In [11]:
import torch
import torch_geometric as pyg
import torch.nn.functional as F
import torch_scatter as tc
from functools import partial
from tqdm import tqdm
import time
from torch_geometric.loader import DataLoader

In [12]:
# Training hyper-parameters shared by every model trained in this notebook.
# NOTE(review): "seed", "print_epoch_interval" and "max_time" are not read by any
# cell visible in this notebook — confirm whether they are still needed.
global_params = dict(
    seed=41,
    epochs=1000,
    batch_size=128,
    init_lr=1e-3,
    lr_reduce_factor=0.5,
    lr_schedule_patience=20,
    min_lr=1e-6,
    weight_decay=0.0,
    print_epoch_interval=5,
    max_time=12,
)

# Settings for the GCN model. The visible cells only read "hidden_dim",
# "out_dim" and "L" when constructing the model.
params_gcn = dict(
    L=4,
    hidden_dim=145,
    out_dim=145,
    residual=True,
    readout="mean",
    in_feat_dropout=0.0,
    dropout=0.0,
    batch_norm=True,
    self_loop=True,
)

In [13]:
def add_x(data):
    """Attach a constant placeholder node-feature tensor to ``data``.

    Sets ``data.x`` to an all-zero ``int64`` tensor with one row per node
    (``data.num_nodes`` rows) and a single column, overwriting any existing
    ``x`` attribute.

    Args:
        data: Object exposing ``num_nodes``; it is modified in place.

    Returns:
        The same ``data`` object, now carrying the ``x`` attribute.
    """
    placeholder_shape = (data.num_nodes, 1)
    data.x = torch.zeros(placeholder_shape, dtype=torch.long)
    return data

In [14]:
# Load the REDDIT-BINARY TUDataset rooted at ./data, passing add_x as the
# pre_transform so graphs receive the placeholder node features it assigns.
# NOTE(review): presumably pre_transform output is cached with the processed
# dataset, so editing add_x would require clearing ./data — confirm.
reddit_binary_dataset = pyg.datasets.TUDataset(root='./data', name='REDDIT-BINARY', pre_transform=add_x)

In [15]:
# Build a shuffled 80/10/10 train/val/test split from two index permutations.
#
# The shuffles draw from a dedicated generator seeded with global_params["seed"],
# so re-running the notebook reproduces the same split. Previously the seed was
# never used and the test set changed on every run. A local generator is used so
# the global RNG state (model initialisation, loader shuffling) is left untouched.
split_generator = torch.Generator().manual_seed(global_params["seed"])
half_size = len(reddit_binary_dataset) // 2

# The -500 shift makes part of permutation1 negative.
# NOTE(review): presumably negative entries index from the end of the dataset,
# in which case, for a 2000-graph dataset, the two permutations together cover
# every index exactly once — confirm the dataset size and the intended class
# balance of each half.
permutation1 = torch.randperm(half_size, generator=split_generator) - 500
permutation2 = torch.randperm(half_size, generator=split_generator) + 500

train_fraction = 0.8
val_fraction = 0.1
test_fraction = 0.1  # remainder after the train and val cut points below

# Cut points inside each permutation: [0, train_end) -> train,
# [train_end, val_end) -> val, [val_end, end) -> test.
train_end1 = int(train_fraction * len(permutation1))
train_end2 = int(train_fraction * len(permutation2))
val_end1 = int((train_fraction + val_fraction) * len(permutation1))
val_end2 = int((train_fraction + val_fraction) * len(permutation2))

train_premutation = torch.cat((permutation1[:train_end1], permutation2[:train_end2]))
val_premutation = torch.cat((permutation1[train_end1:val_end1], permutation2[train_end2:val_end2]))
test_premutation = torch.cat((permutation1[val_end1:], permutation2[val_end2:]))

train_dataset = reddit_binary_dataset[train_premutation]
val_dataset = reddit_binary_dataset[val_premutation]
test_dataset = reddit_binary_dataset[test_premutation]



In [16]:
# One loader per split, all using the configured batch size; only the training
# loader shuffles its samples.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split, batch_size=global_params["batch_size"], shuffle=should_shuffle)
    for split, should_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)


