In [11]:
import pandas as pd
import numpy as np
from VAE import (
    train_conditional_vae, 
    train_conditional_subspace_vae, 
    train_discover_vae, 
    train_base_model,
    FingerprintDataset,
    BaseVAETrainer,
    ConditionalVAETrainer,
    ConditionalSubspaceVAETrainer,
    DiscoverVAETrainer
    )
from fg_funcs import (
    extract_and_save_latents,
    extract_prefixed_arrays,
    get_nearest_neighbors,
    metric
)

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [12]:
# Run-control switches for the stages further down the notebook.
# TRAIN holds the runs to train; each entry is '<model>_<dataset>', where the
# training cell recognises the models 'Base', 'CVAE', 'CSVAE', 'DISCoVeR' and
# the datasets 'CUR' (curated) and 'FULL'. An empty list skips training.
TRAIN = list()
# Latent extraction and both visualisation stages are switched off by default.
GET_LATENTS = VIS_CUR = VIS_FULL = False

In [13]:
# --- Latent sizes -----------------------------------------------------------
# The training loop zips these three lists, so entry i of each list forms one
# configuration. Edit latent_dims to try other sizes.
latent_dims = [4, 8, 16]
latent_dims_z = list(latent_dims)  # z subspace, for CSVAE and DISCoVeR
latent_dims_w = [size // 2 for size in latent_dims_z]  # w subspace, for CSVAE and DISCoVeR

# --- Network widths ---------------------------------------------------------
# Each list is an independent copy so editing one never changes another.
encoder_hidden_dims = [1024, 512, 256, 128]
encoder_hidden_dims_z = list(encoder_hidden_dims)  # for CSVAE and DISCoVeR
encoder_hidden_dims_w = list(encoder_hidden_dims)  # for CSVAE and DISCoVeR
decoder_hidden_dims = [128, 512]
decoder_z_hidden_dims = list(decoder_hidden_dims)  # decoder layers for DISCoVeR
adversarial_hidden_dims = [8]

# --- Optimisation -----------------------------------------------------------
batch_size, learning_rate, max_epochs = 64, 1e-3, 5


In [14]:
# Train every run requested in TRAIN, once per latent-size configuration.
# Entry i of latent_dims / latent_dims_z / latent_dims_w together form one
# configuration (zip stops at the shortest of the three lists).
KNOWN_MODEL_TYPES = ('Base', 'CVAE', 'CSVAE', 'DISCoVeR')
KNOWN_DATASET_TYPES = ('CUR', 'FULL')

for latent_dim, latent_dim_z, latent_dim_w in zip(latent_dims, latent_dims_z, latent_dims_w):
    for MODEL in TRAIN:
        # Each TRAIN entry is '<model_type>_<dataset_type>', e.g. 'CSVAE_FULL'.
        name_parts = MODEL.split('_')
        if len(name_parts) != 2:
            raise ValueError(
                f"TRAIN entry {MODEL!r} must have the form '<model>_<dataset>'."
            )
        model_type, dataset_type = name_parts

        # Validate before any data is loaded. Without this an unknown name
        # falls through every branch below and either crashes with a NameError
        # or silently reuses the dataset / trainer of the previous iteration.
        if model_type not in KNOWN_MODEL_TYPES:
            raise ValueError(
                f"Unknown model type {model_type!r} in TRAIN entry {MODEL!r}; "
                f"expected one of {KNOWN_MODEL_TYPES}."
            )
        if dataset_type not in KNOWN_DATASET_TYPES:
            raise ValueError(
                f"Unknown dataset type {dataset_type!r} in TRAIN entry {MODEL!r}; "
                f"expected one of {KNOWN_DATASET_TYPES}."
            )

        if dataset_type == 'CUR':
            # Load the curated dataset
            curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

            # Replace any entry that is not already a numpy array with an
            # all-zero vector so every row holds an array.
            curated_dataset['fingerprint_array'] = curated_dataset['fingerprint_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((2048,), dtype=int))
            # NOTE(review): the fallback length here is 100 while fg_dim below is 4 — confirm which is intended.
            curated_dataset['fg_array'] = curated_dataset['fg_array'].apply(lambda x: x if isinstance(x, np.ndarray) else np.zeros((100,), dtype=int))

            MODEL_OUTPUT = 'models/small_models'  # Directory to save trained models
            dataset = curated_dataset  # Use the curated dataset for training
            sparse = False

            input_dim = 2048
            fg_dim = 4

        else:  # dataset_type == 'FULL'
            # Load the full dataset
            full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")

            MODEL_OUTPUT = 'models/large_models'  # Directory to save trained models
            dataset = full_dataset  # Use the full dataset for training
            sparse = True

            input_dim = 2048
            fg_dim = 50

        # Re-seed before every run so each training starts from the same RNG state.
        torch.manual_seed(42)

        if model_type == 'Base':

            print("Training BaseVAE model...")

            vae_trainer = train_base_model(
                dataset=dataset,
                input_dim=input_dim,
                latent_dim=latent_dim,
                fg_dim=fg_dim,  # per the original note: not used by BaseVAE itself, for logging purposes only
                encoder_hidden_dims=encoder_hidden_dims,
                decoder_hidden_dims=decoder_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'CVAE':

            print("Training CVAE model...")

            vae_trainer = train_conditional_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim=latent_dim,
                encoder_hidden_dims=encoder_hidden_dims,
                decoder_hidden_dims=decoder_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'CSVAE':

            print("Training CSVAE model based on the NeurIPS 2018 paper...")

            vae_trainer = train_conditional_subspace_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim_z=latent_dim_z,
                latent_dim_w=latent_dim_w,
                encoder_hidden_dims_z=encoder_hidden_dims_z,
                encoder_hidden_dims_w=encoder_hidden_dims_w,
                decoder_hidden_dims=decoder_hidden_dims,
                adversarial_hidden_dims=adversarial_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        elif model_type == 'DISCoVeR':

            print("Training DISCoVeR VAE model...")

            vae_trainer = train_discover_vae(
                dataset=dataset,
                fingerprint_dim=input_dim,
                fg_dim=fg_dim,
                latent_dim_z=latent_dim_z,
                latent_dim_w=latent_dim_w,
                encoder_hidden_dims_z=encoder_hidden_dims_z,
                encoder_hidden_dims_w=encoder_hidden_dims_w,
                decoder_hidden_dims=decoder_hidden_dims,
                decoder_z_hidden_dims=decoder_z_hidden_dims,
                adversarial_hidden_dims=adversarial_hidden_dims,
                batch_size=batch_size,
                learning_rate=learning_rate,
                max_epochs=max_epochs,
                sparse=sparse
            )

        # Free memory after each model
        del vae_trainer


In [15]:
# Registry of trained checkpoints used by the latent-extraction cells.
# Only the second latent size of the sweep is inspected from here on.
latent_dim = latent_dims[1]

# (display name, checkpoint directory tag, trainer class) per architecture.
_architectures = (
    ("Base", "bvae", BaseVAETrainer),
    ("CVAE", "cvae", ConditionalVAETrainer),
    ("CSVAE", "csvae", ConditionalSubspaceVAETrainer),
    ("DISCoVeR", "dvae", DiscoverVAETrainer),
)

# Checkpoint directories tagged with 4 (listed as the small models).
small_models = [
    (label, f"checkpoints/fg_{tag}_4_{latent_dim}/best-checkpoint.ckpt", trainer_cls)
    for label, tag, trainer_cls in _architectures
]

# Checkpoint directories tagged with 50 (listed as the large models).
large_models = [
    (label, f"checkpoints/fg_{tag}_50_{latent_dim}/best-checkpoint.ckpt", trainer_cls)
    for label, tag, trainer_cls in _architectures
]

In [16]:
if GET_LATENTS:
    # Reload the curated data and rebuild the test split for latent extraction.
    curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

    # Any entry that is not already a numpy array is replaced by an all-zero
    # vector of the given length.
    for column, width in (('fingerprint_array', 2048), ('fg_array', 100)):
        curated_dataset[column] = curated_dataset[column].apply(
            lambda x, width=width: x if isinstance(x, np.ndarray) else np.zeros((width,), dtype=int)
        )

    # Two-stage split: 20% of the rows are held out, then 20% of that hold-out
    # becomes the test set and the remainder the validation set.
    train_data, holdout_data = train_test_split(curated_dataset, test_size=0.2, random_state=42)
    val_data, test_data = train_test_split(holdout_data, test_size=0.2, random_state=42)

    test_dataset = FingerprintDataset(test_data, sparse=False)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Prefer Apple MPS, then CUDA, and fall back to the CPU.
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # Extract and save latents for each small-model checkpoint.
    for model_name, model_path, model_class in small_models:
        extraction_kwargs = {
            'model_path': model_path,
            'dataloader': test_dataloader,
            'model_type': model_name,
            'model_class': model_class,
            'device': device,
        }
        extract_and_save_latents(**extraction_kwargs)

In [17]:
if GET_LATENTS:
    try:
        # Re-run the extraction and saving process for the full dataset
        full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")

        # Two-stage split: hold out 20% of the rows, then keep 20% of that
        # hold-out as the test set.
        train_full, test_full = train_test_split(full_dataset, test_size=0.2, random_state=42)
        val_full, test_full = train_test_split(test_full, test_size=0.2, random_state=42)

        test_full_dataset = FingerprintDataset(test_full, sparse=True)
        full_dataloader = DataLoader(test_full_dataset, batch_size=64, shuffle=False)

        # Pick the device with the same mps -> cuda -> cpu fallback used for
        # the small models instead of hard-coding "mps", so the cell also
        # works on machines without Apple MPS.
        device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

        for model_name, model_path, model_class in large_models:
            extract_and_save_latents(
                model_path=model_path,
                dataloader=full_dataloader,
                model_type=model_name,
                model_class=model_class,
                device=device
            )
    except Exception as e:
        # Best effort: report the failure and let the rest of the notebook run.
        print(f"Error occurred while extracting latents: {e}")

In [18]:
# Load the saved latent CSVs of all four models for the visualisation stage.
# Files are named latents/latents_<model>_<size>_<latent_dim>.csv.
latent_dim = latent_dim  # no-op: keeps the latent size chosen earlier in the notebook
cur_size = 80_000
full_size = 1_826_415

if VIS_CUR:
    try:
        cur_suffix = f'{cur_size}_{latent_dim}.csv'
        # Base and CVAE files are read for 'z' and 'y'; CSVAE and DISCoVeR
        # files additionally for 'w'.
        base_latents_cur = extract_prefixed_arrays('latents/latents_Base_' + cur_suffix, ['z', 'y'])
        cvae_latents_cur = extract_prefixed_arrays('latents/latents_CVAE_' + cur_suffix, ['z', 'y'])
        csvae_latents_cur = extract_prefixed_arrays('latents/latents_CSVAE_' + cur_suffix, ['z', 'w', 'y'])
        discover_latents_cur = extract_prefixed_arrays('latents/latents_DISCoVeR_' + cur_suffix, ['z', 'w', 'y'])
    except Exception as e:
        # Best effort: report the problem and continue.
        print(f"Error occurred while loading current latents: {e}")

if VIS_FULL:
    try:
        full_suffix = f'{full_size}_{latent_dim}.csv'
        base_latents_full = extract_prefixed_arrays('latents/latents_Base_' + full_suffix, ['z', 'y'])
        cvae_latents_full = extract_prefixed_arrays('latents/latents_CVAE_' + full_suffix, ['z', 'y'])
        csvae_latents_full = extract_prefixed_arrays('latents/latents_CSVAE_' + full_suffix, ['z', 'w', 'y'])
        discover_latents_full = extract_prefixed_arrays('latents/latents_DISCoVeR_' + full_suffix, ['z', 'w', 'y'])
    except Exception as e:
        # Best effort: report the problem and continue.
        print(f"Error occurred while loading full latents: {e}")

In [19]:
from fg_funcs import visualize_latent_space_per_fg

method = 'umap'


def _plot_latents(latent_table, key, model_label, dim_label, save_path, split_label=None):
    """Plot one latent block with visualize_latent_space_per_fg.

    Args:
        latent_table: Loaded latents, indexable by `key` and by 'y'.
        key: Which latent block to plot ('z' or 'w').
        model_label: Model name shown at the start of both titles.
        dim_label: Latent size shown in the titles (formatted with str()).
        save_path: Directory passed through as `save_path`.
        split_label: Optional trailing title text such as 'Train Set'.
    """
    title_tail = f'{method.upper()} {dim_label}'
    if split_label is not None:
        title_tail = f'{title_tail} {split_label}'
    visualize_latent_space_per_fg(
        latent_table[key],
        latent_table['y'],
        method,
        sample_size=5000,
        combined_title=f'{model_label} {title_tail}',
        per_fg_title=f'{model_label} by FG {title_tail}',
        save_path=save_path,
    )


if VIS_CUR:
    try:
        # Computed once outside the f-strings: nesting the same quote character
        # inside an f-string replacement field only parses on Python 3.12+.
        split_label = 'Test Set' if cur_size < 50000 else 'Train Set'

        # Visualize the latent space
        _plot_latents(base_latents_cur, 'z', 'Base VAE', latent_dim, f'images/base_vae_cur_{latent_dim}_{cur_size}/', split_label)
        _plot_latents(cvae_latents_cur, 'z', 'CVAE', latent_dim, f'images/cvae_cur_{latent_dim}_{cur_size}/', split_label)
        _plot_latents(csvae_latents_cur, 'z', 'CSVAE', latent_dim, f'images/csvae_cur_{latent_dim}_{cur_size}/', split_label)
        _plot_latents(discover_latents_cur, 'z', 'Discover', latent_dim, f'images/discover_cur_{latent_dim}_{cur_size}/', split_label)
        # NOTE(review): latent_dim/2 is a float, so these titles read e.g. '4.0' — confirm whether '//' was intended.
        _plot_latents(csvae_latents_cur, 'w', 'CSVAE', latent_dim/2, f'images/csvae_cur_w_{latent_dim}_{cur_size}/', split_label)
        _plot_latents(discover_latents_cur, 'w', 'Discover', latent_dim/2, f'images/discover_cur_w_{latent_dim}_{cur_size}/', split_label)
    except Exception as e:
        # Best effort: report the problem and continue.
        print(f"Error occurred while visualizing current latents: {e}")

if VIS_FULL:
    try:
        _plot_latents(base_latents_full, 'z', 'Base VAE', latent_dim, f'images/base_vae_full_{latent_dim}_{full_size}/')
        _plot_latents(cvae_latents_full, 'z', 'CVAE', latent_dim, f'images/cvae_full_{latent_dim}_{full_size}/')
        _plot_latents(csvae_latents_full, 'z', 'CSVAE', latent_dim, f'images/csvae_full_{latent_dim}_{full_size}/')
        _plot_latents(discover_latents_full, 'z', 'Discover', latent_dim, f'images/discover_full_{latent_dim}_{full_size}/')
        # NOTE(review): the 'w' plots below reuse the exact titles of the 'z' plots above
        # (only save_path differs) — confirm whether the titles should show the w size.
        _plot_latents(csvae_latents_full, 'w', 'CSVAE', latent_dim, f'images/csvae_full_w_{latent_dim}_{full_size}/')
        _plot_latents(discover_latents_full, 'w', 'Discover', latent_dim, f'images/discover_full_w_{latent_dim}_{full_size}/')
    except Exception as e:
        # Best effort: report the problem and continue.
        print(f"Error occurred while visualizing full latents: {e}")

In [20]:
import csv

# Score every saved latent file with the nearest-neighbour metric and write
# the mean score per (model, size, dim, latent type) to latent_metrics.csv.
N_NEIGHBORS = 50            # neighbours requested and passed to metric()
MAX_METRIC_ROWS = 100000    # larger latent tables are subsampled to this many rows
SAMPLE_SEED = 42            # same seed as the train/test splits, for reproducible scores

# load all latents into latents dict, keyed by file stem
latents = {}
models = ['Base', 'CVAE', 'CSVAE', 'DISCoVeR']
sizes = ['4000', '80000', '91321', '1826415']
dims = ['4','8', '16']

for model in models:
    # CSVAE and DISCoVeR files are read for 'w' in addition to 'z' and 'y'.
    terms = ['z', 'w', 'y'] if model in ['CSVAE', 'DISCoVeR'] else ['z', 'y']
    for size in sizes:
        for dim in dims:
            name = f"latents_{model}_{size}_{dim}"
            latents[name] = extract_prefixed_arrays(f"latents/{name}.csv", terms)

# Prepare to collect metrics
metrics_rows = []
header = ['model', 'size', 'dim', 'latent_type', 'score']

for name, latent in latents.items():
    # Parse model, size, dim from name ('latents_<model>_<size>_<dim>')
    parts = name.split('_')
    model, size, dim = parts[1], parts[2], parts[3]

    # NOTE(review): counts come from the full table even when only a sample is
    # scored below — confirm that is what metric() expects.
    fg_counts = latent['y'].sum()

    # Take a sample of the latents. Seeded so repeated runs give the same
    # scores for large files (the original sample was unseeded).
    if len(latent['y']) > MAX_METRIC_ROWS:
        latent_sample = latent.sample(n=MAX_METRIC_ROWS, random_state=SAMPLE_SEED)
    else:
        latent_sample = latent

    # 'z' is always scored; 'w' only when the table has it.
    latent_types = ['z', 'w'] if 'w' in latent_sample else ['z']
    n_rows = len(latent_sample['y'])

    for latent_type in latent_types:
        nbrs, latent_points, labels = get_nearest_neighbors(
            latents=latent_sample[latent_type].to_list(),
            fg_labels=latent_sample['y'].to_list(),
            n_neighbors=N_NEIGHBORS,
        )
        scores = [
            metric(nbrs, i, latent_points, fg_counts, labels, N_NEIGHBORS)
            for i in range(n_rows)
        ]
        mean_score = np.mean(scores)

        # Keep the original log format: no suffix for z, ' (w)' for w.
        suffix = '' if latent_type == 'z' else f' ({latent_type})'
        print(f"Metrics for latent {name}{suffix}: {mean_score}")
        metrics_rows.append([model, size, dim, latent_type, mean_score])

# Save to CSV
with open('latent_metrics.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(metrics_rows)

Metrics for latent latents_Base_4000_4: 0.738165
Metrics for latent latents_Base_4000_8: 0.636715
Metrics for latent latents_Base_4000_16: 0.5323399999999999
Metrics for latent latents_Base_80000_4: 0.8097952500000001
Metrics for latent latents_Base_80000_8: 0.7324422500000001
Metrics for latent latents_Base_80000_16: 0.6234390000000001


KeyboardInterrupt: 

In [None]:
# Use metric directly on fingerprints: the same metric call as for the
# latents, applied to the raw fingerprint arrays of the training split.
fingerprint_metric_df = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

# Two-stage split with the same sizes and seed as the other cells.
fingerprint_train, fingerprint_test = train_test_split(fingerprint_metric_df, test_size=0.2, random_state=42)
fingerprint_val, fingerprint_test = train_test_split(fingerprint_test, test_size=0.2, random_state=42)

nbrs, fingerprints, labels = get_nearest_neighbors(latents=fingerprint_train['fingerprint_array'].to_list(), fg_labels=fingerprint_train['fg_array'].to_list(), n_neighbors=50)

# Loop-invariant, so summed once instead of on every iteration.
fingerprint_fg_counts = fingerprint_train['fg_array'].sum()

scores = [
    metric(nbrs, i, fingerprints, fingerprint_fg_counts, labels, 50)
    for i in range(len(fingerprints))
]

# scores is a plain list, which has no .mean(); average it with numpy.
print(np.mean(scores))