In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

In [None]:
# ============= GPU SETUP =============
# Use CUDA when PyTorch reports an available GPU; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')

# The device name is only queried when a CUDA device was actually selected.
if device.type == 'cuda':
    print(f'GPU: {torch.cuda.get_device_name(0)}')

In [None]:
# ============= DATA =============
# Toy graph: 9 nodes laid out as a 3x3 grid (rows 0-1-2 / 3-4-5 / 6-7-8).
num_nodes = 9

# Directed edge list. Every connection is listed in both directions, first
# the three horizontal rows, then the vertical links between rows.
edges = [
    [0, 1], [1, 0], [1, 2], [2, 1],
    [3, 4], [4, 3], [4, 5], [5, 4],
    [6, 7], [7, 6], [7, 8], [8, 7],
    [0, 3], [3, 0], [3, 6], [6, 3],
    [1, 4], [4, 1], [4, 7], [7, 4],
    [2, 5], [5, 2], [5, 8], [8, 5],
]

# Transposed to shape (2, num_edges): row 0 holds sources, row 1 targets.
edge_index = torch.tensor(edges, dtype=torch.long).t().to(device)  # Move to selected device

# Seed NumPy and PyTorch. The torch seed is what makes weight initialisation
# and dropout repeatable across "Restart & Run All".
np.random.seed(42)
torch.manual_seed(42)

# One row of 6 features per node.
node_features = torch.tensor([
    [8.0, 1, 45.0, 12, 1, 0.3],
    [8.0, 1, 25.0, 35, 1, 0.7],
    [8.0, 1, 50.0, 8, 0, 0.2],
    [8.0, 1, 40.0, 15, 1, 0.4],
    [8.0, 1, 20.0, 45, 1, 0.8],
    [8.0, 1, 35.0, 20, 1, 0.5],
    [8.0, 1, 30.0, 28, 0, 0.6],
    [8.0, 1, 38.0, 18, 1, 0.45],
    [8.0, 1, 55.0, 5, 0, 0.15],
], dtype=torch.float).to(device)  # Move to selected device

# Per-node class label (0=Low, 1=Medium, 2=High congestion).
labels = torch.tensor([0, 2, 0, 1, 2, 1, 2, 1, 0], dtype=torch.long).to(device)  # Move to selected device

# Sanity checks: the hand-written tables must agree with num_nodes.
assert node_features.size(0) == num_nodes, "expected one feature row per node"
assert labels.size(0) == num_nodes, "expected one label per node"
assert 0 <= int(edge_index.min()) and int(edge_index.max()) < num_nodes, \
    "edge endpoint outside the node range"

