In [1]:
import torch
import torch_geometric as pyg
import torch.nn.functional as F
import torch_scatter as tc
from functools import partial
from tqdm import tqdm
import time
from torch_geometric.loader import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
# Training-loop hyperparameters shared by every experiment in this notebook.
global_params = {
    "seed": 41,                   # applied to torch's global RNG right below
    "epochs": 1000,               # upper bound passed to train(...)
    "batch_size": 128,            # used by all three DataLoaders
    "init_lr": 0.001,             # Adam starting learning rate
    "lr_reduce_factor": 0.5,      # ReduceLROnPlateau `factor`
    "lr_schedule_patience": 10,   # ReduceLROnPlateau `patience`
    "min_lr": 1e-5,               # scheduler floor and train(...) stop threshold
    "weight_decay": 0.0,
    "print_epoch_interval": 5,    # not read by any code visible in this notebook
    "max_time": 12,               # not read by any code visible in this notebook
}

# The seed used to be declared but never applied, so weight initialisation and
# DataLoader shuffling differed on every run. Seeding here (before any model or
# loader is built) makes a Restart & Run All reproducible. It sits before the
# last assignment so the cell does not display the returned Generator.
torch.manual_seed(global_params["seed"])

# Architecture settings for the GCN experiment.
params_gcn = {
    "L": 4,                  # number of graph-convolution layers
    "hidden_dim": 146,
    "out_dim": 10,
    "residual": True,
    "readout": "mean",
    "in_feat_dropout": 0.0,
    "dropout": 0.0,
    "batch_norm": True,
    "self_loop": True,
}

In [16]:
# All four objects come from the same GNNBenchmarkDataset ('CIFAR10') stored
# under data/; only the requested split differs. The first call passes no
# split and therefore uses whatever default the dataset class defines.
_load_cifar10 = partial(pyg.datasets.GNNBenchmarkDataset, root='data/', name='CIFAR10')

cifar10_dataset = _load_cifar10()
train_dataset = _load_cifar10(split='train')
test_dataset = _load_cifar10(split='test')
val_dataset = _load_cifar10(split='val')

In [17]:
# One loader per split, all with the configured batch size; only the training
# loader shuffles.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split_dataset, batch_size=global_params["batch_size"], shuffle=do_shuffle)
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [24]:
import torch
import torch_geometric as tg
import torch.nn.functional as F
import torch_scatter as tc
import torch_scatter
from enum import Enum
from functools import partial

# Names of the reductions a message-passing layer can use; each member's value
# is the lowercase string form of its name.
AggregationType = Enum(
    'AggregationType',
    [
        ('ADD', 'add'),
        ('MEAN', 'mean'),
        ('MAX', 'max'),
        ('MIN', 'min'),
        ('MUL', 'mul'),
    ],
)


class messagePassing(torch.nn.Module):
    """Gather features from source nodes and reduce them onto target nodes.

    Args:
        agg: Reduction to apply, either an ``AggregationType`` member or the
            equivalent string (e.g. ``"add"``).
        **kwargs: Accepted for call-site compatibility; ignored.
    """

    def __init__(self, agg: AggregationType = AggregationType.ADD, **kwargs):
        super(messagePassing, self).__init__()
        # torch_scatter.scatter takes the reduction as a plain string, so an
        # enum member (including the default) must be unwrapped first.
        reduce = agg.value if isinstance(agg, AggregationType) else agg
        self.agg = partial(torch_scatter.scatter, reduce=reduce)

    def forward(self, x, edge_index, edge_attr = None, self_loop = False):
        """Aggregate ``x[edge_index[0]]`` at the positions ``edge_index[1]``.

        Args:
            x: Node features; rows are indexed by the entries of ``edge_index``.
            edge_index: Two-row index tensor, row 0 = sources, row 1 = targets.
            edge_attr: Unused.
            self_loop: When true, the reduction starts from each node's own
                features instead of from an empty output.

        Returns:
            A tensor with the same number of rows as ``x``.
        """
        # first gather the values from the source nodes
        messages = torch.index_select(x, 0, edge_index[0])
        if self_loop:
            # Reduce into a copy: passing `x` itself as `out` would overwrite
            # the caller's tensor in place.
            return self.agg(messages, edge_index[1], dim=0, out=x.clone())
        # Without an explicit size the output would stop at the largest target
        # index, dropping trailing nodes that receive no message.
        return self.agg(messages, edge_index[1], dim=0, dim_size=x.size(0))

