In [1]:
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F

from torch.nn import Module, Sequential, Linear, ReLU

import torch_geometric
from torch_geometric.nn import MessagePassing, global_mean_pool

# Fix the global torch RNG so weight initialisation and DataLoader shuffling are
# repeatable across Restart & Run All (torch.manual_seed seeds CPU and CUDA RNGs;
# some GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


# Data

In [79]:
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader

root = "./data/zinc"  # root directory handed to the ZINC dataset constructors

# Target normalisation constants: mean and std of the training set.
# `y` is whitened as (y - MEAN) / STD, so predictions are mapped back with y * STD + MEAN.
MEAN, STD = torch.Tensor([0.0153]), torch.Tensor([2.0109])

def transform(data):
    """Featurise one graph sample in place and return the same object.

    - ``data.x``: integer atom ids -> 28-way one-hot float32 features.
    - ``data.y``: standardised with the module-level MEAN / STD.
    - ``data.edge_attr``: integer bond ids -> 3-way one-hot float32 features.
      The ``- 1`` shift implies bond ids start at 1 (i.e. values 1..3).

    ``squeeze(1)`` removes a singleton column axis when ids are stored as
    ``(n, 1)``; for 1-D ids it is a no-op because the class axis has size > 1.
    """
    node_onehot = F.one_hot(data.x, 28)
    data.x = node_onehot.squeeze(1).to(torch.float32)

    data.y = (data.y - MEAN) / STD

    # Shift bond ids to 0-based class indices before one-hot encoding.
    bond_onehot = F.one_hot(data.edge_attr - 1, 3)
    data.edge_attr = bond_onehot.squeeze(1).to(torch.float32)
    return data

# Which ZINC variant to load: True -> the subset (10k / 1k / 1k graphs, per the
# printed sizes below); False -> the full dataset (ZINC's default, subset=False).
# NOTE(review): MEAN / STD above are described as training-set statistics; if you
# switch variants they likely need recomputing — TODO confirm which split they came from.
USE_SUBSET = True
BATCH_SIZE = 128

train_dataset = ZINC(root, subset=USE_SUBSET, split="train", transform=transform)
val_dataset = ZINC(root, subset=USE_SUBSET, split="val", transform=transform)
test_dataset = ZINC(root, subset=USE_SUBSET, split="test", transform=transform)

# Shuffle only the training split; val / test keep a fixed order for evaluation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"train: {len(train_dataset)}")
print(f"val: {len(val_dataset)}")
print(f"test: {len(test_dataset)}")

train: 10000
val: 1000
test: 1000


# Train

In [14]:
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