In [None]:
# ============= GAT LAYER =============
class GATLayer(nn.Module):
    """Single-head graph attention layer driven by an edge list.

    Each node's output is an attention-weighted sum of the linearly
    transformed features of the nodes that have an edge pointing to it.
    Attention is normalised only over existing edges; node pairs that are
    not connected receive exactly zero weight.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        dropout: Dropout probability applied to the attention coefficients
            while the module is in training mode.
        alpha: Negative slope of the LeakyReLU applied to raw edge scores.
        concat: If True, apply ELU to the aggregated output (hidden layer
            use); if False, return the raw aggregation (output layer use).
    """

    def __init__(self, in_features, out_features, dropout=0.6, alpha=0.2, concat=True):
        super(GATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        # Linear projection applied to every node.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention vector scoring the concatenation [h_source || h_target].
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, x, edge_index):
        """Apply the attention layer.

        Args:
            x: Node features, shape (num_nodes, in_features).
            edge_index: Integer tensor with two rows; row 0 holds source
                node ids and row 1 holds target node ids of each edge.

        Returns:
            Tensor of shape (num_nodes, out_features). A node with no
            incoming edge receives an all-zero aggregation.
        """
        h = torch.mm(x, self.W)
        num_nodes = h.size(0)

        source_nodes = edge_index[0]
        target_nodes = edge_index[1]

        # Raw (unnormalised) score for every edge.
        h_concat = torch.cat([h[source_nodes], h[target_nodes]], dim=1)
        e = self.leakyrelu(torch.matmul(h_concat, self.a)).squeeze(-1)

        # Boolean adjacency, indexed [source, target].
        adjacency = torch.zeros(num_nodes, num_nodes, dtype=torch.bool, device=h.device)
        adjacency[source_nodes, target_nodes] = True

        scores = torch.zeros(num_nodes, num_nodes, dtype=h.dtype, device=h.device)
        scores[source_nodes, target_nodes] = e

        # Non-edges must not take part in the softmax: a score of 0 would
        # still contribute exp(0) = 1 and leak attention to non-neighbours.
        scores = scores.masked_fill(~adjacency, float('-inf'))

        # A target without incoming edges would be an all -inf column and
        # softmax would yield NaN; neutralise it here and zero it below.
        no_incoming = ~adjacency.any(dim=0, keepdim=True)
        scores = scores.masked_fill(no_incoming, 0.0)

        # Normalise over sources (dim 0) separately for every target.
        attention = F.softmax(scores, dim=0)
        attention = attention.masked_fill(~adjacency, 0.0)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # h_prime[t] = sum over sources s of attention[s, t] * h[s]
        h_prime = torch.matmul(attention.t(), h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime


class MultiHeadGATLayer(nn.Module):
    """Runs several independent ``GATLayer`` heads on the same input.

    Args:
        in_features: Input feature size handed to every head.
        out_features: Output feature size of every head.
        num_heads: Number of attention heads to create.
        dropout: Dropout probability forwarded to every head.
        alpha: LeakyReLU negative slope forwarded to every head.
        concat: Forwarded to every head. Also selects how head outputs are
            merged: concatenated along the feature axis when True,
            averaged across heads when False.
    """

    def __init__(self, in_features, out_features, num_heads, dropout=0.6, alpha=0.2, concat=True):
        super().__init__()
        self.num_heads = num_heads
        self.concat = concat

        # Heads are built one after another so parameter order is stable.
        heads = nn.ModuleList()
        for _ in range(num_heads):
            heads.append(GATLayer(in_features, out_features, dropout, alpha, concat))
        self.attention_heads = heads

    def forward(self, x, edge_index):
        """Evaluate every head on ``(x, edge_index)`` and merge the results."""
        per_head = []
        for attn_head in self.attention_heads:
            per_head.append(attn_head(x, edge_index))

        if not self.concat:
            # Element-wise mean over the head axis.
            return torch.stack(per_head).mean(dim=0)
        # Side-by-side along the feature dimension.
        return torch.cat(per_head, dim=1)


class GAT(nn.Module):
    """Two-layer graph attention network for node classification.

    The first layer is a multi-head attention layer whose head outputs are
    concatenated; the second is a single attention head that maps the
    concatenated features to ``num_classes`` scores per node.

    Args:
        in_features: Size of each input node feature vector.
        hidden_dim: Output size of each first-layer head.
        num_classes: Number of output scores per node.
        num_heads: Number of heads in the first layer.
        dropout: Dropout probability forwarded to both layers.
    """

    def __init__(self, in_features, hidden_dim, num_classes, num_heads=4, dropout=0.6):
        super().__init__()

        # Hidden layer: num_heads heads, concatenated.
        self.gat1 = MultiHeadGATLayer(
            in_features, hidden_dim, num_heads=num_heads, dropout=dropout, concat=True
        )

        # Output layer: its input width is the concatenation of all heads.
        concat_width = hidden_dim * num_heads
        self.gat2 = GATLayer(concat_width, num_classes, dropout=dropout, concat=False)

    def forward(self, x, edge_index):
        """Return the per-node output of the second layer for ``x``."""
        hidden = self.gat1(x, edge_index)
        return self.gat2(hidden, edge_index)

In [None]:
# ============= TRAINING =============
# Build the network (6 input features, 4 heads of width 8, 3 classes) and
# place its parameters on the selected device.
model = GAT(
    in_features=6,
    hidden_dim=8,
    num_classes=3,
    num_heads=4,
    dropout=0.3,
)
model = model.to(device)

# Adam with L2 regularisation supplied through weight_decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=5e-4,
)

print(f'\nModel parameters on: {next(model.parameters()).device}')

