In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid, Flickr
from torch_geometric.loader import GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler

from transformers import GPT2Model

# load the modeules I wrote
from graph_gpt_classification import Graph_GPT_Classification



In [2]:
# dataset = Planetoid(root = '/tmp/Cora', name = 'Cora')
dataset = Flickr(root = './tmp/Flickr')

# check data
dataset[0]

Data(x=[89250, 500], edge_index=[2, 899756], y=[89250], train_mask=[89250], val_mask=[89250], test_mask=[89250])

In [3]:
loader = GraphSAINTNodeSampler(
    dataset[0],
    batch_size = 1280, 
    num_steps = 100)

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# initilize model
gpt_model = GPT2Model.from_pretrained('distilgpt2')
graph_gpt_model = Graph_GPT_Classification(gpt_model,
                                           dataset.num_node_features, 
                                           128,
                                           dataset.num_classes)

graph_gpt_model.to(device)

Graph_GPT_Classification(
  (g_conv_1): GCNConv(500, 768)
  (transformer_layers): ModuleList(
    (0-5): 6 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (g_conv_2): GCNConv(768, 7)
)

In [6]:
loader = GraphSAINTNodeSampler(
    dataset[0],
    batch_size = 128, 
    num_steps = 1000)

In [7]:
optimizer = torch.optim.Adam(graph_gpt_model.parameters(), lr=0.001, weight_decay=5e-4)

for epoch in range(5):
    step = 1
    for subgraph in loader:
        step += 1
        subgraph = subgraph.to(device)
        optimizer.zero_grad()
        out = graph_gpt_model(subgraph)
        loss = F.nll_loss(out[subgraph.train_mask], subgraph.y[subgraph.train_mask])

        if (step + 1) % 10 == 0:
            # print(f'step: {step+1}, training loss: {loss.item()}')
            t_loss = F.nll_loss(out[subgraph.test_mask], subgraph.y[subgraph.test_mask])
            print(f'epoch: {epoch + 1}, step: {step+1}, training loss: {loss.item()}, testing loss: {t_loss.item()}')

        if (step + 1) % 10 == 0:
            pred = graph_gpt_model(subgraph).argmax(dim=1)
            correct = (pred[subgraph.test_mask] == subgraph.y[subgraph.test_mask]).sum()
            acc = int(correct) / int(subgraph.test_mask.sum())
            print(f'Accuracy: {acc:.4f}')
            
        loss.backward()
        optimizer.step()

epoch: 1, step: 10, training loss: 16.852779388427734, testing loss: 15.763978958129883
Accuracy: 0.3429
epoch: 1, step: 20, training loss: 12.42660903930664, testing loss: 8.590229988098145
Accuracy: 0.2308
epoch: 1, step: 30, training loss: 9.657149314880371, testing loss: 7.8424391746521
Accuracy: 0.3913
epoch: 1, step: 40, training loss: 8.627907752990723, testing loss: 7.771993637084961
Accuracy: 0.3333
epoch: 1, step: 50, training loss: 8.068082809448242, testing loss: 5.506824016571045
Accuracy: 0.3793
epoch: 1, step: 60, training loss: 7.083344459533691, testing loss: 10.865238189697266
Accuracy: 0.2609
epoch: 1, step: 70, training loss: 5.7629218101501465, testing loss: 7.5988969802856445
Accuracy: 0.2069
epoch: 1, step: 80, training loss: 6.7849016189575195, testing loss: 5.192543983459473
Accuracy: 0.2821
epoch: 1, step: 90, training loss: 5.071207523345947, testing loss: 5.8787689208984375
Accuracy: 0.3125
epoch: 1, step: 100, training loss: 4.722263813018799, testing loss: