In [1]:
import os
import torch
print("Using torch", torch.__version__)

Using torch 2.1.0+cu118


In [3]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install pyg-library

Looking in links: https://data.pyg.org/whl/torch-2.1.0+cu118.html


In [4]:
from torch_geometric.data import Data
from torch_geometric.datasets import MovieLens100K
from torch_geometric import nn
import torch_geometric.transforms as T

In [6]:
dataset = MovieLens100K(root='/tmp/movielens')

In [7]:
movielens_raw = dataset[0]
movielens_raw

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
    edge_label_index=[2, 20000],
    edge_label=[20000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [8]:
del movielens_raw[("user", "rates", "movie")].edge_label_index
del movielens_raw[("user", "rates", "movie")].edge_label

In [9]:
movielens_raw

HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 80000],
    rating=[80000],
    time=[80000],
  }
)

In [10]:
print(movielens_raw.num_nodes)
print(movielens_raw.num_edges)
print(movielens_raw.x_dict)
print(movielens_raw.edge_index_dict)

2625
160000
{'movie': tensor([[0., 0., 1.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]]), 'user': tensor([[0.3288, 0.0000, 1.0000,  ..., 0.0000, 1.0000, 0.0000],
        [0.7260, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3151, 0.0000, 1.0000,  ..., 0.0000, 0.0000, 1.0000],
        ...,
        [0.2740, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000],
        [0.6575, 1.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.3014, 0.0000, 1.0000,  ..., 1.0000, 0.0000, 0.0000]])}
{('user', 'rates', 'movie'): tensor([[   0,    0,    0,  ...,  942,  942,  942],
        [   0,    1,    2,  ..., 1187, 1227, 1329]]), ('movie', 'rated_by', 'user'): tensor([[   0,    1,    2,  ..., 1187, 1227, 1329],
        [   0,    0,    0,  ...,  942,  942,  942]])}


In [11]:
node_types, edge_types = movielens_raw.metadata()
print(node_types)
print(edge_types)

['movie', 'user']
[('user', 'rates', 'movie'), ('movie', 'rated_by', 'user')]


In [12]:
print(movielens_raw["movie"].num_nodes)
print(movielens_raw["user"].num_nodes)

1682
943


In [13]:
print(movielens_raw[("user", "rates", "movie")].edge_attrs())

['edge_index', 'rating', 'time']


In [14]:
print(movielens_raw[("user", "rates", "movie")].rating)
print(movielens_raw[("user", "rates", "movie")].time)

tensor([5, 3, 4,  ..., 3, 3, 3])
tensor([874965758, 876893171, 878542960,  ..., 888640250, 888640275,
        888692465])


In [15]:
transform = T.Compose([
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        disjoint_train_ratio=0.2,   # supervision
        add_negative_train_samples=False,
        neg_sampling_ratio=1.0,
        edge_types=("user", "rates", "movie"),
        rev_edge_types=('movie', 'rated_by', 'user')
    )
])

In [16]:
train_data, val_data, test_data = transform(movielens_raw)

In [17]:
print(f"Train Data: {train_data}")
print(f"Val Data: {val_data}")
print(f"Test Data: {test_data}")

Train Data: HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 54400],
    rating=[54400],
    time=[54400],
    edge_label=[13600],
    edge_label_index=[2, 13600],
  },
  (movie, rated_by, user)={
    edge_index=[2, 54400],
    rating=[54400],
    time=[54400],
  }
)
Val Data: HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 68000],
    rating=[68000],
    time=[68000],
    edge_label=[8000],
    edge_label_index=[2, 8000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 68000],
    rating=[68000],
    time=[68000],
  }
)
Test Data: HeteroData(
  movie={ x=[1682, 18] },
  user={ x=[943, 24] },
  (user, rates, movie)={
    edge_index=[2, 72000],
    rating=[72000],
    time=[72000],
    edge_label=[16000],
    edge_label_index=[2, 16000],
  },
  (movie, rated_by, user)={
    edge_index=[2, 72000],
    rating=[72000],
    time=[72000],
  }
)


In [18]:
print(train_data[edge_types[0]].edge_label)
print(train_data[edge_types[0]].edge_label_index)

tensor([1., 1., 1.,  ..., 1., 1., 1.])
tensor([[392, 876, 420,  ..., 579, 642, 586],
        [495, 339, 173,  ..., 293, 715, 994]])


In [40]:
import sys
import torch_geometric
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GINConv
from torch_geometric.nn import BatchNorm, LayerNorm, HeteroBatchNorm, HeteroLayerNorm
from torch_geometric.nn import to_hetero

