In [None]:
import os
from typing import List

import numpy as np
import pandas as pd
import torch
from rdkit.Chem.rdmolfiles import MolFromSmarts
from tokenizers import Tokenizer

from src.data.components.utils import (
    smiles2vector_fg,
    smiles2vector_mfg,
    standardize_smiles,
)
from src.models.fgr_module import FGRPretrainLitModule

In [None]:
# Load the pretrained FGR checkpoint and put it in inference mode in one step:
# nn.Module.eval() returns the module itself, so the call can be chained.
model = FGRPretrainLitModule.load_from_checkpoint("./epoch_000_val_0.8505.ckpt").eval()

In [None]:
# Functional-group SMARTS patterns, compiled to RDKit query molecules.
fgroups = pd.read_parquet("fg.parquet")["SMARTS"].tolist()
fgroups_list = [MolFromSmarts(x) for x in fgroups]
# MolFromSmarts returns None for a pattern it cannot parse; fail here rather than
# deep inside featurization.
bad_smarts = [s for s, mol in zip(fgroups, fgroups_list) if mol is None]
assert not bad_smarts, f"Unparseable SMARTS patterns: {bad_smarts[:5]}"

# Suffix of tokenizers/BPE_pubchem_<N>.json (presumably the BPE vocab size — confirm).
TOKENIZER_VARIANT = 500
tokenizer = Tokenizer.from_file(
    os.path.join("tokenizers", f"BPE_pubchem_{TOKENIZER_VARIANT}.json")
)

In [None]:
def get_representation(
    smiles: List[str],
    method: str,
    fgroups_list: List,
    tokenizer: "Tokenizer",
) -> np.ndarray:
    """Featurize SMILES strings into one stacked numeric matrix.

    Every SMILES is first passed through ``standardize_smiles`` and then encoded:

    * ``"FG"``  -- ``smiles2vector_fg`` against ``fgroups_list``
    * ``"MFG"`` -- ``smiles2vector_mfg`` using ``tokenizer``
    * ``"FGR"`` -- FG and MFG vectors concatenated column-wise (FG columns first)

    Args:
        smiles: SMILES strings to featurize; one output row per entry.
        method: One of ``"FG"``, ``"MFG"`` or ``"FGR"``.
        fgroups_list: Functional-group patterns as RDKit molecules produced by
            ``MolFromSmarts``; only used by ``"FG"`` and ``"FGR"``.
        tokenizer: Tokenizer handed to ``smiles2vector_mfg``; only used by
            ``"MFG"`` and ``"FGR"``.

    Returns:
        Array with one row per input SMILES.

    Raises:
        ValueError: if ``method`` is not supported or ``smiles`` is empty.
    """
    # Reject a bad method before paying for standardization of every SMILES.
    if method not in ("FG", "MFG", "FGR"):
        raise ValueError("Method not supported")  # Raise error if method not supported

    standardized = [standardize_smiles(smi) for smi in smiles]
    if not standardized:
        # np.stack([]) would otherwise fail with an unhelpful message.
        raise ValueError("No SMILES given; cannot build a representation")

    blocks = []
    if method in ("FG", "FGR"):
        blocks.append(np.stack([smiles2vector_fg(smi, fgroups_list) for smi in standardized]))
    if method in ("MFG", "FGR"):
        blocks.append(np.stack([smiles2vector_mfg(smi, tokenizer) for smi in standardized]))
    return blocks[0] if len(blocks) == 1 else np.concatenate(blocks, axis=1)

In [None]:
# Smoke test: featurize three example molecules and run them through the model.
x = get_representation(
    ["CC(C)(C)NCC(O)c1cc(Cl)c(N)c(c1)C(F)(F)F", "CCN", "CCF"], "FGR", fgroups_list, tokenizer
)
x = torch.tensor(x, dtype=torch.float32, device=model.device)
with torch.no_grad():  # inference only: don't build an autograd graph
    z_d = model(x)

In [None]:
# Shape of the first element of the model output (presumably the latent tensor — confirm).
z_d[0].size()

In [None]:
df = pd.read_csv("./PRISM_19Q4_chemicals.csv")
# Later cells filter on `name`, featurize `smiles_cleaned` and drop `smiles`.
missing_cols = {"name", "smiles_cleaned", "smiles"} - set(df.columns)
assert not missing_cols, f"PRISM table is missing expected columns: {missing_cols}"
df.head()

