# Weighted edges variant

Below we repeat the pipeline on a weighted graph in which every existing edge carries a `weight` attribute. We pass `edge_weight` into the GCN layers and compute weighted variants of the centrality features (degree, betweenness, closeness). The decoder remains the dot product.

## Exercise 4 
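Implement the weighted GCN variant. Use unit weights (1.0) on all edges for reproducibility, and show the link-prediction evaluation and visualization. With unit weights, weighted degree, betweenness, and closeness reduce to their unweighted definitions.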

In [None]:
# Build weighted graph and features (ensure weights exist).
# Every edge gets weight 1.0. A separate 'inv_weight' (1 / weight) is stored because
# the NetworkX shortest-path centralities below read the attribute as a *distance*:
# a heavier (stronger) tie must become a shorter path.
Gw = G.copy()
for u, v in Gw.edges():
  Gw[u][v]['weight'] = 1.0  # set simple unit weights
  Gw[u][v]['inv_weight'] = 1.0 / max(Gw[u][v]['weight'], 1e-12)  # floor avoids division by zero

# NOTE(review): the lookups below assume Gw's nodes are exactly the integers 0..N-1.
# Weighted degree (sum of incident edge weights), divided by its maximum when positive.
deg_w_vals = torch.tensor([Gw.degree(i, weight='weight') for i in range(N)], dtype=torch.float32)
if deg_w_vals.max() > 0:
  deg_w_vals = deg_w_vals / deg_w_vals.max()

# Shortest-path centralities use the inverted weights as path lengths.
betw_w_dict = nx.betweenness_centrality(Gw, weight='inv_weight', normalized=True)
betw_w_vals = torch.tensor([betw_w_dict[i] for i in range(N)], dtype=torch.float32)
close_w_dict = nx.closeness_centrality(Gw, distance='inv_weight')
close_w_vals = torch.tensor([close_w_dict[i] for i in range(N)], dtype=torch.float32)
scalars_w = torch.stack([deg_w_vals, betw_w_vals, close_w_vals], dim=1)  # (N, 3) node scalars

# One direction per undirected edge, with weights in the same Gw.edges() order.
edges_w = torch.tensor(list(Gw.edges()), dtype=torch.long).t().contiguous()
edge_weight = torch.tensor([Gw[u][v]['weight'] for u, v in Gw.edges()], dtype=torch.float32)
# Symmetrize edges AND weights in a single call. to_undirected sorts/coalesces the
# edge list, so weights concatenated in the original order would no longer line up
# with edge_index_w. Passing them as edge_attr permutes them together with the edges.
# reduce="mean" leaves a self-loop's weight unchanged (the default "add" would double it).
edge_index_w, edge_weight_w = to_undirected(edges_w, edge_attr=edge_weight, num_nodes=N, reduce="mean")


class WeightedKarateGCN(nn.Module):
  """Two-layer GCN encoder that honours per-edge weights.

  Node input = learned club embedding concatenated with the scalar features,
  projected by a linear layer, then passed through two GCNConv layers that
  both receive ``edge_weight``. The returned embedding concatenates the
  (pre-activation) output of the first conv with the output of the second,
  so its width is ``2 * hidden_dim``.
  """

  def __init__(self, num_clubs=2, club_emb_dim=4, scalar_dim=3, hidden_dim=10, dropout=0.0):
    super().__init__()
    in_features = club_emb_dim + scalar_dim
    # Submodules are created in a fixed order so parameter initialisation
    # consumes the RNG in the same sequence.
    self.club_emb = nn.Embedding(num_clubs, club_emb_dim)
    self.lin = nn.Linear(in_features, hidden_dim)
    self.conv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops=True)
    self.conv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops=True)
    self.dropout = nn.Dropout(dropout)

  def forward(self, club_idx, scalars, edge_index, edge_weight):
    # assumes club_idx is an index tensor of length N and scalars is (N, scalar_dim)
    node_in = torch.cat([self.club_emb(club_idx), scalars], dim=1)
    projected = self.dropout(F.relu(self.lin(node_in)))
    first_hop = self.conv1(projected, edge_index, edge_weight=edge_weight)
    second_hop = self.dropout(
        self.conv2(F.relu(first_hop), edge_index, edge_weight=edge_weight)
    )
    # Skip-style concatenation: keep both receptive-field sizes in the embedding.
    return torch.cat([first_hop, second_hop], dim=1)


# Encoder and optimiser for the weighted variant.
model_w = WeightedKarateGCN()
opt_w = torch.optim.Adam(model_w.parameters(), lr=0.01, weight_decay=5e-4)

# Full-batch training: each step encodes every node and scores how well the
# decoder reconstructs edge_index_w. Progress is printed every 500 epochs.
for epoch in range(10_000):
  model_w.train()  # training mode (dropout active) on every step
  opt_w.zero_grad()
  z_w = model_w(club_idx, scalars_w, edge_index_w, edge_weight_w)
  loss_w = edge_recon_loss(z_w, edge_index_w, N)
  loss_w.backward()
  opt_w.step()
  if not epoch % 500:
    print(f"[Weighted] epoch {epoch:4d} | loss {loss_w.item():.4f}")

# Evaluation: score every unordered node pair (i < j) with the dot-product decoder
# and compare the result against the observed adjacency of Gw.
# NOTE: training reconstructed edge_index_w, which contains every edge of Gw, so
# these are in-sample reconstruction metrics, not held-out link-prediction metrics.
model_w.eval()  # disables dropout
with torch.no_grad():
  z_eval_w = model_w(club_idx, scalars_w, edge_index_w, edge_weight_w)

threshold = 0.5  # a pair with P(edge) >= threshold counts as a predicted edge
TP = TN = FP = FN = 0
pos_probs, neg_probs = [], []  # predicted P(edge) for true edges / true non-edges
logloss = 0.0
count = 0

for i in range(N):
  for j in range(i + 1, N):
    y = 1 if Gw.has_edge(i, j) else 0
    prob = torch.sigmoid((z_eval_w[i] * z_eval_w[j]).sum()).item()
    yhat = 1 if prob >= threshold else 0
    if y == 1 and yhat == 1:
      TP += 1
    elif y == 0 and yhat == 0:
      TN += 1
    elif y == 0 and yhat == 1:
      FP += 1
    else:
      FN += 1
    (pos_probs if y == 1 else neg_probs).append(prob)
    p = min(max(prob, 1e-12), 1 - 1e-12)  # clamp so log() never receives 0
    logloss += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    count += 1

# Guard every average: with N < 2, no edges, or no non-edges, a denominator is 0
# and the original code raised ZeroDivisionError. Report NaN instead.
accuracy = (TP + TN) / count if count else float('nan')
avg_pos_p = sum(pos_probs) / len(pos_probs) if pos_probs else float('nan')
avg_neg_p = sum(neg_probs) / len(neg_probs) if neg_probs else float('nan')
logloss = logloss / count if count else float('nan')

print("[Weighted] Pairs:", count)
print(f"[Weighted] Pos: {len(pos_probs)} | Neg: {len(neg_probs)}")
print(f"[Weighted] Accuracy @ {threshold:.2f}: {accuracy:.4f}")
print(f"[Weighted] TP: {TP}  TN: {TN}  FP: {FP}  FN: {FN}")
print(f"[Weighted] Avg P(edge) on edges:     {avg_pos_p:.4f}")
print(f"[Weighted] Avg P(edge) on non-edges: {avg_neg_p:.4f}")
print(f"[Weighted] Mean log loss: {logloss:.4f}")

# Visualization: observed edges in black, plus every non-edge whose predicted
# P(edge) reaches `threshold` (the same cut-off used in the evaluation above),
# shaded by that probability.
pos_w = pos  # reuse the existing layout `pos` (defined earlier in the notebook) so plots are comparable

# Decoder probabilities for all non-adjacent pairs (i < j).
non_edges_w, probs_w = [], []
for i in range(N):
  for j in range(i + 1, N):
    if not Gw.has_edge(i, j):
      p = torch.sigmoid((z_eval_w[i] * z_eval_w[j]).sum()).item()
      non_edges_w.append((i, j))
      probs_w.append(p)

# Keep only the non-edges the model would predict as edges.
edges_probs_w = [(e, p) for e, p in zip(non_edges_w, probs_w) if p >= threshold]
non_edges_w_draw = [e for e, _ in edges_probs_w]
probs_w_draw = [p for _, p in edges_probs_w]

# NOTE(review): assumes club index 0 is Mr. Hi's faction (matches the legend below) — confirm.
node_colors_w = ['tab:blue' if int(club_idx[i].item()) == 0 else 'tab:orange' for i in range(N)]
Hw = nx.Graph(); Hw.add_nodes_from(Gw.nodes()); Hw.add_edges_from(non_edges_w_draw)

fig, ax = plt.subplots(figsize=(8, 6))
# NOTE: Greys maps edge_vmin to near-white, so predictions just above the
# threshold are drawn very faintly.
nx.draw_networkx_edges(Hw, pos_w, edgelist=non_edges_w_draw, edge_color=probs_w_draw,
             edge_cmap=plt.cm.Greys, edge_vmin=threshold, edge_vmax=1.0, width=0.7, alpha=0.9, ax=ax)
nx.draw_networkx_edges(Gw, pos_w, edgelist=list(Gw.edges()), edge_color="black", width=1.6, ax=ax)
nx.draw_networkx_nodes(Gw, pos_w, node_color=node_colors_w, edgecolors="black", linewidths=0.5, node_size=180, ax=ax)
# Stand-alone mappable so the colorbar uses exactly the edge color scale above.
sm = mpl.cm.ScalarMappable(cmap=plt.cm.Greys, norm=plt.Normalize(vmin=threshold, vmax=1))
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label(f"Predicted P(edge) for non-edges (≥ {threshold:g}, weighted model)")
ax.legend(handles=[
  Line2D([0], [0], marker='o', color='w', label='Mr. Hi',
       markerfacecolor='tab:blue', markeredgecolor='black', markeredgewidth=0.5, markersize=8),
  Line2D([0], [0], marker='o', color='w', label='Officer',
       markerfacecolor='tab:orange', markeredgecolor='black', markeredgewidth=0.5, markersize=8)
], loc='best', frameon=False)
ax.set_title(f"Weighted GCN: real edges (black) + non-edges shaded by predicted probability (≥ {threshold:g})")
ax.axis("off")
plt.tight_layout(); plt.show()