In [71]:
from models import GCN, GAT
from utils import train

In [18]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [19]:
# Stand-alone GCN instance with its optimiser, plateau-based LR scheduler and
# loss. The training loop further down rebuilds all four objects for every run.
model = GCN(
    params_gcn["hidden_dim"],
    params_gcn["hidden_dim"],
    params_gcn["out_dim"],
    params_gcn["L"],
).to(device)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=global_params["init_lr"],
    weight_decay=global_params["weight_decay"],
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=global_params["lr_reduce_factor"],
    patience=global_params["lr_schedule_patience"],
    min_lr=global_params["min_lr"],
    verbose=True,
)
loss = torch.nn.BCEWithLogitsLoss().to(device)

In [20]:
# Grab the first mini-batch produced by the training loader for manual inspection.
test = next(iter(train_dataloader))

In [21]:

# Display the `y` attribute (targets) of the inspected mini-batch.
test.y

tensor([1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1,
        0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1,
        1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0,
        1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1,
        1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,
        0, 1, 0, 1, 0, 0, 0, 0])

In [22]:
# Number of independent training runs (fresh model each time) in the GCN loop below.
number_of_tried = 5

In [23]:
# Train `number_of_tried` freshly initialised GCN models and keep whatever
# `train` returns for each run.
gcn_logs = []
for _ in range(number_of_tried):
    # Every run starts from a new model, optimiser, scheduler and loss.
    model = GCN(
        params_gcn["hidden_dim"],
        params_gcn["hidden_dim"],
        params_gcn["out_dim"],
        params_gcn["L"],
    ).to(device)
    optimizer = torch.optim.Adam(
        params=model.parameters(),
        lr=global_params["init_lr"],
        weight_decay=global_params["weight_decay"],
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        factor=global_params["lr_reduce_factor"],
        patience=global_params["lr_schedule_patience"],
        min_lr=global_params["min_lr"],
        verbose=True,
    )
    loss = torch.nn.BCEWithLogitsLoss().to(device)

    gcn_logs.append(
        train(
            model,
            loss,
            optimizer,
            scheduler,
            device,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            global_params["epochs"],
            global_params["min_lr"],
        )
    )


Epoch 288:  29%|██▉       | 288/1000 [01:14<03:05,  3.84it/s, time=74.9, lr=1.95e-6, train_loss=0.466, val_loss=0.548, test_loss=0.467, train_metric=0.744, val_metric=0.659, test_metric=0.707] 


Early stopping


Epoch 308:  31%|███       | 308/1000 [01:28<03:17,  3.50it/s, time=88, lr=1.95e-6, train_loss=0.43, val_loss=0.536, test_loss=0.42, train_metric=0.785, val_metric=0.705, test_metric=0.788]     


Early stopping


Epoch 420:  42%|████▏     | 420/1000 [02:02<02:48,  3.44it/s, time=122, lr=1.95e-6, train_loss=0.253, val_loss=0.447, test_loss=0.279, train_metric=0.898, val_metric=0.822, test_metric=0.855]  


Early stopping


Epoch 338:  34%|███▍      | 338/1000 [01:39<03:14,  3.40it/s, time=99.4, lr=1.95e-6, train_loss=0.467, val_loss=0.551, test_loss=0.464, train_metric=0.736, val_metric=0.651, test_metric=0.708] 


Early stopping


Epoch 505:  50%|█████     | 505/1000 [02:28<02:25,  3.40it/s, time=149, lr=1.95e-6, train_loss=0.274, val_loss=0.426, test_loss=0.292, train_metric=0.891, val_metric=0.836, test_metric=0.854] 

Early stopping





In [24]:
# Display the module structure of the most recently constructed GCN.
model

GCN(
  (embedding_layer): Embedding(1, 145)
  (convs): ModuleList(
    (0-3): 4 x GCNconv(
      (lin): Linear(in_features=145, out_features=145, bias=True)
      (mp): messagePassing()
    )
  )
  (mlp): MLPReadout(
    (FC_layers): ModuleList(
      (0): Linear(in_features=145, out_features=72, bias=True)
      (1): Linear(in_features=72, out_features=36, bias=True)
      (2): Linear(in_features=36, out_features=1, bias=True)
    )
  )
)