def experiment(
        model,
        train_loader, val_loader,
        test_loader=None,
        num_epoch=100,
        lr=1e-3,
        gamma=0.95,
        max_grad_norm=1.0,
        log_every=1,
        ):
    """Train `model` with Adam + exponential LR decay, tracking the best validation epoch.

    The optimised *loss* is MSE on whitened targets; the reported *score* is
    MAE after un-whitening with the module-level MEAN / STD (the same constants
    `transform` uses to whiten `y`).

    Args:
        model: module called as ``model(batch)``; its output is compared with
            ``batch.y[:, None]``.
        train_loader, val_loader: iterables of batches exposing ``.to(device)``,
            ``.y`` and ``.num_graphs``, plus a ``.dataset`` whose ``len`` is the
            number of graphs (used to turn summed batch losses into averages).
        test_loader: optional; evaluated each time the validation loss improves,
            so the returned test score belongs to the best-validation epoch.
        num_epoch: number of training epochs.
        lr: Adam learning rate.
        gamma: multiplicative LR decay applied after every epoch.
        max_grad_norm: gradient-norm clipping threshold.
        log_every: print a progress line on epoch 1 and every `log_every`
            epochs; 0 / None disables logging.

    Returns:
        dict with "train_loss_hist", "train_score_hist", "val_loss_hist",
        "val_score_hist", "best_epoch", "val_loss", "val_score" and, when
        `test_loader` is given, "test_score". If no epoch ever improves on the
        initial +inf validation loss (e.g. num_epoch=0 or NaN losses),
        "best_epoch" is None, "val_loss" is +inf and the scores are NaN.
    """
    # Same target statistics as `transform`, moved to the training device.
    mean = MEAN.to(device)
    std = STD.to(device)

    def unwhiten(y):
        # Map whitened targets / predictions back to the original target scale.
        return y * std + mean

    loss_fn = nn.MSELoss()  # optimised objective, on whitened targets
    eval_fn = nn.L1Loss()   # reported score: MAE on the original scale

    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, gamma)

    def per_graph(total, loader):
        # Batch losses are batch means weighted by batch size, so divide by the
        # number of graphs; an empty dataset yields NaN instead of ZeroDivisionError.
        n = len(loader.dataset)
        return total / n if n else float("nan")

    def train_epoch():
        model.train()
        loss_sum = 0.0
        score_sum = 0.0
        for data in train_loader:
            data = data.to(device)
            target = data.y[:, None]  # adds a trailing axis; presumably y is (B,) -> (B, 1)

            optimizer.zero_grad()
            y_pred = model(data)
            loss = loss_fn(y_pred, target)  # mean over the batch
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()

            # The score is for reporting only; don't build an autograd graph for it.
            with torch.no_grad():
                score = eval_fn(unwhiten(y_pred), unwhiten(target))
            loss_sum += loss.item() * data.num_graphs
            score_sum += score.item() * data.num_graphs
        return per_graph(loss_sum, train_loader), per_graph(score_sum, train_loader)

    @torch.no_grad()
    def evaluate(loader):
        # Validation / test pass (named `evaluate` to avoid shadowing builtin `eval`).
        model.eval()
        loss_sum = 0.0
        score_sum = 0.0
        for data in loader:
            data = data.to(device)
            target = data.y[:, None]
            y_pred = model(data)
            loss_sum += loss_fn(y_pred, target).item() * data.num_graphs
            score_sum += eval_fn(unwhiten(y_pred), unwhiten(target)).item() * data.num_graphs
        return per_graph(loss_sum, loader), per_graph(score_sum, loader)

    train_loss_hist = []
    train_score_hist = []
    val_loss_hist = []
    val_score_hist = []

    # Bind every "best" value up front so the log line and the return dict work
    # even when no epoch improves (num_epoch=0, NaN losses, ...).
    val_loss_best = np.inf
    val_score_best = float("nan")
    epoch_best = None
    test_score = float("nan")

    for epoch in range(1, num_epoch + 1):
        train_loss, train_score = train_epoch()
        train_loss_hist.append(train_loss)
        train_score_hist.append(train_score)

        val_loss, val_score = evaluate(val_loader)
        val_loss_hist.append(val_loss)
        val_score_hist.append(val_score)

        if val_loss < val_loss_best:
            val_loss_best = val_loss
            val_score_best = val_score
            epoch_best = epoch
            if test_loader is not None:
                _, test_score = evaluate(test_loader)

        if log_every and (epoch == 1 or epoch % log_every == 0):
            output = f"Epoch {epoch}, train loss {train_loss:.4f}, val loss {val_loss:.4f}; Best Epoch {epoch_best}, val score = {val_score_best:.4f}"
            if test_loader is not None:
                output = output + f", test score = {test_score:.4f}"
            print(output)

        scheduler.step()

    ret = {
        "train_loss_hist": train_loss_hist,
        "train_score_hist": train_score_hist,
        "val_loss_hist": val_loss_hist,
        "val_score_hist": val_score_hist,
        "best_epoch": epoch_best,
        "val_loss": val_loss_best,
        "val_score": val_score_best,
    }
    if test_loader is not None:
        ret["test_score"] = test_score
    return ret

# Models

