In [1]:
import pandas as pd
import numpy as np
from fg_funcs import (
    train_conditional_vae, 
    train_conditional_subspace_vae, 
    train_discover_vae, 
    train_base_model,
    save_model,
    FingerprintDataset,
    BaseVAETrainer,
    ConditionalVAETrainer,
    ConditionalSubspaceVAETrainer,
    DiscoverVAETrainer
    )
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model names to train in the training cell below; valid entries are
# 'Base', 'CVAE', 'CSVAE' and 'DISCoVeR'. Leave empty to skip training.
# NOTE(review): TRAIN_FULL is not read by any code visible in this notebook section.
TRAIN_SMALL, TRAIN_FULL = [], []

In [3]:
if TRAIN_SMALL:
    # The curated dataset is only needed when at least one model is scheduled for training.
    curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

    # Any cell that is not already an ndarray is replaced by an all-zero integer
    # vector (fallback widths: 2048 for fingerprint_array, 100 for fg_array).
    for column, width in (('fingerprint_array', 2048), ('fg_array', 100)):
        curated_dataset[column] = curated_dataset[column].apply(
            lambda value, width=width: value if isinstance(value, np.ndarray) else np.zeros((width,), dtype=int)
        )

    dataset = curated_dataset  # the training loop below reads from `dataset`
    MODEL_OUTPUT = 'models/small_models'  # directory that save_model writes into

# Base / CVAE: single latent space
latent_dim = 16
encoder_hidden_dims = [1024, 512, 256, 128]

# Decoder widths, passed to every model type.
# NOTE(review): the decoder has no 512 layer although the encoder does —
# confirm the asymmetry is intentional before retraining.
decoder_hidden_dims = [128, 256, 1024]

# CSVAE / DISCoVeR: separate z and w latent spaces and encoders
latent_dim_z = 16
latent_dim_w = 16
encoder_hidden_dims_z = [1024, 512, 256, 128]
encoder_hidden_dims_w = [1024, 512, 256, 128]
adversarial_hidden_dims = [64]
decoder_z_hidden_dims = [128, 256, 1024]  # extra z-decoder, DISCoVeR only

# Optimisation settings shared by all models
batch_size = 64
learning_rate = 1e-3
max_epochs = 5


for MODEL in TRAIN_SMALL:

    # Re-seed before every model so each one starts from the same RNG state,
    # independent of its position in TRAIN_SMALL.
    torch.manual_seed(42)

    if MODEL is None:
        raise ValueError("MODEL must be defined before training.")
    if MODEL not in ('Base', 'CVAE', 'CSVAE', 'DISCoVeR'):
        # Fail fast on a typo: an unknown name would otherwise match no branch
        # and the cell would die with a confusing NameError at `del vae_trainer`.
        raise ValueError(
            f"Unknown MODEL {MODEL!r}; expected one of 'Base', 'CVAE', 'CSVAE', 'DISCoVeR'."
        )

    # Input / condition widths are read from the first row.
    # NOTE(review): assumes every row shares the first row's array lengths.
    fingerprint_dim = len(dataset['fingerprint_array'].iloc[0])
    fg_dim = len(dataset['fg_array'].iloc[0])

    if MODEL == 'Base':
        print("Training BaseVAE model...")
        vae_trainer = train_base_model(
            dataset=dataset,
            input_dim=fingerprint_dim,
            latent_dim=latent_dim,
            fg_dim=fg_dim,  # BaseVAE does not condition on fg_array; passed for logging purposes only
            encoder_hidden_dims=encoder_hidden_dims,
            decoder_hidden_dims=decoder_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            sparse=False
        )

    elif MODEL == 'CVAE':
        print("Training CVAE model...")
        vae_trainer = train_conditional_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim=latent_dim,
            encoder_hidden_dims=encoder_hidden_dims,
            decoder_hidden_dims=decoder_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            sparse=False
        )

    elif MODEL == 'CSVAE':
        print("Training CSVAE model based on the NeurIPS 2018 paper...")
        vae_trainer = train_conditional_subspace_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim_z=latent_dim_z,
            latent_dim_w=latent_dim_w,
            encoder_hidden_dims_z=encoder_hidden_dims_z,
            encoder_hidden_dims_w=encoder_hidden_dims_w,
            decoder_hidden_dims=decoder_hidden_dims,
            adversarial_hidden_dims=adversarial_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            sparse=False
        )

    else:  # 'DISCoVeR' — the only remaining name after validation above
        print("Training DISCoVeR VAE model...")
        vae_trainer = train_discover_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim_z=latent_dim_z,
            latent_dim_w=latent_dim_w,
            encoder_hidden_dims_z=encoder_hidden_dims_z,
            encoder_hidden_dims_w=encoder_hidden_dims_w,
            decoder_hidden_dims=decoder_hidden_dims,
            decoder_z_hidden_dims=decoder_z_hidden_dims,
            adversarial_hidden_dims=adversarial_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs,
            sparse=False
        )

    # Every model type is saved the same way, keyed by its model name
    save_model(vae_trainer, MODEL_OUTPUT, MODEL)

    # Free memory after each model
    del vae_trainer