In [25]:
# https://arxiv.org/pdf/1907.02204v4.pdf

In [26]:
# Settings for the GAT models. The visible cells read "hidden_dim", "out_dim",
# "L" and "n_heads" when constructing the models.
gat_params = dict(
    L=4,
    hidden_dim=144,
    out_dim=144,
    readout="mean",
    n_heads=2,
    in_feat_dropout=0.0,
    dropout=0.0,
)

In [27]:
# Train several freshly initialised GAT models and keep whatever `train`
# returns for each run. The number of runs comes from `number_of_tried`, the
# same setting the GCN loop uses, instead of a separately hard-coded 5, so the
# architectures are always compared over the same number of runs.
gat_logs = []
for _ in range(number_of_tried):
    gat_model = GAT(
        gat_params["hidden_dim"],
        gat_params["hidden_dim"],
        gat_params["out_dim"],
        gat_params["L"],
        gat_params["n_heads"],
    ).to(device)
    # The GAT runs start from a learning rate ten times smaller than init_lr.
    gat_optimizer = torch.optim.Adam(
        gat_model.parameters(),
        lr=global_params["init_lr"] / 10,
        weight_decay=global_params["weight_decay"],
    )
    gat_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        gat_optimizer,
        mode='min',
        factor=global_params["lr_reduce_factor"],
        patience=global_params["lr_schedule_patience"],
        min_lr=global_params["min_lr"],
        verbose=True,
    )
    gat_loss = torch.nn.BCEWithLogitsLoss().to(device)
    gat_logs.append(
        train(
            gat_model,
            gat_loss,
            gat_optimizer,
            gat_scheduler,
            device,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            global_params["epochs"],
            global_params["min_lr"],
        )
    )

Epoch 269:  27%|██▋       | 269/1000 [01:59<05:25,  2.25it/s, time=120, lr=1.56e-6, train_loss=0.73, val_loss=0.745, test_loss=0.836, train_metric=0.74, val_metric=0.696, test_metric=0.727]   


Early stopping


Epoch 330:  33%|███▎      | 330/1000 [02:26<04:58,  2.25it/s, time=147, lr=1.56e-6, train_loss=0.784, val_loss=0.684, test_loss=0.709, train_metric=0.776, val_metric=0.743, test_metric=0.797] 


Early stopping


Epoch 265:  26%|██▋       | 265/1000 [01:57<05:26,  2.25it/s, time=118, lr=1.56e-6, train_loss=0.554, val_loss=0.598, test_loss=0.477, train_metric=0.714, val_metric=0.668, test_metric=0.719] 


Early stopping


Epoch 226:  23%|██▎       | 226/1000 [01:40<05:44,  2.25it/s, time=101, lr=1.56e-6, train_loss=0.579, val_loss=0.633, test_loss=0.69, train_metric=0.749, val_metric=0.71, test_metric=0.727]   


Early stopping


Epoch 301:  30%|███       | 301/1000 [02:14<05:11,  2.25it/s, time=134, lr=1.56e-6, train_loss=0.685, val_loss=0.634, test_loss=0.642, train_metric=0.712, val_metric=0.694, test_metric=0.714] 

Early stopping





In [42]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [67]:
# Stand-alone GAT instance using the "set2set" reduce option, with its
# optimiser, plateau-based LR scheduler and loss. The training loop further
# down rebuilds all four objects (with different sizes) for every run.
gat_model_with_set2set = GAT(
    64,
    64,
    64,
    gat_params["L"],
    gat_params["n_heads"],
    reduce="set2set",
    n_iter=10,
).to(device)
gat_optimizer_with_set2set = torch.optim.Adam(
    params=gat_model_with_set2set.parameters(),
    lr=global_params["init_lr"],
    weight_decay=global_params["weight_decay"],
)
gat_scheduler_with_set2set = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=gat_optimizer_with_set2set,
    mode='min',
    factor=global_params["lr_reduce_factor"],
    patience=global_params["lr_schedule_patience"],
    min_lr=global_params["min_lr"],
    verbose=True,
)
gat_loss_with_set2set = torch.nn.BCEWithLogitsLoss().to(device)


