In [None]:
import torch

# Numpy for matrices
import numpy as np

# Visualization libraries
import matplotlib.pyplot as plt
import networkx as nx

In [None]:
from torch_geometric.datasets import WikiCS

# Load the WikiCS dataset through PyTorch Geometric, using the current
# directory as the dataset root.
dataset = WikiCS('./')

# Summary of the dataset, printed one entry per line
summary_lines = (
    dataset,
    '------------',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
)
for line in summary_lines:
    print(line)

In [None]:
# Work with the dataset's first graph from here on
data = dataset[0]
print(f'Graph: {data}')

In [None]:
from torch.nn import Linear
import torch_geometric.nn as gnn

class GNN(torch.nn.Module):
    """Node classifier: one graph convolution followed by a two-layer MLP head.

    Args:
        algorithm: ``'gat'`` selects ``GATConv``; any other value selects
            ``GCNConv``.
        in_channels: Size of each input node feature vector. When ``None``,
            ``dataset.num_features`` of the notebook-level ``dataset`` is used.
        hidden_channels: Output size of the graph convolution, which is also
            the size of the embeddings returned by ``forward``.
        mlp_channels: Width of the hidden linear layer.
        num_classes: Number of output scores per node. When ``None``,
            ``dataset.num_classes`` of the notebook-level ``dataset`` is used.

    Attributes:
        loss: Loss history; starts empty and is appended to by the training loop.
        accuracy: Accuracy history; starts empty and is appended to by the
            training loop.
        optim: Optimizer for this model; ``None`` until the caller assigns one.
    """

    def __init__(self, algorithm, in_channels=None, hidden_channels=64,
                 mlp_channels=32, num_classes=None):
        super().__init__()

        # Fall back to the notebook-level dataset when sizes are not given,
        # so that `GNN('gcn')` / `GNN('gat')` keep working unchanged.
        if in_channels is None:
            in_channels = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes

        # The attribute is called `gcn` for both layer types.
        if algorithm == 'gat':
            self.gcn = gnn.GATConv(in_channels, hidden_channels)
        else:
            self.gcn = gnn.GCNConv(in_channels, hidden_channels)
        self.hdn = Linear(hidden_channels, mlp_channels)
        self.out = Linear(mlp_channels, num_classes)

        self.loss = []
        self.accuracy = []
        self.optim = None

    def forward(self, x, edge_index):
        """Run the model on a graph.

        Args:
            x: Node features, passed to the graph convolution.
            edge_index: Graph connectivity, passed to the graph convolution.

        Returns:
            A ``(scores, embeddings)`` pair: ``scores`` are the raw outputs of
            the final linear layer and ``embeddings`` are the ReLU-activated
            outputs of the graph convolution.
        """
        embeddings = self.gcn(x, edge_index).relu()
        x = self.hdn(embeddings).relu()
        x = self.out(x)
        return x, embeddings

# Seed PyTorch's RNG before the layers are created so that the randomly
# initialised weights are the same on every Restart & Run All.
torch.manual_seed(0)

# Learning rate shared by both optimizers
LEARNING_RATE = 5e-3

gcn = GNN('gcn')
gcn.optim = torch.optim.Adam(gcn.parameters(), lr=LEARNING_RATE)
print(gcn)

gat = GNN('gat')
gat.optim = torch.optim.Adam(gat.parameters(), lr=LEARNING_RATE)
print(gat)

In [None]:
# Node feature tensor; this is what both models receive as `x`.
data.x

In [None]:
# Loss shared by both models: cross-entropy between model outputs and labels
criterion = torch.nn.CrossEntropyLoss()


def accuracy(pred_y, y):
    """Return the fraction of entries in `pred_y` that equal the matching entry of `y`.

    Args:
        pred_y: Predicted class per item.
        y: True class per item.

    Returns:
        Number of matching entries divided by `len(y)`.
    """
    n_correct = (pred_y == y).sum()
    n_total = len(y)
    return n_correct / n_total

# Training loop
# NOTE(review): loss and accuracy are computed over every node in `data`
# (no train/validation/test mask is applied), so the accuracy reported here
# is training accuracy.
for epoch in range(100):
    # Both models go through the same optimisation step, one after the other.
    for model in (gcn, gat):
        # Clear gradients
        model.optim.zero_grad()

        # Forward pass (the embeddings are not needed during training)
        logits, _ = model(data.x, data.edge_index)

        # Calculate loss function
        loss = criterion(logits, data.y)

        # Compute gradients
        loss.backward()

        # Tune parameters
        model.optim.step()

        # Store data for animations, as plain Python floats so that the
        # histories can be plotted and formatted without holding tensors
        model.loss.append(loss.item())
        model.accuracy.append(float(accuracy(logits.argmax(dim=1), data.y)))

    # Print metrics every 10 epochs
    if epoch % 10 == 0:
        print(f'Epoch {epoch:>3}')
        print(f'\tGCN | Loss: {gcn.loss[-1]:.2f} | Acc: {gcn.accuracy[-1]*100:.2f}%')
        print(f'\tGAT | Loss: {gat.loss[-1]:.2f} | Acc: {gat.accuracy[-1]*100:.2f}%')

