In [None]:
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, LinkNeighborLoader
import torch_geometric.transforms as T
import torch

import networkx as nx

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

from tqdm.autonotebook import tqdm

In [None]:
anime = pd.read_csv("../data/anime.csv")
rating = pd.read_csv("../data/rating.csv")

anime["genre"] = anime["genre"].str.split(", ")
# anime = anime.dropna(axis=1)
anime.head()

In [None]:
anime.isna().sum()

In [None]:
anime = anime.dropna()
anime.info()

In [None]:
anime["anime_id"] = anime.loc[:, "anime_id"].apply(lambda x: x + 0.5)
anime.head()

In [None]:
rating.head()

In [None]:
types = anime["type"].unique()
genres = anime["genre"].explode().unique()

type2id = {t: i for i, t in enumerate(types)}
id2type = {i: t for i, t in enumerate(types)}

genre2id = {g: i for i, g in enumerate(genres)}
id2genre = {i: g for i, g in enumerate(genres)}


unique_values = {
    "anime_id": anime["anime_id"].unique(),
    "types": [type2id[t] for t in anime["type"].unique()],
    "genre": [genre2id[g] for g in anime["genre"].explode().unique()],
    "user_id": rating["user_id"].unique(),
}

In [None]:
G = nx.Graph()

G.add_nodes_from(unique_values["anime_id"], node_type="anime")
G.add_nodes_from(unique_values["types"], node_type="types")
G.add_nodes_from(unique_values["genre"], node_type="genre")
G.add_nodes_from(unique_values["user_id"], node_type="user")

for anime_id in unique_values["anime_id"]:
    if G.nodes[anime_id]["node_type"] == "anime":
        G.nodes[anime_id]["rating"] = anime[anime["anime_id"] == anime_id][
            "rating"
        ].values

for _, row in anime.iterrows():
    anime_id = row["anime_id"]

    anime_type = type2id[row["type"]]
    genres = [genre2id[g] for g in row["genre"]]

    G.add_edge(anime_id, anime_type, relation="type")

    for genre in genres:
        G.add_edge(anime_id, genre, relation="genre")

for _, row in rating.iterrows():
    user_id = row["user_id"]
    anime_id = row["anime_id"]
    rating_value = row["rating"]

    G.add_edge(user_id, anime_id, weight=rating_value, relation="rating")

In [None]:
edge_index = []
edge_type = []
for u, v, data in G.edges(data=True):
    edge_index.append([u, v])
    if data["relation"] == "rating":
        edge_type.append(0)
    elif data["relation"] == "type":
        edge_type.append(1)
    elif data["relation"] == "genre":
        edge_type.append(2)

edge_index = torch.tensor(edge_index, dtype=torch.long).T
edge_type = torch.tensor(edge_type, dtype=torch.long)

In [None]:
node_features = []
for node in G:
    node_type = G.nodes[node]["node_type"]
    if node_type == "anime":
        node_features.append(G.nodes[node]["rating"])
    elif node_type == "genre":
        node_features.append([0])
    elif node_type == "types":
        node_features.append([1])
    elif node_type == "user":
        node_features.append([2])

x = torch.tensor(node_features, dtype=torch.float).view(-1, 1)

In [None]:
data = Data(x=x, edge_index=edge_index, edge_type=edge_type)
data

In [None]:
data.is_directed()

### Split the data into train, test, and validation sets on edge-level

In [None]:
# Normalize and split the data
transforms = T.Compose(
    [T.NormalizeFeatures(), T.RandomLinkSplit(num_val=0.1, num_test=0.2)]
)

train_data, val_data, test_data = transforms(data)

In [None]:
train_data

In [None]:
val_data

In [None]:
test_data

In [None]:
BATCH_SIZE = 64

# Create DataLoaders for all sets of data
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[40, 40],
    batch_size=BATCH_SIZE,
    edge_label_index=train_data.edge_index,
    edge_label=train_data.edge_label,
    shuffle=True,
)

test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[40, 40],
    batch_size=BATCH_SIZE,
    edge_label_index=test_data.edge_index,
    edge_label=test_data.edge_label,
    shuffle=False,
)

val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[40, 40],
    batch_size=BATCH_SIZE,
    edge_label_index=val_data.edge_index,
    edge_label=val_data.edge_label,
    shuffle=False,
)

In [None]:
print(f"Train DataLoader length: {len(train_loader)}")
print(f"Test DataLoader length: {len(test_loader)}")
print(f"Val DataLoader length: {len(val_loader)}")

In [None]:
# Example of batch in train DataLoader
for batch in train_loader:
    print(batch)
    break

### Save everything into the ".pkl" file (optional)

In [None]:
data_dict = {
    "train_loader": train_loader,
    "test_loader": test_loader,
    "val_loader": val_loader,
    "data": data,
    "graph": G,
}

In [None]:
save_file_path = "../data/pickle_checkpoints/data_stats_v1.pkl"

In [None]:
with open(save_file_path, "wb") as file:
    pickle.dump(data_dict, file)

In [None]:
with open(save_file_path, "rb") as f:
    data_dict = pickle.load(f)
data_dict

In [109]:
a = iter(train_loader)

In [110]:
# x, edge_index, edge_type, edge_label = next(iter(train_loader))
b = next(a)

In [None]:
batch = Data(
    x=[10103, 1],
    edge_index=[2, 78764],
    edge_type=[78764],
    edge_label=[64],
    edge_label_index=[2, 64],
    n_id=[10103],
    e_id=[78764],
    num_sampled_nodes=[3],
    num_sampled_edges=[2],
    input_id=[64]
  )

