In [1]:
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.4/9.4 MB[0m [31m59.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m27.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch-geometric (pyproject.toml) ... [?25l[?25hdone


In [3]:
from torch_geometric.datasets import Planetoid

# Import dataset from PyTorch Geometric
dataset = Planetoid(root=".", name="CiteSeer")
data = dataset[0]

# Print information about the dataset
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...


Number of graphs: 1
Number of nodes: 3327
Number of features: 3703
Number of classes: 6
Has isolated nodes: True


Done!


In [12]:
import torch
import torch.nn as nn
from tqdm import tqdm

import torch.nn.functional as F
from torch.nn import Linear, Dropout
from torch_geometric.nn import GCNConv, GATv2Conv

import numpy as np
np.random.seed(0)

# Visualization
import networkx as nx
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
plt.rcParams.update({'font.size': 24})

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GraphAttentionNetwork(nn.Module):
    def __init__(self, num_features : int, h_dim : int, num_classes : int, num_heads : int = 8):
        super().__init__()
        self.gat1 : GATv2Conv = GATv2Conv(num_features, h_dim * num_heads)
        self.gat2 : GATv2Conv = GATv2Conv(h_dim * num_heads, num_classes)

        self.__criterion = None
        self.__losses = None
        self.__optimizer = None
    
    def forward(self, x, edge_index):
        h = F.dropout(x, p=0.6, training=self.training)
        h = self.gat1(h, edge_index)
        h = F.elu(h)
        h = F.dropout(h, p=0.6, training=self.training)
        h = self.gat2(h, edge_index)
        return h, F.log_softmax(h, dim=1)

    def compile(self, criterion : nn.Module, optimizer : str, lr : float = 0.001):
        self.__criterion = criterion
        
        if optimizer.lower() == 'adam':
            self.__optimizer = torch.optim.Adam(self.parameters(), lr = lr)

        else: # default
            self.__optimizer = torch.optim.Adam(self.parameters(), lr = lr)
    
    def accuracy(self, pred_y, y):
        return ((pred_y == y).sum() / len(y)).item()
    
    def fit(self, data, epochs):
        self.losses = []
        if self.__criterion == None:
            print(f"You Should Compile the model first using model.compile()")
            return
        
        self.train()

        training_losses = []
        validation_losses = []

        train_accuracy = []
        validation_accuracy = []

        for epoch in tqdm(range(epochs+1)):
            self.__optimizer.zero_grad()
            _, output = self(data.x, data.edge_index)
            loss = self.__criterion(output[data.val_mask], data.y[data.val_mask])
            acc = self.accuracy(output[data.train_mask].argmax(dim=1), data.y[data.train_mask])

            loss.backward()
            self.__optimizer.step()

            val_loss = self.__criterion(output[data.val_mask], data.y[data.val_mask])
            val_acc = self.accuracy(output[data.val_mask].argmax(dim=1), data.y[data.val_mask])

            training_losses.append(loss)
            validation_losses.append(val_loss)

            train_accuracy.append(acc)
            validation_accuracy.append(val_acc)

            self.losses.append((loss, val_loss))
            if(epoch % 10 == 0):
                print(f'Epoch {epoch:>3} | Train Loss: {loss:.3f} | Train Acc: '
                    f'{acc*100:>6.2f}% | Val Loss: {val_loss:.2f} | '
                    f'Val Acc: {val_acc*100:.2f}%')
        

        return {
            "training_loss" : training_losses,
            "validation_loss" : validation_losses,
            "training_accuracy" : train_accuracy,
            "validation_accuracy" : validation_accuracy
        }

    
    @torch.no_grad()
    def test(self, data):
        model.eval()
        _, output = self(data.x, data.edge_index)
        acc = self.accuracy(output.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

    def visualize_graph(self, data):
        h, _ = self(data.x, data.edge_index)
        tsne = TSNE(n_components=2, learning_rate='auto',
         init='pca').fit_transform(h.detach())

        # Plot TSNE
        plt.figure(figsize=(10, 10))
        plt.axis('off')
        plt.scatter(tsne[:, 0], tsne[:, 1], s=50, c=data.y)
        plt.show()
    
    


In [18]:
model = GraphAttentionNetwork(dataset.num_features, 8, dataset.num_classes).to(DEVICE)

In [19]:
criterion = torch.nn.CrossEntropyLoss()
model.compile(criterion, 'adam')

In [20]:
data = data.to(DEVICE)

In [21]:
results = model.fit(data, 200)

 10%|█         | 21/201 [00:00<00:01, 100.23it/s]

Epoch   0 | Train Loss: 1.793 | Train Acc:  10.00% | Val Loss: 1.79 | Val Acc: 20.00%
Epoch  10 | Train Loss: 1.213 | Train Acc:  50.00% | Val Loss: 1.21 | Val Acc: 75.60%
Epoch  20 | Train Loss: 0.804 | Train Acc:  63.33% | Val Loss: 0.80 | Val Acc: 85.20%


 21%|██▏       | 43/201 [00:00<00:01, 104.30it/s]

Epoch  30 | Train Loss: 0.584 | Train Acc:  60.00% | Val Loss: 0.58 | Val Acc: 88.80%
Epoch  40 | Train Loss: 0.411 | Train Acc:  66.67% | Val Loss: 0.41 | Val Acc: 91.60%
Epoch  50 | Train Loss: 0.310 | Train Acc:  59.17% | Val Loss: 0.31 | Val Acc: 94.20%


 38%|███▊      | 76/201 [00:00<00:01, 101.39it/s]

Epoch  60 | Train Loss: 0.247 | Train Acc:  64.17% | Val Loss: 0.25 | Val Acc: 95.40%
Epoch  70 | Train Loss: 0.181 | Train Acc:  64.17% | Val Loss: 0.18 | Val Acc: 96.60%


 50%|████▉     | 100/201 [00:00<00:00, 105.21it/s]

Epoch  80 | Train Loss: 0.171 | Train Acc:  64.17% | Val Loss: 0.17 | Val Acc: 96.00%
Epoch  90 | Train Loss: 0.129 | Train Acc:  67.50% | Val Loss: 0.13 | Val Acc: 97.80%
Epoch 100 | Train Loss: 0.114 | Train Acc:  65.83% | Val Loss: 0.11 | Val Acc: 98.20%


 64%|██████▍   | 129/201 [00:01<00:00, 122.78it/s]

Epoch 110 | Train Loss: 0.095 | Train Acc:  61.67% | Val Loss: 0.09 | Val Acc: 98.20%
Epoch 120 | Train Loss: 0.086 | Train Acc:  63.33% | Val Loss: 0.09 | Val Acc: 98.00%
Epoch 130 | Train Loss: 0.070 | Train Acc:  59.17% | Val Loss: 0.07 | Val Acc: 99.00%


 78%|███████▊  | 157/201 [00:01<00:00, 125.50it/s]

Epoch 140 | Train Loss: 0.066 | Train Acc:  60.00% | Val Loss: 0.07 | Val Acc: 98.80%
Epoch 150 | Train Loss: 0.063 | Train Acc:  60.00% | Val Loss: 0.06 | Val Acc: 98.60%
Epoch 160 | Train Loss: 0.062 | Train Acc:  63.33% | Val Loss: 0.06 | Val Acc: 98.80%


 93%|█████████▎| 186/201 [00:01<00:00, 133.00it/s]

Epoch 170 | Train Loss: 0.056 | Train Acc:  63.33% | Val Loss: 0.06 | Val Acc: 99.00%
Epoch 180 | Train Loss: 0.061 | Train Acc:  60.83% | Val Loss: 0.06 | Val Acc: 99.00%
Epoch 190 | Train Loss: 0.048 | Train Acc:  59.17% | Val Loss: 0.05 | Val Acc: 99.20%


100%|██████████| 201/201 [00:01<00:00, 117.05it/s]

Epoch 200 | Train Loss: 0.038 | Train Acc:  62.50% | Val Loss: 0.04 | Val Acc: 99.20%



