### Multi-Conditional Graph Variational Autoencoder

In [1]:
import torch
# Pick the fastest available backend: CUDA GPU first, then Apple-silicon MPS,
# and fall back to CPU so the notebook still runs on any machine.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [2]:
from data.dataset import QM9GraphDataset
from utils.datautils import create_data_loaders
from mgcvae import MGCVAE
from trainer import MGCVAETrainer
from utils.trainutils import load_from_checkpoint

In [3]:
# Build the graph dataset from the CSV file (path is relative to the notebook's working directory).
qm9_csv_path = './data/qm9_bbbp2.csv'
dataset = QM9GraphDataset(csv_path=qm9_csv_path)

Processing...


Processing 4898 molecules...
Successfully processed 4898 molecules, failed: 0


Done!


In [4]:
# Split the dataset into train / validation / test loaders using a batch size of 4.
data_loaders = create_data_loaders(dataset, batch_size=4)
train_loader, val_loader, test_loader = data_loaders

Dataset splits — Train: 3918, Val: 490, Test: 490


In [5]:
# Resume from the best checkpoint when one exists. `loaded` is always defined
# (True or False) so the following cell can decide whether it has to build a
# fresh model instead of failing with a NameError.
import os

checkpoint_path = 'checkpoints/mgcvae/best_model.pth'
loaded = os.path.isfile(checkpoint_path)

if loaded:
    # Load the best model from training
    model, optimizer_state, scheduler_state, checkpoint = load_from_checkpoint(
        checkpoint_path,
        device=device
    )

    # Model is ready for inference!
    # The checkpoint dict also exposes the training history:
    print(f"\nTraining stopped at epoch: {checkpoint['epoch']}")
    print(f"Final validation loss: {checkpoint['best_val_loss']:.4f}")
else:
    # No checkpoint on disk: define the resume state explicitly as "absent"
    # rather than leaving these names undefined.
    optimizer_state = scheduler_state = checkpoint = None
    print(f"No checkpoint found at {checkpoint_path}; a fresh model must be created.")

Model loaded from checkpoints/mgcvae/best_model.pth
Epoch: 18, Best Val Loss: 0.8214

Training stopped at epoch: 18
Final validation loss: 0.8214


### If we haven't trained yet

In [6]:
if not loaded:
    # Nothing was restored from disk, so build an untrained MGCVAE from
    # explicit hyperparameters.
    model_config = dict(
        node_dim=29,
        edge_dim=6,
        latent_dim=32,
        hidden_dim=64,
        num_properties=1,
        num_layers=3,
        heads=4,
        max_nodes=20,
        beta=0.01,   # KL weight: start low
        gamma=1.0,   # weight of the property-prediction term
        dropout=0.1,
    )
    model = MGCVAE(**model_config)
    model = model.to(device)

In [7]:
# Hand the model and the train/val/test loaders to the trainer.
# NOTE(review): save_dir is the same directory the checkpoint is loaded from,
# so it is presumably where the trainer writes new checkpoints — confirm in MGCVAETrainer.
trainer = MGCVAETrainer(
    model=model, device=device,
    train_loader=train_loader, val_loader=val_loader, test_loader=test_loader,
    lr=1e-3, save_dir='checkpoints/mgcvae',
)

In [8]:
# Restore optimizer/scheduler state only when a checkpoint was actually loaded.
# A freshly built model has no saved state, and the state variables would not
# exist in that case.
if loaded:
    trainer.load_optimizer_scheduler(optimizer_state, scheduler_state)

Optimizer and scheduler states loaded


In [9]:
# Continue training for three more epochs, starting right after the epoch
# recorded in the checkpoint.
resume_epoch = checkpoint['epoch'] + 1
train_metrics, val_metrics = trainer.train(num_epochs=3, start_epoch=resume_epoch)

Starting MGCVAE training from epoch 19 to 21
Model parameters: 105,216
Device: mps

Epoch 19/21


Training: 100%|██████████| 980/980 [12:53<00:00,  1.27it/s, Loss=0.9873, Recon=0.9723, KL=1.3295, Prop=0.0017]
Validation: 100%|██████████| 123/123 [00:57<00:00,  2.15it/s]


New best model saved! Val loss: 0.9521
Train Loss: 1.0201 | Val Loss: 0.9521
Recon: 0.8823 | KL: 2.8503 | Prop: 0.1093
LR: 1.00e-03 | Patience: 0/30

Epoch 20/21


Training:  28%|██▊       | 279/980 [02:35<06:30,  1.79it/s, Loss=1.0501, Recon=0.9246, KL=2.4851, Prop=0.1006]


KeyboardInterrupt: 

### Metrics

In [10]:
from utils.metrics import (
    evaluate_property_prediction,
    evaluate_reconstruction_and_kl,
    evaluate_novelty_diversity,
    evaluate_conditioning_latent
)

In [11]:
from utils.inference import (
    batch_logits_to_molecules,
    evaluate_generation_quality
)

In [12]:
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

In [13]:
# Both helpers print their own summary (see the cell output). Keep the returned
# values under descriptive names instead of assigning to `_`, which IPython
# reserves for the previous cell's output and stops updating once overwritten.
property_metrics = evaluate_property_prediction(model, val_loader, device)
recon_kl_metrics = evaluate_reconstruction_and_kl(model, val_loader, device)

Property Prediction MSE: 0.3753
Avg Reconstruction Loss: 0.7220
Avg KL Divergence: 2.5864


In [14]:
# Check how well conditioning works: draw 20 samples for a target property of
# 0.9 and report the share of predictions within ±0.15 of that target.
evaluate_conditioning_latent(
    model,
    target=[0.9],
    num_samples=20,
    tolerance=0.15,
    device=device,
)

Conditioning Evaluation (latent):
Target:            [0.9]
Success rate:      20.0% within ±0.15
Mean absolute err: 0.2218
Predicted mean:    [0.6782122850418091]
  Predicted std:     [0.10340068489313126]


{'success_rate': 20.0,
 'mae': 0.22178768515586858,
 'mean_pred': [0.6782122850418091],
 'std_pred': [0.10340068489313126]}

### Inference
Once the metrics are within reasonable thresholds, we can generate new molecules conditioned on target property values.