In [1]:
import torch
from torch import optim
import matplotlib.pyplot as plt

from GAT_model import GAT
from trainer import Trainer
from datasets_list import datasets_list

First, detect the compute device: the GPU if CUDA is available, otherwise the CPU.

In [2]:
# Pick CUDA when PyTorch can see a GPU; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

if device.type == 'cuda':
    # GPU 0 name plus caching-allocator usage, converted from bytes to GiB
    # and rounded to one decimal place.
    bytes_per_gib = 1024 ** 3
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    allocated_gib = round(torch.cuda.memory_allocated(0) / bytes_per_gib, 1)
    print('Allocated:', allocated_gib, 'GB')
    # memory_reserved = memory held by the caching allocator (formerly "cached").
    reserved_gib = round(torch.cuda.memory_reserved(0) / bytes_per_gib, 1)
    print('Cached:   ', reserved_gib, 'GB')

Using device: cpu


Next, choose one of the datasets (Cora, Pubmed or Citeseer) by its key in `datasets_list` (for example `'cora'`), then load its constants and its data.

In [3]:
# Key into `datasets_list`; 'cora' is lowercase here. NOTE(review): confirm
# the exact keys used for Pubmed/Citeseer against datasets_list.
dataset = 'cora'

# Look the entry up once and pull out the loader and the constants we need.
dataset_spec = datasets_list[dataset]
data_loader = dataset_spec['load_function']
# 'constants' is unpacked as exactly five values; only the last two are used here.
(_, _, _, NUM_INPUT_FEATURES, NUM_CLASSES) = dataset_spec['constants']

In [4]:
# Load the chosen dataset from its folder, passing the detected device to the
# project loader (presumably it places tensors there — confirm in datasets_list).
dataset_dir = f'datasets/{dataset}/'
data = data_loader(dataset_dir, device=device)

  0%|          | 0/2708 [00:00<?, ?it/s]

  self._set_intXint(row, col, x.flat[0])


Define the GAT model, the optimizer and the trainer, then train the network (with or without returning logs).

In [5]:
# Seed the global torch RNG so that weight initialisation and dropout masks
# (and therefore the training log below) are reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

NUM_EPOCHS = 50

# Two single-head GAT layers: input features -> 50 hidden features -> classes.
model = GAT(
    num_of_layers=2,
    num_heads_per_layer=[1, 1],
    num_features_per_layer=[NUM_INPUT_FEATURES, 50, NUM_CLASSES],
    add_skip_connection=False,
    bias=True,
    dropout=0.7,
).to(device)  # `data` was loaded onto `device`; keep the parameters there too.
# NOTE(review): assumes GAT is a torch.nn.Module (it exposes .parameters()),
# so .to() moves it in place and returns the model.

# Build the optimizer after the move so it tracks the on-device parameters.
optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
trainer = Trainer(model=model, optimizer=optimizer, data=data, return_logs=True)

logs_na = trainer.train(NUM_EPOCHS)

  0%|          | 0/50 [00:00<?, ?it/s]

----------------------
train 0.16428571428571428
val 0.238
----------------------
train 0.20714285714285716
val 0.282
----------------------
train 0.14285714285714285
val 0.33
----------------------
train 0.19285714285714287
val 0.384
----------------------
train 0.22142857142857142
val 0.408
----------------------
train 0.3142857142857143
val 0.442
----------------------
train 0.30714285714285716
val 0.446
----------------------
train 0.37142857142857144
val 0.462
----------------------
train 0.34285714285714286
val 0.478
----------------------
train 0.32857142857142857
val 0.49
----------------------
train 0.37857142857142856
val 0.512
----------------------
train 0.40714285714285714
val 0.522
----------------------
train 0.37857142857142856
val 0.53
----------------------
train 0.45714285714285713
val 0.538
----------------------
train 0.45714285714285713
val 0.536
----------------------
train 0.4
val 0.538
----------------------
train 0.45
val 0.538
----------------------
train 0.4