In [37]:
import os
import torch
print(torch.__version__)

2.4.0+cu121


In [2]:
if 'IS_GRADESCOPE_ENV' not in os.environ:
  !pip install torch==2.4.0



In [3]:
# Install torch geometric
if 'IS_GRADESCOPE_ENV' not in os.environ:
  torch_version = str(torch.__version__)
  scatter_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
  sparse_src = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
  !pip install torch-scatter -f $scatter_src
  !pip install torch-sparse -f $sparse_src
  !pip install torch-geometric
  !pip install ogb

Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp311-cp311-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m85.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt24cu121
Looking in links: https://pytorch-geometric.com/whl/torch-2.4.0+cu121.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_sparse-0.6.18%2Bpt24cu121-cp311-cp311-linux_x86_64.whl (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt24cu121
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K

In [38]:
import torch
import torch_geometric
from torch_geometric.datasets import Planetoid


In [39]:
dataset = Planetoid(root="/dataset", name='Cora')

In [40]:
data = dataset[0]

In [41]:
import random
random.seed(8)
index = random.sample(range(0, 2707), 270)
index = sorted(index)
index[:10] # Index of nodes removed from subgraph

[12, 18, 27, 48, 54, 69, 81, 92, 104, 124]

In [42]:
mask = torch.ones(len(data.x), dtype=bool)
mask[index] = False

In [43]:
mask.nonzero(as_tuple=False).squeeze() # Indexes of Nodes which would be in subgraph

tensor([   0,    1,    2,  ..., 2704, 2705, 2707])

In [44]:
original_to_relabel_map = {original_idx.item(): relabel_idx
                           for relabel_idx, original_idx in enumerate(mask.nonzero(as_tuple=False).squeeze())}

In [45]:
# creating subgraph
subgraph = data.clone()

subgraph.x = data.x[mask]
subgraph.y = data.y[mask]

subgraph.train_mask = data.train_mask[mask]
subgraph.test_mask = data.test_mask[mask]
subgraph.val_mask = data.val_mask[mask]

subgraph.edge_index, _ = torch_geometric.utils.subgraph(mask, data.edge_index, relabel_nodes=True)


In [46]:
subgraph, subgraph.edge_index

(Data(x=[2438, 1433], edge_index=[2, 8478], y=[2438], train_mask=[2438], val_mask=[2438], test_mask=[2438]),
 tensor([[ 570, 1679, 2326,  ...,  151,  540, 1336],
         [   0,    0,    0,  ..., 2437, 2437, 2437]]))

In [13]:
removed_edges = [] # contains (x, y) these are the edge which were removed during the subgraph formation. Serves as the final test case to predict these edges.

for i in index:
  for j in range(10556):
    if i == tuple(data.edge_index.t()[j].tolist())[0]:
      removed_edges.append(tuple(data.edge_index.t()[j].tolist()))
      #print(tuple(data.edge_index.t()[j].tolist()))


In [47]:
from torch_geometric.utils import negative_sampling

In [48]:
def Labels(data):
  neg_edges = negative_sampling(data.edge_index, num_nodes=len(subgraph.x))
  pos_edges = data.edge_index

  combined_edges = torch.cat((pos_edges, neg_edges), dim=1)
  labels = torch.cat((torch.ones(pos_edges.size(1)), torch.zeros(neg_edges.size(1))), dim=0)

  return combined_edges, labels


In [49]:
combined_edges, labels = Labels(subgraph)

In [50]:
from torch_geometric.nn import SAGEConv
import torch.nn.functional as f

In [52]:
class EdgePrediction(torch.nn.Module):
  def __init__(self, hidden_dim, input_dim= dataset.num_features, output_dim = 1):
    super().__init__()
    self.conv1 = SAGEConv(input_dim, hidden_dim)
    self.conv2 = SAGEConv(hidden_dim, input_dim)
    #self.linear1 = torch.nn.Linear(input_dim *2, 128)

    #self.linear2 = torch.nn.Linear(128, 32)
    #self.output_layer = torch.nn.Linear(32, output_dim)

  def forward(self, x: torch.tensor, edge_list: torch.tensor):
    x = self.conv1(x, edge_list)
    x = f.relu(x)
    x = self.conv2(x, edge_list)
    return x

  #def predict(self, x: torch.tensor, edge_list: torch.tensor):
    #embed = torch.cat([x[edge_list[0]], x[edge_list[1]]], dim=-1)
    #out = self.linear1(embed)
    #out = self.linear2(out)

    #return self.output_layer(out).squeeze()


