In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal, kl_divergence


def _select_device() -> torch.device:
    """Return the preferred available compute device: CUDA, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps is absent on PyTorch < 1.12; probe it defensively so
    # older installs fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = _select_device()

# Reproducibility: seed the Python, NumPy and torch RNGs before anything random
# happens (model weight init below; the data split in create_data_loaders may also
# draw from one of these — its implementation isn't visible here, TODO confirm).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch generators on all devices

In [2]:
from data.dataset import QM9GraphDataset, create_data_loaders
from mgcvae import MGCVAE
from trainer import MGCVAETrainer

In [3]:
# Source CSV for the molecule dataset; the path is resolved relative to the kernel's CWD.
DATA_CSV = './data/qm9_bbbp.csv'
dataset = QM9GraphDataset(csv_path=DATA_CSV)

Processing...


Processing 2142 molecules...
Successfully processed 2142 molecules, failed: 0


Done!


In [4]:
# Mini-batch size shared by the train/val/test loaders.
# NOTE(review): 4 graphs per batch is very small for ~1.7k training molecules — confirm intended.
BATCH_SIZE = 4
loaders = create_data_loaders(dataset, batch_size=BATCH_SIZE)
train_loader, val_loader, test_loader = loaders

Dataset splits — Train: 1713, Val: 214, Test: 215


In [5]:
# Hyperparameters forwarded verbatim to MGCVAE(**model_config).
model_config = dict(
    node_dim=29,        # per-node feature width; must match the dataset's featurisation (TODO confirm against dataset)
    edge_dim=6,         # per-edge feature width; must match the dataset's featurisation (TODO confirm against dataset)
    latent_dim=32,
    hidden_dim=64,
    num_properties=2,
    num_layers=3,
    heads=4,            # presumably attention heads per graph layer — confirm in MGCVAE
    max_nodes=20,       # NOTE(review): confirm no molecule in the CSV exceeds this
    beta=0.01,          # KL-term weight; fixed here — any annealing must live in the trainer (not visible)
    gamma=1.0,          # property-prediction loss weight
    dropout=0.1,
)

In [6]:
# Instantiate the model from the config, then move its parameters to the selected device.
model = MGCVAE(**model_config)
model = model.to(device)

In [7]:
# Optimiser learning rate and checkpoint output directory for this run.
LEARNING_RATE = 1e-3
CHECKPOINT_DIR = 'checkpoints/mgcvae'

trainer = MGCVAETrainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
    lr=LEARNING_RATE,
    device=device,
    save_dir=CHECKPOINT_DIR,
)