class GCNconv(torch.nn.Module):
    """Graph-convolution layer: degree-scaled sum aggregation followed by a linear map.

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the features produced by the linear map.
        add_self_loops: Forwarded to the message-passing step as ``self_loop``.
        residual: Accepted for signature compatibility; not used by this layer.
    """

    def __init__(self, in_channels, out_channels, add_self_loops = True,residual = True):
        super(GCNconv, self).__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.mp = messagePassing("add")
        self.self_loops = add_self_loops

    def forward(self, x, edge_index):
        """Scale by 1/sqrt(deg + 1), aggregate, scale again, then apply the linear map."""
        # Size the degree vector by the node count. Otherwise it ends at the
        # largest index present in edge_index[0] and the division below fails
        # whenever trailing nodes have no outgoing edge.
        # NOTE(review): degrees are counted over edge_index[0] while messages are
        # reduced onto edge_index[1]; the two agree only for symmetric edge lists.
        degrees = torch.sqrt(tg.utils.degree(edge_index[0], num_nodes=x.size(0)) + 1)
        scale = degrees.view(-1, 1)
        x = x / scale
        x = self.mp(x, edge_index, self_loop = self.self_loops)
        x = x / scale
        x = self.lin(x)
        return x
    
class MLPReadout(torch.nn.Module):
    """MLP head whose hidden widths halve at every layer.

    ``L`` hidden Linear+ReLU layers map ``input_dim`` down to
    ``input_dim // 2**L``; a final Linear layer (no activation) maps that to
    ``output_dim``.
    """

    def __init__(self, input_dim, output_dim, L=2): #L=nb_hidden_layers
        super().__init__()
        # widths[k] is the feature size after k halvings
        widths = [input_dim // 2 ** depth for depth in range(L + 1)]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(torch.nn.Linear(fan_in, fan_out, bias=True))
        layers.append(torch.nn.Linear(widths[-1], output_dim, bias=True))
        self.FC_layers = torch.nn.ModuleList(layers)
        self.L = L

    def forward(self, x):
        hidden = x
        for layer in list(self.FC_layers)[: self.L]:
            hidden = F.relu(layer(hidden))
        return self.FC_layers[self.L](hidden)
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, add_self_loops = True, reduce = "mean", n_iter = None, embedding = True, residual = True):
        super(GCN, self).__init__()
        if embedding:
            self.embedding_layer = torch.nn.Embedding(in_channels, hidden_channels)
        else:
            self.embedding_layer = torch.nn.Linear(in_channels, hidden_channels)

        self.convs = torch.nn.ModuleList()
        for i in range(num_layers-1):
            self.convs.append(GCNconv(hidden_channels, hidden_channels, add_self_loops))
        self.convs.append(GCNconv(hidden_channels, hidden_channels, add_self_loops))
        self.residual = residual

        # for now implementing sum reduction
        match reduce:
            case "sum":
                self.reduce = partial(tc.scatter_add, dim = 0)
            case "mean":
                self.reduce = partial(tc.scatter_mean, dim = 0)
            case "max":
                self.reduce = partial(tc.scatter_max, dim = 0)
            case "set2set":
                self.reduce = set2setReadout(hidden_channels, n_iter)
            case _:
                raise ValueError("Invalid value for reduce")
        self.mlp = MLPReadout(hidden_channels, out_channels) if reduce != "set2set" else MLPReadout(hidden_channels*2, out_channels)

    def forward(self, x, edge_index, ptr=None):
        x = self.embedding_layer(x)
        for conv in self.convs:
            x = F.relu(conv(x, edge_index)) + x if self.residual else F.relu(conv(x, edge_index))
        x = self.reduce(x, ptr)
        x = self.mlp(x)
        return x
    