In [4]:
# Checkpoints to load, one (model_type, checkpoint path, trainer class) triple per model.
# model_type must be one of the names extract_and_save_latents dispatches on.

# Small models: checkpoint directories under */tl8m62ls/
small_models = [
    ("Base",
     "fg_vae_4/tl8m62ls/checkpoints/epoch=4-step=6250.ckpt",
     BaseVAETrainer),
    ("CVAE",
     "fg_cvae_4/tl8m62ls/checkpoints/epoch=4-step=6250.ckpt",
     ConditionalVAETrainer),
    ("CSVAE",
     "fg_csvae_4/tl8m62ls/checkpoints/epoch=4-step=12500.ckpt",
     ConditionalSubspaceVAETrainer),
    ("DISCoVeR",
     "fg_discover_vae_4/tl8m62ls/checkpoints/epoch=4-step=12500.ckpt",
     DiscoverVAETrainer),
]

# Large models: */ix9u9770/ except DISCoVeR, which lives under */e90lks0j/
large_models = [
    ("Base",
     "fg_vae_50/ix9u9770/checkpoints/epoch=4-step=142690.ckpt",
     BaseVAETrainer),
    ("CVAE",
     "fg_cvae_50/ix9u9770/checkpoints/epoch=4-step=142690.ckpt",
     ConditionalVAETrainer),
    ("CSVAE",
     "fg_csvae_50/ix9u9770/checkpoints/epoch=4-step=285380.ckpt",
     ConditionalSubspaceVAETrainer),
    ("DISCoVeR",
     "fg_discover_vae_50/e90lks0j/checkpoints/epoch=4-step=285380.ckpt",
     DiscoverVAETrainer),
]

In [11]:
def _split_batch(batch, device):
    """Return ``(x, y)`` moved to ``device`` from one DataLoader batch.

    For list/tuple batches, element 0 is x and element 1 is y when present and a
    tensor; otherwise y is None. Any other batch object is treated as x alone.
    """
    if isinstance(batch, (list, tuple)):
        x = batch[0].to(device)
        has_y = len(batch) > 1 and torch.is_tensor(batch[1])
        return x, (batch[1].to(device) if has_y else None)
    return batch.to(device), None


def _tensor_columns(prefix, chunks):
    """Concatenate per-batch CPU tensors and return ``{f"{prefix}{i}": numpy column i}``.

    Tensors are flattened to (n_samples, n_features), so a 1-D tensor (one value per
    sample, e.g. scalar labels) becomes a single column instead of raising IndexError.
    """
    stacked = torch.cat(chunks)
    stacked = stacked.reshape(stacked.shape[0], -1)
    return {f"{prefix}{i}": stacked[:, i].numpy() for i in range(stacked.shape[1])}


def extract_and_save_latents(
    model_path,
    dataloader,
    model_type,
    model_class,   # LightningModule class
    device="cpu",
    map_location=None,
    output_dir="latents",
    use_mean=False,
):
    """
    Extract latent variables from a trained LightningModule (.ckpt) and save to CSV.

    Supports BaseVAE, ConditionalVAE, ConditionalSubspaceVAE, DiscoverVAE.

    Args:
        model_path (str or Path): Path to the saved .ckpt model.
        dataloader (DataLoader): DataLoader for dataset; batches are either a
            tensor x or a (x, y, ...) list/tuple.
        model_type (str): One of ['Base', 'CVAE', 'CSVAE', 'DISCoVeR'].
        model_class (LightningModule): Class to load from checkpoint; its
            ``.model`` attribute must provide ``encode`` and ``reparameterize``.
        device (str): 'cpu', 'cuda', or 'mps'.
        map_location: optional, forwarded to ``load_from_checkpoint``.
        output_dir (str or Path): Directory the CSV is written to (default "latents").
        use_mean (bool): If True, store the encoder means (deterministic);
            if False (default), store one ``reparameterize`` sample per input.

    Returns:
        pd.DataFrame: Columns z0..z{k}, then w0..w{m} for CSVAE/DISCoVeR, then
        y0..y{c} when the batches carry labels.

    Raises:
        ValueError: if ``model_type`` is unsupported (checked before loading the
            checkpoint) or the dataloader yields no batches.
    """
    if model_type not in ("Base", "CVAE", "CSVAE", "DISCoVeR"):
        raise ValueError(f"Unsupported model type: {model_type}")

    model_path = Path(model_path)
    print(f"Loading {model_type} model from {model_path}...")

    # Load LightningModule from checkpoint and access the underlying model
    model = model_class.load_from_checkpoint(str(model_path), map_location=map_location).model
    model = model.to(device)
    model.eval()

    def draw(mu, logvar):
        # Posterior mean when use_mean=True, otherwise a single sample.
        return mu if use_mean else model.reparameterize(mu, logvar)

    z_chunks, w_chunks, y_chunks = [], [], []

    with torch.no_grad():
        for batch in dataloader:
            x, y = _split_batch(batch, device)

            mu_w = logvar_w = None
            if model_type == "Base":
                mu_z, logvar_z = model.encode(x)
            elif model_type == "CVAE":
                mu_z, logvar_z = model.encode(x, y)
            else:  # CSVAE / DISCoVeR encode into separate z and w subspaces
                mu_z, logvar_z, mu_w, logvar_w = model.encode(x, y)

            # z is drawn before w so sampled latents consume the RNG in a fixed order.
            z_chunks.append(draw(mu_z, logvar_z).cpu())
            if mu_w is not None:
                w_chunks.append(draw(mu_w, logvar_w).cpu())
            if y is not None:
                y_chunks.append(y.cpu())

    if not z_chunks:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("dataloader yielded no batches; nothing to extract.")

    # Build DataFrame: z columns first, then w, then y (same order as before)
    data = _tensor_columns("z", z_chunks)
    if w_chunks:
        data.update(_tensor_columns("w", w_chunks))
    if y_chunks:
        data.update(_tensor_columns("y", y_chunks))
    df = pd.DataFrame(data)

    # Save CSV
    output_csv = Path(output_dir) / f"latents_{model_type}_{len(dataloader.dataset)}.csv"
    output_csv.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output_csv, index=False)
    y_note = " (including y)" if y_chunks else ""
    print(f"Saved latents{y_note} to {output_csv}")

    return df


