In [80]:
import networkx as nx
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# PyTorch Geometric modules.
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# For reproducibility.
torch.manual_seed(42)
np.random.seed(42)


In [81]:
# Read your graph from a GraphML file.
graph = nx.read_graphml("../code_graphs/sdk_graph_anthropic.graphml")

# Get nodes (with attributes) and edges (with attributes).
# (Nodes in your GraphML should have an attribute 'type', and
#  edges should have an attribute 'relationship'.)
nodes = list(graph.nodes(data=True))
edges = list(graph.edges(data=True))


In [82]:
G = nx.Graph()
G.add_nodes_from(nodes)
G.add_edges_from(edges)


In [83]:
# Create a fixed ordering for nodes (needed for the PyG Data object).
node_list = list(G.nodes())
node_to_idx = {node: i for i, node in enumerate(node_list)}


In [84]:
print(node_to_idx)

{'File:/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_constants.py': 0, 'File:/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 1, 'Class:RequestOptions@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 2, 'Class:NotGiven@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 3, 'Method:NotGiven.__bool__@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 4, 'Method:NotGiven.__repr__@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 5, 'Class:Omit@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 6, 'Method:Omit.__bool__@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 7, 'Class:ModelBuilderProtocol@/Users/tomas/graph_test/test/venv/lib/python3.13/site-packages/anthropic/_types.py': 8, 'Method:ModelBuilderProtocol.build@/Use

In [85]:
# Build the edge index list.
edge_list = []
edge_attr_list = []  # will hold the relationship type for each edge.
for u, v, data in G.edges(data=True):
    # Convert the original node names to indices.
    u_idx = node_to_idx[u]
    v_idx = node_to_idx[v]
    # Because the graph is undirected, add both (u,v) and (v,u).
    edge_list.append((u_idx, v_idx))
    edge_list.append((v_idx, u_idx))
    # Duplicate the edge attribute for the reverse edge.
    edge_attr_list.append(data['relationship'])
    edge_attr_list.append(data['relationship'])


In [86]:
# Convert the edge list into a tensor (shape [2, num_edges]).
edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()


In [87]:
edge_index.shape

torch.Size([2, 4866])

In [88]:
unique_node_types = sorted({ data['type'] for _, data in G.nodes(data=True) })
print("Unique node types:", unique_node_types)


Unique node types: ['class', 'external_class', 'external_function', 'file', 'function', 'method']


In [89]:
# Create a mapping from node type to index.
node_type_to_idx = {typ: i for i, typ in enumerate(unique_node_types)}


In [90]:
node_type_to_idx

{'class': 0,
 'external_class': 1,
 'external_function': 2,
 'file': 3,
 'function': 4,
 'method': 5}

In [91]:
node_type_indices = []
for node in node_list:
    typ = G.nodes[node]['type']
    node_type_indices.append(node_type_to_idx[typ])
node_type_indices = torch.tensor(node_type_indices, dtype=torch.long)


In [92]:
for node, typ in zip(node_list[:5], node_type_indices[:5]):
    print(f"Node: {node[:20]}.., type: {typ}")


Node: File:/Users/tomas/gr.., type: 3
Node: File:/Users/tomas/gr.., type: 3
Node: Class:RequestOptions.., type: 0
Node: Class:NotGiven@/User.., type: 0
Node: Method:NotGiven.__bo.., type: 5


In [93]:
node_emb_dim = 16

In [94]:
node_type_embedding = nn.Embedding(num_embeddings=len(unique_node_types), embedding_dim=node_emb_dim) # (6, 16)


In [95]:
node_type_embedding.weight.shape

torch.Size([6, 16])

In [96]:
x = node_type_embedding(node_type_indices)


In [97]:
node_type_indices[:5]
x[0:5]


tensor([[-0.9138, -0.6581,  0.0780,  0.5258, -0.4880,  1.1914, -0.8140, -0.7360,
         -1.4032,  0.0360, -0.0635,  0.6756, -0.0978,  1.8446, -1.1845,  1.3835],
        [-0.9138, -0.6581,  0.0780,  0.5258, -0.4880,  1.1914, -0.8140, -0.7360,
         -1.4032,  0.0360, -0.0635,  0.6756, -0.0978,  1.8446, -1.1845,  1.3835],
        [ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
         -0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624],
        [ 1.9269,  1.4873,  0.9007, -2.1055,  0.6784, -1.2345, -0.0431, -1.6047,
         -0.7521,  1.6487, -0.3925, -1.4036, -0.7279, -0.5594, -0.7688,  0.7624],
        [ 0.0109, -0.3387, -1.3407, -0.5854,  0.5362,  0.5246,  1.1412,  0.0516,
          0.7440, -0.4816, -1.0495,  0.6039, -1.7223, -0.8278,  1.3347,  0.4835]],
       grad_fn=<SliceBackward0>)

In [98]:
unique_edge_relationships = sorted(set(edge_attr_list))
print("Unique edge relationships:", unique_edge_relationships)


Unique edge relationships: ['calls', 'contains', 'inherits']


In [99]:
# Create a mapping from relationship to index.
edge_type_to_idx = {rel: i for i, rel in enumerate(unique_edge_relationships)}


In [100]:
edge_type_indices = [edge_type_to_idx[rel] for rel in edge_attr_list]
edge_type_indices = torch.tensor(edge_type_indices, dtype=torch.long)


In [101]:
edge_emb_dim = node_emb_dim # This is just for simplicity right now should definlty not be the same as node_emb_dim


In [102]:
edge_type_embedding = nn.Embedding(num_embeddings=len(unique_edge_relationships), embedding_dim=edge_emb_dim)


In [103]:
edge_type_embedding.weight.shape

torch.Size([3, 16])

In [104]:
edge_attr= edge_type_embedding(edge_type_indices)

In [105]:
print(edge_type_indices[0:5])
print(edge_attr[0:5])


tensor([1, 1, 1, 1, 1])
tensor([[-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279,
         -0.0127,  0.2408,  0.1325,  0.7642,  1.0950,  0.3399,  0.7200,  0.4114],
        [-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279,
         -0.0127,  0.2408,  0.1325,  0.7642,  1.0950,  0.3399,  0.7200,  0.4114],
        [-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279,
         -0.0127,  0.2408,  0.1325,  0.7642,  1.0950,  0.3399,  0.7200,  0.4114],
        [-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279,
         -0.0127,  0.2408,  0.1325,  0.7642,  1.0950,  0.3399,  0.7200,  0.4114],
        [-1.4570, -0.1023, -0.5992,  0.4771,  0.7262,  0.0912, -0.3891,  0.5279,
         -0.0127,  0.2408,  0.1325,  0.7642,  1.0950,  0.3399,  0.7200,  0.4114]],
       grad_fn=<SliceBackward0>)


In [106]:
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
print(data)

Data(x=[1284, 16], edge_index=[2, 4866], edge_attr=[4866, 16])


In [107]:
class SimpleGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        """
        A simple two-layer GCN.
        """
        super(SimpleGCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, edge_attr=None):
        # Here edge_attr is not used. To incorporate it,
        # consider using a custom MessagePassing layer.
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


In [108]:
# Instantiate the model.
model = SimpleGCN(in_channels=node_emb_dim, hidden_channels=32, out_channels=2)


In [109]:
def sample_contrastive_pairs(edge_index, num_nodes, num_negatives=3):
    """
    For each node, sample one positive (neighbor) and a few negatives (non-neighbors).
    """
    neighbors = {i: set() for i in range(num_nodes)}
    edge_index_np = edge_index.cpu().numpy()
    for src, dst in zip(edge_index_np[0], edge_index_np[1]):
        neighbors[src].add(dst)
    
    anchors, positives, negatives = [], [], []
    for i in range(num_nodes):
        if len(neighbors[i]) == 0:
            continue  # skip isolated nodes
        pos_candidates = list(neighbors[i])
        pos_sample = pos_candidates[torch.randint(len(pos_candidates), (1,)).item()]
        neg_candidates = list(set(range(num_nodes)) - neighbors[i] - {i})
        if len(neg_candidates) >= num_negatives:
            neg_sample = [neg_candidates[idx] for idx in torch.randperm(len(neg_candidates))[:num_negatives]]
            anchors.append(i)
            positives.append(pos_sample)
            negatives.append(neg_sample)
    return anchors, positives, negatives



In [110]:
def contrastive_loss(embeddings, anchors, positives, negatives, margin=0.5):
    """
    For each anchor node, the loss encourages the cosine similarity with a positive
    (neighbor) to be higher than with negatives (non-neighbors) by at least the margin.
    """
    loss_all = []
    for i, pos_idx, neg_idxs in zip(anchors, positives, negatives):
        anchor_emb = embeddings[i]
        pos_emb = embeddings[pos_idx]
        pos_sim = F.cosine_similarity(anchor_emb.unsqueeze(0), pos_emb.unsqueeze(0))
        neg_embs = embeddings[neg_idxs]
        neg_sim = F.cosine_similarity(anchor_emb.unsqueeze(0), neg_embs)
        loss_per_negative = F.relu(margin - (pos_sim - neg_sim))
        loss_all.append(loss_per_negative.mean())
    if loss_all:
        return torch.stack(loss_all).mean()
    else:
        return torch.tensor(0.0)

In [111]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
num_epochs = 1001

In [112]:
# Training loop.
model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    embeddings = model(data.x, data.edge_index, data.edge_attr).detach()
    
    anchors, positives, negatives = sample_contrastive_pairs(data.edge_index, data.num_nodes, num_negatives=30)
    loss = contrastive_loss(embeddings, anchors, positives, negatives, margin=0.5)
    
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch:03d}, Loss: {loss.item():.4f}")


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn