In [None]:
import dgl
import dgl.function as fn

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
import pickle

# Locations of the pickled training and validation molecule sets.
train_path = "data/zinc_250k_train.pkl"
valid_path = "data/zinc_250k_valid.pkl"


def _read_pickle(path):
    """Open `path` in binary mode and return the object unpickled from it."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_graphs = _read_pickle(train_path)
valid_graphs = _read_pickle(valid_path)

In [None]:
from utils import to_dgl_graph

# Convert every loaded molecule with `to_dgl_graph`, so that the model can be
# trained with the Deep Graph Library (DGL).
train_dataset = [to_dgl_graph(molecule) for molecule in train_graphs]
valid_dataset = [to_dgl_graph(molecule) for molecule in valid_graphs]

# Passed to the VGAE constructor when the model is created.
nb_node_features = 14

In [None]:
from torch.utils.data import DataLoader

# Create mini-batches
def collate(samples):
    """Combine the graphs in `samples` into a single batched graph via `dgl.batch`."""
    batched_graph = dgl.batch(samples)
    return batched_graph

# Both loaders use the same batch size and the graph-batching `collate`
# function; only the training loader shuffles its samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=32,
    collate_fn=collate,
)

In [None]:
from model import VGAE

# Create the model and the optimizer.
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines where CUDA is not available (a hard-coded "cuda" would make
# `.to(device)` fail there).
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 5  # number of passes over the training loader
lr = 0.01  # learning rate handed to Adam

model = VGAE(nb_node_features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print(model)

In [None]:
from tqdm import tqdm
from utils import recon_loss, kl_loss

def eval(dataloader, model, device):
    """Return the mean per-batch loss of `model` over `dataloader`.

    Each batch loss is ``recon_loss(z, bg) + kl_loss(mu, logstd)``, computed
    with the model in evaluation mode and without gradient tracking.

    NOTE: the name shadows the builtin ``eval``; it is kept unchanged because
    the training loop calls this function under that name.

    Args:
        dataloader: Iterable of batches; each batch must support
            ``.to(device)`` and expose ``ndata["feats"]`` and ``edata["type"]``.
        model: Module providing ``encode(bg, n_features, e_types)`` that
            returns ``(z, mu, logstd)``.
        device: Device every batch is moved to before encoding.

    Returns:
        The sum of the batch losses divided by the number of batches, or
        ``float("nan")`` when the dataloader yields no batch at all.
    """
    # Remember the caller's mode so it can be restored afterwards instead of
    # unconditionally forcing the model back into training mode.
    was_training = model.training
    eval_loss = 0
    n_batches = 0
    model.eval()

    try:
        with torch.no_grad():
            for bg in dataloader:
                bg = bg.to(device)

                n_features = bg.ndata["feats"]
                e_types = bg.edata["type"]

                z, mu, logstd = model.encode(bg, n_features, e_types)
                loss = recon_loss(z, bg) + kl_loss(mu, logstd)
                eval_loss += loss.item()
                n_batches += 1
    finally:
        # Runs even if a batch raises, so the model is never left stuck in
        # evaluation mode by a failed validation pass.
        model.train(was_training)

    if n_batches == 0:
        # A mean over zero batches is undefined; report NaN rather than
        # failing with ZeroDivisionError.
        return float("nan")
    return eval_loss / n_batches

# Train the model: one optimisation pass over `train_loader` per epoch,
# followed by an evaluation on `valid_loader`.
for epoch in range(n_epochs):
    model.train()
    batch_losses = []

    for bg in tqdm(train_loader):
        bg = bg.to(device)
        n_features, e_types = bg.ndata["feats"], bg.edata["type"]

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass: encode the batch and combine both loss terms.
        z, mu, logstd = model.encode(bg, n_features, e_types)
        loss = recon_loss(z, bg) + kl_loss(mu, logstd)
        batch_losses.append(loss.item())

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

    # Average the training loss over the number of batches in the loader.
    train_loss = sum(batch_losses) / len(train_loader)
    valid_loss = eval(valid_loader, model, device)

    print("Epoch {}, train_loss: {:,.3f}, valid_loss: {:,.3f}".format(epoch, train_loss, valid_loss))

In [16]:
# Write the model's state dict (its parameters, not the whole module) to disk.
checkpoint_path = "trained_model.pkl"
torch.save(model.state_dict(), checkpoint_path)