In [124]:
b.edge_label

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [117]:
b.num_sampled_nodes, b.num_sampled_edges

([128, 2531, 7659], [4452, 74455])

In [123]:
b.edge_type

tensor([2, 2, 2,  ..., 0, 0, 0])

In [121]:
train_loader = data_dict["train_loader"]
test_loader = data_dict["test_loader"]
val_loader = data_dict["val_loader"]

### Model and Training


Model based on [DSKReG: Differentiable Sampling on Knowledge Graph for
Recommendation with Relational GNN](https://arxiv.org/pdf/2108.11883v1)


[BPR: Bayesian Personalized Ranking](https://arxiv.org/pdf/1205.2618) loss used. This loss mostly optimize ranking of model's predictions. The main idea of this loss is to maximize the probability of user prefering observed item over an unobserved.

In [None]:
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import torch.nn as nn
import torch_scatter


class DSKReG(MessagePassing):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_relations: int,
        num_classes: int,
        top_k: int = 5,
    ) -> None:
        super(DSKReG, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.num_relations = num_relations
        self.num_classes = num_classes

        self.top_k = top_k

        self.linear_rel = nn.Linear(hidden_dim * 2, 1, bias=True)
        self.linear_agg = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.relation_weight = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x, edge_index, edge_type, user_emb, size=None):
        self.propagate(
            edge_index, x=x, edge_type=edge_type, user_emb=user_emb, size=size
        )

    def message(self, x_i, x_j, index, ptr, size_i):
        pass

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        return torch_scatter.scatter(inputs, index, dim=0, reduce="mean")

    def rel_scores(self, relation_emb, neighbor_emb):
        concat_emb = torch.cat([relation_emb, neighbor_emb], dim=-1)
        return torch.softmax(self.linear_rel(concat_emb).squeeze(-1), dim=0)

    def gumbel_softmax_sampling(self, relevance_score, index):
        grouped_scores = softmax(relevance_score, index=index)

        gumbel_noise = (
            torch.rand_like(grouped_scores).log()
            - torch.rand_like(grouped_scores).log()
        )

        softmax_logits = torch.softmax(
            (torch.log(grouped_scores) + gumbel_noise) / self.tau, dim=0
        )

        _, top_k_indices = torch.topk(
            softmax_logits, self.top_k, dim=0, largest=True, sorted=False, out=None
        )

        mask = torch.zeros_like(softmax_logits)
        mask[top_k_indices] = 1.0

        return mask * softmax_logits

    def loss(self, user_emb, pos_item_emb, neg_item_emb, reg_lambda=0.001):
        pos_scores = (user_emb * pos_item_emb).sum(dim=-1)
        neg_scores = (user_emb * neg_item_emb).sum(dim=-1)

        bpr_loss = -torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()
        l2_norm = (
            user_emb.norm(2).pow(2)
            + pos_item_emb.norm(2).pow(2)
            + neg_item_emb.norm(2).pow(2)
        )

        return bpr_loss + reg_lambda * l2_norm

In [None]:
def negative_sampling(edge_index, num_nodes, node_types, num_neg_samples):
    existing_edges = set(map(tuple, edge_index.T.tolist()))
    neg_edges = set()

    while len(neg_edges) < num_neg_samples:
        u = torch.randint(0, num_nodes, (1,)).item()
        v = torch.randint(0, num_nodes, (1,)).item()

        if node_types[u] == "user" and node_types[v] == "anime":
            if (u, v) not in existing_edges and (v, u) not in existing_edges:
                neg_edges.add((u, v))
        elif node_types[u] == "anime" and node_types[v] == "user":
            if (u, v) not in existing_edges and (v, u) not in existing_edges:
                neg_edges.add((u, v))

    return torch.tensor(list(neg_edges), dtype=torch.long).T


def validate(model, dataloader, device: str = "cpu"):
    loss_ = 0
    with torch.no_grad():
        for batch in dataloader:
            #loss = model.loss(user_emb, pos_item_emb, neg_item_emb)
            #loss_ += loss.item()
            loss += 1
    return loss_


def train(
    model,
    optimizer,
    train_dataloader,
    validation_dataloader,
    test_dataloader,
    num_epochs: int = 10,
    device: str = "cpu",
    reg_lambda: float = 0.001,
):
    for epoch in range(num_epochs):
        epoch_loss = 0
        validation_loss = 0
        test_loss = 0

        train_loader = tqdm(
            enumerate(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}"
        )

        for batch in train_loader:
            (
                x,
                edge_index,
                edge_type,
                edge_label,
                edge_label_index,
                n_id,
                e_id,
                num_sampled_nodes,
                num_sampled_edges,
                input_id,
            ) = batch

            optimizer.zero_grad()

            num_neg_samples = 
            neg_edge_index = negative_sampling(
                edge_index, num_nodes, node_types, num_neg_samples
            )

            pos_item_emb = user_emb[pos_edge_index[1]]
            neg_item_emb = user_emb[neg_edge_index[1]]

            loss = model.loss(user_emb, pos_item_emb, neg_item_emb)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            train_loader.set_postfix({"loss": loss.item()})

        validation_loss = validate(validation_dataloader)
        test_loss = validate(test_dataloader)

        print(
            f"Epoch {epoch + 1}, Training Loss: {epoch_loss / len(train_dataloader):.4f}"
        )
        print(f"Validation Loss: {validation_loss / len(validation_dataloader):.4f}")
        print(f"Test Loss: {test_loss / len(test_dataloader):.4f}")