### Multi-Conditional Graph Variational Autoencoder (MGCVAE)

In [1]:
import torch
# Select the compute device for the whole notebook.
# Preference order: CUDA GPU, then Apple's Metal (MPS) backend, then CPU.
# The original MPS -> CPU order is preserved; CUDA is only added in front so
# that machines with an NVIDIA GPU do not silently fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [2]:
from data.dataset import QM9GraphDataset
from utils.datautils import create_data_loaders
from mgcvae import MGCVAE
from trainer import MGCVAETrainer
from utils.trainutils import load_from_checkpoint

In [3]:
# Build the graph dataset from the CSV file; the constructor processes the
# molecules listed in that file when it is instantiated.
dataset = QM9GraphDataset(
    csv_path='./data/qm9_bbbp2.csv',
)

Processing...


Processing 4898 molecules...
Successfully processed 4898 molecules, failed: 0


Done!


In [4]:
# Split the dataset into train / validation / test loaders with a batch size of 4.
(
    train_loader,
    val_loader,
    test_loader,
) = create_data_loaders(dataset, batch_size=4)

Dataset splits — Train: 3918, Val: 490, Test: 490


In [5]:
# Resume switch read by the following cells:
#   True  -> restore model/optimizer/scheduler from the saved checkpoint and
#            continue training from the epoch after the checkpointed one;
#   False -> build a fresh model and train from epoch 0.
loaded = True

In [6]:
# Resume path: restore the best checkpoint written by an earlier training run.
# `load_from_checkpoint` hands back the model, the optimizer and scheduler
# states (consumed later by the trainer) and the raw checkpoint dict.
if loaded:
    (
        model,
        optimizer_state,
        scheduler_state,
        checkpoint,
    ) = load_from_checkpoint('checkpoints/mgcvae/best_model.pth', device=device)

    # From here on the model can be used for inference. The checkpoint dict
    # also exposes the bookkeeping fields reported below.
    # NOTE(review): the second label says "Final" but the value printed is
    # checkpoint['best_val_loss'] - confirm the wording is intended.
    print(f"\nTraining stopped at epoch: {checkpoint['epoch']}")
    print(f"Final validation loss: {checkpoint['best_val_loss']:.4f}")

Model loaded from checkpoints/mgcvae/best_model.pth
Epoch: 2, Best Val Loss: 1.0025

Training stopped at epoch: 2
Final validation loss: 1.0025


In [7]:
# Fresh-start path: a new model is only built when no checkpoint was restored.
if not loaded:
    model_config = dict(
        # Architecture hyper-parameters
        node_dim=29,
        edge_dim=6,
        latent_dim=32,
        hidden_dim=64,
        num_properties=1,
        num_layers=2,
        heads=4,
        max_nodes=20,
        # Loss weighting
        beta=0.01,   # Start with low KL weight
        gamma=1.0,   # Property prediction weight
        # Regularisation
        dropout=0.1,
    )
    model = MGCVAE(**model_config).to(device)

In [8]:
# Wire the model and the three data loaders into the trainer. Checkpoints are
# written to the same directory the resume cell reads its checkpoint from.
trainer = MGCVAETrainer(
    model=model,
    device=device,
    lr=1e-3,
    save_dir='checkpoints/mgcvae',
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

In [9]:
# When resuming, also restore the optimizer and scheduler states that were
# returned alongside the model by the checkpoint loader.
if loaded:
    trainer.load_optimizer_scheduler(optimizer_state, scheduler_state)

Optimizer and scheduler states loaded


In [10]:
# Train for 4 epochs. When resuming, start at the epoch right after the one
# stored in the checkpoint; otherwise start from epoch 0.
train_metrics, val_metrics = trainer.train(
    num_epochs=4,
    start_epoch=(checkpoint['epoch'] + 1) if loaded else 0,
)

Starting MGCVAE training from epoch 3 to 6
Model parameters: 96,640
Device: mps

Epoch 3/6


Training: 100%|██████████| 980/980 [18:16<00:00,  1.12s/it, Loss=0.9784, Recon=0.9550, KL=1.6734, Prop=0.0066]
Validation: 100%|██████████| 123/123 [02:05<00:00,  1.02s/it]


New best model saved! Val loss: 0.9889
Train Loss: 1.0601 | Val Loss: 0.9889
Recon: 0.9279 | KL: 2.1369 | Prop: 0.1108
LR: 1.00e-03 | Patience: 0/30

Epoch 4/6


Training:  19%|█▉        | 189/980 [03:26<14:24,  1.09s/it, Loss=1.0741, Recon=0.9661, KL=1.7038, Prop=0.0909]


KeyboardInterrupt: 

### Metrics

In [11]:
from utils.metrics import (
    evaluate_property_prediction,
    evaluate_reconstruction_and_kl,
    evaluate_novelty_diversity,
    evaluate_conditioning_latent
)

In [12]:
from utils.inference import (
    batch_logits_to_molecules,
    evaluate_generation_quality
)

In [13]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [14]:
# Evaluate the model on the validation split; both helpers print their own
# summary lines.
# Whatever the helpers return is kept under descriptive names instead of `_`:
# in IPython, rebinding `_` stops the automatic "last output" variables
# (`_`, `__`, `___`) from being updated for the rest of the session.
property_metrics = evaluate_property_prediction(model, val_loader, device)
recon_kl_metrics = evaluate_reconstruction_and_kl(model, val_loader, device)

Property Prediction MSE: 0.4312
Avg Reconstruction Loss: 0.8863
Avg KL Divergence: 1.3024


In [15]:
# Check how well conditioning steers the model: draw 20 samples for a target
# of 0.9 and count a sample as a success when it lands within +/-0.15 of it.
# The call is the cell's last expression, so its result dict is displayed.
evaluate_conditioning_latent(
    model,
    target=[0.9],
    num_samples=20,
    tolerance=0.15,
    device=device,
)

Conditioning Evaluation (latent):
Target:            [0.9]
Success rate:      0.0% within ±0.15
Mean absolute err: 0.2898
Predicted mean:    [0.6102389693260193]
  Predicted std:     [0.02855761907994747]


{'success_rate': 0.0,
 'mae': 0.289760982990265,
 'mean_pred': [0.6102389693260193],
 'std_pred': [0.02855761907994747]}

### Inference
Once the metrics above are within reasonable thresholds, we can use the model to conditionally generate new molecules.