In [1]:
from fg_funcs import save_model
from full_vae import train_base_model, train_conditional_vae, train_conditional_subspace_vae, train_discover_vae
import pandas as pd
import numpy as np
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the curated ChEMBL 35 dataset.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
curated_dataset = pd.read_pickle('data/chembl_35_fg_scaf_curated.pkl')

# Any fingerprint that is not already a numpy array (e.g. a missing/NaN entry)
# is replaced by an all-zero fingerprint so that every row has the same length.
# The fill length is taken from the existing fingerprints rather than hard-coded,
# so a non-2048-bit fingerprint file does not produce mismatched rows.
DEFAULT_FINGERPRINT_DIM = 2048  # fallback only when no valid fingerprint exists
fp_is_array = curated_dataset['fingerprint_array'].map(lambda fp: isinstance(fp, np.ndarray)).astype(bool)
valid_fingerprints = curated_dataset.loc[fp_is_array, 'fingerprint_array']
FINGERPRINT_DIM = len(valid_fingerprints.iloc[0]) if not valid_fingerprints.empty else DEFAULT_FINGERPRINT_DIM

n_zero_filled = int((~fp_is_array).sum())
if n_zero_filled:
    # Make the substitution visible: zero vectors are fed to the model as real inputs.
    print(f"Zero-filling {n_zero_filled} row(s) without a fingerprint array (dim={FINGERPRINT_DIM}).")

curated_dataset['fingerprint_array'] = curated_dataset['fingerprint_array'].apply(
    lambda fp: fp if isinstance(fp, np.ndarray) else np.zeros((FINGERPRINT_DIM,), dtype=int)
)

In [3]:
MODELS = ['DISCoVeR']  # List of models to train
MODEL_OUTPUT = 'models/small_models'  # Directory to save trained models
dataset = curated_dataset  # Use the curated dataset for training

# Every model name the dispatch below knows how to train. MODELS is validated
# against this up front, so a typo fails immediately instead of after earlier
# models have already spent time training.
SUPPORTED_MODELS = ('Base', 'CVAE', 'CSVAE', 'DISCoVeR')
SEED = 42  # Re-applied before each model so every run starts from the same RNG state

# Hyperparameters
latent_dim = 16  # Latent dimension for Base and CVAE
encoder_hidden_dims = [1024, 512, 256, 128]
# NOTE(review): the decoders omit the 512 layer the encoders have — confirm this is intentional.
decoder_hidden_dims = [128, 256, 1024]
decoder_z_hidden_dims = [128, 256, 1024]  # Decoder layers for DISCoVeR
latent_dim_z = 16  # for CSVAE and DISCoVeR
latent_dim_w = 16  # for CSVAE and DISCoVeR
encoder_hidden_dims_z = [1024, 512, 256, 128]  # for CSVAE and DISCoVeR
encoder_hidden_dims_w = [1024, 512, 256, 128]  # for CSVAE and DISCoVeR
adversarial_hidden_dims = [64]  # for CSVAE and DISCoVeR
batch_size = 64
learning_rate = 1e-3
max_epochs = 5

unknown_models = [name for name in MODELS if name not in SUPPORTED_MODELS]
if unknown_models:
    raise ValueError(
        f"Unknown model(s) {unknown_models}; expected one of {SUPPORTED_MODELS}. "
        "MODEL must be defined before training."
    )

# Input sizes depend only on the dataset, not on the model, so compute them once.
fingerprint_dim = len(dataset['fingerprint_array'].iloc[0])
input_dim = fingerprint_dim  # Parameter name used by train_base_model
fg_dim = len(dataset['fg_array'].iloc[0])

for MODEL in MODELS:

    torch.manual_seed(SEED)

    if MODEL == 'Base':
        print("Training BaseVAE model...")
        vae_trainer = train_base_model(
            dataset=dataset,
            input_dim=input_dim,
            latent_dim=latent_dim,
            fg_dim=fg_dim,  # BaseVAE does not condition on fg_array; fg_dim is passed for logging purposes only
            encoder_hidden_dims=encoder_hidden_dims,
            decoder_hidden_dims=decoder_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs
        )

    elif MODEL == 'CVAE':
        print("Training CVAE model...")
        vae_trainer = train_conditional_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim=latent_dim,
            encoder_hidden_dims=encoder_hidden_dims,
            decoder_hidden_dims=decoder_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs
        )

    elif MODEL == 'CSVAE':
        print("Training CSVAE model based on the NeurIPS 2018 paper...")
        vae_trainer = train_conditional_subspace_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim_z=latent_dim_z,
            latent_dim_w=latent_dim_w,
            encoder_hidden_dims_z=encoder_hidden_dims_z,
            encoder_hidden_dims_w=encoder_hidden_dims_w,
            decoder_hidden_dims=decoder_hidden_dims,
            adversarial_hidden_dims=adversarial_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs
        )

    elif MODEL == 'DISCoVeR':
        print("Training DISCoVeR VAE model...")
        vae_trainer = train_discover_vae(
            dataset=dataset,
            fingerprint_dim=fingerprint_dim,
            fg_dim=fg_dim,
            latent_dim_z=latent_dim_z,
            latent_dim_w=latent_dim_w,
            encoder_hidden_dims_z=encoder_hidden_dims_z,
            encoder_hidden_dims_w=encoder_hidden_dims_w,
            decoder_hidden_dims=decoder_hidden_dims,
            decoder_z_hidden_dims=decoder_z_hidden_dims,
            adversarial_hidden_dims=adversarial_hidden_dims,
            batch_size=batch_size,
            learning_rate=learning_rate,
            max_epochs=max_epochs
        )

    else:
        # Unreachable after the validation above; kept so vae_trainer is always
        # bound before it is saved and deleted below.
        raise ValueError(f"Unknown model {MODEL!r}; expected one of {SUPPORTED_MODELS}.")

    save_model(vae_trainer, MODEL_OUTPUT, MODEL)

    # Free memory after each model
    del vae_trainer

Training DISCoVeR VAE model...


💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mdamianrelkins[0m ([33mdamianrelkins-university-college-london-ucl-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



  | Name                     | Type        | Params | Mode 
-----------------------------------------------------------------
0 | model                    | DiscoverVAE | 10.5 M | train
1 | reconstruction_loss_func | BCELoss     | 0      | train
2 | adversarial_loss_func    | BCELoss     | 0      | train
-----------------------------------------------------------------
10.5 M    Trainable params
0         Non-trainable params
10.5 M    Total params
42.092    Total estimated model params size (MB)
69        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/damianelkins/miniconda3/envs/rdkit-thesis/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


                                                                           

/Users/damianelkins/miniconda3/envs/rdkit-thesis/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Epoch 4: 100%|██████████| 1250/1250 [01:09<00:00, 18.07it/s, v_num=3jub, val_total_loss=0.506]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 1250/1250 [01:09<00:00, 18.07it/s, v_num=3jub, val_total_loss=0.506]


/Users/damianelkins/miniconda3/envs/rdkit-thesis/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 63/63 [00:01<00:00, 61.16it/s]
─────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
─────────────────────────────────────────────────────────────
      test_adv_disc         0.5209407806396484
       test_kld_w         2.0425319235073403e-05
       test_kld_z           0.0058710603043437
 