class GATconv(torch.nn.Module):
    """Multi-head graph-attention layer.

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Total output width; must be divisible by ``heads`` because
            each head gets ``out_channels // heads`` channels.
        heads: Number of attention heads.
        add_self_loops: Add a self-loop to every node before attending.
    """

    def __init__(self, in_channels, out_channels, heads = 1, add_self_loops = True):
        super(GATconv, self).__init__()
        assert out_channels % heads == 0
        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
        self.att_src = torch.nn.Parameter(torch.empty(1, heads, out_channels//heads))
        self.att_dst = torch.nn.Parameter(torch.empty(1, heads, out_channels//heads))
        self.heads = heads
        self.self_loops = add_self_loops
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise the projection and both attention vectors."""
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.att_src)
        torch.nn.init.xavier_uniform_(self.att_dst)

    def forward(self, x, edge_index):
        """Project, attend over incoming edges, and flatten the heads back together."""
        num_nodes = x.size(0)
        if self.self_loops:
            edge_index, _ = tg.utils.add_self_loops(edge_index, num_nodes=num_nodes)
        x = self.lin(x).view(num_nodes, self.heads, -1)
        e_ij = self.prepareEdgeWeights(x, edge_index)
        y = self.propagate(x, edge_index, e_ij)
        return y.view(num_nodes, -1)

    def prepareEdgeWeights(self,x ,edge_index):
        """Return per-edge, per-head attention weights normalised over each target node."""
        alpha_dst = (x * self.att_dst).sum(dim=-1)
        alpha_src = (x * self.att_src).sum(dim=-1)
        alpha = torch.index_select(alpha_src, 0, edge_index[0]) + torch.index_select(alpha_dst, 0, edge_index[1])
        alpha = F.leaky_relu(alpha, 0.2)
        # Give the softmax the real node count instead of letting it infer one
        # from the largest target index.
        e_ij = tg.utils.softmax(alpha, edge_index[1], num_nodes=x.size(0))
        return e_ij

    def propagate(self, x, edge_index, e_ij):
        """Sum attention-weighted source features onto their target nodes."""
        weighted = torch.index_select(x, 0, edge_index[0]) * e_ij.unsqueeze(-1)
        # Without dim_size the result stops at the largest target index, so the
        # reshape in forward() fails or mis-shapes when trailing nodes have no
        # incoming edge (possible when self-loops are disabled).
        return tc.scatter_add(weighted, edge_index[1], dim=0, dim_size=x.size(0))
    
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels, num_layers, heads, add_self_loops = True, reduce = "mean", n_iter = None, embedding = True):
        super(GAT, self).__init__()
        if embedding:
            self.embedding_layer = torch.nn.Embedding(in_channels, hidden_channels)
        else:
            self.embedding_layer = torch.nn.linear(in_channels, hidden_channels)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers-1):
            self.convs.append(GATconv(hidden_channels, hidden_channels, heads, add_self_loops))
        self.convs.append(GATconv(hidden_channels, out_channels, heads, add_self_loops))
        # for now implementing sum reduction
        match reduce:
            case "sum":
                self.reduce = partial(tc.scatter_add, dim = 0)
            case "mean":
                self.reduce = partial(tc.scatter_mean, dim = 0)
            case "max":
                self.reduce = partial(tc.scatter_max, dim = 0)
            case "set2set":
                self.reduce = set2setReadout(out_channels, n_iter)
            case _:
                raise ValueError("Invalid value for reduce")
        self.mlp = MLPReadout(out_channels, 1) if reduce != "set2set" else MLPReadout(out_channels*2, 1)

    def forward(self, x, edge_index, ptr=None):
        x = self.embedding_layer(x.view(-1))
        a = 0
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
        x = self.reduce(x, ptr)
        x = self.mlp(x)
        return x
    
class set2setReadout(torch.nn.Module):
    """LSTM-driven attention readout producing a ``2 * input_dim`` vector per graph.

    Each iteration feeds the current query through the LSTM, scores every node
    against its graph's LSTM output, softmax-normalises the scores within each
    graph, and concatenates the LSTM output with the weighted node sum.
    """

    def __init__(self, input_dim, n_iters, n_layers=2):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = input_dim * 2
        self.n_iters = n_iters
        self.n_layers = n_layers
        self.LSTM_layers = torch.nn.LSTM(self.output_dim, input_dim, num_layers=n_layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the LSTM weights."""
        self.LSTM_layers.reset_parameters()

    def forward(self, x, ptr):
        # `ptr` holds each node's graph index, so the largest value + 1 is the graph count
        n_graphs = ptr.max().item() + 1
        state_shape = (self.n_layers, n_graphs, self.input_dim)
        # new_zeros already allocates with x's dtype and device
        query = x.new_zeros((n_graphs, self.output_dim))
        state = (x.new_zeros(state_shape), x.new_zeros(state_shape))
        for _ in range(self.n_iters):
            lstm_out, state = self.LSTM_layers(query.unsqueeze(0), state)
            lstm_out = lstm_out.view(n_graphs, self.input_dim)
            scores = (x * lstm_out[ptr]).sum(dim=-1, keepdim=True)
            weights = tg.utils.softmax(scores, ptr, dim=-2)
            pooled = tc.scatter_add(weights * x, ptr, dim=-2)
            query = torch.cat([lstm_out, pooled], dim=-1)
        return query


In [25]:
import torch.nn.functional as F
import torch
import time
from tqdm import tqdm
from torchmetrics.classification import  MulticlassAccuracy
def train_epoch(model, loss_fn, optimizer,scheduler, device, data_loader, epoch, metric):
    """Run one optimisation pass over ``data_loader``.

    Args:
        model: Module called as ``model(x, edge_index, batch_index)``.
        loss_fn: Callable ``loss_fn(scores, targets)`` returning a scalar tensor.
        optimizer: Optimizer stepped once per batch.
        scheduler: Unused here; kept for signature compatibility.
        device: Device every batch tensor is moved to.
        data_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``batch`` tensors.
        epoch: Unused here; kept for signature compatibility.
        metric: Callable ``metric(sigmoid(scores), targets)`` returning a scalar tensor.

    Returns:
        ``(mean batch loss, mean batch metric, optimizer)``. Both means are 0.0
        when the loader yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    epoch_train_metric = 0.0
    n_batches = 0
    for batch in data_loader:
        x = batch.x.to(device)
        edge_index = batch.edge_index.to(device)
        targets = batch.y.to(device)
        ptr = batch.batch.to(device)
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so registered hooks run.
        batch_scores = model(x, edge_index, ptr)
        loss = loss_fn(batch_scores, targets)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
        epoch_train_metric += metric(torch.sigmoid(batch_scores), targets).detach().item()
        n_batches += 1
    # Count batches explicitly: the old `iter + 1` shadowed the builtin and
    # raised on an empty loader.
    if n_batches:
        epoch_loss /= n_batches
        epoch_train_metric /= n_batches

    return epoch_loss, epoch_train_metric, optimizer
def evaluate_epoch(model, loss_fn, device, data_loader, metric):
    """Compute the mean batch loss and metric over ``data_loader`` without gradients.

    Args:
        model: Module called as ``model(x, edge_index, batch_index)``.
        loss_fn: Callable ``loss_fn(scores, targets)`` returning a scalar tensor.
        device: Device every batch tensor is moved to.
        data_loader: Iterable of batches exposing ``x``, ``edge_index``, ``y``
            and ``batch`` tensors.
        metric: Callable ``metric(sigmoid(scores), targets)`` returning a scalar tensor.

    Returns:
        ``(mean batch loss, mean batch metric)``. Both are 0.0 when the loader
        yields no batches.
    """
    model.eval()
    epoch_loss = 0.0
    epoch_metric = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in data_loader:
            x = batch.x.to(device)
            edge_index = batch.edge_index.to(device)
            targets = batch.y.to(device)
            ptr = batch.batch.to(device)
            # Call the module itself rather than .forward() so registered hooks run.
            batch_scores = model(x, edge_index, ptr)
            loss = loss_fn(batch_scores, targets)
            epoch_loss += loss.detach().item()
            epoch_metric += metric(torch.sigmoid(batch_scores), targets).item()
            n_batches += 1
    # Count batches explicitly: the old `iter + 1` shadowed the builtin and
    # raised on an empty loader.
    if n_batches:
        epoch_loss /= n_batches
        epoch_metric /= n_batches
    return epoch_loss, epoch_metric
def evaluate_model(model, loss_fn, device, data_loader, metric=None):
    """Evaluate ``model`` on ``data_loader``, print the result and return it.

    Args:
        model: Module to evaluate.
        loss_fn: Loss callable forwarded to ``evaluate_epoch``.
        device: Device the batches (and the default metric) are moved to.
        data_loader: Iterable of batches.
        metric: Metric callable forwarded to ``evaluate_epoch``. Defaults to
            the same ``MulticlassAccuracy(10)`` that ``train`` builds.

    Returns:
        ``(loss, metric)`` as returned by ``evaluate_epoch``.
    """
    # evaluate_epoch requires a metric; the previous call omitted it and always
    # raised TypeError.
    if metric is None:
        metric = MulticlassAccuracy(10).to(device)
    epoch_loss, epoch_metric = evaluate_epoch(model, loss_fn, device, data_loader, metric=metric)
    print(f'Loss: {epoch_loss}, metric: {epoch_metric}')
    return epoch_loss, epoch_metric

from collections import OrderedDict
def train(model, loss_fn, optimizer, scheduler, device, train_loader, val_loader, test_loader, epochs, min_lr):
    """Train for up to ``epochs`` epochs, evaluating on val and test after each one.

    The scheduler is stepped with the validation loss after every epoch, and
    training stops early once the first parameter group's learning rate is at
    or below ``min_lr``.

    Args:
        model: Module to optimise.
        loss_fn: Loss callable ``loss_fn(scores, targets)``.
        optimizer: Optimizer whose first param group's ``lr`` is monitored.
        scheduler: Scheduler exposing ``step(val_loss)``.
        device: Device for batches and for the accuracy metric.
        train_loader, val_loader, test_loader: Batch iterables for each split.
        epochs: Maximum number of epochs.
        min_lr: Learning-rate threshold that triggers early stopping.

    Returns:
        Dict of per-epoch lists: ``train_loss``, ``val_loss``, ``test_loss``,
        ``train_metric``, ``val_metric``, ``test_metric`` and ``time_per_epoch``
        (seconds spent on that epoch alone).
    """
    start = time.time()
    logs = { 'train_loss': [], 'val_loss': [], 'test_loss': [], 'train_metric': [], 'val_metric': [], 'test_metric': [], "time_per_epoch": []}
    metric =  MulticlassAccuracy(10).to(device)
    with tqdm(range(epochs)) as pbar:
        for epoch in pbar:
            epoch_start = time.time()
            pbar.set_description(f'Epoch {epoch}')
            epoch_loss, epoch_train_metric, optimizer = train_epoch(model, loss_fn, optimizer, scheduler, device, train_loader, epoch, metric = metric)
            val_loss, val_metric = evaluate_epoch(model, loss_fn, device, val_loader, metric = metric)
            test_loss, test_metric = evaluate_epoch(model, loss_fn, device, test_loader, metric = metric)
            # The progress bar keeps showing total elapsed time since the run started.
            postfix_dict = OrderedDict(
                time = time.time() - start,
                lr=optimizer.param_groups[0]['lr'],
                train_loss=epoch_loss,
                val_loss=val_loss,
                test_loss=test_loss,
                train_metric=epoch_train_metric,
                val_metric=val_metric,
                test_metric=test_metric
            )
            pbar.set_postfix(postfix_dict)
            scheduler.step(val_loss)
            logs['train_loss'].append(epoch_loss)
            logs['val_loss'].append(val_loss)
            logs['test_loss'].append(test_loss)
            logs['train_metric'].append(epoch_train_metric)
            logs['val_metric'].append(val_metric)
            logs['test_metric'].append(test_metric)
            # Record this epoch's own duration; the key used to receive the
            # cumulative time since `start`, contradicting its name.
            logs['time_per_epoch'].append(time.time() - epoch_start)
            if optimizer.param_groups[0]['lr'] <= min_lr:
                print('Early stopping')
                break
    return logs

In [26]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [27]:
# GCN classifier: 3 input features per node go through a Linear projection
# (embedding=False); the widths and depth come from params_gcn.
model = GCN(
    in_channels=3,
    out_channels=params_gcn["out_dim"],
    hidden_channels=params_gcn["hidden_dim"],
    num_layers=params_gcn["L"],
    embedding=False,
    residual=params_gcn["residual"],
).to(device)

# Adam plus a plateau scheduler, both configured from global_params.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=global_params["init_lr"],
    weight_decay=global_params["weight_decay"],
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=global_params["lr_reduce_factor"],
    patience=global_params["lr_schedule_patience"],
    min_lr=global_params["min_lr"],
    verbose = True,
)

loss = torch.nn.CrossEntropyLoss().to(device)

In [28]:
# Display the module tree of the GCN built in the previous cell.
model

GCN(
  (embedding_layer): Linear(in_features=3, out_features=146, bias=True)
  (convs): ModuleList(
    (0-3): 4 x GCNconv(
      (lin): Linear(in_features=146, out_features=146, bias=True)
      (mp): messagePassing()
    )
  )
  (mlp): MLPReadout(
    (FC_layers): ModuleList(
      (0): Linear(in_features=146, out_features=73, bias=True)
      (1): Linear(in_features=73, out_features=36, bias=True)
      (2): Linear(in_features=36, out_features=10, bias=True)
    )
  )
)

In [29]:
# Train the GCN; `logs` receives the per-epoch history dict returned by train().
logs = train(
    model=model,
    loss_fn=loss,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    test_loader=test_dataloader,
    epochs=global_params["epochs"],
    min_lr=global_params["min_lr"],
)

Epoch 0:   0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 202:  20%|██        | 202/1000 [36:30<2:20:06, 10.53s/it, time=2.19e+3, lr=0.001, train_loss=1.62, val_loss=1.67, test_loss=1.67, train_metric=0.421, val_metric=0.4, test_metric=0.395]  

Epoch 00202: reducing learning rate of group 0 to 5.0000e-04.


Epoch 224:  22%|██▏       | 224/1000 [40:32<2:26:48, 11.35s/it, time=2.43e+3, lr=0.0005, train_loss=1.59, val_loss=1.64, test_loss=1.66, train_metric=0.432, val_metric=0.412, test_metric=0.409]

Epoch 00224: reducing learning rate of group 0 to 2.5000e-04.


Epoch 269:  27%|██▋       | 269/1000 [49:25<2:58:26, 14.65s/it, time=2.97e+3, lr=0.00025, train_loss=1.57, val_loss=1.62, test_loss=1.63, train_metric=0.44, val_metric=0.424, test_metric=0.421] 

Epoch 00269: reducing learning rate of group 0 to 1.2500e-04.


Epoch 293:  29%|██▉       | 293/1000 [54:00<2:05:36, 10.66s/it, time=3.24e+3, lr=0.000125, train_loss=1.55, val_loss=1.61, test_loss=1.62, train_metric=0.445, val_metric=0.42, test_metric=0.423] 

In [13]:
# Display the module tree of `model`.
# NOTE(review): the saved output below shows Embedding(1, 145) and a 1-unit
# final layer, which does not match the GCN constructed earlier in this
# notebook (Linear(3, 146) projection, 10 outputs). Presumably it was saved
# from a different kernel session; re-run to refresh.
model

GCN(
  (embedding_layer): Embedding(1, 145)
  (convs): ModuleList(
    (0-3): 4 x GCNconv(
      (lin): Linear(in_features=145, out_features=145, bias=True)
      (mp): messagePassing()
    )
  )
  (mlp): MLPReadout(
    (FC_layers): ModuleList(
      (0): Linear(in_features=145, out_features=72, bias=True)
      (1): Linear(in_features=72, out_features=36, bias=True)
      (2): Linear(in_features=36, out_features=1, bias=True)
    )
  )
)

In [14]:
# Bind the history held in `logs` to a GCN-specific name (same dict object, not
# a copy), presumably so it stays identifiable next to gat_logs and set2set_logs.
gcn_logs = logs

In [10]:
# Architecture settings for the GAT experiments.
gat_params = dict(
    L=4,
    hidden_dim=144,
    out_dim=144,
    readout="mean",
    n_heads=8,
    in_feat_dropout=0.0,
    dropout=0.0,
)

In [16]:
# GAT model with its own optimizer, plateau scheduler and loss.
# NOTE(review): GAT's head ends in a single output and the loss is
# BCEWithLogitsLoss, while train() scores with MulticlassAccuracy(10) on the
# CIFAR10 loaders, and the recorded loss below stays at 0.693. Confirm this
# configuration is intended for this dataset.
gat_model = GAT(
    in_channels=gat_params["hidden_dim"],
    out_channels=gat_params["hidden_dim"],
    hidden_channels=gat_params["out_dim"],
    num_layers=gat_params["L"],
    heads=gat_params["n_heads"],
).to(device)

gat_optimizer = torch.optim.Adam(
    gat_model.parameters(),
    lr=global_params["init_lr"],
    weight_decay=global_params["weight_decay"],
)
gat_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    gat_optimizer,
    mode='min',
    factor=global_params["lr_reduce_factor"],
    patience=global_params["lr_schedule_patience"],
    min_lr=global_params["min_lr"],
    verbose = True,
)

gat_loss = torch.nn.BCEWithLogitsLoss().to(device)


In [17]:
# Display the module tree of the GAT built in the previous cell.
gat_model

GAT(
  (embedding_layer): Embedding(1, 144)
  (convs): ModuleList(
    (0-3): 4 x GATconv(
      (lin): Linear(in_features=144, out_features=144, bias=False)
    )
  )
  (mlp): MLPReadout(
    (FC_layers): ModuleList(
      (0): Linear(in_features=144, out_features=72, bias=True)
      (1): Linear(in_features=72, out_features=36, bias=True)
      (2): Linear(in_features=36, out_features=1, bias=True)
    )
  )
)

In [18]:
# Train the GAT; `gat_logs` receives the per-epoch history dict returned by train().
gat_logs = train(
    model=gat_model,
    loss_fn=gat_loss,
    optimizer=gat_optimizer,
    scheduler=gat_scheduler,
    device=device,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    test_loader=test_dataloader,
    epochs=global_params["epochs"],
    min_lr=global_params["min_lr"],
)

Epoch 24:   2%|▏         | 24/1000 [00:25<17:51,  1.10s/it, time=25.8, lr=0.001, train_loss=0.693, val_loss=0.692, test_loss=0.692, train_metric=0.502, val_metric=0.609, test_metric=0.609]

Epoch 00024: reducing learning rate of group 0 to 5.0000e-04.


Epoch 45:   4%|▍         | 45/1000 [00:47<16:52,  1.06s/it, time=47.9, lr=0.0005, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.502, val_metric=0.609, test_metric=0.609]

Epoch 00045: reducing learning rate of group 0 to 2.5000e-04.


Epoch 66:   7%|▋         | 66/1000 [01:09<16:24,  1.05s/it, time=69.9, lr=0.00025, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.498, val_metric=0.609, test_metric=0.609]

Epoch 00066: reducing learning rate of group 0 to 1.2500e-04.


Epoch 87:   9%|▊         | 87/1000 [01:31<15:57,  1.05s/it, time=91.6, lr=0.000125, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.499, val_metric=0.609, test_metric=0.609]

Epoch 00087: reducing learning rate of group 0 to 6.2500e-05.


Epoch 108:  11%|█         | 108/1000 [01:54<15:44,  1.06s/it, time=114, lr=6.25e-5, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.501, val_metric=0.609, test_metric=0.609]

Epoch 00108: reducing learning rate of group 0 to 3.1250e-05.


Epoch 129:  13%|█▎        | 129/1000 [02:16<15:44,  1.08s/it, time=137, lr=3.13e-5, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.504, val_metric=0.609, test_metric=0.609]

Epoch 00129: reducing learning rate of group 0 to 1.5625e-05.


Epoch 150:  15%|█▌        | 150/1000 [02:38<14:42,  1.04s/it, time=159, lr=1.56e-5, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.502, val_metric=0.609, test_metric=0.609]

Epoch 00150: reducing learning rate of group 0 to 7.8125e-06.


Epoch 171:  17%|█▋        | 171/1000 [03:02<16:09,  1.17s/it, time=182, lr=7.81e-6, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.498, val_metric=0.609, test_metric=0.609]

Epoch 00171: reducing learning rate of group 0 to 3.9063e-06.


Epoch 192:  19%|█▉        | 192/1000 [03:24<14:13,  1.06s/it, time=205, lr=3.91e-6, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.501, val_metric=0.609, test_metric=0.609]

Epoch 00192: reducing learning rate of group 0 to 1.9531e-06.


Epoch 212:  21%|██        | 212/1000 [03:47<14:05,  1.07s/it, time=227, lr=1.95e-6, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.502, val_metric=0.609, test_metric=0.609]

Epoch 00213: reducing learning rate of group 0 to 1.0000e-06.
Early stopping





In [11]:
# Same GAT configuration as the plain GAT run, but with the set2set readout
# (10 iterations) instead of the default reduction.
gat_model_with_set2set = GAT(
    in_channels=gat_params["hidden_dim"],
    out_channels=gat_params["hidden_dim"],
    hidden_channels=gat_params["out_dim"],
    num_layers=gat_params["L"],
    heads=gat_params["n_heads"],
    reduce="set2set",
    n_iter=10,
).to(device)

gat_optimizer_with_set2set = torch.optim.Adam(
    gat_model_with_set2set.parameters(),
    lr=global_params["init_lr"],
    weight_decay=global_params["weight_decay"],
)
gat_scheduler_with_set2set = torch.optim.lr_scheduler.ReduceLROnPlateau(
    gat_optimizer_with_set2set,
    mode='min',
    factor=global_params["lr_reduce_factor"],
    patience=global_params["lr_schedule_patience"],
    min_lr=global_params["min_lr"],
    verbose = True,
)

gat_loss_with_set2set = torch.nn.BCEWithLogitsLoss().to(device)


In [12]:
# Train the set2set variant; `set2set_logs` receives the per-epoch history dict.
set2set_logs = train(
    model=gat_model_with_set2set,
    loss_fn=gat_loss_with_set2set,
    optimizer=gat_optimizer_with_set2set,
    scheduler=gat_scheduler_with_set2set,
    device=device,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    test_loader=test_dataloader,
    epochs=global_params["epochs"],
    min_lr=global_params["min_lr"],
)

Epoch 29:   3%|▎         | 29/1000 [00:59<33:43,  2.08s/it, time=59.4, lr=0.001, train_loss=0.693, val_loss=0.693, test_loss=0.693, train_metric=0.471, val_metric=0.609, test_metric=0.609]