In [176]:
import torch
import torch_geometric as pyg
import torch.nn as nn
import torch_geometric.nn as gnn
from torch_geometric.datasets import Amazon
from torch_geometric.utils import train_test_split_edges, add_self_loops
from torch_geometric.transforms import RandomLinkSplit
import torch.optim as optim

In [177]:
amazon = Amazon('amazon', 'Computers')
data = amazon[0]

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  return torch.load(f, map_location)


In [178]:
data

Data(x=[13752, 767], edge_index=[2, 491722], y=[13752])

In [179]:
train_data, val_data, test_data = RandomLinkSplit(num_val=0.1, num_test=0.1, add_negative_train_samples=True)(data)

In [180]:
train_data

Data(x=[13752, 767], edge_index=[2, 393378], y=[13752], edge_label=[786756], edge_label_index=[2, 786756])

In [181]:
features = torch.tensor([
    [1, 1, 1],
    [2, 2, 2],
    [3, 3, 3]
], dtype=torch.float)

edge_index = torch.tensor([
    [0, 1, 2],
    [2, 0, 1]
])

neg_edge_index = torch.tensor([
    [1, 0, 2],
    [2, 1, 0]
])

In [182]:
class GINLayer(gnn.MessagePassing):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, dropout: float):
        super(GINLayer, self).__init__(aggr='add')
        self.mlp_f = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features)
        )
        self.mlp_o = nn.Sequential(
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)
        )
        self.eps = nn.Parameter(torch.rand(size=(1,)))
        self.skip = nn.Linear(hidden_features, out_features)
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0])
        out = self.propagate(edge_index=edge_index, x=x) # (num_nodes, hidden_features)
        skip_out = out
        eps = (1 + self.eps)
        out = eps * self.mlp_f(x) + out
        out = self.mlp_o(out) + self.skip(skip_out) # (num_nodes, out_features)
        out = self.dropout(out)
        out = self.norm(out)
        return out # (num_nodes, out_features)

    def message(self, x_j: torch.Tensor):
        return self.mlp_f(x_j) # (num_nodes, hidden_features)


In [191]:
gin = GINLayer(features.shape[-1], 16, 4, 0.2)
gin_out = gin(train_data.x, train_data.edge_index)
gin_out.shape

RuntimeError: mat1 and mat2 shapes cannot be multiplied (407130x767 and 3x16)

In [184]:
class MyGNN(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int, dropout: float):
        super(MyGNN, self).__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList([])
        self.layers.append(GINLayer(in_features, hidden_features, hidden_features, self.dropout))
        for _ in range(num_layers - 1):
            self.layers.append(GINLayer(hidden_features, hidden_features, hidden_features, self.dropout))
        self.link_predictor = nn.Sequential(
            nn.Linear(2 * hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features)        
        )
        

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, negative_edges: torch.Tensor):
        for layer in self.layers:
            x = layer(x, edge_index) # (num_nodes, hidden_features)
        x_cat_pos = self.__combine_node_embeddings(edge_index, x) # (num_nodes, 2 * hidden_features)
        x_cat_neg = self.__combine_node_embeddings(negative_edges, x) # (num_nodes, 2 * hidden_features)
        x_pos = self.link_predictor(x_cat_pos) # (num_nodes, out_features)
        x_neg = self.link_predictor(x_cat_neg) # (num_nodes, out_features)
        return x_pos, x_neg
    
    def __combine_node_embeddings(self, edges: torch.Tensor, nodes: torch.Tensor):
        return torch.cat([nodes[edges[0]], nodes[edges[1]]], dim=-1) # (num_nodes, 2 * hidden_features)

In [197]:
in_features = data.x.shape[-1]
hidden_features = 32
out_features = 1
num_layers = 1
dropout = 0.0
model = MyGNN(in_features, hidden_features, out_features, num_layers=num_layers, dropout=dropout)

In [198]:
def get_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor):
    pos_scores = torch.sigmoid(pos_scores)
    neg_scores = torch.sigmoid(neg_scores)
    return -torch.mean(torch.log(pos_scores) + torch.log(1 - neg_scores))

In [200]:
val_data

Data(x=[13752, 767], edge_index=[2, 393378], y=[13752], edge_label=[98344], edge_label_index=[2, 98344])

In [201]:
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    pos_scores, neg_scores = model(train_data.x, train_data.edge_index, train_data.edge_label_index[:, train_data.edge_label == 0])
    loss = get_loss(pos_scores, neg_scores)
    loss.backward()
    train_loss = loss.item()
    optimizer.step()
    model.eval()
    with torch.inference_mode():
        pos_scores, neg_scores = model(val_data.x, val_data.edge_index, torch.tensor([]))
        val_loss = get_loss(pos_scores, neg_scores).item()


    print(f'Epoch: {epoch + 1} | Loss: {train_loss.item()} | Val Loss: {val_loss}')

IndexError: index 0 is out of bounds for dimension 0 with size 0

In [189]:
train_data.x.shape

torch.Size([13752, 767])

In [None]:
model(features, edge_index, neg_edge_index)

(tensor([[ 0.0035],
         [-0.0018],
         [ 0.0031]], grad_fn=<AddmmBackward0>),
 tensor([[ 0.0167],
         [-0.0120],
         [-0.0081]], grad_fn=<AddmmBackward0>))