In [53]:
model = EdgePrediction(128)
model

EdgePrediction(
  (conv1): SAGEConv(1433, 128, aggr=mean)
  (conv2): SAGEConv(128, 1433, aggr=mean)
)

In [54]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)
criterion = torch.nn.BCEWithLogitsLoss()

In [55]:
def accuracy(pred, label):
  accu = 0.0

  pred = (pred > 0.5).float()
  count = (pred == label).sum().item()
  accu = count / len(pred)
  accu = round(accu, 4)
  return accu

def train(x, edge_index, combined_edges, labels):
    model.train()
    for epoch in range(200):
        optimizer.zero_grad()

        # Get node embeddings
        node_emb = model(x, edge_index)

        # Extract source and destination node embeddings
        src = combined_edges[0]
        dst = combined_edges[1]
        src_emb = node_emb[src]
        dst_emb = node_emb[dst]

        # Compute dot product for edge prediction
        dot_prd = (src_emb * dst_emb).sum(dim=1)
        pred = torch.sigmoid(dot_prd)

        # Compute loss
        loss = criterion(pred, labels)

        # Backpropagation
        loss.backward()
        optimizer.step()

        # Compute accuracy
        Accuracy = accuracy(pred, labels)

        if epoch % 1 == 0 or epoch == 199:
            print(f"Epoch {epoch + 1}/200 - Loss: {loss.item():.4f}, Accuracy: {Accuracy}")

    return node_emb



In [56]:
emb = train(subgraph.x, subgraph.edge_index, combined_edges, labels)

Epoch 1/200 - Loss: 0.8112, Accuracy: 0.5
Epoch 2/200 - Loss: 0.8098, Accuracy: 0.5023
Epoch 3/200 - Loss: 0.7921, Accuracy: 0.5364
Epoch 4/200 - Loss: 0.7835, Accuracy: 0.5472
Epoch 5/200 - Loss: 0.7141, Accuracy: 0.6182
Epoch 6/200 - Loss: 0.7195, Accuracy: 0.6456
Epoch 7/200 - Loss: 0.6798, Accuracy: 0.7174
Epoch 8/200 - Loss: 0.6782, Accuracy: 0.7091
Epoch 9/200 - Loss: 0.6660, Accuracy: 0.7301
Epoch 10/200 - Loss: 0.6588, Accuracy: 0.7483
Epoch 11/200 - Loss: 0.6488, Accuracy: 0.76
Epoch 12/200 - Loss: 0.6416, Accuracy: 0.7706
Epoch 13/200 - Loss: 0.6348, Accuracy: 0.7816
Epoch 14/200 - Loss: 0.6313, Accuracy: 0.7879
Epoch 15/200 - Loss: 0.6247, Accuracy: 0.7982
Epoch 16/200 - Loss: 0.6198, Accuracy: 0.8066
Epoch 17/200 - Loss: 0.6147, Accuracy: 0.8166
Epoch 18/200 - Loss: 0.6096, Accuracy: 0.8267
Epoch 19/200 - Loss: 0.6055, Accuracy: 0.8335
Epoch 20/200 - Loss: 0.6020, Accuracy: 0.8399
Epoch 21/200 - Loss: 0.5986, Accuracy: 0.8452
Epoch 22/200 - Loss: 0.5958, Accuracy: 0.8496
Ep

In [57]:
e = emb
e.size()

torch.Size([2438, 1433])

In [58]:
modified_data = data.clone()


In [59]:
for original_index, relabel_index in original_to_relabel_map.items():
  modified_data.x[original_index] = e[relabel_index]

modified_data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [60]:
modified_data.x[11], modified_data.x[12]

