In [None]:
from baselines.global_utils import (
    get_all_model_funcs,
    smiles_from_file, 
    BASELINE_DIR
)
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import numpy as np
import random


In [None]:
# Experiment selection. `model_name` picks the model family whose helper
# functions are looked up later; `model_id` and `dataset` are forwarded to
# that model's "load" function.
model_name = "JTVAE"
model_id = "3bsp47ta"
dataset = "zinc"

In [None]:
# Read the validation SMILES for the configured dataset. The dataset folder
# is taken from `dataset` (instead of a hard-coded "zinc") so the validation
# molecules always match the dataset the model is loaded for.
val_smiles = smiles_from_file(BASELINE_DIR / "smiles_files" / dataset / "val.txt")

In [None]:
num_plots = 3
pair_seed = 0  # fixed seed so the same pairs are drawn on every run

# Draw 2 * num_plots entries from a private, seeded RNG. Unlike shuffling and
# popping, this leaves `val_smiles` untouched, so the cell can be re-run any
# number of times and always yields the same pairs. `sample` raises
# ValueError if `val_smiles` holds fewer than 2 * num_plots entries.
pair_rng = random.Random(pair_seed)
sampled_smiles = pair_rng.sample(val_smiles, 2 * num_plots)

# Group consecutive sampled entries into [start, end] pairs.
smiles_pairs = [
    [sampled_smiles[idx], sampled_smiles[idx + 1]]
    for idx in range(0, len(sampled_smiles), 2)
]

In [None]:
# Interpolation setting, passed as the second argument to
# `interpolate_between_molecules` below.
# NOTE(review): presumably the number of interpolation steps per pair - confirm.
num_interpol = 5

# Look up the helper functions registered for `model_name` and use its "load"
# entry to build the inference server for the configured dataset/checkpoint.
all_funcs = get_all_model_funcs(model_name)
load_fn = all_funcs["load"]
inference_server = load_fn(dataset=dataset, model_id=model_id, seed=0)

# Run the interpolation for every SMILES pair; the result is iterated below,
# one SMILES sequence at a time.
interpolation_outputs = inference_server.interpolate_between_molecules(
    smiles_pairs,
    num_interpol,
)
for smiles_list in interpolation_outputs:
    # Convert SMILES to RDKit molecules
    molecules = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]

    # Nothing to compare or draw for an empty interpolation path.
    if not molecules:
        continue

    # MolFromSmiles returns None for strings it cannot parse; fail with a
    # readable message instead of an opaque argument error from the
    # fingerprint call below.
    invalid_smiles = [
        smiles for smiles, mol in zip(smiles_list, molecules) if mol is None
    ]
    if invalid_smiles:
        raise ValueError(f"Could not parse interpolated SMILES: {invalid_smiles}")

    # Compute each Morgan fingerprint (radius 2) exactly once. Only the
    # similarities to the first and last molecule are shown, so a full
    # pairwise matrix is not needed.
    fingerprints = [AllChem.GetMorganFingerprint(mol, 2) for mol in molecules]
    first_fp = fingerprints[0]
    last_fp = fingerprints[-1]
    last_idx = len(fingerprints) - 1

    # Add Tanimoto similarity labels to the molecules
    labels = []
    for idx, fp in enumerate(fingerprints):
        # A molecule compared with itself is labelled 1 by definition.
        if idx == 0:
            left_sim = 1.0
        else:
            left_sim = DataStructs.TanimotoSimilarity(first_fp, fp)
        if idx == last_idx:
            right_sim = 1.0
        else:
            right_sim = DataStructs.TanimotoSimilarity(last_fp, fp)
        labels.append(f"Left: {left_sim:.2f}\nRight: {right_sim:.2f}")
    print(labels)

    # Create an image grid with labels
    img = Draw.MolsToGridImage(
        molecules,
        molsPerRow=len(molecules),
        subImgSize=(300, 300),
        legends=labels,
        useSVG=True,
    )
    display(img)
    # Optional: save the SVG to disk.
    # NOTE(review): `k` is not defined in this cell; derive it from
    # enumerate(interpolation_outputs) before enabling this snippet.
    # img_name = "interpolation/" + model_name + "/" + str(k)
    # with open(img_name + '.svg', 'w') as f:
    #     f.write(img.data)
