In [65]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import csv
from pathlib import Path
from collections import Counter
import os
import random
from rdkit import Chem
from rdkit.Chem import Descriptors, AllChem, DataStructs, QED
from efgs import get_dec_fgs
import io
from PIL import Image, ImageDraw, ImageFont
import re

from VAE import (
    train_conditional_vae, 
    train_conditional_subspace_vae, 
    train_discover_vae, 
    train_base_model,
    FingerprintDataset,
    BaseVAETrainer,
    ConditionalVAETrainer,
    ConditionalSubspaceVAETrainer,
    DiscoverVAETrainer,
    )
from fg_funcs import (
    extract_and_save_latents,
    extract_prefixed_arrays,
    evaluate_reconstructions,
    visualize_latent_space_per_fg,
    metric,
    get_nearest_neighbors,
    visualize_latent_space_tanimoto,
    fingerprint_to_bv,
    mol_to_fingerprint
)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from rdkit.DataStructs.cDataStructs import TanimotoSimilarity

In [2]:
# Pipeline switches for this notebook run.
# TRAIN lists the models the training cell iterates over; each entry is split on
# '_' into a model type and a dataset type (e.g. 'Base_CUR'). Empty = train nothing.
TRAIN = list()
# Gate the latent-extraction cells and the reconstruction-evaluation cells.
GET_LATENTS, GET_EVAL = True, True
# Gate the curated-set and the full-set latent-space visualisations.
VIS_CUR, VIS_FULL = True, False

In [3]:
# --- Latent sizes ------------------------------------------------------------
# The training cell zips these three lists together, so position i of each list
# belongs to the same training configuration.
latent_dims = [4, 8, 16]    # You can change this to test different latent dimensions
latent_dims_z = [4, 8, 16]  # z size for CSVAE and DISCoVeR
latent_dims_w = [2, 4, 8]   # w size for CSVAE and DISCoVeR

# --- Layer widths ------------------------------------------------------------
encoder_hidden_dims = [1024, 512, 256, 128]
encoder_hidden_dims_z = [1024, 512, 256, 128]  # for CSVAE and DISCoVeR
encoder_hidden_dims_w = [1024, 512, 256, 128]  # for CSVAE and DISCoVeR
decoder_hidden_dims = [128, 512]
decoder_z_hidden_dims = [128, 512]  # Decoder layers for DISCoVeR
adversarial_hidden_dims = [8]

# --- Optimisation ------------------------------------------------------------
batch_size = 64
learning_rate = 0.001
max_epochs = 5

# Beta values looped over by the curated reconstruction-evaluation cell; they are
# interpolated into checkpoint paths, so the int/float spelling matters.
betas = (0.5, 1, 1.5, 2, 3, 5)


