In [22]:
import torch
import torch_geometric as tg
import torch.nn.functional as F
import torch_scatter as tc
from functools import partial
from tqdm import tqdm
import time

In [23]:
# Load the ZINC dataset from ./data/ZINC. `subset=True` is passed and no
# `split` argument is given, so the dataset class's default split is used.
dataset = tg.datasets.ZINC('./data/ZINC', subset=True)

In [24]:
# Take the first graph of the dataset as a sample to inspect in the next cells.
A = next(iter(dataset))

In [25]:
# Expose the sample graph's node features, connectivity and edge features as
# notebook-level names for the quick shape checks below.
x, edge_index, edge_attr = A.x, A.edge_index, A.edge_attr


In [26]:
from mp import messagePassing 

In [27]:
# Instantiate the project's message-passing operator (from the local `mp`
# module). The string "add" is passed to its constructor — presumably the
# aggregation mode; confirm against mp.py.
layer = messagePassing("add")

In [28]:
# Sanity check: shape of the sample graph's edge index.
edge_index.shape

torch.Size([2, 64])

In [29]:
# Sanity check: shape of the sample graph's node-feature tensor.
x.shape

torch.Size([29, 1])

In [30]:
class GCNconv(torch.nn.Module):
    """Graph-convolution layer built on the project's ``messagePassing`` operator.

    Node features are divided by ``sqrt(deg + 1)`` before and after the
    message-passing step and are then passed through a linear layer.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        add_self_loops: forwarded to the message-passing call as ``self_loop``.
    """

    def __init__(self, in_channels, out_channels, add_self_loops = True,):
        super(GCNconv, self).__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.mp = messagePassing("add")
        self.self_loops = add_self_loops

    def forward(self, x, edge_index):
        """Apply one normalised propagation step followed by the linear map.

        Args:
            x: node feature matrix with one row per node.
            edge_index: ``(2, E)`` tensor; row 0 is used to count degrees.

        Returns:
            Tensor with one row per node and ``out_channels`` columns.
        """
        # Size the degree vector by the number of nodes. Without `num_nodes`
        # it is only as long as the largest index present in edge_index[0], so
        # trailing nodes without edges would make the division below fail.
        degrees = tg.utils.degree(edge_index[0], num_nodes=x.size(0))
        # The +1 is always added, also when add_self_loops is False.
        norm = torch.sqrt(degrees + 1).view(-1, 1)
        x = x / norm
        x = self.mp(x, edge_index, self_loop = self.self_loops)
        x = x / norm
        x = self.lin(x)
        return x


