In [1]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import scipy.sparse as sp

import dgl
import numpy as np
import torch
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from sklearn.model_selection import train_test_split
from time import time
from tqdm import tqdm

from functions_HGPLS import train, test

In [2]:
dataset = dgl.data.FakeNewsDataset('gossipcop', 'bert')

In [3]:
dataset = dgl.data.GINDataset("PROTEINS", self_loop=False)

In [3]:
new_dataset = list()
for i in range(len(dataset)):
    g, l = dataset[i]
    g.ndata["feature"] = torch.ones_like(dataset.feature[g.ndata["_ID"]])
    # g.ndata["feature"] = g.ndata["attr"]
    g = dgl.add_self_loop(g)
    g = dgl.add_reverse_edges(g)
    new_dataset.append((g, l))

In [4]:
train_dataset, test_dataset = train_test_split(new_dataset, test_size=0.25, random_state=42)

train_dataloader = GraphDataLoader(train_dataset, batch_size=16, drop_last=False)
test_dataloader = GraphDataLoader(test_dataset, batch_size=16, drop_last=False)

In [5]:
from dgl.nn.pytorch import GATConv

class GAT(nn.Module):
    def __init__(self, in_feats, h_feats, n_classes):
        super(GAT, self).__init__()
        self.layer1 = GATConv(in_feats, h_feats, num_heads=4)
        self.layer2 = GATConv(4*h_feats, h_feats, num_heads=4)
        self.layer3 = GATConv(4*h_feats, h_feats, num_heads=6)
        self.fc = nn.Linear(h_feats, n_classes)
        self.elu = nn.ELU()

    def forward(self, g, in_feat):
        x1 = self.layer1(g, in_feat)
        x1 = self.elu(x1)
        x1 = x1.view(in_feat.shape[0], -1)
        x2 = self.layer2(g, x1)
        x2 = self.elu(x2)
        x2 = x2.view(in_feat.shape[0], -1)
        x3 = self.layer3(g, x2)
        x3 = torch.mean(x3, dim=1)
        with g.local_scope():
            g.ndata['h'] = x3
            x4 = dgl.readout_nodes(g, 'h')
        return F.log_softmax(self.fc(x4), dim=-1)

In [6]:
# Load model architecture
device = 'cpu' if torch.cuda.is_available() else 'cpu'
model = GAT(in_feats=768, n_classes=2, h_feats=128).to(device)

In [7]:
# Define optimizer and loss
optimizer = torch.optim.Adam(
        model.parameters(), lr=0.001, weight_decay=0.001
    )
loss = torch.nn.CrossEntropyLoss()

In [8]:
# Train model and keep the best validation loss model
bad_cound = 0
best_val_acc = 0
best_epoch = 0
epochs = 10
patience = 10
print_every = 1
train_times = []
for e in range(epochs):
    s_time = time()
    train_loss, train_acc = train(model, optimizer, loss, train_dataloader, device)
    train_times.append(time() - s_time)
    val_acc, val_loss = test(model, loss, test_dataloader, device)
    if best_val_acc < val_acc:
        best_val_loss = val_loss
        bad_cound = 0
        best_epoch = e + 1
        torch.save(model.state_dict(), "../models/GATModel_prot.pt")
    else:
        bad_cound += 1
    if bad_cound >= patience:
        break

    if (e + 1) % print_every == 0:
        log_format = (
            "Epoch {}: train_loss={:.4f}, train_acc={:.4f}, val_acc={:.4f}, vall_loss={:.4f}"
        )
        print(log_format.format(e + 1, train_loss, train_acc, val_acc, val_loss))
print(
    "Best Epoch {}, final test acc {:.4f}".format(
        best_epoch, best_val_loss
    )
)

100%|██████████| 257/257 [00:37<00:00,  6.86it/s]
100%|██████████| 86/86 [00:04<00:00, 18.91it/s]


Epoch 1: train_loss=8.8670, train_acc=0.4980, val_acc=0.5329, vall_loss=0.9505


100%|██████████| 257/257 [00:29<00:00,  8.61it/s]
100%|██████████| 86/86 [00:04<00:00, 20.19it/s]


Epoch 2: train_loss=1.2711, train_acc=0.5554, val_acc=0.6164, vall_loss=0.7304


100%|██████████| 257/257 [00:35<00:00,  7.19it/s]
100%|██████████| 86/86 [00:05<00:00, 17.20it/s]


Epoch 3: train_loss=0.7757, train_acc=0.6362, val_acc=0.7592, vall_loss=0.6067


100%|██████████| 257/257 [00:35<00:00,  7.30it/s]
100%|██████████| 86/86 [00:04<00:00, 17.44it/s]


Epoch 4: train_loss=0.6818, train_acc=0.6798, val_acc=0.7313, vall_loss=0.6357


100%|██████████| 257/257 [00:34<00:00,  7.35it/s]
100%|██████████| 86/86 [00:04<00:00, 17.50it/s]


Epoch 5: train_loss=0.6715, train_acc=0.6811, val_acc=0.7313, vall_loss=0.6663


100%|██████████| 257/257 [00:34<00:00,  7.36it/s]
100%|██████████| 86/86 [00:04<00:00, 17.60it/s]


Epoch 6: train_loss=0.6387, train_acc=0.6986, val_acc=0.7635, vall_loss=0.6213


100%|██████████| 257/257 [00:34<00:00,  7.37it/s]
100%|██████████| 86/86 [00:04<00:00, 17.59it/s]


Epoch 7: train_loss=0.6281, train_acc=0.7106, val_acc=0.7694, vall_loss=0.5894


100%|██████████| 257/257 [00:35<00:00,  7.33it/s]
100%|██████████| 86/86 [00:04<00:00, 17.50it/s]


Epoch 8: train_loss=0.6263, train_acc=0.7013, val_acc=0.7701, vall_loss=0.5886


100%|██████████| 257/257 [00:35<00:00,  7.15it/s]
100%|██████████| 86/86 [00:04<00:00, 17.56it/s]


Epoch 9: train_loss=0.6172, train_acc=0.7077, val_acc=0.7687, vall_loss=0.5945


100%|██████████| 257/257 [00:34<00:00,  7.37it/s]
100%|██████████| 86/86 [00:04<00:00, 17.53it/s]

Epoch 10: train_loss=0.6539, train_acc=0.6901, val_acc=0.7723, vall_loss=0.5861
Best Epoch 10, final test acc 0.5861