In [12]:
# Build the held-out test DataLoader from the curated dataset
curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

# Any cell that is not already an ndarray is replaced by an all-zero integer
# vector (fallback widths: 2048 for fingerprint_array, 100 for fg_array).
for column, width in (('fingerprint_array', 2048), ('fg_array', 100)):
    curated_dataset[column] = curated_dataset[column].apply(
        lambda value, width=width: value if isinstance(value, np.ndarray) else np.zeros((width,), dtype=int)
    )

# Two-stage split with a fixed seed: hold out 20%, then keep the final 20% of
# that hold-out as the test set.
train_data, test_data = train_test_split(curated_dataset, test_size=0.2, random_state=42)
val_data, test_data = train_test_split(test_data, test_size=0.2, random_state=42)

test_dataset = FingerprintDataset(test_data, sparse=False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Prefer MPS, then CUDA, and fall back to CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Extract and save latents for every small-model checkpoint
for model_name, model_path, model_class in small_models:
    extract_and_save_latents(
        model_path=model_path,
        dataloader=test_dataloader,
        model_type=model_name,
        model_class=model_class,
        device=device,
    )

Loading Base model from fg_vae_4/tl8m62ls/checkpoints/epoch=4-step=6250.ckpt...
Saved latents (including y) to latents/latents_Base_4000.csv
Loading CVAE model from fg_cvae_4/tl8m62ls/checkpoints/epoch=4-step=6250.ckpt...
Saved latents (including y) to latents/latents_CVAE_4000.csv
Loading CSVAE model from fg_csvae_4/tl8m62ls/checkpoints/epoch=4-step=12500.ckpt...
Saved latents (including y) to latents/latents_CSVAE_4000.csv
Loading DISCoVeR model from fg_discover_vae_4/tl8m62ls/checkpoints/epoch=4-step=12500.ckpt...
Saved latents (including y) to latents/latents_DISCoVeR_4000.csv


In [14]:
# Extract latents for the large models on the held-out test split of the full dataset
full_dataset = pd.read_csv("data/chembl_35_fg_full.csv")

# Same two-stage split and seed as the curated set: hold out 20%, then keep the
# final 20% of that hold-out as the test set.
train_full, test_full = train_test_split(full_dataset, test_size=0.2, random_state=42)
val_full, test_full = train_test_split(test_full, test_size=0.2, random_state=42)

test_full_dataset = FingerprintDataset(test_full, sparse=True)
full_dataloader = DataLoader(test_full_dataset, batch_size=64, shuffle=False)

for model_name, model_path, model_class in large_models:
    extract_and_save_latents(
        model_path=model_path,
        dataloader=full_dataloader,
        model_type=model_name,
        model_class=model_class,
        # Reuse the auto-detected device (mps -> cuda -> cpu) instead of a
        # hard-coded "mps", which fails on machines without Apple MPS.
        device=device,
    )

Loading Base model from fg_vae_50/ix9u9770/checkpoints/epoch=4-step=142690.ckpt...
Saved latents (including y) to latents/latents_Base_91321.csv
Loading CVAE model from fg_cvae_50/ix9u9770/checkpoints/epoch=4-step=142690.ckpt...
Saved latents (including y) to latents/latents_CVAE_91321.csv
Loading CSVAE model from fg_csvae_50/ix9u9770/checkpoints/epoch=4-step=285380.ckpt...
Saved latents (including y) to latents/latents_CSVAE_91321.csv
Loading DISCoVeR model from fg_discover_vae_50/e90lks0j/checkpoints/epoch=4-step=285380.ckpt...
Saved latents (including y) to latents/latents_DISCoVeR_91321.csv
