In [None]:
import dgl
import dgl.function as fn

import torch
import torch.nn as nn
import torch.nn.functional as F

import matplotlib.pyplot as plt
plt.rcParams.update({"figure.figsize": (20, 10)})  # default (width, height) in inches for every figure below

In [None]:
import pickle

# Paths to the pre-split training and validation molecule sets.
# NOTE: pickle.load can execute arbitrary code — only load these files from a trusted source.
train_path = "data/zinc_250k_train.pkl"
valid_path = "data/zinc_250k_valid.pkl"


def _load_pickle(path):
    """Open ``path`` in binary mode and return the object unpickled from it."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_graphs = _load_pickle(train_path)
valid_graphs = _load_pickle(valid_path)

In [None]:
from utils import to_dgl_graph

# Convert each molecule into a DGLGraph so the model can be trained with the Deep Graph Library (DGL).
train_dataset = [to_dgl_graph(graph) for graph in train_graphs]
valid_dataset = [to_dgl_graph(graph) for graph in valid_graphs]

# Number of distinct node / edge type labels passed to the model.
# NOTE(review): presumably these must match the type encoding produced by to_dgl_graph — confirm in utils.
nb_node_types = 14
nb_edge_types = 3

In [None]:
from torch.utils.data import DataLoader

def collate(samples):
    """Merge a list of per-molecule DGLGraphs into a single batched DGLGraph (one mini-batch)."""
    batched_graph = dgl.batch(samples)
    return batched_graph


BATCH_SIZE = 32

# Only the training loader shuffles; the validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, collate_fn=collate)

In [None]:
import importlib

# Bind the project module under a distinct name: `model` is reused below for the
# VGAE instance, and sharing one name for a module and an object is confusing.
import model as model_module

# Reload so edits to model.py are picked up without restarting the kernel;
# `from model import VGAE` then fetches the freshly reloaded class.
importlib.reload(model_module)

from model import VGAE

# Training hyperparameters.
n_epochs = 1
lr = 0.01
seed = 0

# Seed PyTorch's global RNG before building the model so parameter initialization
# and the training DataLoader's shuffle order are reproducible across runs.
torch.manual_seed(seed)

# Create the model and the optimizer.
model = VGAE(nb_node_types, nb_edge_types)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print(model)

In [None]:
import utils
import importlib
importlib.reload(utils)

from tqdm import tqdm
from utils import recon_loss, kl_loss

def eval(dataloader, model, device=None):
    """Return the mean per-batch reconstruction and KL losses of ``model`` on ``dataloader``.

    Runs without gradient tracking and with the model in eval mode, then restores
    whatever train/eval mode the model was in before the call (even if a batch raises).
    Note: this name shadows Python's builtin ``eval`` in the notebook namespace; it is
    kept so existing calls keep working.

    Args:
        dataloader: sized iterable of batched DGLGraphs carrying ``ndata["feats"]`` and
            ``edata["type"]`` (e.g. the DataLoaders built with ``collate`` above).
        model: module called as ``model(bg, n_features, e_types)`` that returns
            ``(pred_graphs, mu, logstd)``.
        device: optional target device; when given, each batch is moved there with
            ``bg.to(device)`` before the forward pass. The model must already live there.

    Returns:
        Tuple ``(mean_recon_loss, mean_kl_loss)``, each summed over batches and divided
        by the number of batches.

    Raises:
        ValueError: if ``dataloader`` contains no batches (the mean would be undefined).
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("eval() needs a non-empty dataloader to average losses over.")

    eval_recon_loss = 0.0
    eval_kl_loss = 0.0
    # Remember the caller's mode so it is restored exactly rather than forced to train().
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for bg in dataloader:
                if device is not None:
                    bg = bg.to(device)
                n_features = bg.ndata["feats"]
                e_types = bg.edata["type"]
                pred_graphs, mu, logstd = model(bg, n_features, e_types)

                # recon_loss compares predictions against the individual (unbatched) graphs.
                eval_recon_loss += recon_loss(pred_graphs, dgl.unbatch(bg)).item()
                eval_kl_loss += kl_loss(mu, logstd).item()
    finally:
        model.train(was_training)

    return eval_recon_loss / n_batches, eval_kl_loss / n_batches

# Train the model, recording per-batch training losses for the plots below.
recon_loss_history = []
kl_loss_history = []

for epoch in range(n_epochs):
    train_loss = 0
    model.train()

    for bg in tqdm(train_loader):
        n_features = bg.ndata["feats"]
        e_types = bg.edata["type"]
        pred_graphs, mu, logstd = model(bg, n_features, e_types)

        # Objective: reconstruction term plus KL term, summed with equal weight.
        train_recon_loss = recon_loss(pred_graphs, dgl.unbatch(bg))
        train_kl_loss = kl_loss(mu, logstd)
        loss = train_recon_loss + train_kl_loss

        recon_loss_history.append(train_recon_loss.item())
        kl_loss_history.append(train_kl_loss.item())
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean training loss per batch for this epoch.
    train_loss /= len(train_loader)
    # eval() takes (dataloader, model); no device is configured in this notebook.
    valid_recon_loss, valid_kl_loss = eval(valid_loader, model)

    print("Epoch {}, train_loss: {:,.3f}, valid_recon_loss: {:,.3f}, valid_kl_loss: {:,.3f}".format(epoch, train_loss, valid_recon_loss, valid_kl_loss))
    # Checkpoint the weights after every epoch.
    torch.save(model.state_dict(), "trained_model_epoch_{}.pkl".format(epoch))

In [None]:
# Plot the per-batch training reconstruction loss on explicit Axes so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(recon_loss_history)
ax.set(
    title="Training reconstruction loss",
    xlabel="Number of processed batches",
    ylabel="Reconstruction loss",
)
plt.show()  # render the figure without echoing the Line2D repr

In [None]:
# Plot the per-batch training KL loss on explicit Axes so the figure stands alone.
fig, ax = plt.subplots()
ax.plot(kl_loss_history)
ax.set(
    title="Training KL loss",
    xlabel="Number of processed batches",
    ylabel="KL loss",
)
plt.show()  # render the figure without echoing the Line2D repr

In [None]:
# Persist the final trained weights (the state_dict only, not the module object).
final_state = model.state_dict()
torch.save(final_state, "trained_model.pkl")