In [1]:
!pip install torch_geometric torch-scatter torch-sparse torch-cluster torch-spline-conv -q
!pip install plotly -q

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GATConv

from sklearn.manifold import TSNE
import plotly.graph_objects as go
import pandas as pd
import numpy as np
import random

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.5/54.5 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m39.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Building wheel for torch-sparse (setup.py) ... [?25l[?25hdone
  Building wheel for torch-clust

## Dataset + seeds

In [2]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

dataset = Planetoid(root="/tmp/Cora", name="Cora")
data = dataset[0].to(device)
num_features = dataset.num_node_features
num_classes = dataset.num_classes

Device: cpu


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


## Define the two models

In [3]:
class GCN(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 heads=4, dropout=0.5):
        super().__init__()
        self.gat1 = GATConv(in_channels, hidden_channels,
                            heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_channels * heads, out_channels,
                            heads=1, concat=False, dropout=dropout)
        self.dropout = dropout

    def forward(self, x, edge_index):
        x_res = x
        x = self.gat1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        if x_res.shape[1] == x.shape[1]:
            x = x + x_res
        x = self.gat2(x, edge_index)
        return x


## Training helper

In [4]:
def train_model(model, data, lr=0.01, weight_decay=5e-4, patience=20, max_epochs=200):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_state = None
    best_val_acc = 0.0
    bad_epochs = 0

    for epoch in range(1, max_epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            logits = model(data.x, data.edge_index)
            preds = logits.argmax(dim=1)
            val_acc = (preds[data.val_mask] == data.y[data.val_mask]).float().mean().item()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = model.state_dict()
            bad_epochs = 0
        else:
            bad_epochs += 1

        if bad_epochs >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

    model.load_state_dict(best_state)
    model.eval()
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        preds = logits.argmax(dim=1)
        test_acc = (preds[data.test_mask] == data.y[data.test_mask]).float().mean().item()
    return model, logits, test_acc


## Train both and compute embeddings

In [5]:
gcn = GCN(num_features, 64, num_classes)
gat = GAT(num_features, 8, num_classes, heads=8, dropout=0.6)

gcn, gcn_logits, gcn_test_acc = train_model(gcn, data, lr=0.01, weight_decay=5e-4)
print("GCN test accuracy:", gcn_test_acc)

gat, gat_logits, gat_test_acc = train_model(gat, data, lr=0.005, weight_decay=5e-4)
print("GAT test accuracy:", gat_test_acc)

# take logits as embeddings
gcn_emb = gcn_logits.detach().cpu().numpy()
gat_emb = gat_logits.detach().cpu().numpy()
labels = data.y.cpu().numpy()


Early stopping at epoch 28
GCN test accuracy: 0.800000011920929
Early stopping at epoch 47
GAT test accuracy: 0.7820000052452087


## t-SNE + interactive Plotly toggle

In [7]:
## t-SNE + interactive Plotly toggle

from sklearn.manifold import TSNE

# Run t-SNE separately for GCN and GAT embeddings
tsne_gcn = TSNE(n_components=2, learning_rate="auto", init="random", random_state=seed)
gcn_2d = tsne_gcn.fit_transform(gcn_emb)

tsne_gat = TSNE(n_components=2, learning_rate="auto", init="random", random_state=seed)
gat_2d = tsne_gat.fit_transform(gat_emb)

class_names = [
    "Case-Based Reasoning",      # 0
    "Genetic Algorithms",        # 1
    "Neural Networks",           # 2
    "Probabilistic Methods",     # 3
    "Reinforcement Learning",    # 4
    "Rule Learning",             # 5
    "Theory",                    # 6
]

# Add topic names to the dataframes
df_gcn = pd.DataFrame({
    "x": gcn_2d[:, 0],
    "y": gcn_2d[:, 1],
    "label": labels,
    "topic": [class_names[i] for i in labels]
})

df_gat = pd.DataFrame({
    "x": gat_2d[:, 0],
    "y": gat_2d[:, 1],
    "label": labels,
    "topic": [class_names[i] for i in labels]
})

# ---- Plotly Figure ----
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=df_gcn["x"],
        y=df_gcn["y"],
        mode="markers",
        marker=dict(
            size=5,
            color=df_gcn["label"],
            colorscale="Turbo",
            showscale=True
        ),
        name=f"GCN (test acc = {gcn_test_acc:.3f})",
        visible=True,
        text=df_gcn["topic"],   # hover label
        hovertemplate="<b>Topic:</b> %{text}<br>" +
                      "x: %{x}<br>y: %{y}<br>" +
                      "<extra></extra>"
    )
)

fig.add_trace(
    go.Scatter(
        x=df_gat["x"],
        y=df_gat["y"],
        mode="markers",
        marker=dict(
            size=5,
            color=df_gat["label"],
            colorscale="Turbo",
            showscale=True
        ),
        name=f"GAT (test acc = {gat_test_acc:.3f})",
        visible=False,
        text=df_gat["topic"],   # hover label
        hovertemplate="<b>Topic:</b> %{text}<br>" +
                      "x: %{x}<br>y: %{y}<br>" +
                      "<extra></extra>"
    )
)

fig.update_layout(
    title="t-SNE node embeddings: GCN vs GAT on Cora",
    xaxis_title="t-SNE dim 1",
    yaxis_title="t-SNE dim 2",
    width=800,
    height=600,
    updatemenus=[
        dict(
            type="buttons",
            direction="left",
            x=0.5,
            y=1.15,
            xanchor="center",
            buttons=[
                dict(
                    label="GCN",
                    method="update",
                    args=[{"visible": [True, False]},
                          {"title": "t-SNE embeddings – GCN"}],
                ),
                dict(
                    label="GAT",
                    method="update",
                    args=[{"visible": [False, True]},
                          {"title": "t-SNE embeddings – GAT"}],
                ),
            ],
        )
    ]
)

fig.show()
fig.write_html("gcn_gat_tsne_interactive.html")


In [8]:
fig.write_html("gcn_gat_tsne_interactive.html")