In [None]:
# SMILES to featurize, excluding tyloxapol (reason not recorded here — presumably
# its SMILES cannot be featurized; confirm).
all_smiles = df.loc[df["name"] != "tyloxapol", "smiles_cleaned"].tolist()

In [None]:
# Featurize every selected compound and encode it in a single forward pass.
rep = get_representation(all_smiles, "FGR", fgroups_list, tokenizer)
rep_x = torch.tensor(rep, dtype=torch.float32, device=model.device)
with torch.no_grad():  # inference only: avoid keeping activations for backprop
    z_d = model(rep_x)[0]

In [None]:
# Debug helper (disabled): featurize SMILES one at a time to find which inputs
# make get_representation fail. NOTE(review): if re-enabled, this rebinds `rep`
# to a list, which breaks later cells that use `rep.shape`; the bare `except:`
# also hides unrelated errors — catch Exception instead.
# rep = []
# correct_smiles = []
# error = []
# for i, smile in enumerate(all_smiles):
#     try:
#         rep.append(get_representation([smile], "FGR", fgroups_list, tokenizer))
#         correct_smiles.append(smile)
#     except:
#         error.append(smile)

In [None]:
# Number of featurized rows.
rep.shape[0]

In [None]:
# Debug view (disabled): rows of `df` whose SMILES failed to featurize; relies on
# `correct_smiles` from the disabled per-SMILES loop.
# error_df = df[~df['smiles_cleaned'].isin(correct_smiles)]
# error_df

In [None]:
# Every SMILES must yield exactly one feature row, or the later positional join misaligns.
assert len(rep) == len(all_smiles), (len(rep), len(all_smiles))
len(all_smiles)

In [None]:
# The encoder input was placed on model.device, so z_d may be on a GPU;
# .numpy() only works on CPU tensors.
latent_array = z_d.detach().cpu().numpy()
latent_df = pd.DataFrame(
    data=latent_array,
    # Derive the width from the data instead of hard-coding the latent size.
    columns=[f"latent_{i}" for i in range(latent_array.shape[1])],
)

In [None]:
latent_df.head()

In [None]:
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler

In [None]:
# MinMaxScaler rescales each latent dimension (column) to [0, 1] so they share one colour scale.
ax = sns.heatmap(MinMaxScaler().fit_transform(latent_df))
ax.set(title="Latent features (min-max scaled per dimension)", xlabel="latent dimension", ylabel="compound (row)");

In [None]:
# Same filter as the SMILES selection, so rows line up positionally with latent_df.
# reset_index(drop=True) discards the old index without touching any real "index" column.
final_df = df[df.name != "tyloxapol"].reset_index(drop=True)

In [None]:
final_df.head()

In [None]:
# Rebuild from `df` instead of extending `final_df` in place, so this cell can be
# re-run (a second join would collide with the latent columns already present and
# "smiles" would already be gone).
metadata_df = df[df.name != "tyloxapol"].reset_index(drop=True)
assert len(metadata_df) == len(latent_df), "latent rows must align 1:1 with the filtered compounds"
final_df = metadata_df.join(latent_df).drop(columns="smiles")
final_df.head()

In [None]:
# Export compound metadata plus latent features.
# NOTE: previously written as "PRSIM_19Q4_latent.csv" (typo); update downstream readers.
final_df.to_csv("PRISM_19Q4_latent.csv", index=False)

In [None]:
# (rows, features) of the stacked FGR representation.
np.shape(rep)

In [None]:
# Spot-check: total of the feature vector for the fifth compound.
np.sum(rep[4])

In [None]:
from umap import UMAP

# Fix the seed so the 2-D embedding (and the plot below) is reproducible between runs.
# umap-learn trades away some parallelism when random_state is set.
UMAP_SEED = 42
embedder = UMAP(random_state=UMAP_SEED)
embedding = embedder.fit_transform(latent_df)

In [None]:
# Colour encodes row position in final_df only, not any compound property.
ax = sns.scatterplot(x=embedding[:, 0], y=embedding[:, 1], hue=final_df.index)
ax.set(title="UMAP of FGR latent space (PRISM 19Q4 compounds)", xlabel="UMAP 1", ylabel="UMAP 2");