# Part A — Unweighted GCN for Link Prediction

## Exercise 1
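Build the node feature matrix and the unweighted GCN. Train it with an edge-reconstruction objective (binary cross-entropy on dot-product scores, with negative sampling). Training is full-batch and cheap on this 34-node graph, so the solution runs 10,000 epochs; report the training loss every 500 epochs.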

Solution 1

In [None]:
# Imports and feature construction
import torch
import torch.nn.functional as F
from torch import nn
import networkx as nx

try:
  from torch_geometric.nn import GCNConv
  from torch_geometric.utils import to_undirected, negative_sampling
except Exception as e:
  raise RuntimeError("torch_geometric is required. See install note at the end.") from e

# Seed the global torch RNG before any model parameters are created so that
# parameter initialization is reproducible.
torch.manual_seed(42)

# Load the graph; node ids 0..N-1 are used directly as row indices below.
G = nx.karate_club_graph()
N = G.number_of_nodes()  # 34

# Club membership as a categorical index: 0 = 'Mr. Hi', 1 = any other club ('Officer').
club_idx = torch.tensor(
  [int(G.nodes[n]['club'] != 'Mr. Hi') for n in range(N)],
  dtype=torch.long,
)  # [N]

# Per-node scalar features: degree divided by the maximum degree (so in [0, 1]),
# normalized betweenness centrality, and closeness centrality.
deg_vals = torch.tensor([G.degree[n] for n in range(N)], dtype=torch.float32)
deg_vals = deg_vals / deg_vals.max()
betw_dict = nx.betweenness_centrality(G, normalized=True)
close_dict = nx.closeness_centrality(G)
betw_vals, close_vals = (
  torch.tensor([cent[n] for n in range(N)], dtype=torch.float32)
  for cent in (betw_dict, close_dict)
)
scalars = torch.stack((deg_vals, betw_vals, close_vals), dim=1)  # [N, 3]

# PyG-style edge index: one column per edge, then made undirected.
edges = torch.tensor(list(G.edges), dtype=torch.long).T.contiguous()  # [2, E]
edge_index = to_undirected(edges, num_nodes=N)


# Model: 10-d projection + 2-layer GCN; return concat of conv1/conv2 outputs (dim=20)
class KarateGCN(nn.Module):
  """Node encoder: club embedding + scalar features -> linear -> two GCN layers.

  Each node's input is its learned club embedding concatenated with its
  ``scalar_dim`` scalar features. This is projected to ``hidden_dim`` (ReLU)
  and passed through two ``GCNConv`` layers with self-loops. The returned
  embedding concatenates the first GCN layer's *pre-activation* output with
  the second layer's output, giving ``2 * hidden_dim`` features per node.
  """

  def __init__(self, num_clubs=2, club_emb_dim=4, scalar_dim=3, hidden_dim=10):
    super().__init__()
    # Keep this creation order: it determines which RNG draws each
    # submodule's parameter initialization consumes after seeding.
    self.club_emb = nn.Embedding(num_clubs, club_emb_dim)
    self.lin = nn.Linear(club_emb_dim + scalar_dim, hidden_dim)
    self.conv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops=True)
    self.conv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops=True)

  def forward(self, club_idx, scalars, edge_index):
    """Return node embeddings ``cat(conv1_out, conv2_out)`` of width ``2 * hidden_dim``."""
    node_in = torch.cat((self.club_emb(club_idx), scalars), dim=1)  # [N, club_emb_dim + scalar_dim]
    hidden = F.relu(self.lin(node_in))  # [N, hidden_dim]
    layer1 = self.conv1(hidden, edge_index)  # [N, hidden_dim], kept pre-activation
    layer2 = self.conv2(F.relu(layer1), edge_index)  # [N, hidden_dim]
    return torch.cat((layer1, layer2), dim=1)


def edge_recon_loss(z, edge_index, num_nodes, neg_edge_index=None):
  """Binary cross-entropy edge-reconstruction loss with a dot-product decoder.

  The score for a pair (u, v) is the logit ``<z[u], z[v]>``. Columns of
  ``edge_index`` are treated as positives (label 1). Negatives (label 0) are
  either the columns of ``neg_edge_index``, when given, or are drawn with
  ``negative_sampling``, requesting one negative per positive column.

  Args:
    z: node embeddings indexed by node id along dim 0 (assumed [num_nodes, d]).
    edge_index: [2, E] long tensor of positive (observed) pairs.
    num_nodes: number of nodes; only used when sampling negatives.
    neg_edge_index: optional [2, E'] long tensor of negative pairs. Supplying a
      fixed set makes the loss deterministic (e.g. for evaluation or tests);
      when None, fresh negatives are sampled on every call (original behavior).

  Returns:
    Scalar tensor: mean BCE over positives plus mean BCE over negatives.
  """
  src, dst = edge_index
  pos_logits = (z[src] * z[dst]).sum(dim=1)
  pos_loss = F.binary_cross_entropy_with_logits(pos_logits, torch.ones_like(pos_logits))

  if neg_edge_index is None:
    neg_edge_index = negative_sampling(edge_index, num_nodes=num_nodes, num_neg_samples=src.size(0))
  ns, nd = neg_edge_index
  neg_logits = (z[ns] * z[nd]).sum(dim=1)
  neg_loss = F.binary_cross_entropy_with_logits(neg_logits, torch.zeros_like(neg_logits))

  return pos_loss + neg_loss


model = KarateGCN()
opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Full-batch training: every epoch processes the whole graph. 10k epochs is not
# "small", but each one is cheap on a 34-node graph.
NUM_EPOCHS = 10000
LOG_EVERY = 500

model.train()  # set once; nothing inside the loop switches to eval mode
for epoch in range(NUM_EPOCHS):
  opt.zero_grad()
  z = model(club_idx, scalars, edge_index)
  loss = edge_recon_loss(z, edge_index, N)  # negatives are re-sampled on each call
  loss.backward()
  opt.step()
  # Also log the final epoch so the last training loss is always reported.
  if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
    print(f"epoch {epoch:4d} | loss {loss.item():.4f}")

## Exercise 2
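Evaluate link-prediction quality by scoring all unordered node pairs with the dot-product decoder. Report the accuracy at a probability threshold of 0.5 and the mean log loss. Because every pair is scored against the same graph the model was trained on, this measures how well the graph is reconstructed, not how well unseen links are predicted.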

Solution 2

In [None]:
import math

# Score every unordered node pair (i < j) with the dot-product decoder and
# compare against the observed graph. NOTE: these are the same edges the model
# was trained to reconstruct, so this measures reconstruction quality on the
# training graph, not generalization to held-out links.
model.eval()
with torch.no_grad():
  z_eval = model(club_idx, scalars, edge_index)

threshold = 0.5
TP = TN = FP = FN = 0
pos_probs, neg_probs = [], []
logloss = 0.0
count = 0

for i in range(N):
  for j in range(i + 1, N):
    y = 1 if G.has_edge(i, j) else 0
    logit = (z_eval[i] * z_eval[j]).sum()
    prob = torch.sigmoid(logit).item()
    yhat = 1 if prob >= threshold else 0

    if y == 1 and yhat == 1:
      TP += 1
    elif y == 0 and yhat == 0:
      TN += 1
    elif y == 0 and yhat == 1:
      FP += 1
    else:
      FN += 1

    (pos_probs if y == 1 else neg_probs).append(prob)
    # Compute log loss from the logit, not the probability:
    #   -log(sigmoid(l)) = softplus(-l),  -log(1 - sigmoid(l)) = softplus(l).
    # A float32 sigmoid rounds to exactly 1.0 (or 0.0) once |l| exceeds ~17,
    # so clamping the probability capped each pair's loss at ~27.6 no matter
    # how confidently wrong the prediction was.
    logloss += F.softplus(-logit if y == 1 else logit).item()
    count += 1

accuracy = (TP + TN) / count
# Guard against an empty class (e.g. a complete or edgeless graph).
avg_pos_p = sum(pos_probs) / len(pos_probs) if pos_probs else float("nan")
avg_neg_p = sum(neg_probs) / len(neg_probs) if neg_probs else float("nan")
logloss /= count

print("Pairs:", count)
print(f"Pos: {len(pos_probs)} | Neg: {len(neg_probs)}")
print(f"Accuracy @ {threshold:.2f}: {accuracy:.4f}")
print(f"TP: {TP}  TN: {TN}  FP: {FP}  FN: {FN}")
print(f"Avg P(edge) on edges:     {avg_pos_p:.4f}")
print(f"Avg P(edge) on non-edges: {avg_neg_p:.4f}")
print(f"Mean log loss: {logloss:.4f}")

## Exercise 3
Visualize the spring layout with the true edges drawn in black, and overlay only the non-edges whose predicted probability is ≥ 0.5, shaded by that probability.

Solution 3

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.lines import Line2D

# Embeddings for plotting: eval mode, no autograd (nothing is trained here).
model.eval()
with torch.no_grad():
  z_vis = model(club_idx, scalars, edge_index)

pos = nx.spring_layout(G, seed=42)

# Every unordered pair (i < j) absent from G, with its decoder probability.
non_edges = [(a, b) for a in range(N) for b in range(a + 1, N) if not G.has_edge(a, b)]
probs = [torch.sigmoid((z_vis[a] * z_vis[b]).sum()).item() for a, b in non_edges]

# Keep only the non-edges the model scores at or above the threshold.
threshold = 0.5
edges_probs = [(pair, prob) for pair, prob in zip(non_edges, probs) if prob >= threshold]
non_edges_draw = [pair for pair, _ in edges_probs]
probs_draw = [prob for _, prob in edges_probs]

node_colors = ['tab:blue' if club == 0 else 'tab:orange' for club in club_idx.tolist()]

# Auxiliary graph holding only the non-edges to draw (all nodes of G included).
H = nx.Graph()
H.add_nodes_from(G)
H.add_edges_from(non_edges_draw)

fig, ax = plt.subplots(figsize=(8, 6))

# Draw order: predicted non-edges underneath, real edges on top, then nodes.
nx.draw_networkx_edges(
  H, pos,
  edgelist=non_edges_draw,
  edge_color=probs_draw,
  edge_cmap=plt.cm.Greys,
  edge_vmin=0.5,
  edge_vmax=1.0,
  width=0.7,
  alpha=0.9,
  ax=ax,
)
nx.draw_networkx_edges(G, pos, edgelist=list(G.edges()), edge_color="black", width=1.6, ax=ax)
nx.draw_networkx_nodes(
  G, pos, node_color=node_colors, edgecolors="black", linewidths=0.5, node_size=180, ax=ax,
)

# Colorbar for the non-edge shading, on the same 0.5..1 range used above.
sm = mpl.cm.ScalarMappable(norm=plt.Normalize(vmin=0.5, vmax=1), cmap=plt.cm.Greys)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label("Predicted P(edge) for non-edges (≥ 0.5)")

legend_elems = [
  Line2D([0], [0], marker='o', color='w', label=label,
         markerfacecolor=face, markeredgecolor='black', markeredgewidth=0.5, markersize=8)
  for label, face in (('Mr. Hi', 'tab:blue'), ('Officer', 'tab:orange'))
]
ax.legend(handles=legend_elems, loc='best', frameon=False)
ax.set_title("Zachary's Karate Club: real edges (black) + non-edges shaded by predicted probability (≥ 0.5)")
ax.axis("off")
plt.tight_layout()
plt.show()