### Multi-Conditional Graph Variational Autoencoder

In [1]:
import torch
# Select the compute device for the whole notebook.
# Preference order: CUDA GPU, then Apple-Silicon MPS, then CPU. Machines
# with only MPS or only CPU resolve to the same device as before.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [2]:
from data.dataset import QM9GraphDataset
from utils.datautils import create_data_loaders
from mgcvae import MGCVAE
from trainer import MGCVAETrainer
from utils.trainutils import load_from_checkpoint

In [3]:
# Build the molecular graph dataset from the CSV file at this path.
QM9_CSV_PATH = './data/qm9_bbbp.csv'
dataset = QM9GraphDataset(csv_path=QM9_CSV_PATH)

Processing...


Processing 2142 molecules...
Successfully processed 2142 molecules, failed: 0


Done!


In [4]:
# Split the dataset into train / validation / test loaders.
BATCH_SIZE = 4
loaders = create_data_loaders(dataset, batch_size=BATCH_SIZE)
train_loader, val_loader, test_loader = loaders

Dataset splits — Train: 1713, Val: 214, Test: 215


In [5]:
# Resume from the best saved checkpoint when one exists; otherwise record
# that nothing was loaded so the next cell can build a fresh model.
import os

CHECKPOINT_PATH = 'checkpoints/mgcvae/best_model.pth'

if os.path.exists(CHECKPOINT_PATH):
    # Load the best model from training
    model, checkpoint = load_from_checkpoint(CHECKPOINT_PATH, device=device)
    loaded = True

    # Model is ready for inference!
    # You can also access training history:
    print(f"\nTraining stopped at epoch: {checkpoint['epoch']}")
    print(f"Final validation loss: {checkpoint['best_val_loss']:.4f}")
else:
    # No checkpoint on disk (e.g. a fresh clone). Provide a placeholder with
    # the two keys this notebook reads, so that `checkpoint['epoch'] + 1`
    # later evaluates to 1 instead of raising a NameError.
    checkpoint = {'epoch': 0, 'best_val_loss': float('inf')}
    loaded = False
    print(f"No checkpoint found at {CHECKPOINT_PATH}; a new model will be created.")

Model loaded from checkpoints/mgcvae/best_model.pth
Epoch: 5, Best Val Loss: 0.8495

Training stopped at epoch: 5
Final validation loss: 0.8495


### If we haven't trained yet

In [6]:
# Build a brand-new model only when the `loaded` flag is falsy, i.e. no
# checkpoint supplied one.
if not loaded:
    model_config = dict(
        node_dim=29,
        edge_dim=6,
        latent_dim=32,
        hidden_dim=64,
        num_properties=1,
        num_layers=3,
        heads=4,
        max_nodes=20,
        beta=0.01,   # Start with low KL weight
        gamma=1.0,   # Property prediction weight
        dropout=0.1,
    )
    model = MGCVAE(**model_config)
    model = model.to(device)

In [7]:
# Gather the trainer arguments in one place, then construct the trainer.
trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    lr=1e-3,
    device=device,
    save_dir='checkpoints/mgcvae',
)
trainer = MGCVAETrainer(**trainer_kwargs)

In [None]:
# Continue epoch numbering right after the epoch stored in the checkpoint.
resume_epoch = checkpoint['epoch'] + 1
train_metrics, val_metrics = trainer.train(
    num_epochs=5,
    start_epoch=resume_epoch,
)

Starting MGCVAE training from epoch 6 to 10
Model parameters: 105,216
Device: mps

Epoch 6/10


Training:   0%|          | 0/429 [00:00<?, ?it/s]

### Metrics

In [None]:
from utils.metrics import (
    evaluate_property_prediction,
    evaluate_reconstruction_and_kl,
    evaluate_novelty_diversity,
    evaluate_conditioning_latent
)

In [None]:
from utils.inference import (
    batch_logits_to_molecules,
    evaluate_generation_quality
)

In [None]:
# Run both validation-set evaluations in order. The return values are
# discarded, as in the original cell.
for evaluate in (evaluate_property_prediction, evaluate_reconstruction_and_kl):
    _ = evaluate(model, val_loader, device)

  mse = F.mse_loss(preds, targets, reduction='sum').item()
  mse = F.mse_loss(preds, targets, reduction='sum').item()


Property Prediction MSE: 0.3804
Avg Reconstruction Loss: 0.7515
Avg KL Divergence: 2.7008