def train():
    """Perform one full-batch optimisation step on the global model.

    Uses the module-level ``model``, ``optimizer``, ``node_features``,
    ``edge_index`` and ``labels``.

    Returns:
        The cross-entropy loss (Python float) computed before the
        parameter update of this step.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(node_features, edge_index)
    step_loss = F.cross_entropy(logits, labels)

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def test():
    """Evaluate the global model on all nodes without tracking gradients.

    Uses the module-level ``model``, ``node_features``, ``edge_index`` and
    ``labels``. The model is switched to eval mode.

    Returns:
        Tuple ``(acc, pred)``: the fraction of nodes whose arg-max class
        equals its label, and the tensor of predicted class ids.
    """
    model.eval()
    with torch.no_grad():
        pred = model(node_features, edge_index).argmax(dim=1)
        num_correct = (pred == labels).sum().item()
    return num_correct / len(labels), pred

# Train
NUM_EPOCHS = 200   # number of full-batch optimisation steps
LOG_EVERY = 20     # evaluate and print every LOG_EVERY epochs

# Name the device actually in use instead of always claiming "GPU".
print(f"\nTraining GAT on {'GPU' if device.type == 'cuda' else 'CPU'}...")
for epoch in range(NUM_EPOCHS):
    loss = train()
    if epoch % LOG_EVERY == 0:
        # NOTE: test() scores the same nodes train() fits, so this is
        # training accuracy, not a held-out estimate.
        acc, _ = test()
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Accuracy: {acc:.4f}')

# Final predictions
acc, predictions = test()
print(f'\n=== Final Results ===')
print(f'Accuracy: {acc:.4f}')
print(f'True labels:      {labels.cpu().numpy()}')  # Move to CPU for display
print(f'Predictions:      {predictions.cpu().numpy()}')
print(f'\n0=Low, 1=Medium, 2=High Congestion')

In [None]:
# ============= VISUALIZE =============
def _edge_attention_matrix(head, x, edge_index):
    """Compute one head's attention matrix, indexed ``[source, target]``.

    Scores are normalised with a softmax over the incoming edges of each
    target only; entries for node pairs without an edge are exactly zero,
    and a target without any incoming edge gets an all-zero column.

    Args:
        head: Object exposing ``W``, ``a`` and ``leakyrelu``.
        x: Node features, shape (num_nodes, in_features).
        edge_index: Integer tensor with two rows (sources, targets).
    """
    h = torch.mm(x, head.W)
    n = h.size(0)
    src, dst = edge_index[0], edge_index[1]

    pair_features = torch.cat([h[src], h[dst]], dim=1)
    e = head.leakyrelu(torch.matmul(pair_features, head.a)).squeeze(-1)

    has_edge = torch.zeros(n, n, dtype=torch.bool, device=h.device)
    has_edge[src, dst] = True

    scores = torch.zeros(n, n, dtype=h.dtype, device=h.device)
    scores[src, dst] = e
    # Exclude non-edges from the softmax; a plain 0 would still get weight.
    scores = scores.masked_fill(~has_edge, float('-inf'))
    # Avoid NaN for targets with no incoming edge (all -inf column).
    scores = scores.masked_fill(~has_edge.any(dim=0, keepdim=True), 0.0)

    return F.softmax(scores, dim=0).masked_fill(~has_edge, 0.0)


def visualize_attention():
    """Plot head-0 attention of the first layer next to the predicted classes.

    Left panel: heat-map with one row per target node and one column per
    source node, so each row with incoming edges sums to 1. Right panel:
    the graph on a fixed 3x3 layout, nodes coloured by predicted class.

    Reads the module-level ``model``, ``node_features``, ``edge_index``,
    ``edges`` and ``predictions``.
    """
    model.eval()
    with torch.no_grad():
        attention = _edge_attention_matrix(
            model.gat1.attention_heads[0], node_features, edge_index
        )

    num_nodes = node_features.size(0)

    # Move to CPU for visualization. The matrix is stored [source, target];
    # transpose so image rows are targets and columns are sources, matching
    # the axis labels below.
    attention_cpu = attention.t().cpu().numpy()
    predictions_cpu = predictions.cpu().numpy()

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(attention_cpu, cmap='hot', interpolation='nearest')
    plt.colorbar(label='Attention Weight')
    plt.title('Attention Weight Matrix')
    plt.xlabel('Source Node')
    plt.ylabel('Target Node')

    plt.subplot(1, 2, 2)
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    # The graph is undirected, so reversed duplicates in `edges` collapse
    # on their own; no assumption about the ordering of the list is needed.
    G.add_edges_from(edges)

    # Fixed coordinates placing nodes 0..8 on a 3x3 grid.
    pos = {
        0: (0, 2), 1: (1, 2), 2: (2, 2),
        3: (0, 1), 4: (1, 1), 5: (2, 1),
        6: (0, 0), 7: (1, 0), 8: (2, 0)
    }

    colors = ['green', 'yellow', 'red']
    node_colors = [colors[int(predictions_cpu[i])] for i in range(num_nodes)]

    nx.draw(G, pos, node_color=node_colors, node_size=1200,
            with_labels=True, font_size=14, font_weight='bold',
            edge_color='gray', width=2)
    plt.title('GAT Predictions\nGreen=Low, Yellow=Med, Red=High')

    plt.tight_layout()
    plt.show()

# Draw the attention heat-map and the prediction graph for the current model.
visualize_attention()