In [None]:
# Left panel: accuracy history; right panel: loss history.
fig, ax = plt.subplots(1, 2)

ax[0].plot(gcn.accuracy, linewidth=2, color='green', label="GCN")
ax[1].plot(gcn.loss, linewidth=2, color='green')

ax[0].plot(gat.accuracy, linewidth=2, color='red', label="GAT")
ax[1].plot(gat.loss, linewidth=2, color='red')

ax[0].set_title('Accuracy over training')
ax[1].set_title('Loss over training')

# Label the axes so each panel can be read on its own
for axis, ylabel in zip(ax, ('Accuracy', 'Loss')):
    axis.set_xlabel('Epoch')
    axis.set_ylabel(ylabel)

# Only the accuracy lines carry labels, so the legend has one entry per model
fig.legend()
fig.tight_layout()

# `plt.show()` rather than `fig.show()`: the latter needs a GUI backend and
# only emits a warning under the notebook's inline backend.
plt.show()

In [None]:
# Human-readable name for each class; a name's position in the list is its class id.
labels = [
    'Computational Linguistics',           # 0
    'Databases',                           # 1
    'Operating Systems',                   # 2
    'Computer Architecture',               # 3
    'Computer Security',                   # 4
    'Internet Protocols',                  # 5
    'Computer File Systems',               # 6
    'Distributed Computing Architecture',  # 7
    'Web Technology',                      # 8
    'Programming Language',                # 9
]

In [None]:
# Graph connectivity tensor; the models receive it as `edge_index` alongside `data.x`.
data.edge_index

In [None]:
# Inference only: disable gradient tracking so that `output` and `embeddings`
# carry no autograd history.
with torch.no_grad():
    output, embeddings = gat(data.x, data.edge_index)

# True class name of node 0, followed by the class name the GAT model predicts for it
print(labels[data.y[0].item()])
print(labels[torch.argmax(output[0]).item()])

In [None]:
import torch.nn.functional as F

In [None]:
# Shape of a single node's embedding (the activated graph-convolution output for node 0)
embeddings[0].shape

In [None]:
# Embedding values for node 0; they come out of a ReLU, so none are negative.
embeddings[0]

In [None]:
# True classes of the first ten nodes, for comparison with the predictions.
# Printed explicitly: a bare expression is only rendered when it is the last
# statement of a cell.
print(data.y[:10])

# Predicted class of the first ten nodes (index of the largest score per row)
values = output[:10].argmax(dim=1).tolist()
values ## keeping only ten items for simplicity

In [None]:
# Position of the reference node. The comparison cells only look at the first
# ten embeddings and expect the reference to be one of them, so keep idx < 10.
idx = 2
# Embedding of the reference node, compared against the others
to_compare = embeddings[idx] # picking an embedding from the list

In [None]:
# Cosine similarity between the reference embedding and each of the first ten
# embeddings, computed in one batched call instead of a Python loop.
# `detach()` makes the result a plain tensor without autograd history.
similarities = F.cosine_similarity(
    to_compare.unsqueeze(0),  # add a leading axis so it broadcasts over the rows
    embeddings[:10],          # get similarities from our 10 items
    dim=1,
).detach()
_, indices = torch.topk(similarities, 2) # top 2 largest similarities

In [None]:
# ANSI escape codes used for highlighting
RESET, GREEN, YELLOW = "\033[00m", "\033[92m", "\033[93m"

# keeping 2nd similarity because #1 is `to_compare` itself
most_similar = indices[1].item()

for i, value in enumerate(values):
    print(RESET, end="")
    if i == most_similar:
        print(GREEN, end="")
    if i == idx:  # checked last, so yellow wins if both conditions hold
        print(YELLOW, end="")
    print(value, end=',')

# Restore the default colour and end the line so that later output is not tinted
print(RESET)

In yellow: the predicted class of the reference node (`idx`), whose embedding is compared against the others.

In green: the predicted class of the node whose embedding is most similar (by cosine similarity) to the reference node's.