In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pandas.api.types import CategoricalDtype
from scipy.sparse import csr_matrix

from MCMC import run_mcmc_with_gibbs

In [2]:
# Load every .parquet file in `directory` and stack them into one long frame.
# os.listdir returns entries in arbitrary, filesystem-dependent order, so the files are
# sorted to make the concatenated row order (and any order-sensitive aggregation) reproducible.
directory = '../out/'
parquet_files = sorted(name for name in os.listdir(directory) if name.endswith('.parquet'))
if not parquet_files:
    raise FileNotFoundError(f"No .parquet files found in {directory!r}")

dataframes = [pd.read_parquet(os.path.join(directory, name)) for name in parquet_files]
df = pd.concat(dataframes, ignore_index=True)
# concat copies the data; drop the per-file frames so only one copy stays in memory.
del dataframes

# float32 halves the memory of the probability column before pivoting.
df['prob'] = df['prob'].astype('float32')

In [3]:
# Reshape the long frame into a species × molecule matrix of `prob`
# (pivot_table's default aggregation is the mean, so duplicate pairs are averaged).
pivot_df = pd.pivot_table(df, values='prob', index='species', columns='molecule')
# The long frame is no longer needed once pivoted; release it.
del df

In [4]:
# Aggregated LOTUS table ships as train/test splits; recombine them (train rows first).
agg_paths = ("../data/lotus_agg_train.csv.gz", "../data/lotus_agg_test.csv.gz")
df_agg_train, df_agg_test = (pd.read_csv(path, index_col=0) for path in agg_paths)
df_agg = pd.concat([df_agg_train, df_agg_test])

In [5]:
# Ordered categorical dtypes over the sorted unique species and 2-D SMILES; their integer
# codes become the row / column positions in the sparse matrix below.
species_u = CategoricalDtype(sorted(df_agg.organism_name.unique()), ordered=True)
mol_u = CategoricalDtype(sorted(df_agg.structure_smiles_2D.unique()), ordered=True)
row = df_agg['organism_name'].astype(species_u).cat.codes
col = df_agg['structure_smiles_2D'].astype(mol_u).cat.codes

# species × molecule matrix of `reference_wikidata` values (presumably a per-pair
# reference count — confirm). scipy sums duplicate (row, col) entries, and values above
# 65535 would not fit the uint16 storage type.
sparse_matrix = csr_matrix(
    (df_agg["reference_wikidata"], (row, col)),
    shape=(len(species_u.categories), len(mol_u.categories)),
    dtype='uint16',
)
lotus_n_papers = pd.DataFrame.sparse.from_spmatrix(
    sparse_matrix, index=species_u.categories, columns=mol_u.categories
)

In [6]:
# Exclude this species from the `prob` matrix.
# NOTE(review): the reason for excluding it is not recorded in this notebook — document it.
# Reassign instead of inplace=True, and ignore a missing label, so re-running this cell
# is harmless rather than raising KeyError.
pivot_df = pivot_df.drop(index='Arnica amplexicaulis', errors='ignore')

In [7]:
# Restrict the prediction matrix to the LOTUS molecules, in LOTUS column order
# (raises KeyError if some LOTUS molecule has no prediction column).
lotus_molecules = lotus_n_papers.columns
pivot_df = pivot_df[lotus_molecules]

In [8]:
# Keep only the LOTUS species that also appear in the prediction matrix.
has_prediction = lotus_n_papers.index.isin(pivot_df.index)
lotus_n_papers = lotus_n_papers.loc[has_prediction]

In [9]:
# Restrict both matrices to a small species subset for a quick run
# (set N_SPECIES = None to keep every shared species).
# Rows are selected by *label* on both frames: pivot_df is never filtered down to the LOTUS
# species, so taking .iloc[:10] of each independently could pair rows belonging to different
# species once .values is handed positionally to the sampler.
N_SPECIES = 10
shared_species = lotus_n_papers.index[lotus_n_papers.index.isin(pivot_df.index)][:N_SPECIES]
pivot_df = pivot_df.loc[shared_species]
lotus_n_papers = lotus_n_papers.loc[shared_species]

# Positional alignment is what the sampler relies on.
assert pivot_df.index.equals(lotus_n_papers.index)
assert list(pivot_df.columns) == list(lotus_n_papers.columns)

In [10]:
# MCMC settings.
n_iter = 100
# Initial x state: all zeros, same shape as the LOTUS count matrix
# (presumably the latent matrix whose draws come back as x_samples — confirm in MCMC.py).
x_init = np.zeros_like(lotus_n_papers)
gamma_init = 0.1
delta_init = 0.1

# Seed NumPy's global RNG so the chain is reproducible.
# NOTE(review): only effective if run_mcmc_with_gibbs draws from np.random's global
# state rather than its own Generator — confirm.
SEED = 0
np.random.seed(SEED)

In [11]:
# Run the Gibbs/MCMC sampler on the aligned subset. Both matrices are passed as plain
# arrays, so their rows and columns must already be in matching order.
observed_counts = lotus_n_papers.values
prob_matrix = pivot_df.values
samples, x_samples, accept_gamma, accept_delta = run_mcmc_with_gibbs(
    observed_counts,
    x_init,
    n_iter,
    gamma_init,
    delta_init,
    prob_matrix,
)