In [4]:
# Train every model listed in TRAIN once per latent-size configuration.
# Each TRAIN entry has the form '<model_type>_<dataset_type>' (e.g. 'Base_CUR'):
#   model_type   in {'Base', 'CVAE', 'CSVAE', 'DISCoVeR'}
#   dataset_type in {'CUR', 'FULL'}
# Anything else raises ValueError instead of silently training with stale settings.
for latent_dim, latent_dim_z, latent_dim_w in zip(latent_dims, latent_dims_z, latent_dims_w):
    for MODEL in TRAIN:
        model_type, dataset_type = MODEL.split('_')

        if dataset_type == 'CUR':
            # Load the curated dataset
            curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

            # Keep ndarray entries; replace anything else with an all-zero placeholder
            curated_dataset['fingerprint_array'] = curated_dataset['fingerprint_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((2048,), dtype=int))
            curated_dataset['fg_array'] = curated_dataset['fg_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((100,), dtype=int))

            MODEL_OUTPUT = 'models/small_models'  # Directory to save trained models
            dataset = curated_dataset  # Use the curated dataset for training
            sparse = False

            input_dim = 2048
            fg_dim = 4

        elif dataset_type == 'FULL':
            # Load the full dataset
            full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")

            MODEL_OUTPUT = 'models/large_models'  # Directory to save trained models
            dataset = full_dataset  # Use the full dataset for training
            sparse = True

            input_dim = 2048
            fg_dim = 50

        else:
            # Without this guard an unknown suffix would reuse whatever dataset
            # variables the previous iteration (or an older cell run) left behind.
            raise ValueError(
                f"Unknown dataset type {dataset_type!r} in TRAIN entry {MODEL!r}; expected 'CUR' or 'FULL'."
            )

        # Re-seed torch before every run so each model starts from the same RNG state
        torch.manual_seed(42)

        if model_type == 'Base':

            print("Training BaseVAE model...")

            print("Training with beta=1 ...")
            vae_trainer = train_base_model(
                dataset=dataset,
                input_dim=input_dim,
                latent_dim=latent_dim,
                fg_dim=fg_dim,  # per the original note: only used for logging by the Base VAE
                encoder_hidden_dims=encoder_hidden_dims,
                decoder_hidden_dims=decoder_hidden_dims,
                beta=1,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'CVAE':

            print("Training CVAE model...")

            vae_trainer = train_conditional_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim=latent_dim,
                encoder_hidden_dims=encoder_hidden_dims,
                decoder_hidden_dims=decoder_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'CSVAE':

            print("Training CSVAE model based on the NeurIPS 2018 paper...")

            vae_trainer = train_conditional_subspace_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim_z=latent_dim_z,
                latent_dim_w=latent_dim_w,
                encoder_hidden_dims_z=encoder_hidden_dims_z,
                encoder_hidden_dims_w=encoder_hidden_dims_w,
                decoder_hidden_dims=decoder_hidden_dims,
                adversarial_hidden_dims=adversarial_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'DISCoVeR':

            print("Training DISCoVeR VAE model...")

            vae_trainer = train_discover_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim_z=latent_dim_z,
                latent_dim_w=latent_dim_w,
                encoder_hidden_dims_z=encoder_hidden_dims_z,
                encoder_hidden_dims_w=encoder_hidden_dims_w,
                decoder_hidden_dims=decoder_hidden_dims,
                decoder_z_hidden_dims=decoder_z_hidden_dims,
                adversarial_hidden_dims=adversarial_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        else:
            # The original `model_type is None` check could never fire (str.split
            # never yields None), so a typo fell through to `del vae_trainer`.
            raise ValueError(
                f"Unknown model type {model_type!r} in TRAIN entry {MODEL!r}; "
                "expected 'Base', 'CVAE', 'CSVAE' or 'DISCoVeR'."
            )

        # Free memory after each model
        del vae_trainer


Training BaseVAE model...
Training with beta=1 ...


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mdamianrelkins[0m ([33mdamianrelkins-university-college-london-ucl-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | BaseVAE | 3.9 M  | train
------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.622    Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 28538/28538 [05:04<00:00, 93.77it/s, v_num=hlqn, val_loss=147.0, val_bce=137.0, val_kld=10.20] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 28538/28538 [05:04<00:00, 93.77it/s, v_num=hlqn, val_loss=147.0, val_bce=137.0, val_kld=10.20]
Testing DataLoader 0: 100%|██████████| 3568/3568 [00:14<00:00, 242.44it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_bce            136.44338989257812
        test_kld            10.158361434936523
        test_loss           146.60186767578125
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▆▆▆▆▆▆▆▆█████
test_bce,▁
test_kld,▁
test_loss,▁
train_bce,█▅▄▃▄▃▄▂▃▃▁▁▂▂▂▂▃▁▂▃▂▁▁▂▂▂▂▂▃▂▁▂▁▃▁▂▂▁▁▁
train_kld,▁▂▄▅▅▅▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇▇▇██▇▇█▇██████████
train_loss,█▆▆▅▄▃▃▄▃▂▂▄▂▂▄▂▃▁▃▁▁▃▃▃▃▂▃▂▃▂▂▂▃▂▃▄▃▃▁▃
trainer/global_step,▁▁▁▁▁▂▂▂▂▂▂▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇█
val_bce,█▅▄▃▂▂▂▁▁▁
val_kld,▁▄▅▆▇▇█▇██

0,1
epoch,5
test_bce,136.44339
test_kld,10.15836
test_loss,146.60187
train_bce,138.96391
train_kld,10.32231
train_loss,149.28622
trainer/global_step,142690
val_bce,136.53821
val_kld,10.16389


Training BaseVAE model...
Training with beta=1 ...


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | BaseVAE | 3.9 M  | train
------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.628    Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 28538/28538 [04:42<00:00, 100.96it/s, v_num=olts, val_loss=147.0, val_bce=136.0, val_kld=10.10]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 28538/28538 [04:42<00:00, 100.95it/s, v_num=olts, val_loss=147.0, val_bce=136.0, val_kld=10.10]
Testing DataLoader 0: 100%|██████████| 3568/3568 [00:12<00:00, 285.29it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_bce            136.28671264648438
        test_kld            10.088326454162598
        test_loss           146.37486267089844
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁▁▁▁▁▃▃▃▃▃▃▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▆███████
test_bce,▁
test_kld,▁
test_loss,▁
train_bce,█▅▂▄▅▄▃▃▃▃▂▃▃▃▄▂▄▃▃▃▃▂▃▃▃▃▃▃▂▃▃▃▃▁▃▃▂▂▂▄
train_kld,▁▃▄▅▅▅▆▆▆▇▇▇▆▆▇▇▇▇██▇▇▇██▇▇████▇███▇███▇
train_loss,▇▆▇█▇▇▆▄▁▄▇▃▄▁▆▄▄▄▄▅▃▄▂▅▆▅▄▃▂▃▅▄▆▄▂▁▄▃▅▆
trainer/global_step,▁▁▁▁▂▂▂▂▂▂▂▃▃▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇███
val_bce,█▅▄▃▃▂▂▂▁▁
val_kld,▁▃▅▆▆▆█▇▇█

0,1
epoch,5
test_bce,136.28671
test_kld,10.08833
test_loss,146.37486
train_bce,133.21512
train_kld,10.0992
train_loss,143.31432
trainer/global_step,142690
val_bce,136.4131
val_kld,10.09376


Training BaseVAE model...
Training with beta=1 ...


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs



  | Name  | Type    | Params | Mode 
------------------------------------------
0 | model | BaseVAE | 3.9 M  | train
------------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.641    Total estimated model params size (MB)
18        Modules in train mode
0         Modules in eval mode


Epoch 4: 100%|██████████| 28538/28538 [04:47<00:00, 99.16it/s, v_num=9y5o, val_loss=143.0, val_bce=131.0, val_kld=11.90] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 28538/28538 [04:47<00:00, 99.16it/s, v_num=9y5o, val_loss=143.0, val_bce=131.0, val_kld=11.90]
Testing DataLoader 0: 100%|██████████| 3568/3568 [00:12<00:00, 284.84it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_bce            131.33856201171875
        test_kld            11.871740341186523
        test_loss           143.21054077148438
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


0,1
epoch,▁▁▁▁▁▃▃▃▃▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆█████
test_bce,▁
test_kld,▁
test_loss,▁
train_bce,█▇▇▆▅▆▄▅▆▄▅▄▅▄▄▄▃▃▅▃▃▄▂▃▁▄▂▃▂▄▃▁▂▄▂▂▁▂▂▃
train_kld,▁▃▄▅▆▆▆▇▆▇▇▇▇▇▇▇▇███████▇▇███▇██████████
train_loss,█▇▆▅▅▂▃▂▃▃▄▃▄▃▄▂▂▂▄▃▃▄▄▄▂▃▁▃▂▃▃▁▁▄▁▂▁▁▃▃
trainer/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇▇█
val_bce,█▅▄▃▂▂▂▁▁▁
val_kld,▁▄▄▆▆▆▆▇▇█

0,1
epoch,5
test_bce,131.33856
test_kld,11.87174
test_loss,143.21054
train_bce,131.18587
train_kld,11.60181
train_loss,142.78767
trainer/global_step,142690
val_bce,131.41809
val_kld,11.87794


In [None]:
if GET_LATENTS:
    # Beta of the Base VAE checkpoint whose latents are extracted. The output file
    # name previously interpolated `beta`, which no earlier cell defines, so the
    # cell raised NameError on a fresh kernel. The Base checkpoint path used the
    # literal `_1`, so 1 is the value that was intended.
    # NOTE(review): the suffix is appended for every model (as before), while the
    # visualisation cell looks up non-Base latents without it — confirm the naming.
    base_beta = 1

    # Load the curated dataset
    curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

    # Keep ndarray entries; replace anything else with an all-zero placeholder
    curated_dataset['fingerprint_array'] = curated_dataset['fingerprint_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((2048,), dtype=int))
    curated_dataset['fg_array'] = curated_dataset['fg_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((100,), dtype=int))

    # Hold out 20% of the rows, then keep 20% of that hold-out as the test split
    # (the other 80% of the hold-out becomes val_data).
    train_data, test_data = train_test_split(curated_dataset, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(test_data, test_size=0.2, random_state=42)

    test_dataset = FingerprintDataset(test_data, sparse=False)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    train_dataset = FingerprintDataset(train_data, sparse=False)
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=False)

    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

    # 'Test' extracts latents for the test split; any other value uses the train split.
    dataset_partition = 'Test'
    if dataset_partition == 'Test':
        partition_dataloader, partition_size = test_dataloader, len(test_dataset)
    else:
        partition_dataloader, partition_size = train_dataloader, len(train_dataset)

    for latent_dim in latent_dims:
        small_models = [
            ("Base", f"checkpoints/fg_bvae_4_{latent_dim}_{base_beta}/best-checkpoint.ckpt", BaseVAETrainer),
            ("CVAE", f"checkpoints/fg_cvae_4_{latent_dim}/best-checkpoint.ckpt", ConditionalVAETrainer),
            ("CSVAE", f"checkpoints/fg_csvae_4_{latent_dim}/best-checkpoint.ckpt", ConditionalSubspaceVAETrainer),
            ("DISCoVeR", f"checkpoints/fg_dvae_4_{latent_dim}/best-checkpoint.ckpt", DiscoverVAETrainer),
        ]

        # Extract and save latents for small models
        for model_name, model_path, model_class in small_models:
            extract_and_save_latents(
                model_path=model_path,
                dataloader=partition_dataloader,
                model_type=model_name,
                model_class=model_class,
                device=device,
                output_csv=f"latents/latents_{model_name}_{partition_size}_{latent_dim}_{base_beta}.csv"
            )

In [None]:
if GET_EVAL:
    def _array_or_zeros(width):
        """Return a cleaner that keeps ndarrays and maps anything else to zeros(width)."""
        def _clean(value):
            if isinstance(value, np.ndarray):
                return value
            return np.zeros((width,), dtype=int)
        return _clean

    # Curated dataset, with non-array cells replaced by all-zero placeholders
    curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')
    curated_dataset['fingerprint_array'] = curated_dataset['fingerprint_array'].apply(_array_or_zeros(2048))
    curated_dataset['fg_array'] = curated_dataset['fg_array'].apply(_array_or_zeros(100))

    # 80/20 split first; the 20% hold-out is split 80/20 again into val/test
    train_data, test_data = train_test_split(curated_dataset, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(test_data, test_size=0.2, random_state=42)

    # Unshuffled loader over the test split
    test_dataset = FingerprintDataset(test_data, sparse=False)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Prefer MPS, then CUDA, then CPU
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # One Base VAE checkpoint per (beta, latent_dim) pair
    for beta in betas:
        for latent_dim in latent_dims:
            small_models = [
                (
                    "Base",
                    "checkpoints/fg_bvae_4_{}_{}/best-checkpoint.ckpt".format(latent_dim, beta),
                    BaseVAETrainer,
                ),
                # ("CVAE", f"checkpoints/fg_cvae_4_{latent_dim}/best-checkpoint.ckpt", ConditionalVAETrainer),
                # ("CSVAE", f"checkpoints/fg_csvae_4_{latent_dim}/best-checkpoint.ckpt", ConditionalSubspaceVAETrainer),
                # ("DISCoVeR", f"checkpoints/fg_dvae_4_{latent_dim}/best-checkpoint.ckpt", DiscoverVAETrainer)
            ]

            # Score reconstructions over a sweep of thresholds
            for model_name, model_path, model_class in small_models:
                evaluate_reconstructions(
                    model_path=model_path,
                    dataloader=test_dataloader,
                    model_type=model_name,
                    model_class=model_class,
                    device=device,
                    thresholds=np.arange(0.4, 0.7, 0.02),
                )

In [11]:
latent_dims = [4, 8, 16]  # You can change this to test different latent dimensions
if GET_LATENTS:
    try:
        # Set partition to either 'test' or 'train'
        partition = 'test'  # or 'train'

        # Pick the accelerator the same way the curated-dataset cells do, instead
        # of assuming Apple MPS is present.
        device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

        full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")
        # Hold out 20% of the rows, then keep 20% of that hold-out as the test split
        train_full, test_full = train_test_split(full_dataset, test_size=0.2, random_state=42)
        val_full, test_full = train_test_split(test_full, test_size=0.2, random_state=42)

        if partition == 'test':
            data_full = test_full
        elif partition == 'train':
            data_full = train_full
        else:
            raise ValueError("partition must be 'test' or 'train'")

        # SMILES are passed alongside the loader so they are saved with the latents;
        # the loader is unshuffled so row order matches smiles_list.
        smiles_list = data_full['smiles'].tolist()
        full_dataset_obj = FingerprintDataset(data_full, sparse=True)
        full_dataloader = DataLoader(full_dataset_obj, batch_size=64, shuffle=False)

        for latent_dim in latent_dims:
            large_models = [
                ("Base", f"checkpoints/fg_bvae_50_{latent_dim}_1/best-checkpoint.ckpt", BaseVAETrainer),
                # ("CVAE", f"checkpoints/fg_cvae_50_{latent_dim}/best-checkpoint.ckpt", ConditionalVAETrainer),
                # ("CSVAE", f"checkpoints/fg_csvae_50_{latent_dim}/best-checkpoint.ckpt", ConditionalSubspaceVAETrainer),
                # ("DISCoVeR", f"checkpoints/fg_dvae_50_{latent_dim}/best-checkpoint.ckpt", DiscoverVAETrainer)
            ]

            for model_name, model_path, model_class in large_models:
                extract_and_save_latents(
                    model_path=model_path,
                    dataloader=full_dataloader,
                    model_type=model_name,
                    model_class=model_class,
                    device=device,
                    smiles_list=smiles_list,
                    output_csv=f"latents/latents_{model_name}_{len(data_full)}_{latent_dim}.csv"
                )
    except Exception as e:
        # Best-effort cell: report the failure and let the notebook continue
        print(f"Error occurred while extracting latents: {e}")

Loading Base model from checkpoints/fg_bvae_50_4_1/best-checkpoint.ckpt...
Saved latents (including y and smiles) to latents/latents_Base_91321_4.csv
Loading Base model from checkpoints/fg_bvae_50_8_1/best-checkpoint.ckpt...
Saved latents (including y and smiles) to latents/latents_Base_91321_8.csv
Loading Base model from checkpoints/fg_bvae_50_16_1/best-checkpoint.ckpt...
Saved latents (including y and smiles) to latents/latents_Base_91321_16.csv


In [10]:
latent_dims = [8,16]  # You can change this to test different latent dimensions
if GET_EVAL:
    # Pick the accelerator the same way the curated-dataset cells do, instead of
    # assuming Apple MPS is present.
    device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

    full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")

    # Hold out 20% of the rows, then keep 20% of that hold-out as the test split
    train_full, test_full = train_test_split(full_dataset, test_size=0.2, random_state=42)
    val_full, test_full = train_test_split(test_full, test_size=0.2, random_state=42)

    test_full_dataset = FingerprintDataset(test_full, sparse=True)
    full_dataloader = DataLoader(test_full_dataset, batch_size=64, shuffle=False)

    for latent_dim in latent_dims:
        large_models = [
            ("Base", f"checkpoints/fg_bvae_50_{latent_dim}_1/best-checkpoint.ckpt", BaseVAETrainer),
            # ("CVAE", f"checkpoints/fg_cvae_50_{latent_dim}/best-checkpoint.ckpt", ConditionalVAETrainer),
            # ("CSVAE", f"checkpoints/fg_csvae_50_{latent_dim}/best-checkpoint.ckpt", ConditionalSubspaceVAETrainer),
            # ("DISCoVeR", f"checkpoints/fg_dvae_50_{latent_dim}/best-checkpoint.ckpt", DiscoverVAETrainer)
        ]

        # Score reconstructions over a sweep of thresholds
        for model_name, model_path, model_class in large_models:
            evaluate_reconstructions(
                model_path=model_path,
                dataloader=full_dataloader,
                model_type=model_name,
                model_class=model_class,
                device=device,
                thresholds=np.arange(0.5, 0.7, 0.02)
            )

Loading Base model from checkpoints/fg_bvae_50_8_1/best-checkpoint.ckpt...
Selected threshold based on max F1: 0.640 (F1=0.5103)
Saved reconstruction metrics to metrics/reconstruction_metrics_fg_bvae_50_8_1.csv
Loading Base model from checkpoints/fg_bvae_50_16_1/best-checkpoint.ckpt...
Selected threshold based on max F1: 0.640 (F1=0.5343)
Saved reconstruction metrics to metrics/reconstruction_metrics_fg_bvae_50_16_1.csv


In [57]:
# Load saved latent CSVs into `latents`, keyed by file stem
# ("latents_<model>_<size>_<dim>").
latents = dict()
models = ['Base', 'CSVAE', 'DISCoVeR']
sizes = ['1826415']
dims = ['4']
neighbors = 50

for model in models:
    # Subspace models carry a 'w' block next to 'z' and the labels 'y'
    if model in ['CSVAE', 'DISCoVeR']:
        terms = ['z', 'w', 'y']
    elif model in ['Base', 'CVAE']:
        terms = ['z', 'y']
    has_w = 'w' in terms

    for size in sizes:
        for dim in dims:
            name = "latents_{}_{}_{}".format(model, size, dim)
            latent = "latents/" + name + ".csv"
            frame = extract_prefixed_arrays(latent, terms)
            latents[name] = frame
            if has_w:
                # Concatenate w and z (w first) into a joint 'wz' column
                frame['wz'] = frame.apply(
                    lambda row: np.concatenate([row['w'], row['z']]), axis=1
                )

In [27]:
# Sanity check: list the loaded latent tables and preview the Base VAE one.
loaded_names = latents.keys()
print(loaded_names)
base_preview = latents['latents_Base_1826415_4'].head()
print(base_preview)

dict_keys(['latents_Base_1826415_4', 'latents_CSVAE_1826415_4', 'latents_DISCoVeR_1826415_4'])
                                              smiles  \
0  Cc1c[nH]c2ncnc(N3CC4CC3CN4/C(=N/C#N)Nc3cccc(Br...   
1                NC(C(=O)NO)C(=O)N1CCN(Cc2ccccc2)CC1   
2  Cc1c(-c2noc(C(=O)N3CCC(Cc4ccccc4)CC3)n2)oc2c(C...   
3  COc1cc(C(=O)O)ccc1NC(=O)[C@H]1[C@H](c2cccc(Cl)...   
4         Cn1c([C@H]2CCN(Cc3ccc(Cl)cc3)C2)nc2ccccc21   

                                                  z  \
0   [0.36363792, -1.1114069, 0.57658255, 1.0574797]   
1    [1.7424464, 0.95360804, 0.5800475, 0.35313958]   
2  [0.9032097, -0.015957396, 0.3586692, 0.18548699]   
3  [0.20626788, 0.1730261, -1.5167452, -0.12495865]   
4  [0.62835634, 1.3160388, 0.020977855, 0.31995744]   

                                                   y  
0  [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...  
1  [0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, ...  
2  [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
3  [0.0, 1.0, 1.0

In [None]:
# Latent-space visualisation settings
method = 'umap'
cur_size = 80000
full_size = 1826415
sample_size = 100000
beta = 1
fg_names = ['[R][NH][R]', 'O=[C](O)[R]', 'C=C', '[NH2][Car]']

# Title fragments are computed once. The original nested single quotes inside
# single-quoted f-strings, which only parses on Python >= 3.12.
method_label = method.upper()
partition_label = 'Test Set' if cur_size < 50000 else 'Train Set'

# (latents-key model name, title label, save-path slug)
subspace_models = (('CSVAE', 'CSVAE', 'csvae'), ('DISCoVeR', 'Discover', 'discover'))
conditional_models = (('CVAE', 'CVAE', 'cvae'),) + subspace_models

for latent_dim in latent_dims:
    # The w subspace is half the size of z (matches latent_dims_w in the config cell)
    w_dim = int(latent_dim / 2)
    wz_dim = latent_dim + w_dim

    if VIS_CUR:
        # Each spec: (latents key, column, combined title, per-FG title, save path).
        # Order matters: the first failure aborts the remaining plots for this size.
        cur_plots = [(
            f'latents_Base_{cur_size}_{latent_dim}_{beta}',
            'z',
            f'Base VAE (Beta = {beta}) {method_label} {latent_dim} Dimensions {partition_label}',
            f'Base VAE (Beta = {beta}) by FG {method_label} {latent_dim} {partition_label}',
            f'images/base_vae_cur_{latent_dim}_{cur_size}_{beta}/',
        )]
        for vis_model, vis_label, vis_slug in conditional_models:
            cur_plots.append((
                f'latents_{vis_model}_{cur_size}_{latent_dim}',
                'z',
                f'{vis_label} {method_label} {latent_dim} {partition_label}',
                f'{vis_label} by FG {method_label} {latent_dim} {partition_label}',
                f'images/{vis_slug}_cur_{latent_dim}_{cur_size}/',
            ))
        # Only the per-FG title of the joint 'wz' plots carries the " WZ" tag
        for vis_column, shown_dim, per_fg_tag in (('w', w_dim, ''), ('wz', wz_dim, ' WZ')):
            for vis_model, vis_label, vis_slug in subspace_models:
                cur_plots.append((
                    f'latents_{vis_model}_{cur_size}_{latent_dim}',
                    vis_column,
                    f'{vis_label} {method_label} {shown_dim} {partition_label}',
                    f'{vis_label}{per_fg_tag} by FG {method_label} {shown_dim} {partition_label}',
                    f'images/{vis_slug}_cur_{vis_column}_{latent_dim}_{cur_size}/',
                ))

        try:
            # Visualize the latent space
            for vis_key, vis_column, combined_title, per_fg_title, save_path in cur_plots:
                visualize_latent_space_per_fg(
                    latents[vis_key][vis_column],
                    latents[vis_key]['y'],
                    method,
                    sample_size=sample_size,
                    combined_title=combined_title,
                    per_fg_title=per_fg_title,
                    save_path=save_path,
                    fg_names=fg_names,
                )
        except Exception as e:
            print(f"Error occurred while visualizing current latents: {e}")

    if VIS_FULL:
        # Full-dataset latents are always looked up under the '_50' key suffix and
        # plotted with a fixed sample of 5000 points, without fg_names.
        full_plots = []
        for vis_model, vis_label, vis_slug in (('Base', 'Base VAE', 'base_vae'),) + conditional_models:
            full_plots.append((
                f'latents_{vis_model}_{full_size}_50',
                'z',
                f'{vis_label} {method_label} {latent_dim}',
                f'{vis_label} by FG {method_label} {latent_dim}',
                f'images/{vis_slug}_full_{latent_dim}_{full_size}/',
            ))
        for vis_model, vis_label, vis_slug in subspace_models:
            full_plots.append((
                f'latents_{vis_model}_{full_size}_50',
                'w',
                f'{vis_label} {method_label} {latent_dim}',
                f'{vis_label} by FG {method_label} {latent_dim}',
                f'images/{vis_slug}_full_w_{latent_dim}_{full_size}/',
            ))

        try:
            for vis_key, vis_column, combined_title, per_fg_title, save_path in full_plots:
                visualize_latent_space_per_fg(
                    latents[vis_key][vis_column],
                    latents[vis_key]['y'],
                    method,
                    sample_size=5000,
                    combined_title=combined_title,
                    per_fg_title=per_fg_title,
                    save_path=save_path,
                )
        except Exception as e:
            print(f"Error occurred while visualizing full latents: {e}")

In [50]:
# Score every loaded latent table with the nearest-neighbour metric and write one
# CSV row per (table, latent column).

# Prepare to collect metrics
metrics_rows = []
header = ['model', 'size', 'dim', 'beta', 'latent_type', 'score']


def _mean_neighbor_score(latent_sample, latent_key, fg_counts):
    """Return the mean `metric` score over all rows of one latent column.

    latent_sample: table holding the latent column and the label column 'y'.
    latent_key:    name of the latent column to score ('z', 'w' or 'wz').
    fg_counts:     per-label totals handed through to `metric`.
    Uses the notebook-level `neighbors` setting as the neighbourhood size.
    """
    nbrs, latent_data, labels = get_nearest_neighbors(
        latents=latent_sample[latent_key].to_list(),
        fg_labels=np.array(latent_sample['y'].to_list()),
        n_neighbors=neighbors
    )
    scores = [metric(nbrs, i, latent_data, fg_counts, labels, neighbors)
              for i in range(len(latent_sample['y']))]
    return np.mean(scores)


# Only outer progress bar
for name, latent in tqdm(latents.items(), desc="Processing latent sets", dynamic_ncols=True):
    # Keys look like "latents_<model>_<size>_<dim>" with an optional "_<beta>" suffix
    parts = name.split('_')
    model, size, dim = parts[1], parts[2], parts[3]
    beta_label = parts[4] if len(parts) > 4 else ''

    # Label totals come from the full table, not from the sub-sample below
    fg_counts = latent['y'].sum()

    # Cap the work at 50k rows; the fixed seed makes the reported scores reproducible
    latent_sample = latent.sample(n=50000, random_state=42) if len(latent['y']) > 50000 else latent

    # 'z' is always present; 'w' and 'wz' exist only for the subspace models
    latent_keys = ['z'] + [key for key in ('w', 'wz') if key in latent_sample]
    for latent_key in latent_keys:
        # Every row carries all six header columns. Previously the 'beta' field was
        # missing, which shifted latent_type and score under the wrong headers.
        metrics_rows.append([
            model, size, dim, beta_label, latent_key,
            _mean_neighbor_score(latent_sample, latent_key, fg_counts),
        ])

# Save to CSV
metrics_path = Path(f'metrics/latent_metrics_{neighbors}.csv')
metrics_path.parent.mkdir(parents=True, exist_ok=True)
with open(metrics_path, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(metrics_rows)


Processing latent sets: 100%|██████████| 6/6 [02:22<00:00, 23.83s/it]


In [69]:
# Create the top-level output folders up front (no-op when they already exist)
for _top_level_dir in ('images', 'interpolations'):
    Path(_top_level_dir).mkdir(exist_ok=True)

# Compute descriptors
def compute_descriptors(mol):
    """Collect three summary descriptors for a molecule.

    Returns a dict with keys 'MolWt', 'LogP' and 'QED', holding the results of
    Descriptors.MolWt, Descriptors.MolLogP and QED.qed for `mol`.
    """
    mol_weight = Descriptors.MolWt(mol)
    log_p = Descriptors.MolLogP(mol)
    qed_score = QED.qed(mol)
    return dict(MolWt=mol_weight, LogP=log_p, QED=qed_score)

def save_molecule_image(mol, smi, output_dir, anchor_mol=None):
    """Render a molecule with a descriptor caption and save it as a PNG.

    mol:        molecule to draw; passed to get_dec_fgs and compute_descriptors.
    smi:        label used for the file name; every character other than word
                characters and '-' is replaced by '_'.
    output_dir: directory the PNG is written to (created if missing).
    anchor_mol: optional reference molecule; when given, the Tanimoto similarity
                between the two fingerprints is appended to the caption and printed.
    """
    # Generate molecule image (first element of get_dec_fgs' 4-tuple is image bytes)
    img_text, _, _, _ = get_dec_fgs(mol)
    img = Image.open(io.BytesIO(img_text)).convert("RGB")  # ensure RGB

    # Compute descriptors
    desc = compute_descriptors(mol)
    desc_text = f"MolWt: {desc['MolWt']:.2f}, LogP: {desc['LogP']:.2f}, QED: {desc['QED']:.2f}"

    # Compute Tanimoto similarity if anchor provided
    if anchor_mol is not None:
        fp_mol = mol_to_fingerprint(mol)
        fp_anchor = mol_to_fingerprint(anchor_mol)
        tanimoto = DataStructs.TanimotoSimilarity(fp_mol, fp_anchor)
        desc_text += f", Tanimoto: {tanimoto:.2f}"
        print(f"Tanimoto similarity between {Chem.MolToSmiles(mol)} and anchor: {tanimoto:.2f}")

    # Shrink the caption font from 48pt in 2pt steps until the text fits the image width
    draw_temp = ImageDraw.Draw(img)
    max_width = img.width - 10
    font_size = 48
    while font_size > 8:
        try:
            font = ImageFont.truetype("/System/Library/Fonts/Supplemental/Arial.ttf", size=font_size)
            font_is_scalable = True
        except OSError:
            # Arial is not available (e.g. not on macOS): use PIL's default font
            font = ImageFont.load_default()
            font_is_scalable = False
        bbox = draw_temp.textbbox((0,0), desc_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        # The fallback font does not depend on font_size, so further iterations
        # would only repeat the same failing load and the same measurement.
        if text_width <= max_width or not font_is_scalable:
            break
        font_size -= 2

    # Create new image with a white strip below the drawing for the caption
    total_height = img.height + text_height + 10
    new_img = Image.new("RGB", (img.width, total_height), "white")
    new_img.paste(img, (0, 0))

    # Centre the caption horizontally, 5px below the drawing
    draw = ImageDraw.Draw(new_img)
    draw.text(((img.width - text_width)//2, img.height + 5), desc_text, fill="black", font=font)

    # Save image under a filesystem-safe version of the label
    safe_smi = re.sub(r'[^\w\-]', '_', smi)
    os.makedirs(output_dir, exist_ok=True)  # avoid FileNotFoundError on a fresh output tree
    file_path = os.path.join(output_dir, f"{safe_smi}.png")
    new_img.save(file_path)

# SLERP interpolation
def slerp(z1, z2, t):
    """Spherical linear interpolation between two latent vectors.

    Args:
        z1: start vector (array-like, 1-D).
        z2: end vector (array-like, 1-D, same length as ``z1``).
        t: interpolation fraction; ``0`` returns ``z1`` and ``1`` returns ``z2``.

    Returns:
        numpy.ndarray with the interpolated vector. Plain linear interpolation
        is used whenever the spherical formula is undefined or numerically
        unstable: when either vector has zero norm, or when the vectors are
        (anti)parallel so that ``sin(omega)`` is ~0.
    """
    z1 = np.asarray(z1, dtype=float)
    z2 = np.asarray(z2, dtype=float)

    norm_product = np.linalg.norm(z1) * np.linalg.norm(z2)
    if norm_product == 0:
        # A zero vector has no direction; the angle would be 0/0 = NaN.
        return (1 - t) * z1 + t * z2

    # Clip guards against |cos| creeping above 1 through rounding.
    omega = np.arccos(np.clip(np.dot(z1, z2) / norm_product, -1, 1))
    so = np.sin(omega)
    if np.isclose(so, 0.0):
        # sin(pi) evaluates to ~1e-16 rather than 0, so an exact == 0 test would
        # let antiparallel inputs divide by a near-zero number.
        return (1 - t) * z1 + t * z2
    return (np.sin((1 - t) * omega) / so) * z1 + (np.sin(t * omega) / so) * z2



In [71]:
def visualize_nearest_neighbors(latent_df, latent_key='z', anchor_idx=None, output_root="images/nearest_neighbors", n_neighbors=5, anchor_name="ibuprofen"):
    """
    Selects a molecule (by anchor_idx if provided, else random) and visualizes its n nearest neighbors in latent space.

    Neighbours are ranked by Euclidean distance between the vectors stored in
    ``latent_df[latent_key]``. Images are written with ``save_molecule_image``
    into ``<output_root>/<sanitised anchor SMILES>/``; rows whose SMILES RDKit
    cannot parse are skipped when drawing but still reported.

    latent_df: pandas DataFrame with a 'smiles' column and a ``latent_key`` column holding one latent vector per row
    latent_key: name of the column that holds the latent vectors
    anchor_idx: positional index of anchor molecule (int), or None for random
    output_root: root directory to save images
    n_neighbors: number of neighbors to visualize (capped at len(latent_df) - 1 so the anchor is never its own neighbour)
    anchor_name: file name (without extension) used for the anchor image

    Returns (idx, anchor_smi, neighbor_smiles).
    Raises ValueError if latent_df is empty and IndexError if anchor_idx is out of range.
    """
    n_rows = len(latent_df)
    if n_rows == 0:
        raise ValueError("latent_df is empty; there is no molecule to use as anchor")

    # Pick anchor molecule
    if anchor_idx is None:
        idx = random.randint(0, n_rows - 1)
    else:
        idx = anchor_idx
        if not -n_rows <= idx < n_rows:
            raise IndexError(f"anchor_idx {idx} is out of range for a latent_df with {n_rows} rows")

    Path(output_root).mkdir(parents=True, exist_ok=True)

    anchor_row = latent_df.iloc[idx]
    anchor_smi = anchor_row['smiles']
    anchor_z = np.array(anchor_row[latent_key])

    # Compute distances
    latents_array = np.vstack(latent_df[latent_key].to_numpy())
    distances = np.linalg.norm(latents_array - anchor_z, axis=1)
    distances[idx] = np.inf  # exclude itself

    # Find n nearest neighbors; never ask for more rows than exist besides the
    # anchor, otherwise the anchor (distance inf) would be returned as a neighbour.
    n_wanted = max(0, min(n_neighbors, n_rows - 1))
    neighbor_indices = np.argsort(distances)[:n_wanted]
    neighbor_smiles = latent_df['smiles'].iloc[neighbor_indices].tolist()

    # Create output folder for this anchor
    safe_anchor = re.sub(r'[^\w\-]', '_', anchor_smi)
    out_dir = Path(output_root) / safe_anchor
    out_dir.mkdir(parents=True, exist_ok=True)

    # Save anchor as "<anchor_name>.png"
    anchor_mol = Chem.MolFromSmiles(anchor_smi)
    if anchor_mol is not None:
        save_molecule_image(anchor_mol, anchor_name, output_dir=str(out_dir), anchor_mol=anchor_mol)

    # Save neighbors as neighbour1.png ... neighbourN.png
    for rank, smi in enumerate(neighbor_smiles, 1):
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            save_molecule_image(mol, f"neighbour{rank}", output_dir=str(out_dir), anchor_mol=anchor_mol)

    print(f"Anchor: {anchor_smi}")
    print(f"Nearest neighbors: {neighbor_smiles}")
    return idx, anchor_smi, neighbor_smiles


# Choose a fixed anchor index for all models so their neighbourhoods are comparable.
# For a random anchor instead, use:
#   anchor_idx = random.randint(0, len(next(iter(latents.values()))) - 1)
anchor_idx = 1051396 # Ibuprofen
print(f"Using anchor index: {anchor_idx}")

# Visualize nearest neighbors for each model using the same anchor
for name, latent_df in latents.items():
    # The key is split on '_' and fields 1..3 are read as model, dataset size, latent dim.
    parts = name.split('_')
    model, size, dim = parts[1], parts[2], parts[3]
    print(f"Visualizing nearest neighbors for {model} with latent dim {dim} and dataset size {size}")
    if 'Base' in model:
        latent_key = 'z'
    elif 'CSVAE' in model or 'DISCoVeR' in model:
        latent_key = 'wz'
    else:
        # Without this branch latent_key would be undefined (first iteration) or
        # silently reused from the previous model.
        print(f"Skipping {name}: no latent column is configured for model type '{model}'")
        continue
    if anchor_idx >= len(latent_df):
        # The anchor is a positional row index; it only makes sense for frames that contain that row.
        print(f"Skipping {name}: anchor index {anchor_idx} is out of range for {len(latent_df)} rows")
        continue
    visualize_nearest_neighbors(latent_df, latent_key=latent_key, anchor_idx=anchor_idx, output_root=f"images/nearest_neighbors/{model}_{size}_{dim}", n_neighbors=5)

Using anchor index: 1051396
Visualizing nearest neighbors for Base with latent dim 4 and dataset size 1826415
Tanimoto similarity between CC(C)Cc1ccc(C(C)C(=O)O)cc1 and anchor: 1.00
Tanimoto similarity between Oc1c(C(Nc2ccccc2)c2ccccc2)cc(Br)c2ccccc12 and anchor: 0.13
Tanimoto similarity between CC(C)NC(C)C(O)c1cccc(Cl)c1 and anchor: 0.17
Tanimoto similarity between CN(C)C(c1ccccc1)C(O)c1ccccc1 and anchor: 0.21
Tanimoto similarity between CC(C)CC(N)(Cc1ccccc1)C(=O)O and anchor: 0.32
Tanimoto similarity between CC(c1ccccc1)C(C(=O)O)C(Cc1ccccc1)C(=O)O and anchor: 0.45
Anchor: CC(C)Cc1ccc(C(C)C(=O)O)cc1
Nearest neighbors: ['Oc1c(C(Nc2ccccc2)c2ccccc2)cc(Br)c2ccccc12', 'CC(C)NC(C)C(O)c1cccc(Cl)c1', 'CN(C)C(c1ccccc1)C(O)c1ccccc1', 'CC(C)C[C@](N)(Cc1ccccc1)C(=O)O', 'C[C@@H](c1ccccc1)[C@H](C(=O)O)[C@H](Cc1ccccc1)C(=O)O']
Visualizing nearest neighbors for CSVAE with latent dim 4 and dataset size 1826415
Tanimoto similarity between CC(C)Cc1ccc(C(C)C(=O)O)cc1 and anchor: 1.00
Tanimoto similarity 

In [None]:
# Use metric directly on fingerprints (curated dataset): baseline score computed on
# the raw fingerprints instead of on a model's latent vectors.
fingerprint_metric_df = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

fingerprint_train, fingerprint_test = train_test_split(fingerprint_metric_df, test_size=0.2, random_state=42)
fingerprint_val, fingerprint_test = train_test_split(fingerprint_test, test_size=0.2, random_state=42)

# select a subset of fingerprints for evaluation (at most 10000 rows)
fingerprint_train = fingerprint_train.sample(n=min(10000, len(fingerprint_train)), random_state=42)

# Neighbourhood size shared by the neighbour search and the metric
n_neighbors_fp = 50

nbrs, fingerprints, labels = get_nearest_neighbors(
    latents=fingerprint_train['fg_array' if False else 'fingerprint_array'].to_list(),
    fg_labels=fingerprint_train['fg_array'].to_list(),
    n_neighbors=n_neighbors_fp,
)

# Loop-invariant: sum the fg_array column once instead of once per scored molecule.
# NOTE(review): assumes metric() does not modify this argument in place - confirm.
fg_counts = fingerprint_train['fg_array'].sum()

scores = [
    metric(nbrs, i, fingerprints, fg_counts, labels, n_neighbors_fp)
    for i in range(len(fingerprints))
]

print(np.array(scores).mean())

In [None]:
# Use metric directly on fingerprints (full dataset), then summarise the pairwise
# Tanimoto similarity of the sampled fingerprints.
fingerprint_metric_df = pd.read_csv('data/chembl_35_fg_full.csv')

fingerprint_train, fingerprint_test = train_test_split(fingerprint_metric_df, test_size=0.2, random_state=42)
fingerprint_val, fingerprint_test = train_test_split(fingerprint_test, test_size=0.2, random_state=42)

# select a subset of fingerprints for evaluation
fingerprint_train = fingerprint_train.sample(n=10000, random_state=42)

fingerprint_train_dataset = FingerprintDataset(fingerprint_train, sparse=True)

# shuffle=False: the loader is only used to materialise the dataset as dense tensors.
# With shuffle=True every pass over the loader uses a new permutation.
fingerprint_train_loader = DataLoader(fingerprint_train_dataset, batch_size=32, shuffle=False)

# Collect fingerprints and FG vectors in ONE pass so row k of both arrays belongs to
# the same molecule (two separate passes over a shuffling loader misalign them).
fingerprint_batches, fg_batches = [], []
for batch in fingerprint_train_loader:
    fingerprint_batches.append(batch[0])
    fg_batches.append(batch[1])

# Concatenate batches into a single tensor
fingerprints = torch.cat(fingerprint_batches, dim=0)
fg_vectors = torch.cat(fg_batches, dim=0)

# Convert to numpy if needed
fingerprints = fingerprints.numpy()
fg_vectors = fg_vectors.numpy()

# Now call nearest neighbors
nbrs, fingerprints, labels = get_nearest_neighbors(
    latents=fingerprints, 
    fg_labels=fg_vectors, 
    n_neighbors=50
)

# Column-wise sum, computed once outside the loop. This mirrors the curated-data
# cell, which passes the element-wise sum of the 'fg_array' column; a bare .sum()
# on the 2-D array would collapse everything into a single scalar.
# NOTE(review): assumes metric() expects one total per FG column - confirm against fg_funcs.metric.
fg_counts = fg_vectors.sum(axis=0)

scores = [
    metric(nbrs, i, fingerprints, fg_counts, labels, 50)
    for i in range(len(fingerprints))
]

print(np.array(scores).mean())

# Get tanimoto similarity matrix (strict upper triangle only; the rest stays 0)
n_fingerprints = len(fingerprints)
# Convert every fingerprint to a bit vector once instead of inside the pair loop.
bit_vectors = [fingerprint_to_bv(fp) for fp in fingerprints]
tanimoto_sim_matrix = np.zeros((n_fingerprints, n_fingerprints))

for i in range(n_fingerprints - 1):
    # One bulk call compares fingerprint i against all later fingerprints (j > i).
    tanimoto_sim_matrix[i, i + 1:] = DataStructs.BulkTanimotoSimilarity(bit_vectors[i], bit_vectors[i + 1:])

# get descriptive statistics over each unordered pair exactly once.
# Selecting by position (not by "value > 0") keeps pairs whose similarity is
# genuinely 0, which would otherwise be dropped and bias mean/median/min upwards.
pair_mask = np.triu(np.ones((n_fingerprints, n_fingerprints), dtype=bool), k=1)
tanimoto_sim_flat = tanimoto_sim_matrix[pair_mask]

print("Tanimoto Similarity - Descriptive Statistics:")
print(f"Mean: {tanimoto_sim_flat.mean()}")
print(f"Median: {np.median(tanimoto_sim_flat)}")
print(f"Std Dev: {tanimoto_sim_flat.std()}")
print(f"Max: {tanimoto_sim_flat.max()}")
print(f"Min: {tanimoto_sim_flat.min()}")


In [None]:
# Plot fingerprint brightness umap: project the Base VAE latent space of the curated
# training split with UMAP and save the figure.
fingerprint_metric_df = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

fingerprint_train, fingerprint_test = train_test_split(fingerprint_metric_df, test_size=0.2, random_state=42)
fingerprint_val, fingerprint_test = train_test_split(fingerprint_test, test_size=0.2, random_state=42)

fingerprint_train_dataset = FingerprintDataset(fingerprint_train, sparse=False)

# Load on CPU and keep only the wrapped network (the trainer's .model attribute).
model = BaseVAETrainer.load_from_checkpoint('checkpoints/fg_bvae_4_8/best-checkpoint.ckpt', map_location='cpu').model

# Make sure the target folder exists before starting the plot, so the run
# cannot fail at the very end because 'figures/' is missing.
umap_save_path = 'figures/fingerprint_latent_space_umap.png'
Path(umap_save_path).parent.mkdir(parents=True, exist_ok=True)

visualize_latent_space_tanimoto(
    model=model,
    dataset=fingerprint_train_dataset,
    method='umap',
    sample_size=100000,
    title='Base VAE Latent Space with Tanimoto Brightness Scale (UMAP)',
    save_path=umap_save_path
)