In [31]:
class MLPReadout(torch.nn.Module):
    """Fully-connected head whose hidden widths halve at every layer.

    Hidden layer ``k`` maps ``input_dim // 2**k`` features to
    ``input_dim // 2**(k+1)`` features and is followed by a ReLU; a final
    linear layer maps ``input_dim // 2**L`` features to ``output_dim``.

    Args:
        input_dim: width of the incoming feature vectors.
        output_dim: width of the produced vectors.
        L: number of hidden (halving) layers.
    """

    def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers
        super().__init__()
        # widths[k] is the input width of hidden layer k.
        widths = [input_dim // 2 ** k for k in range(L + 1)]
        stack = [
            torch.nn.Linear(n_in, n_out, bias=True)
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        stack.append(torch.nn.Linear(input_dim // 2 ** L, output_dim, bias=True))
        self.FC_layers = torch.nn.ModuleList(stack)
        self.L = L

    def forward(self, x):
        """Run ``x`` through the hidden layers (ReLU each) and the output layer."""
        y = x
        for hidden in self.FC_layers[:self.L]:
            y = F.relu(hidden(y))
        return self.FC_layers[self.L](y)

In [32]:
class GCN(torch.nn.Module):
    """Stack of ``GCNconv`` layers with a graph-level readout and an MLP head.

    Node inputs are integer indices looked up in a 28-entry embedding table,
    passed through ``num_layers`` convolutions (ReLU after each, optional
    residual connection), pooled per graph and mapped to one value per graph.

    Args:
        in_channels: embedding size and input width of the first convolution.
        out_channels: output width of the last convolution (readout MLP input).
        hidden_channels: width of the intermediate convolutions.
        num_layers: total number of convolution layers.
        add_self_loops: forwarded to every ``GCNconv``.
        reduce: graph pooling, one of ``"sum"``, ``"mean"`` or ``"max"``.
        residual: add each layer's input to its output when their shapes match.

    Raises:
        ValueError: if ``reduce`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, add_self_loops=True, reduce="mean", residual=True):
        super(GCN, self).__init__()
        self.embedding_layer = torch.nn.Embedding(28, in_channels)
        self.convs = torch.nn.ModuleList()
        # Track the running width so that a single-layer model still consumes
        # the embedding width rather than `hidden_channels`.
        width = in_channels
        for _ in range(num_layers - 1):
            self.convs.append(GCNconv(width, hidden_channels, add_self_loops))
            width = hidden_channels
        self.convs.append(GCNconv(width, out_channels, add_self_loops))

        if reduce == "sum":
            self.reduce = partial(tc.scatter_add, dim=0)
        elif reduce == "mean":
            self.reduce = partial(tc.scatter_mean, dim=0)
        elif reduce == "max":
            self.reduce = self._scatter_max
        else:
            raise ValueError("Invalid value for reduce")

        self.mlp = MLPReadout(out_channels, 1)
        self.residual = residual

    @staticmethod
    def _scatter_max(src, index):
        """Per-group maximum along dim 0, returning only the values."""
        # scatter_max returns a (values, argmax) pair; the head needs a tensor.
        return tc.scatter_max(src, index, dim=0)[0]

    def forward(self, x, edge_index, ptr=None):
        """Predict one value per graph.

        Args:
            x: integer node inputs; flattened and used as embedding indices.
            edge_index: ``(2, E)`` connectivity passed to every convolution.
            ptr: per-node graph index used as the pooling index. ``None`` pools
                all nodes into a single graph.

        Returns:
            Tensor with one row per graph and a single column.
        """
        x = self.embedding_layer(x.view(-1))
        if ptr is None:
            ptr = x.new_zeros(x.size(0), dtype=torch.long)
        for conv in self.convs:
            h = F.relu(conv(x, edge_index))
            # A residual add is only well defined when the layer keeps the width.
            x = h + x if self.residual and h.shape == x.shape else h
        x = self.reduce(x, ptr)
        x = self.mlp(x)
        return x
    


In [33]:
# --- Optimisation hyper-parameters ---
seed = 41
epochs = 1000
batch_size = 128
init_lr = 1e-3
lr_reduce_factor = 0.5
lr_schedule_patience = 10
min_lr = 1e-5
weight_decay = 0

# --- Model hyper-parameters ---
L = 4
hidden_dim = 145
out_dim = hidden_dim
dropout = 0.0
readout = 'mean'

# Apply the seed so that weight initialisation and loader shuffling start from
# a known RNG state after a fresh "Restart & Run All".
torch.manual_seed(seed)

In [34]:
# `torch_geometric.data.DataLoader` is the deprecated location of this class;
# the loader module is its current home.
from torch_geometric.loader import DataLoader

# Define the batch size
# NOTE: this rebinds `batch_size`, which the hyper-parameter cell set to 128.
batch_size = 100

# Create loaders for train, validation, and test sets by slicing the loaded
# dataset: first 8000 graphs / next 1000 / remainder. Only training is shuffled.
train_loader = DataLoader(dataset[:8000], batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset[8000:9000], batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset[9000:], batch_size=batch_size, shuffle=False)


In [35]:
def mae(x, y):
    """Return the mean absolute error between ``x`` and ``y`` as a Python float."""
    return F.l1_loss(x, y).detach().item()

In [36]:

def train_epoch(model, loss_fn, optimizer,scheduler, device, data_loader, epoch):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: module called as ``model(x, edge_index, batch_index)``.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimizer: optimiser stepped once per batch.
        scheduler: unused here; kept so existing call sites keep working.
        device: device every batch tensor is moved to.
        data_loader: iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``batch`` attributes.
        epoch: unused here; kept so existing call sites keep working.

    Returns:
        ``(epoch_loss, epoch_train_mae, optimizer)`` where both metrics are the
        unweighted mean of the per-batch values (0 for an empty loader).
    """
    model.train()
    epoch_loss = 0
    epoch_train_mae = 0
    # Count batches explicitly instead of reusing the loop index after the
    # loop, which is undefined when the loader yields nothing.
    num_batches = 0
    for batch in data_loader:
        x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        # edge_attr = batch.edge_attr.to(device)
        targets = batch.y.to(device).view(-1, 1)
        ptr = batch.batch.to(device)
        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        batch_scores = model(x, edge_index, ptr)
        loss = loss_fn(batch_scores, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_mae += mae(batch_scores, targets)
        num_batches += 1
    if num_batches:
        epoch_loss /= num_batches
        epoch_train_mae /= num_batches

    return epoch_loss, epoch_train_mae, optimizer

In [37]:
def evaluate_epoch(model, loss_fn, device, data_loader):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: module called as ``model(x, edge_index, batch_index)``.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        device: device every batch tensor is moved to.
        data_loader: iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``batch`` attributes.

    Returns:
        ``(epoch_loss, epoch_mae)``, each the unweighted mean of the per-batch
        values (0 for an empty loader).
    """
    model.eval()
    epoch_loss = 0
    epoch_mae = 0
    # Count batches explicitly instead of reusing the loop index after the
    # loop, which is undefined when the loader yields nothing.
    num_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            x = batch.x.to(device)
            edge_index = batch.edge_index.to(device)
            # edge_attr = batch.edge_attr.to(device)
            targets = batch.y.to(device).view(-1, 1)
            ptr = batch.batch.to(device)
            # Call the module itself (not .forward) so registered hooks run.
            batch_scores = model(x, edge_index, ptr)
            loss = loss_fn(batch_scores, targets)
            epoch_loss += loss.detach().item()
            epoch_mae += mae(batch_scores, targets)
            num_batches += 1
    if num_batches:
        epoch_loss /= num_batches
        epoch_mae /= num_batches
    return epoch_loss, epoch_mae

In [38]:
def evaluate_model(model, loss_fn, device, data_loader):
    """Evaluate once via ``evaluate_epoch``, print the metrics and return them.

    Returns:
        ``(epoch_loss, epoch_mae)`` exactly as produced by ``evaluate_epoch``.
    """
    metrics = evaluate_epoch(model, loss_fn, device, data_loader)
    epoch_loss, epoch_mae = metrics
    print(f'Loss: {epoch_loss}, MAE: {epoch_mae}')
    return epoch_loss, epoch_mae

In [39]:
import time
from collections import OrderedDict
def train(model, loss_fn, optimizer, scheduler, device, train_loader, val_loader, test_loader, epochs, min_lr=None):
    """Train ``model`` for up to ``epochs`` epochs and collect per-epoch metrics.

    After every epoch the model is evaluated on the validation and test
    loaders, the scheduler is stepped on the validation loss, and training
    stops early once the first parameter group's learning rate has dropped to
    ``min_lr``.

    Args:
        model, loss_fn, optimizer, device: forwarded to ``train_epoch`` /
            ``evaluate_epoch``.
        scheduler: object with ``step(metric)``; stepped on the validation loss.
        train_loader, val_loader, test_loader: batch iterables for each split.
        epochs: maximum number of epochs.
        min_lr: learning-rate floor that triggers early stopping. When ``None``
            the floor is read from ``scheduler.min_lrs[0]`` if the scheduler
            exposes it; otherwise early stopping is disabled.

    Returns:
        Dict of lists with keys ``train_loss``, ``val_loss``, ``test_loss``,
        ``train_mae``, ``val_mae``, ``test_mae`` and ``time_per_epoch``.
        Despite its name, ``time_per_epoch`` holds the cumulative seconds
        elapsed since training started, sampled at the end of each epoch.
    """
    if min_lr is None:
        # Use the scheduler's own floor rather than a notebook-level global, so
        # the function does not depend on hidden state.
        scheduler_floors = getattr(scheduler, "min_lrs", None)
        min_lr = scheduler_floors[0] if scheduler_floors else None

    start = time.time()
    logs = { 'train_loss': [], 'val_loss': [], 'test_loss': [], 'train_mae': [], 'val_mae': [], 'test_mae': [], "time_per_epoch": []}
    with tqdm(range(epochs)) as pbar:
        for epoch in pbar:
            pbar.set_description(f'Epoch {epoch}')
            epoch_loss, epoch_train_mae, optimizer = train_epoch(model, loss_fn, optimizer, scheduler, device, train_loader, epoch)
            val_loss, val_mae = evaluate_epoch(model, loss_fn, device, val_loader)
            test_loss, test_mae = evaluate_epoch(model, loss_fn, device, test_loader)
            # The progress bar shows the learning rate used for this epoch,
            # i.e. the value before the scheduler step below.
            postfix_dict = OrderedDict(
                time = time.time() - start,
                lr=optimizer.param_groups[0]['lr'],
                train_loss=epoch_loss,
                val_loss=val_loss,
                test_loss=test_loss,
                train_mae=epoch_train_mae,
                val_mae=val_mae,
                test_mae=test_mae
            )
            pbar.set_postfix(postfix_dict)
            scheduler.step(val_loss)
            logs['train_loss'].append(epoch_loss)
            logs['val_loss'].append(val_loss)
            logs['test_loss'].append(test_loss)
            logs['train_mae'].append(epoch_train_mae)
            logs['val_mae'].append(val_mae)
            logs['test_mae'].append(test_mae)
            logs['time_per_epoch'].append(time.time() - start)
            if min_lr is not None and optimizer.param_groups[0]['lr'] <= min_lr:
                print('Early stopping')
                break
    return logs

In [40]:
# Use the GPU when torch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One GCN together with its Adam optimiser, plateau LR scheduler and L1
# criterion. The training-loop cell below rebinds all four names per run.
model = GCN(
    in_channels=hidden_dim,
    hidden_channels=hidden_dim,
    out_channels=out_dim,
    num_layers=L,
    reduce=readout,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_reduce_factor,
    patience=lr_schedule_patience,
    min_lr=min_lr,
    verbose=True,
)
loss = torch.nn.L1Loss().to(device)

In [41]:
# Train five independently initialised GCNs and keep the log dict of each run.
GCN_logs = []
for _ in range(5):
    # Fresh model / optimiser / scheduler / criterion for every repetition.
    model = GCN(
        in_channels=hidden_dim,
        hidden_channels=hidden_dim,
        out_channels=out_dim,
        num_layers=L,
        reduce=readout,
    ).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=init_lr,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=lr_reduce_factor,
        patience=lr_schedule_patience,
        min_lr=min_lr,
        verbose=True,
    )
    loss = torch.nn.L1Loss().to(device)
    GCN_logs.append(
        train(model, loss, optimizer, scheduler, device, train_loader, val_loader, test_loader, epochs)
    )

Epoch 236:  24%|██▎       | 236/1000 [04:28<14:27,  1.14s/it, time=268, lr=1.56e-5, train_loss=0.411, val_loss=0.47, test_loss=0.434, train_mae=0.411, val_mae=0.47, test_mae=0.434]   


Early stopping


Epoch 225:  22%|██▎       | 225/1000 [04:26<15:19,  1.19s/it, time=267, lr=1.56e-5, train_loss=0.405, val_loss=0.472, test_loss=0.425, train_mae=0.405, val_mae=0.472, test_mae=0.425] 


Early stopping


Epoch 242:  24%|██▍       | 242/1000 [05:07<16:02,  1.27s/it, time=307, lr=1.56e-5, train_loss=0.386, val_loss=0.468, test_loss=0.431, train_mae=0.386, val_mae=0.468, test_mae=0.431] 


Early stopping


Epoch 241:  24%|██▍       | 241/1000 [05:30<17:20,  1.37s/it, time=330, lr=1.56e-5, train_loss=0.366, val_loss=0.46, test_loss=0.417, train_mae=0.366, val_mae=0.46, test_mae=0.417]   


Early stopping


Epoch 219:  22%|██▏       | 219/1000 [05:22<19:08,  1.47s/it, time=322, lr=1.56e-5, train_loss=0.411, val_loss=0.476, test_loss=0.43, train_mae=0.411, val_mae=0.476, test_mae=0.43]   

Early stopping





In [42]:
class GATconv(torch.nn.Module):
    """Multi-head graph-attention layer with an optional additive edge weight.

    Inputs are projected linearly and split into ``heads`` chunks. Per-edge
    scores are the sum of a source and a destination term, passed through a
    leaky ReLU and soft-maxed over the edges that share a destination.
    Messages are the source features scaled by that weight (plus 1 when
    ``f_additive`` is set) and summed at the destination node.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: total output size across heads; must be divisible by ``heads``.
        heads: number of attention heads.
        add_self_loops: add a self loop for every node before attending.
        f_additive: use ``attention + 1`` instead of ``attention`` as edge weight.
    """

    def __init__(self, in_channels, out_channels, heads = 1, add_self_loops = True, f_additive = True):
        super(GATconv, self).__init__()
        assert out_channels % heads == 0

        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.att_src = torch.nn.Parameter(torch.empty(1, heads, out_channels//heads))
        self.att_dst = torch.nn.Parameter(torch.empty(1, heads, out_channels//heads))
        self.heads = heads
        self.self_loops = add_self_loops
        self.f_additive = f_additive
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the projection and both attention vectors."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.att_src)
        torch.nn.init.xavier_uniform_(self.att_dst)

    def forward(self, x, edge_index):
        """Return the attended node features, one row per input node.

        Args:
            x: node feature matrix with one row per node.
            edge_index: ``(2, E)`` tensor; row 0 holds sources, row 1 destinations.
        """
        if self.self_loops:
            edge_index, _ = tg.utils.add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x).view(x.size(0), self.heads, -1)
        e_ij = self.prepareEdgeWeights(x, edge_index)
        y = self.propagate(x, edge_index, e_ij)
        return y.view(x.size(0), -1)

    def prepareEdgeWeights(self,x ,edge_index):
        """Compute per-edge, per-head attention normalised over each destination."""
        alpha_dst = (x * self.att_dst).sum(dim=-1)
        alpha_src = (x * self.att_src).sum(dim=-1)
        alpha = torch.index_select(alpha_src, 0, edge_index[0]) + torch.index_select(alpha_dst, 0, edge_index[1])
        alpha = F.leaky_relu(alpha, 0.2)
        # Give the node count explicitly instead of letting it be inferred
        # from the largest destination index.
        e_ij = tg.utils.softmax(alpha, edge_index[1], num_nodes=x.size(0))
        return e_ij

    def propagate(self, x, edge_index, e_ij):
        """Sum the weighted source features at every destination node."""
        temp = torch.index_select(x, 0, edge_index[0])
        edge_weights = e_ij.unsqueeze(-1) + 1 if self.f_additive else e_ij.unsqueeze(-1)
        temp = temp * edge_weights
        # dim_size pins the output to one row per node. Without it, nodes past
        # the largest destination index are dropped and the reshape in
        # forward() fails (possible when self loops are disabled).
        return tc.scatter_add(temp, edge_index[1], dim=0, dim_size=x.size(0))

In [43]:
class GAT(torch.nn.Module):
    """Stack of ``GATconv`` layers with a graph-level readout and an MLP head.

    Node inputs are integer indices looked up in a 28-entry embedding table,
    passed through ``num_layers`` attention layers (ELU after each), pooled
    per graph and mapped to one value per graph.

    Args:
        in_channels: embedding size and input width of the first layer.
        out_channels: output width of the last layer (readout MLP input).
        hidden_channels: width of the intermediate layers.
        num_layers: total number of attention layers.
        heads: number of attention heads in every layer.
        add_self_loops: forwarded to every ``GATconv``.
        reduce: graph pooling, one of ``"sum"``, ``"mean"`` or ``"max"``.
        residual: stored on the module but not applied in ``forward``.
        f_additive: forwarded to every ``GATconv``.

    Raises:
        ValueError: if ``reduce`` is not one of the supported names.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, heads, add_self_loops = True, reduce = "mean", residual = True, f_additive = True):
        super(GAT, self).__init__()
        self.embedding_layer = torch.nn.Embedding(28, in_channels)
        self.convs = torch.nn.ModuleList()
        # Track the running width so that a single-layer model still consumes
        # the embedding width rather than `hidden_channels`.
        width = in_channels
        for _ in range(num_layers-1):
            self.convs.append(GATconv(width, hidden_channels, heads, add_self_loops, f_additive))
            width = hidden_channels
        self.convs.append(GATconv(width, out_channels, heads, add_self_loops, f_additive))

        if reduce == "sum":
            self.reduce = partial(tc.scatter_add, dim=0)
        elif reduce == "mean":
            self.reduce = partial(tc.scatter_mean, dim=0)
        elif reduce == "max":
            self.reduce = self._scatter_max
        else:
            raise ValueError("Invalid value for reduce")
        self.mlp = MLPReadout(out_channels, 1)
        # NOTE(review): `residual` is accepted and stored, but forward() never
        # adds a skip connection. Confirm whether that is intentional before
        # wiring it in, because doing so changes the trained model.
        self.residual = residual

    @staticmethod
    def _scatter_max(src, index):
        """Per-group maximum along dim 0, returning only the values."""
        # scatter_max returns a (values, argmax) pair; the head needs a tensor.
        return tc.scatter_max(src, index, dim=0)[0]

    def forward(self, x, edge_index, ptr=None):
        """Predict one value per graph.

        Args:
            x: integer node inputs; flattened and used as embedding indices.
            edge_index: ``(2, E)`` connectivity passed to every layer.
            ptr: per-node graph index used as the pooling index. ``None`` pools
                all nodes into a single graph.

        Returns:
            Tensor with one row per graph and a single column.
        """
        x = self.embedding_layer(x.view(-1))
        if ptr is None:
            ptr = x.new_zeros(x.size(0), dtype=torch.long)
        for conv in self.convs:
            x = F.elu(conv(x, edge_index))
        x = self.reduce(x, ptr)
        x = self.mlp(x)
        return x

In [44]:
# Reference configuration dictionary. Nothing else in the visible notebook
# reads it; the models below are built from the individual variables instead.
net_params = {
    # architecture
    "L": 4,
    "hidden_dim": 18,
    "out_dim": 144,
    "residual": True,
    "readout": "mean",
    "n_heads": 8,
    # regularisation
    "in_feat_dropout": 0.0,
    "dropout": 0.0,
    "batch_norm": True,
    # graph preprocessing
    "self_loop": False,
}

In [45]:
# Rebind hidden_dim (145 in the hyper-parameter cell) for the attention models:
# GATconv asserts that the width is divisible by the head count, and 144 is
# divisible by the heads=8 used in the GAT training loops.
hidden_dim = 144

In [46]:
# Rebind out_dim as well; it was bound to the old hidden_dim value (145) in
# the hyper-parameter cell and does not follow the reassignment above.
out_dim = 144

In [47]:
# One GAT with its Adam optimiser, plateau LR scheduler and L1 criterion.
# Note heads=2 here; the training-loop cell below rebuilds the model with
# heads=8 and rebinds all four names.
GAT_model = GAT(
    in_channels=hidden_dim,
    hidden_channels=hidden_dim,
    out_channels=out_dim,
    num_layers=L,
    heads=2,
    reduce=readout,
).to(device)
optimizer = torch.optim.Adam(
    GAT_model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_reduce_factor,
    patience=lr_schedule_patience,
    min_lr=min_lr,
    verbose=True,
)
loss = torch.nn.L1Loss().to(device)


In [48]:
# Train five independently initialised 8-head GATs and keep each run's logs.
gat_logs = []
for _ in range(5):
    # Fresh model / optimiser / scheduler / criterion for every repetition.
    GAT_model = GAT(
        in_channels=hidden_dim,
        hidden_channels=hidden_dim,
        out_channels=out_dim,
        num_layers=L,
        heads=8,
        reduce=readout,
    ).to(device)
    optimizer = torch.optim.Adam(
        GAT_model.parameters(),
        lr=init_lr,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=lr_reduce_factor,
        patience=lr_schedule_patience,
        min_lr=min_lr,
        verbose=True,
    )
    loss = torch.nn.L1Loss().to(device)
    gat_logs.append(
        train(GAT_model, loss, optimizer, scheduler, device, train_loader, val_loader, test_loader, epochs)
    )

Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 177:  18%|█▊        | 177/1000 [05:51<27:13,  1.99s/it, time=351, lr=1.56e-5, train_loss=0.239, val_loss=0.421, test_loss=0.41, train_mae=0.239, val_mae=0.421, test_mae=0.41]   


Early stopping


Epoch 153:  15%|█▌        | 153/1000 [05:06<28:17,  2.00s/it, time=307, lr=1.56e-5, train_loss=0.275, val_loss=0.432, test_loss=0.402, train_mae=0.275, val_mae=0.432, test_mae=0.402] 


Early stopping


Epoch 163:  16%|█▋        | 163/1000 [05:20<27:27,  1.97s/it, time=321, lr=1.56e-5, train_loss=0.268, val_loss=0.434, test_loss=0.402, train_mae=0.268, val_mae=0.434, test_mae=0.402] 


Early stopping


Epoch 165:  16%|█▋        | 165/1000 [02:10<11:02,  1.26it/s, time=131, lr=1.56e-5, train_loss=0.244, val_loss=0.426, test_loss=0.393, train_mae=0.244, val_mae=0.426, test_mae=0.393]  


Early stopping


Epoch 159:  16%|█▌        | 159/1000 [02:05<11:06,  1.26it/s, time=126, lr=1.56e-5, train_loss=0.26, val_loss=0.425, test_loss=0.402, train_mae=0.26, val_mae=0.425, test_mae=0.402]    

Early stopping





In [49]:
class set2setReadout(torch.nn.Module):
    """Attention-based graph readout driven by an LSTM.

    For ``n_iters`` steps an LSTM produces one query vector per graph, every
    node is scored by its dot product with its graph's query, the scores are
    soft-maxed within each graph, and the weighted node sum is concatenated to
    the query to form the next LSTM input. The final concatenation (width
    ``2 * input_dim``) is returned.

    Args:
        input_dim: width of the node features.
        n_iters: number of query/attend iterations.
        n_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_dim, n_iters, n_layers=1):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = input_dim * 2
        self.n_iters = n_iters
        self.n_layers = n_layers
        self.LSTM_layers = torch.nn.LSTM(self.output_dim, input_dim, num_layers=n_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the LSTM weights."""
        self.LSTM_layers.reset_parameters()

    def forward(self, x, ptr):
        """Pool node features ``x`` into one ``2 * input_dim`` vector per graph.

        Args:
            x: node feature matrix with one row per node.
            ptr: per-node graph index; the number of graphs is ``ptr.max() + 1``.
        """
        num_graphs = ptr.max().item() + 1
        state_shape = (self.n_layers, num_graphs, self.input_dim)
        # new_zeros already allocates with x's dtype and on x's device.
        lstm_state = (x.new_zeros(state_shape), x.new_zeros(state_shape))
        q_star = x.new_zeros((num_graphs, self.output_dim))
        for _ in range(self.n_iters):
            lstm_out, lstm_state = self.LSTM_layers(q_star.unsqueeze(0), lstm_state)
            query = lstm_out.view(num_graphs, self.input_dim)
            scores = (x * query[ptr]).sum(dim=-1, keepdim=True)
            attention = tg.utils.softmax(scores, ptr, dim=-2)
            readout = tc.scatter_add(attention * x, ptr, dim=-2)
            q_star = torch.cat([query, readout], dim=-1)
        return q_star

In [50]:
class GAT_set2set(torch.nn.Module):
    """``GATconv`` stack whose graph readout is a ``set2setReadout``.

    Node inputs are integer indices looked up in a 28-entry embedding table and
    passed through ``num_layers`` attention layers (ELU after each). The last
    layer emits ``out_channels // 2`` features, which the set2set readout
    doubles; a ReLU and a one-hidden-layer MLP then produce one value per graph.

    Args:
        in_channels: embedding size and input width of the first layer.
        out_channels: nominal readout width; the last layer uses half of it.
        hidden_channels: width of the intermediate layers.
        num_layers: total number of attention layers.
        heads: number of attention heads in every layer.
        add_self_loops: forwarded to every ``GATconv``.
        reduce: accepted for signature parity with the other models; unused.
        n_iters: number of set2set iterations.
        f_additive: forwarded to every ``GATconv``.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, heads, add_self_loops = True, reduce = "mean", n_iters=5, f_additive = True):
        super(GAT_set2set, self).__init__()
        self.embedding_layer = torch.nn.Embedding(28, in_channels)
        self.convs = torch.nn.ModuleList()
        # Track the running width so that a single-layer model still consumes
        # the embedding width rather than `hidden_channels`.
        width = in_channels
        for _ in range(num_layers-1):
            self.convs.append(GATconv(width, hidden_channels, heads, add_self_loops, f_additive))
            width = hidden_channels
        self.convs.append(GATconv(width, out_channels//2, heads, add_self_loops, f_additive))
        self.set2set = set2setReadout(out_channels//2, n_iters)
        # Size the head from what the readout actually emits. That equals
        # out_channels when it is even and stays consistent when it is odd.
        self.mlp = MLPReadout(self.set2set.output_dim, 1, 1)

    def forward(self, x, edge_index, ptr=None):
        """Predict one value per graph.

        Args:
            x: integer node inputs; flattened and used as embedding indices.
            edge_index: ``(2, E)`` connectivity passed to every layer.
            ptr: per-node graph index handed to the readout. ``None`` pools
                all nodes into a single graph.

        Returns:
            Tensor with one row per graph and a single column.
        """
        x = self.embedding_layer(x.view(-1))
        if ptr is None:
            ptr = x.new_zeros(x.size(0), dtype=torch.long)
        for conv in self.convs:
            x = F.elu(conv(x, edge_index))
        x = self.set2set(x, ptr)
        x = F.relu(x)
        x = self.mlp(x)
        return x

In [51]:
# One GAT + set2set model (hidden width halved, 10 readout iterations) with its
# Adam optimiser, plateau LR scheduler and L1 criterion. The training-loop
# cell below rebinds all four names per run.
GAT_model_set2set = GAT_set2set(
    in_channels=hidden_dim,
    hidden_channels=hidden_dim//2,
    out_channels=out_dim,
    num_layers=L,
    heads=8,
    reduce=readout,
    n_iters= 10,
).to(device)
optimizer = torch.optim.Adam(
    GAT_model_set2set.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=lr_reduce_factor,
    patience=lr_schedule_patience,
    min_lr=min_lr,
    verbose=True,
)
loss = torch.nn.L1Loss().to(device)


In [52]:
# Train five independently initialised GAT + set2set models and keep each
# run's logs.
gat_logs_set2set = []
for _ in range(5):
    # Fresh model / optimiser / scheduler / criterion for every repetition.
    GAT_model_set2set = GAT_set2set(
        in_channels=hidden_dim,
        hidden_channels=hidden_dim//2,
        out_channels=out_dim,
        num_layers=L,
        heads=8,
        reduce=readout,
        n_iters= 10,
    ).to(device)
    optimizer = torch.optim.Adam(
        GAT_model_set2set.parameters(),
        lr=init_lr,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=lr_reduce_factor,
        patience=lr_schedule_patience,
        min_lr=min_lr,
        verbose=True,
    )
    loss = torch.nn.L1Loss().to(device)

    gat_logs_set2set.append(
        train(GAT_model_set2set, loss, optimizer, scheduler, device, train_loader, val_loader, test_loader, epochs)
    )

Epoch 158:  16%|█▌        | 158/1000 [04:05<21:47,  1.55s/it, time=245, lr=1.56e-5, train_loss=0.952, val_loss=0.937, test_loss=0.957, train_mae=0.952, val_mae=0.937, test_mae=0.957]


Early stopping


Epoch 134:  13%|█▎        | 134/1000 [03:27<22:21,  1.55s/it, time=208, lr=1.56e-5, train_loss=0.962, val_loss=0.962, test_loss=0.963, train_mae=0.962, val_mae=0.962, test_mae=0.963]


Early stopping


Epoch 112:  11%|█         | 112/1000 [02:52<22:49,  1.54s/it, time=173, lr=1.56e-5, train_loss=1.1, val_loss=1.13, test_loss=1.05, train_mae=1.1, val_mae=1.13, test_mae=1.05]  


Early stopping


Epoch 165:  16%|█▋        | 165/1000 [04:14<21:25,  1.54s/it, time=254, lr=1.56e-5, train_loss=0.903, val_loss=0.902, test_loss=0.895, train_mae=0.903, val_mae=0.902, test_mae=0.895] 


Early stopping


Epoch 256:  26%|██▌       | 256/1000 [06:36<19:10,  1.55s/it, time=396, lr=1.56e-5, train_loss=0.597, val_loss=0.645, test_loss=0.598, train_mae=0.597, val_mae=0.645, test_mae=0.598] 

Early stopping





In [53]:
def count_parameters(model):
    """Return the number of scalar entries in ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

In [54]:
# Collect the last trained instance of each architecture (`model` is the GCN
# left over from the final GCN run) alongside the per-run logs, then count the
# trainable parameters of each.
models = [model, GAT_model, GAT_model_set2set]
logs = [GCN_logs, gat_logs, gat_logs_set2set]
parameter_counts = list(map(count_parameters, models))

In [55]:
# Display the trainable-parameter count of each model (GCN, GAT, GAT + set2set).
parameter_counts

[101917, 101233, 103825]

In [56]:
import pickle

# Persist the collected training logs. The context manager guarantees the file
# is flushed and closed even if pickling raises.
with open("logs_zinc_additive.pkl", "wb") as log_file:
    pickle.dump(logs, log_file)