KeyboardInterrupt: 

In [None]:
# Discard the first half of the chain as burn-in before summarising it.
burn_in = int(n_iter * 0.5)
post_burn_in_samples = samples[burn_in:]

# Posterior-mean point estimates: column 0 of each draw is gamma, column 1 is delta.
gamma_posterior_mean = post_burn_in_samples[:, 0].mean()
delta_posterior_mean = post_burn_in_samples[:, 1].mean()

for label, value in (
    ("Estimated gamma: ", gamma_posterior_mean),
    ("Estimated delta: ", delta_posterior_mean),
    ("rate accept gamma : ", accept_gamma),
    ("rate accept delta : ", accept_delta),
):
    print(label, value)

In [None]:
# Trace plots of the post-burn-in chain: one panel per parameter, red line = posterior mean.
sns.set(style="darkgrid")


def plot_trace(ax, draws, posterior_mean, name):
    """Scatter one parameter's post-burn-in draws against their iteration index.

    Args:
        ax: matplotlib Axes to draw on.
        draws: 1-D array of sampled values for the parameter.
        posterior_mean: value drawn as a horizontal red reference line.
        name: parameter name, used for the title and y-axis label.

    Returns:
        The same Axes, for further customisation.
    """
    sns.scatterplot(x=np.arange(len(draws)), y=draws, ax=ax)
    # A horizontal line is the direct way to draw a constant; no need for a repeated list.
    ax.axhline(posterior_mean, color="r")
    ax.set(title=f"Trace of {name}", xlabel="iteration after burn-in", ylabel=name)
    return ax


fig, axs = plt.subplots(ncols=2)
plot_trace(axs[0], post_burn_in_samples[:, 0], gamma_posterior_mean, "gamma")
plot_trace(axs[1], post_burn_in_samples[:, 1], delta_posterior_mean, "delta")
fig.tight_layout()

In [None]:
# Posterior distribution of gamma from the post-burn-in draws. An explicit new figure
# avoids drawing onto whichever axes pyplot currently treats as active.
fig_gamma, ax_gamma = plt.subplots()
ax_gamma.hist(post_burn_in_samples[:, 0], bins=30)
ax_gamma.set(title="Posterior of gamma", xlabel="gamma", ylabel="count");

In [None]:
# Posterior distribution of delta from the post-burn-in draws. An explicit new figure
# avoids drawing onto whichever axes pyplot currently treats as active.
fig_delta, ax_delta = plt.subplots()
ax_delta.hist(post_burn_in_samples[:, 1], bins=30)
ax_delta.set(title="Posterior of delta", xlabel="delta", ylabel="count");

In [None]:
# Posterior mean of the sampled x matrices, labelled like the LOTUS count matrix.
# Burn-in draws are discarded via `burn_in` (not a fixed tail window) so this estimate
# stays consistent with the gamma/delta estimates for any n_iter.
# NOTE(review): assumes x_samples holds one draw per iteration, like `samples` — confirm.
out = pd.DataFrame(np.mean(x_samples[burn_in:], axis=0, dtype='float32'),
                   index=lotus_n_papers.index,
                   columns=lotus_n_papers.columns,
                   dtype='float32')

In [None]:
# `prob` matrix for the species/molecule subset that was passed to the sampler.
pivot_df

In [None]:
# Posterior-mean x matrix from the sampler, on the LOTUS species/molecule labels.
out

In [None]:
# Element-wise posterior mean minus `prob`, aligned on row and column labels
# (NaN wherever a label exists on only one side).
diff = out.sub(pivot_df)

In [None]:
# Posterior mean minus `prob`: positive cells are where the sampler's estimate exceeds `prob`.
diff

In [None]:
# Flag cells whose posterior mean exceeds `prob` by more than 0.8 (NaN compares False).
mask = diff.gt(0.8)

In [None]:
# Differences restricted to species (rows) and molecules (columns) with at least one flag.
flagged_rows = mask.any(axis=1)
flagged_cols = mask.any(axis=0)
diff.loc[flagged_rows, flagged_cols]

In [None]:
# Posterior means restricted to species (rows) and molecules (columns) with at least one flag.
flagged_rows = mask.any(axis=1)
flagged_cols = mask.any(axis=0)
out.loc[flagged_rows, flagged_cols]

In [None]:
# `prob` values restricted to species (rows) and molecules (columns) with at least one flag.
flagged_rows = mask.any(axis=1)
flagged_cols = mask.any(axis=0)
pivot_df.loc[flagged_rows, flagged_cols]

In [None]:
# Frozen LOTUS metadata export. low_memory=False makes pandas infer each column's dtype
# from the whole file instead of chunk by chunk (avoids mixed-type columns).
lotus_metadata_path = "../data/230106_frozen_metadata.csv.gz"
lotus = pd.read_csv(lotus_metadata_path, low_memory=False)

In [None]:
# Spot-check the raw LOTUS records behind one specific (structure, organism) pair.
same_structure = lotus.structure_smiles_2D == 'CC1(C)CCCC2(C)C3C(=CCC12)COC3O'
same_organism = lotus.organism_name == 'Dendrodoris carbunculosa'
lotus[same_structure & same_organism]