In [84]:
class Regression_Head(Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to pooled graph embeddings.

    Args:
        in_dim: size of the incoming embedding.
        out_dim: number of regression outputs.
        hid_dim: width of the hidden layer.
    """

    def __init__(self, in_dim, out_dim, hid_dim):
        super().__init__()
        layers = [Linear(in_dim, hid_dim), ReLU(), Linear(hid_dim, out_dim)]
        self.mlp = Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

## DE-MF

In [None]:
class DE_Layer(MessagePassing):
    """Weight-tied message-passing layer applied `num_iter` times (DE-MF).

    Every iteration sums raw neighbour embeddings (``aggr='add'``) and feeds
    ``[x, aggregated]`` through ``mlp_update``. ``x`` is held constant across
    iterations: it is ``fixed_x`` when supplied, otherwise the initial ``h``.

    NOTE(review): the "DE-MF + edge feature" cell redefines ``DE_Layer`` under
    the same name; models resolve ``DE_Layer`` when constructed, so whichever
    cell ran last is the one they use.
    """

    def __init__(self, emb_dim, num_iter):
        super().__init__(aggr='add')
        self.num_iter = num_iter
        self.mlp_update = Sequential(
            Linear(2 * emb_dim, emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
        )

    def forward(self, h, edge_index, fixed_x=None):
        # Injected input, unchanged across iterations.
        x = h if fixed_x is None else fixed_x
        for _ in range(self.num_iter):
            h = self.propagate(edge_index, h=h, x=x)
        return h

    def message(self, h_j):
        # Messages are the neighbours' current embeddings, unchanged.
        return h_j

    def update(self, aggr_out, x):
        combined = torch.cat([x, aggr_out], dim=-1)
        return self.mlp_update(combined)

In [None]:
class DE_Model(Module):
    """Graph regressor: input projection -> one iterated DE_Layer -> mean pool -> MLP head.

    NOTE(review): the "Stacked DE-MF" cell later redefines ``DE_Model`` with a
    different constructor signature; whichever cell ran last wins.
    """

    def __init__(self, in_dim, out_dim, emb_dim=64, hid_dim=64, num_iter=4):
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)
        self.emb = DE_Layer(emb_dim, num_iter)
        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        node_emb = self.emb(self.lin_in(data.x), data.edge_index)
        graph_emb = self.pool(node_emb, data.batch)
        return self.head(graph_emb)

## DE-MF + edge features

In [None]:
class DE_Layer(MessagePassing):
    """Weight-tied message-passing layer iterated `num_iter` times, with optional edge features.

    Each iteration sums messages (``aggr='add'``) and feeds ``[x, aggregated]``
    through ``mlp_update``. ``x`` is held constant across iterations: it is
    ``fixed_x`` when supplied, otherwise the initial ``h``.

    Messages are ``mlp_message([h_j, edge_attr])`` when ``edge_attr`` is given.
    When ``edge_attr`` is None, the layer falls back to raw neighbour embeddings
    ``h_j``. This means two-argument calls ``layer(h, edge_index)`` still work
    after this cell shadows the plain ``DE_Layer``. The stacked ``DE_Model``
    makes exactly such calls.

    Args:
        emb_dim: node embedding size.
        num_iter: number of shared-weight propagation steps.
        edge_dim: size of each edge feature vector. The default 3 matches the
            3-way bond one-hot built by ``transform``.
    """

    def __init__(self, emb_dim, num_iter, edge_dim=3):
        super().__init__(aggr='add')
        self.num_iter = num_iter

        self.mlp_update = Sequential(
            Linear(2 * emb_dim, emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
        )

        self.mlp_message = Sequential(
            Linear(emb_dim + edge_dim, emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim)
        )

    def forward(self, h, edge_index, edge_attr=None, fixed_x=None):
        # Injected input, unchanged across iterations.
        x = h if fixed_x is None else fixed_x
        for _ in range(self.num_iter):
            h = self.propagate(edge_index, h=h, x=x, edge_attr=edge_attr)
        return h

    def message(self, h_j, edge_attr):
        if edge_attr is None:
            # No edge features: behave like the plain DE-MF layer.
            return h_j
        return self.mlp_message(torch.cat([h_j, edge_attr], dim=-1))

    def update(self, aggr_out, x):
        return self.mlp_update(torch.cat([x, aggr_out], dim=-1))

## MPNN

In [None]:
class MP_Layer(MessagePassing):
    """One round of sum-aggregation (``aggr='add'``) message passing.

    Messages are the raw neighbour embeddings. The update MLP mixes each node's
    ``x`` with the summed messages. ``x`` is ``fixed_x`` when given, otherwise
    the layer's input ``h``.
    """

    def __init__(self, emb_dim):
        super().__init__(aggr='add')
        self.mlp_update = Sequential(
            Linear(2 * emb_dim, emb_dim),
            ReLU(),
            Linear(emb_dim, emb_dim),
        )

    def forward(self, h, edge_index, fixed_x=None):
        x = h if fixed_x is None else fixed_x
        return self.propagate(edge_index, h=h, x=x)

    def message(self, h_j):
        return h_j

    def update(self, aggr_out, x):
        combined = torch.cat([x, aggr_out], dim=-1)
        return self.mlp_update(combined)

In [None]:
class MP_Model(Module):
    """Stack of `num_layer` independent MP_Layers -> mean pool -> MLP regression head."""

    def __init__(self, in_dim, out_dim, num_layer=4, emb_dim=256, hid_dim=256):
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)
        # One MP_Layer per depth step, each with its own weights.
        self.embs = nn.ModuleList(MP_Layer(emb_dim) for _ in range(num_layer))
        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        h = self.lin_in(data.x)
        for layer in self.embs:
            h = layer(h, data.edge_index)
        return self.head(self.pool(h, data.batch))

## Model 2: MPNN with one shared (weight-tied) layer applied `num_layer` times

In [None]:
class MP_reuse(Module):
    """MPNN that applies a single shared MP_Layer `num_layer` times (weights tied across depth)."""

    def __init__(self, in_dim, out_dim, num_layer=4, emb_dim=256, hid_dim=256):
        super().__init__()
        self.num_layer = num_layer

        self.lin_in = Linear(in_dim, emb_dim)
        self.emb = MP_Layer(emb_dim)  # the one layer reused at every step
        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        node_emb = self.lin_in(data.x)
        for _ in range(self.num_layer):
            node_emb = self.emb(node_emb, data.edge_index)
        graph_emb = self.pool(node_emb, data.batch)
        return self.head(graph_emb)

## Model 3: MPNN with input injection (every layer receives the input projection as `fixed_x`)

In [None]:
class MP_FX(Module):
    """MPNN with input injection.

    Every MP_Layer receives the initial input projection as ``fixed_x``,
    instead of its own input.
    """

    def __init__(self, in_dim, out_dim, num_layer=4, emb_dim=256, hid_dim=256):
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)

        self.embs = nn.ModuleList()
        for _ in range(num_layer):
            self.embs.append(MP_Layer(emb_dim))

        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        injected = self.lin_in(data.x)  # re-fed to every layer as `fixed_x`
        h = injected
        for layer in self.embs:
            h = layer(h, data.edge_index, fixed_x=injected)
        return self.head(self.pool(h, data.batch))

## Stacked DE-MF

In [None]:
class DE_Model(Module):
    """Stacked DE-MF regressor: `num_layer` iterated DE_Layers -> mean pool -> MLP head.

    NOTE(review): this redefines the earlier single-layer ``DE_Model`` with a
    different constructor signature. ``DE_Layer`` is looked up when the model
    is constructed, so it is whichever variant ran last. The edge-feature
    variant's ``forward`` has an ``edge_attr`` parameter; confirm the variant in
    scope accepts the two-argument call made below.
    """

    def __init__(self, in_dim, out_dim, num_iter=4, emb_dim=256, hid_dim=256, num_layer=1):
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)
        self.embs = nn.ModuleList(DE_Layer(emb_dim, num_iter) for _ in range(num_layer))
        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        h = self.lin_in(data.x)
        for layer in self.embs:
            h = layer(h, data.edge_index)
        return self.head(self.pool(h, data.batch))

## Iterated MPNN

**NOTE:** the cell below is currently identical to the MPNN cell above and redefines `MP_Model`.

In [None]:
class MP_Model(Module):
    """Stack of `num_layer` independent MP_Layers -> mean pool -> MLP regression head.

    NOTE(review): this cell is identical to the earlier "MPNN" ``MP_Model`` and
    silently redefines it. If an iterated variant was intended, the body still
    needs to change.
    """

    def __init__(self, in_dim, out_dim, num_layer=4, emb_dim=256, hid_dim=256):
        super().__init__()
        self.lin_in = Linear(in_dim, emb_dim)

        self.embs = nn.ModuleList()
        for _ in range(num_layer):
            self.embs.append(MP_Layer(emb_dim))

        self.pool = global_mean_pool
        self.head = Regression_Head(emb_dim, out_dim, hid_dim)

    def forward(self, data):
        node_emb = self.lin_in(data.x)
        for layer in self.embs:
            node_emb = layer(node_emb, data.edge_index)
        graph_emb = self.pool(node_emb, data.batch)
        return self.head(graph_emb)