class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class SAGE_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)
        self.res = torch.nn.Linear(in_channels, out_channels)
        # self.norm1 = BatchNorm(hidden_channels, 2)
        # self.norm2 = BatchNorm(hidden_channels, 2)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        # output1 = self.norm1(output1)
        output2 = self.relu(self.conv2(output1, edge_index))
        # output2 = self.norm2(output2)
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class SAGE_RES_2(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.conv3 = SAGEConv(hidden_channels, out_channels)
        self.res1 = torch.nn.Linear(in_channels, hidden_channels)
        self.res2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        res1 = self.res1(node_feature)
        res2 = self.res2(output1)
        output2 = self.relu(self.conv2(output1, edge_index))
        output2_res = (output2 + res1) * 0.5
        output3 = self.conv3(output2_res, edge_index)
        output3_res = (output3 + res2) * 0.5
        return output3_res

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output = self.conv3(output2, edge_index)
        return output

class GAT_RES(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)
        self.res = torch.nn.Linear(in_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        output2 = self.relu(self.conv2(output1, edge_index))
        output3 = self.conv3(output2, edge_index)
        output_res = self.res(node_feature)
        return (output3 + output_res) * 0.5

class GAT_RES_2(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.relu = torch.nn.ReLU()
        self.conv1 = GATConv(in_channels, hidden_channels, add_self_loops=False)
        self.conv2 = GATConv(hidden_channels, hidden_channels, add_self_loops=False)
        self.conv3 = GATConv(hidden_channels, out_channels, add_self_loops=False)
        self.res1 = torch.nn.Linear(in_channels, hidden_channels)
        self.res2 = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, node_feature, edge_index):

        output1 = self.relu(self.conv1(node_feature, edge_index))
        res1 = self.res1(node_feature)
        res2 = self.res2(output1)
        output2 = self.relu(self.conv2(output1, edge_index))
        output2_res = (output2 + res1) * 0.5
        output3 = self.conv3(output2_res, edge_index)
        output3_res = (output3 + res2) * 0.5
        return output3_res

class Embedder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # [0]: movie, feature dim: 18
        # [1]: user, feature dim: 24
        movies = movielens_raw[node_types[0]]
        users = movielens_raw[node_types[1]]
        self.sage = SAGE(in_channels, hidden_channels, out_channels)
        self.sage = to_hetero(self.sage, metadata=movielens_raw.metadata())
        self.sage_res = SAGE_RES(in_channels, hidden_channels, out_channels)
        self.sage_res = to_hetero(self.sage_res, metadata=movielens_raw.metadata())
        self.sage_res_2 = SAGE_RES_2(in_channels, hidden_channels, out_channels)
        self.sage_res_2 = to_hetero(self.sage_res_2, metadata=movielens_raw.metadata())
        self.gat = GAT(in_channels, hidden_channels, out_channels)
        self.gat = to_hetero(self.gat, metadata=movielens_raw.metadata())
        self.gat_res = GAT_RES(in_channels, hidden_channels, out_channels)
        self.gat_res = to_hetero(self.gat_res, metadata=movielens_raw.metadata())
        self.gat_res_2 = GAT_RES_2(in_channels, hidden_channels, out_channels)
        self.gat_res_2 = to_hetero(self.gat_res_2, metadata=movielens_raw.metadata())
        self.linear_movie = torch.nn.Linear(movies.num_node_features, in_channels)
        self.linear_user = torch.nn.Linear(users.num_node_features, in_channels)
        self.gnn = self.gat_res_2

    def forward(self, hetero_data):
        features = {
            node_types[0]: self.linear_movie(hetero_data[node_types[0]].x),
            node_types[1]: self.linear_user(hetero_data[node_types[1]].x)
        }
        embeddings = self.gnn(features, hetero_data.edge_index_dict)
        return embeddings

In [41]:
def calc_emb_similarity(node_embs, edge_index, method="cosine"):
    # node_types[1] = "user"
    # node_types[0] = "movie"
    if method == "cosine":
        return torch.sum(node_embs[node_types[1]][edge_index[0]] * node_embs[node_types[0]][edge_index[1]], 1)

In [42]:
from torch_geometric.loader import LinkNeighborLoader
relation_rate = edge_types[0]
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[40, 20, 10],
    neg_sampling="binary",
    neg_sampling_ratio=1.0,
    edge_label_index=(relation_rate, train_data[relation_rate].edge_label_index),
    edge_label=train_data[relation_rate].edge_label,
    batch_size=256,
    shuffle=True
)



In [43]:
from tqdm import tqdm
from torch_geometric.utils import negative_sampling

def train(model, dataloader, optimizer, loss_fn):
    correct_count = 0
    all_count = 0
    loss = 0
    model.train()
    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        node_embeddings = model(batch)

        """
        neg_edge_index = negative_sampling(
            edge_index=batch[edge_types[0]].edge_index,
            num_nodes=batch.num_nodes,
            num_neg_samples=batch[edge_types[0]].edge_label.shape
        )

        edges_all = torch.cat((batch[edge_types[0]].edge_label_index, neg_edge_index), dim=1)
        labels_all = torch.cat((batch[edge_types[0]].edge_label, torch.zeros(batch[edge_types[0]].edge_label.shape)), dim=0)
        """

        labels = batch[edge_types[0]].edge_label
        similarities = calc_emb_similarity(node_embeddings, batch[edge_types[0]].edge_label_index)
        predictions = similarities.sigmoid() > 0.5
        correct_count += torch.sum(predictions == labels)
        all_count += len(labels)

        loss = loss_fn(similarities, labels)
        loss.backward()
        optimizer.step()
    return model, (float(correct_count) / float(all_count))


In [44]:
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def test(model, hetero_data):
    model.eval()
    node_embs = model(hetero_data)
    node_embs = calc_emb_similarity(node_embs, hetero_data[edge_types[0]].edge_label_index).view(-1).sigmoid()
    return roc_auc_score(hetero_data[edge_types[0]].edge_label.cpu().numpy(), node_embs.cpu().numpy())

In [51]:
model = Embedder(movielens_raw[node_types[0]].num_node_features, hidden_channels=128, out_channels=64)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
loss_fn = torch.nn.BCEWithLogitsLoss()
print(model)

Embedder(
  (sage): GraphModule(
    (conv1): ModuleDict(
      (user__rates__movie): SAGEConv(18, 128, aggr=mean)
      (movie__rated_by__user): SAGEConv(18, 128, aggr=mean)
    )
    (relu): ModuleDict(
      (movie): ReLU()
      (user): ReLU()
    )
    (conv2): ModuleDict(
      (user__rates__movie): SAGEConv(128, 128, aggr=mean)
      (movie__rated_by__user): SAGEConv(128, 128, aggr=mean)
    )
    (conv3): ModuleDict(
      (user__rates__movie): SAGEConv(128, 64, aggr=mean)
      (movie__rated_by__user): SAGEConv(128, 64, aggr=mean)
    )
  )
  (sage_res): GraphModule(
    (conv1): ModuleDict(
      (user__rates__movie): SAGEConv(18, 128, aggr=mean)
      (movie__rated_by__user): SAGEConv(18, 128, aggr=mean)
    )
    (relu): ModuleDict(
      (movie): ReLU()
      (user): ReLU()
    )
    (conv2): ModuleDict(
      (user__rates__movie): SAGEConv(128, 128, aggr=mean)
      (movie__rated_by__user): SAGEConv(128, 128, aggr=mean)
    )
    (conv3): ModuleDict(
      (user__rates__m

In [52]:
epochs = 10
for epoch in range(1, epochs + 1):
    model, acc = train(model, train_loader, optimizer, loss_fn)
    val_auc = test(model, val_data)
    test_auc = test(model, test_data)
    print(f'Epoch: {epoch:03d}, Training Accuracy: {acc:.4f}, Val AUC: {val_auc:.4f}, Test AUC: {test_auc:.4f}')

100%|██████████| 54/54 [00:15<00:00,  3.45it/s]


Epoch: 001, Training Accuracy: 0.6656, Val AUC: 0.8423, Test AUC: 0.8416


100%|██████████| 54/54 [00:15<00:00,  3.43it/s]


Epoch: 002, Training Accuracy: 0.7301, Val AUC: 0.8570, Test AUC: 0.8570


100%|██████████| 54/54 [00:15<00:00,  3.48it/s]


Epoch: 003, Training Accuracy: 0.7458, Val AUC: 0.8688, Test AUC: 0.8659


100%|██████████| 54/54 [00:15<00:00,  3.49it/s]


Epoch: 004, Training Accuracy: 0.7483, Val AUC: 0.8608, Test AUC: 0.8686


100%|██████████| 54/54 [00:15<00:00,  3.47it/s]


Epoch: 005, Training Accuracy: 0.7650, Val AUC: 0.8552, Test AUC: 0.8518


100%|██████████| 54/54 [00:16<00:00,  3.33it/s]


Epoch: 006, Training Accuracy: 0.7628, Val AUC: 0.8611, Test AUC: 0.8565


100%|██████████| 54/54 [00:15<00:00,  3.38it/s]


Epoch: 007, Training Accuracy: 0.7647, Val AUC: 0.8808, Test AUC: 0.8820


100%|██████████| 54/54 [00:15<00:00,  3.43it/s]


Epoch: 008, Training Accuracy: 0.7726, Val AUC: 0.8830, Test AUC: 0.8770


100%|██████████| 54/54 [00:15<00:00,  3.48it/s]


Epoch: 009, Training Accuracy: 0.7736, Val AUC: 0.8711, Test AUC: 0.8706


100%|██████████| 54/54 [00:15<00:00,  3.50it/s]


Epoch: 010, Training Accuracy: 0.7662, Val AUC: 0.8813, Test AUC: 0.8809