(tensor([ 0.0548,  0.0856,  0.0040,  ..., -0.0575,  0.0327, -0.0348],
        grad_fn=<SelectBackward0>),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], grad_fn=<SelectBackward0>))

In [61]:
removed_edges_tensor = torch.tensor(removed_edges).T
removed_edges_tensor.size(), removed_edges_tensor

(torch.Size([2, 1106]),
 tensor([[  12,   12,   12,  ..., 2706, 2706, 2706],
         [1001, 1318, 2661,  ...,  169, 1473, 2707]]))

In [62]:
# To test on all the edges
src = modified_data.edge_index[0]
dst = modified_data.edge_index[1]
src_emb = modified_data.x[src]
dst_emb = modified_data.x[dst]

# Compute dot product for edge prediction
dot_prd = (src_emb * dst_emb).sum(dim=1)
pred = torch.sigmoid(dot_prd).round()

Test_Accuracy_overall = float((pred == 1).sum()/ len(pred))
Test_Accuracy_overall

0.8946570754051208

In [63]:
# To test on all the removed edges
src = removed_edges_tensor[0]
dst = removed_edges_tensor[1]
src_emb = modified_data.x[src]
dst_emb = modified_data.x[dst]

# Compute dot product for edge prediction
dot_prd = (src_emb * dst_emb).sum(dim=1)
pred = torch.sigmoid(dot_prd).round()

Test_Accuracy_overall_removed = float((pred == 1).sum()/ len(pred))
Test_Accuracy_overall_removed

0.5081374049186707

In [64]:
src_nodes = data.edge_index[0]
dest_nodes = data.edge_index[1]

degree_map = {} # Degree of each of the nodes sampled: key-> node index; value-> degree
for i in index:
  count = 0
  for j in src_nodes:
    if i == j:
      count = count + 1
  degree_map[i] = count

degree_sampled = list(set(degree_map.values()))
degree_sampled

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 22, 34, 42, 65]

In [64]:
list(degree_map.items())[:5]

[(12, 4), (18, 5), (27, 4), (48, 9), (54, 2)]

In [65]:
d = {} # Dictionary: key-> unique node degree; values-> list(nodes indexes which has degree as the key value)
for i in degree_sampled:
  l = []
  for j in degree_map.keys():
    if(degree_map[j] == i):
      l.append(j)
  d[i] = l


In [84]:
degree_based_tuples = []

for degree, nodes in d.items():
    degree_tuples = []
    for node in nodes:
        for edge in removed_edges:
            if edge[0] == node:
                degree_tuples.append(edge)
    degree_based_tuples.append(degree_tuples)

In [104]:
for i in degree_based_tuples:
  print(len(i))

44
100
192
192
135
60
49
56
18
10
11
12
13
16
17
18
22
34
42
65


In [92]:
l = list(d.keys())
count = 0
for i in degree_based_tuples:
  removed_edges_tensor = torch.tensor(i).T
  src = removed_edges_tensor[0]
  dst = removed_edges_tensor[1]
  src_emb = modified_data.x[src]
  dst_emb = modified_data.x[dst]

  # Compute dot product for edge prediction
  dot_prd = (src_emb * dst_emb).sum(dim=1)
  pred = torch.sigmoid(dot_prd).round()

  Test_Accuracy_removed = float((pred == 1).sum()/ len(pred))
  print("Degree {}".format(l[count]),Test_Accuracy_removed)
  count = count+1




Degree 1 0.5227272510528564
Degree 2 0.5799999833106995
Degree 3 0.5520833134651184
Degree 4 0.40625
Degree 5 0.4962962865829468
Degree 6 0.550000011920929
Degree 7 0.6530612111091614
Degree 8 0.4107142984867096
Degree 9 0.6111111044883728
Degree 10 0.30000001192092896
Degree 11 0.6363636255264282
Degree 12 0.3333333432674408
Degree 13 0.7692307829856873
Degree 16 1.0
Degree 17 0.29411765933036804
Degree 18 0.5555555820465088
Degree 22 0.3636363744735718
Degree 34 0.6764705777168274
Degree 42 0.2380952388048172
Degree 65 0.5384615659713745