In [73]:
# Train several freshly initialised GAT models that use the "set2set" reduce
# option and keep whatever `train` returns for each run. The number of runs
# comes from `number_of_tried`, the same setting the GCN loop uses, instead of
# a separately hard-coded 5.
set2set_logs = []
for _ in range(number_of_tried):
    # Second size argument is half of hidden_dim; residual connections are disabled.
    gat_model_with_set2set = GAT(
        gat_params["hidden_dim"],
        gat_params["hidden_dim"] // 2,
        gat_params["out_dim"],
        gat_params["L"],
        gat_params["n_heads"],
        reduce="set2set",
        n_iter=30,
        residual=False,
    ).to(device)
    gat_optimizer_with_set2set = torch.optim.Adam(
        gat_model_with_set2set.parameters(),
        lr=global_params["init_lr"],
        weight_decay=global_params["weight_decay"],
    )
    gat_scheduler_with_set2set = torch.optim.lr_scheduler.ReduceLROnPlateau(
        gat_optimizer_with_set2set,
        mode='min',
        factor=global_params["lr_reduce_factor"],
        patience=global_params["lr_schedule_patience"],
        min_lr=global_params["min_lr"],
        verbose=True,
    )
    gat_loss_with_set2set = torch.nn.BCEWithLogitsLoss().to(device)
    set2set_logs.append(
        train(
            gat_model_with_set2set,
            gat_loss_with_set2set,
            gat_optimizer_with_set2set,
            gat_scheduler_with_set2set,
            device,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            global_params["epochs"],
            global_params["min_lr"],
        )
    )

Epoch 286:  29%|██▊       | 286/1000 [07:35<18:56,  1.59s/it, time=455, lr=1.95e-6, train_loss=0.617, val_loss=0.619, test_loss=0.639, train_metric=0.633, val_metric=0.62, test_metric=0.574]  


Early stopping


Epoch 285:  28%|██▊       | 285/1000 [07:35<19:02,  1.60s/it, time=455, lr=1.95e-6, train_loss=0.632, val_loss=0.626, test_loss=0.642, train_metric=0.628, val_metric=0.623, test_metric=0.592] 


Early stopping


Epoch 220:  22%|██▏       | 220/1000 [05:52<20:48,  1.60s/it, time=352, lr=1.95e-6, train_loss=0.651, val_loss=0.641, test_loss=0.646, train_metric=0.63, val_metric=0.619, test_metric=0.582] 


Early stopping


Epoch 212:  21%|██        | 212/1000 [05:39<21:03,  1.60s/it, time=340, lr=1.95e-6, train_loss=0.79, val_loss=0.708, test_loss=0.735, train_metric=0.508, val_metric=0.481, test_metric=0.464]  


Early stopping


Epoch 258:  26%|██▌       | 258/1000 [06:51<19:43,  1.59s/it, time=411, lr=1.95e-6, train_loss=0.659, val_loss=0.642, test_loss=0.652, train_metric=0.633, val_metric=0.622, test_metric=0.579] 

Early stopping





In [39]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.

    Args:
        model: Object exposing ``parameters()``, e.g. a ``torch.nn.Module``.

    Returns:
        Sum of ``numel()`` over all parameters that require gradients.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
import pickle

# The most recently trained instance of each architecture, i.e. the model left
# over from the last iteration of each training loop.
models = [model, gat_model, gat_model_with_set2set]
# Per-architecture lists of run logs, in the same order as `models`.
logs = [gcn_logs, gat_logs, set2set_logs]
# Trainable-parameter count of each model; the loop variable is named `m` so it
# does not shadow the global `model`.
parameter_counts = [count_parameters(m) for m in models]

# Persist the logs. The context manager closes (and flushes) the file even if
# pickling fails; the previous bare open() call never closed the handle.
with open("logs_reddit_additive.pkl", "wb") as log_file:
    pickle.dump(logs, log_file)

In [40]:
parameter_counts

[98002, 97345, 191665]