In [1]:
import torch
from torch_geometric.data import DataLoader
import import_ipynb
from GVAE_Dataset import MoleculeDataset
from tqdm import tqdm
import numpy as np
import mlflow.pytorch
from utils import count_parameters, gvae_loss, reconstruction_accuracy
from gvae import GVAE
from config import DEVICE as device

# Experiment configuration: every value a reader might want to tune lives here.
MLFLOW_TRACKING_URI = "http://localhost:5000"
SEED = 42
N_TRAIN_MOLS = 10000   # prefix of the oversampled training CSV used for this run
N_TEST_MOLS = 4000     # prefix of the test CSV used for evaluation
BATCH_SIZE = 32
LEARNING_RATE = 0.01

# Seed torch before building the model so weight init and loader shuffling
# are reproducible across runs.
torch.manual_seed(SEED)

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)


train_dataset = MoleculeDataset(root="data/", filename="HIV_train_oversampled.csv")[:N_TRAIN_MOLS]
test_dataset = MoleculeDataset(root="data/", filename="HIV_test.csv", test=True)[:N_TEST_MOLS]
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Metrics are averaged per batch, so shuffling the test set only adds noise;
# a fixed order keeps test numbers comparable across epochs and runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Node-feature width is read from the first training graph.
model = GVAE(feature_size=train_dataset[0].x.shape[1]).to(device)
print("Model parameters: ", count_parameters(model))


loss_fn = gvae_loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
kl_beta = 0.001  # passed to gvae_loss alongside mu/logvar; presumably the KL-term weight


def _epoch_mean(values):
    """Return the mean of per-batch metric values as a float, or NaN if empty.

    Avoids numpy's "Mean of empty slice" RuntimeWarning when every batch of an
    epoch was skipped, while still producing NaN for the logged metric.
    """
    if len(values) == 0:
        return float("nan")
    return float(np.mean(values))


def run_one_epoch(data_loader, type, epoch, kl_beta):
    """Run one pass over ``data_loader`` and print/log epoch metrics to MLflow.

    Relies on the module-level ``model``, ``optimizer``, ``loss_fn``,
    ``reconstruction_accuracy`` and ``device``. Parameters are updated only
    when ``type == "Train"``; any other phase (e.g. ``"Test"``) runs with
    autograd disabled and never touches the optimizer.

    Args:
        data_loader: Loader yielding batches exposing ``x``, ``edge_attr``,
            ``edge_index``, ``batch`` and ``smiles``.
        type: Phase label (``"Train"`` or ``"Test"``); also the prefix of every
            printed and logged metric name.
        epoch: Epoch index, used as the MLflow ``step``.
        kl_beta: Forwarded unchanged to ``loss_fn``.
    """
    is_train = type == "Train"

    all_losses = []
    all_accs = []
    all_kldivs = []

    total_mols = 0
    reconstructed_mols = 0

    # Evaluation has nothing to backpropagate, so skip building autograd graphs.
    with torch.set_grad_enabled(is_train):
        for batch in tqdm(data_loader):
            try:
                batch = batch.to(device)

                if is_train:
                    optimizer.zero_grad()

                triu_logits, mu, logvar = model(batch.x.float(),
                                                batch.edge_attr.float(),
                                                batch.edge_index,
                                                batch.batch)

                loss, kl_div = loss_fn(triu_logits, batch.edge_index, mu, logvar,
                                       batch.batch, kl_beta)
                if is_train:
                    loss.backward()
                    optimizer.step()

                acc, num_recon = reconstruction_accuracy(triu_logits, batch.edge_index,
                                                         batch.batch, batch.x.float())
                total_mols += len(batch.smiles)
                reconstructed_mols += num_recon

                # loss is a scalar (backward() requires it in the Train phase).
                all_losses.append(loss.item())
                all_accs.append(acc)
                all_kldivs.append(kl_div.detach().cpu().numpy())
            except IndexError as error:
                # Best-effort: a batch raising IndexError (presumably a malformed
                # graph; TODO confirm) is reported and skipped instead of aborting.
                print("Error: ", error)

    # TODO: optionally sample molecules in the Test phase via model.sample_mols
    # and log the count (was previously commented out here).

    mean_loss = _epoch_mean(all_losses)
    mean_acc = _epoch_mean(all_accs)
    mean_kldiv = _epoch_mean(all_kldivs)

    print(f"{type} epoch {epoch} loss: ", mean_loss)
    print(f"{type} epoch {epoch} accuracy: ", mean_acc)
    print(f"Reconstructed {reconstructed_mols} out of {total_mols} molecules ")
    mlflow.log_metric(key=f"{type} Epoch Loss", value=mean_loss, step=epoch)
    mlflow.log_metric(key=f"{type} Epoch Accuracy", value=mean_acc, step=epoch)
    mlflow.log_metric(key=f"{type} Num Reconstructed", value=float(reconstructed_mols), step=epoch)
    mlflow.log_metric(key=f"{type} KL Divergence", value=mean_kldiv, step=epoch)
    # NOTE(review): this serializes the full model on every call (both phases,
    # every epoch); consider logging only the checkpoints you actually need.
    mlflow.pytorch.log_model(model, "model")


NUM_EPOCHS = 100
TEST_EVERY = 5  # evaluate every TEST_EVERY epochs, and always after the final one

with mlflow.start_run() as run:
    for epoch in range(NUM_EPOCHS):
        model.train()
        run_one_epoch(train_loader, type="Train", epoch=epoch, kl_beta=kl_beta)

        # Without the last-epoch check the final trained model (epoch
        # NUM_EPOCHS - 1) would never be evaluated when it is not a multiple
        # of TEST_EVERY.
        is_last_epoch = epoch == NUM_EPOCHS - 1
        if epoch % TEST_EVERY == 0 or is_last_epoch:
            print("Start test epoch...")
            model.eval()
            run_one_epoch(test_loader, type="Test", epoch=epoch, kl_beta=kl_beta)

importing Jupyter notebook from GVAE_Dataset.ipynb


No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
  torch.utils._pytree._register_pytree_node(
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'haiku'


Torch version: 2.2.1+cpu
Cuda available: False
Torch geometric version: 2.5.1
importing Jupyter notebook from gvae.ipynb


Processing...
100%|██████████| 71634/71634 [17:37<00:00, 67.74it/s] 
Done!
Processing...
100%|██████████| 3999/3999 [00:36<00:00, 109.70it/s]
Done!


Model parameters:  185473


  0%|          | 0/313 [00:00<?, ?it/s]


NameError: name 'reconstruction_accuracy' is not defined