# Necessary imports

In [32]:
import pandas as pd
import networkx as nx
import torch
from random import shuffle, randint
from torch import nn
from torch_geometric.data import HeteroData

# Date preprocessing

In [33]:
anime = pd.read_csv('../data/anime.csv')
rating = pd.read_csv('../data/rating.csv')

In [34]:
def get_graph(anime: pd.DataFrame, rating: pd.DataFrame) -> nx.Graph:
    anime = anime.copy()
    rating = rating.copy()
    anime = anime.dropna()
    rating = rating.dropna()
    anime['genre'] = anime['genre'].str.split(', ')
    anime = anime[anime['members'] > 300_000]
    mp = {i: j for i, j in zip(anime['name'], anime['anime_id'])}
    anime = anime.drop(columns=['name', 'members'])

    anime_id = set(anime['anime_id'])
    genre = anime.explode('genre')

    rating = [(f'user_{i}', f'anime_{j}', r) for i, j, r in zip(rating['user_id'], rating['anime_id'], rating['rating'])
              if j in anime_id]
    genre = [(f'anime_{i}', f'genre_{j}') for i, j in zip(genre['anime_id'], genre['genre'])]
    tp = [(f'anime_{i}', f'type_{j}') for i, j in zip(anime['anime_id'], anime['type'])]

    user_id = list(set(i for i, j, r in rating))
    anime_id = list(set(j for i, j, r in rating))
    genres = list(set(j for i, j in genre))
    types = list(set(j for i, j in tp))

    G = nx.Graph()

    G.add_nodes_from(user_id, node_type="user")
    G.add_nodes_from(anime_id, node_type="anime")
    G.add_nodes_from(genres, node_type="entity")
    G.add_nodes_from(types, node_type="entity")

    G.add_weighted_edges_from(rating)
    G.add_edges_from(genre)
    G.add_edges_from(tp)

    return G, mp

In [35]:
graph, mp = get_graph(anime, rating)
print(graph)

Graph with 70827 nodes and 1535686 edges


In [36]:
def relational_neighborhood_construction(graph: nx.Graph, mp):
    users = [node for node in graph.nodes if node.startswith('user')]
    anime = [node for node in graph.nodes if node.startswith('anime')]
    types = [node for node in graph.nodes if node.startswith('type')]
    genre = [node for node in graph.nodes if node.startswith('genre')]

    mp = {j: i for i, j in mp.items()}
    anime2idx = {j: i for i, j in enumerate(anime)}
    users2idx = {j: len(anime2idx) + i for i, j in enumerate(users)}
    mapping = {mp[int(i[6:])]: j for i, j in anime2idx.items()}

    data = HeteroData()
    data['user'].node_id = torch.tensor(list(users2idx.values()))
    data['anime'].node_id = torch.tensor(list(anime2idx.values()))

    edges = set()
    for user in users:
        for anm in graph.neighbors(user):
            edges.add((user, anm))
    data['user', 'watched', 'anime'].edge_index = torch.tensor([(users2idx[i], anime2idx[j]) for i, j in edges]).T

    edges = set()
    for gnr in genre:
        anm = sorted(graph.neighbors(gnr))
        for i in range(len(anm)):
            for j in range(i + 1, len(anm)):
                edges.add((anm[i], anm[j]))

    data['anime', 'genre', 'anime'].edge_index = torch.tensor([(anime2idx[i], anime2idx[j]) for i, j in edges]).T

    edges = set()
    for tp in types:
        anm = sorted(graph.neighbors(tp))
        for i in range(len(anm)):
            for j in range(i + 1, len(anm)):
                edges.add((anm[i], anm[j]))

    data['anime', 'type', 'anime'].edge_index = torch.tensor([(anime2idx[i], anime2idx[j]) for i, j in edges]).T

    return data, mapping

In [37]:
data, mapping = relational_neighborhood_construction(graph, mp)
print(data)

HeteroData(
  user={ node_id=[70686] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 1535080] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)


In [38]:
def train_test_split(data, test_size=0.2):
    train, test = data.clone(), data.clone()
    users = [i.item() for i in data['user'].node_id]
    shuffle(users)

    idx = int(len(users) * test_size)
    train_idx = {j: i + data['anime'].num_nodes for i, j in enumerate(users[idx:])}
    test_idx = {j: i + data['anime'].num_nodes for i, j in enumerate(users[:idx])}

    train['user'].node_id = torch.tensor([train_idx[i] for i in users[idx:]])
    test['user'].node_id = torch.tensor([test_idx[i] for i in users[:idx]])
    train_edges, test_edges = [], []

    edges = data['user', 'watched', 'anime'].edge_index.cpu().detach().numpy()
    for i in range(edges.shape[-1]):
        if edges[0][i] in train_idx:
            train_edges.append((train_idx[edges[0][i]], edges[1][i]))
        else:
            test_edges.append((test_idx[edges[0][i]], edges[1][i]))
    train['user', 'watched', 'anime'].edge_index = torch.tensor(train_edges).T
    test['user', 'watched', 'anime'].edge_index = torch.tensor(test_edges).T
    return train, test

In [39]:
train, test = train_test_split(data)
print(train)
print(test)

HeteroData(
  user={ node_id=[56549] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 1229745] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)
HeteroData(
  user={ node_id=[14137] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 305335] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)


# Model

In [None]:
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import softmax
import torch.nn as nn
import torch_scatter


class DSKReG(MessagePassing):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_relations: int,
        num_classes: int,
        top_k: int = 5,
    ) -> None:
        super(DSKReG, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.num_relations = num_relations
        self.num_classes = num_classes

        self.top_k = top_k

        self.linear_rel = nn.Linear(hidden_dim * 2, 1, bias=True)
        self.linear_agg = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.relation_weight = nn.Parameter(torch.randn(hidden_dim))

    def forward(self, x, edge_index, edge_type, user_emb, size=None):
        x = self.propagate(
            edge_index, x=x, edge_type=edge_type, user_emb=user_emb, size=size
        )
        return x

    def message(self, x_i, x_j, index, ptr, size_i):
        relevance_score = self.rel_scores(x_i, x_j)
        return self.gumbel_softmax_sampling(relevance_score, index)

    def aggregate(self, inputs, index, ptr=None, dim_size=None):
        return torch_scatter.scatter(inputs, index, dim=0, reduce="mean")

    def rel_scores(self, relation_emb, neighbor_emb):
        concat_emb = torch.cat([relation_emb, neighbor_emb], dim=-1)
        return torch.softmax(self.linear_rel(concat_emb).squeeze(-1), dim=0)

    def gumbel_softmax_sampling(self, relevance_score, index):
        grouped_scores = softmax(relevance_score, index=index)

        gumbel_noise = (
            torch.rand_like(grouped_scores).log()
            - torch.rand_like(grouped_scores).log()
        )

        softmax_logits = torch.softmax(
            (torch.log(grouped_scores) + gumbel_noise) / self.tau, dim=0
        )

        _, top_k_indices = torch.topk(
            softmax_logits, self.top_k, dim=0, largest=True, sorted=False, out=None
        )

        mask = torch.zeros_like(softmax_logits)
        mask[top_k_indices] = 1.0

        return mask * softmax_logits

    def loss(self, user_emb, pos_item_emb, neg_item_emb, reg_lambda=0.001):
        pos_scores = (user_emb * pos_item_emb).sum(dim=-1)
        neg_scores = (user_emb * neg_item_emb).sum(dim=-1)

        bpr_loss = -torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()
        l2_norm = (
            user_emb.norm(2).pow(2)
            + pos_item_emb.norm(2).pow(2)
            + neg_item_emb.norm(2).pow(2)
        )

        return bpr_loss + reg_lambda * l2_norm

In [56]:
from torch_geometric.nn import SAGEConv
from torch.nn import functional as F


class DSKReG(torch.nn.Module):
    def __init__(self, user_num, anime_num, hidden_channels, hidden_layers=2, dropout=0.5):
        super().__init__()
        self.user_num = user_num
        self.anime_num = anime_num

        self.dropout = nn.Dropout(dropout)

        self.watched = nn.ModuleList(
            [SAGEConv(1, hidden_channels)] + [SAGEConv(hidden_channels, hidden_channels) for _ in
                                              range(hidden_layers - 1)])
        self.genre = nn.ModuleList([SAGEConv(hidden_channels, hidden_channels) for _ in range(hidden_layers)])
        self.tp = nn.ModuleList([SAGEConv(hidden_channels, hidden_channels) for _ in range(hidden_layers)])

        self.user_linear = nn.Linear(hidden_channels, hidden_channels)
        self.anime_linear = nn.Linear(hidden_channels, hidden_channels)

    def cat(self, user, anime):
        return torch.cat((user, anime))

    def uncat(self, nodes):
        return nodes[:-self.anime_num], nodes[-self.anime_num:]

    def forward(self, data):
        user = data['user'].node_id
        anime = data['anime'].node_id

        user = torch.ones_like(user)[:, None].float()
        anime = torch.ones_like(anime)[:, None].float()

        nodes = self.cat(user, anime)
        for watched, genre, tp in zip(self.watched, self.genre, self.tp):
            nodes = self.dropout(nodes)

            nodes = watched(nodes, data['user', 'watched', 'anime'].edge_index)

            user, anime = self.uncat(nodes)

            anime = genre(anime, data['anime', 'genre', 'anime'].edge_index)

            anime = tp(anime, data['anime', 'type', 'anime'].edge_index)
            anime = F.relu(anime)

            nodes = self.cat(user, anime)

        user, anime = self.uncat(nodes)

        user = self.user_linear(user)
        anime = self.anime_linear(anime)
        
        user_idx = (torch.ones((data['anime'].num_nodes, 1)).int() * torch.arange(data['user'].num_nodes)).T.reshape(-1)
        anime_idx = (torch.ones((data['user'].num_nodes, 1)).int() * torch.arange(data['anime'].num_nodes)).reshape(-1)
        user = user[user_idx]
        anime = anime[anime_idx]

        return (user * anime).sum(-1)

In [57]:
model = DSKReG(data['user'].num_nodes, data['anime'].num_nodes, 64)
print(model(train))

tensor([ 0.3987,  0.3736,  0.2440,  ...,  0.2416, -0.1113, -0.0401],
       grad_fn=<SumBackward1>)


# Training

In [58]:
def all_edges(data):
    edges = data['user', 'watched', 'anime'].edge_index.clone()
    edges[0] -= data['anime'].num_nodes
    edgs = torch.sparse_coo_tensor(edges, torch.ones(edges.shape[1]),
                                   (data['user'].num_nodes, data['anime'].num_nodes)).to_dense()
    return edgs.reshape(-1)


def get_data(data):
    indices = data['user', 'watched', 'anime'].edge_index
    size = indices.shape[-1]
    idx = list(range(size))
    shuffle(idx)
    idx = idx[:randint(int(size * 0.7), int(size * 0.9))]
    input = data.clone()
    input['user', 'watched', 'anime'].edge_index = indices[:, idx]
    return input

In [59]:
all_edges(test)

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [60]:
print(train)
get_data(train)

HeteroData(
  user={ node_id=[56549] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 1229745] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)


HeteroData(
  user={ node_id=[56549] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 904238] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)

In [61]:
model = DSKReG(data['user'].num_nodes, data['anime'].num_nodes, 64, 2)
loss_fn = nn.BCEWithLogitsLoss()
optim = torch.optim.Adam(model.parameters())
best = 1

In [62]:
for epoch in range(10):
    print(f'Epoch {epoch + 1}...')

    model.train()
    input = get_data(train)
    output = all_edges(train)
    out = model(input)
    
    loss = loss_fn(out, output)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    print(f'Train loss: {loss.item()}')

    model.eval()
    input = get_data(test)
    output = all_edges(test)
    out = model(input)
    
    loss = loss_fn(out, output).item()
    
    if loss < best:
        best = loss
        torch.save(model, 'best.pt')
    print(f'Test loss: {loss}')

Epoch 1...
Train loss: 0.6760521531105042
Test loss: 0.5353950262069702
Epoch 2...
Train loss: 0.5569848418235779
Test loss: 0.5192142724990845
Epoch 3...
Train loss: 0.5505090355873108
Test loss: 0.5373105406761169
Epoch 4...
Train loss: 0.5743471384048462
Test loss: 0.5232699513435364
Epoch 5...
Train loss: 0.5463542342185974
Test loss: 0.5142055749893188
Epoch 6...
Train loss: 0.5367845296859741
Test loss: 0.5197067856788635
Epoch 7...
Train loss: 0.534759521484375
Test loss: 0.527216911315918
Epoch 8...
Train loss: 0.5379371047019958
Test loss: 0.5277498364448547
Epoch 9...
Train loss: 0.5404622554779053
Test loss: 0.5232611298561096
Epoch 10...
Train loss: 0.5345408916473389
Test loss: 0.5178993940353394


# Predict top 10 anime for new user

In [49]:
model = torch.load('best.pt')

In [50]:
new_data = data.clone()
watched = ('Death Note', 'Sword Art Online', 'Naruto')
ids = [mapping[i] for i in watched]
id = new_data['user'].num_nodes
add = torch.tensor([[id, id, id], ids])
new_data['user', 'watched', 'anime'].edge_index = torch.cat((new_data['user', 'watched', 'anime'].edge_index, add), -1)
new_data['user'].node_id = torch.cat((new_data['user'].node_id, torch.tensor([new_data['user'].num_nodes])))
print(data)
print(new_data)

HeteroData(
  user={ node_id=[70686] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 1535080] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)
HeteroData(
  user={ node_id=[70687] },
  anime={ node_id=[103] },
  (user, watched, anime)={ edge_index=[2, 1535083] },
  (anime, genre, anime)={ edge_index=[2, 3757] },
  (anime, type, anime)={ edge_index=[2, 4759] }
)


In [54]:
logits = model(new_data).reshape(new_data['user'].num_nodes, new_data['anime'].num_nodes)
logits.shape

torch.Size([70687, 103])

In [55]:
logits = logits[-1].cpu().detach().numpy().tolist()
logits = [(j, i) for i, j in enumerate(logits)]
logits.sort()
reverse_mapping = {j: i for i, j in mapping.items()}
result = [i for j, i in logits if i not in ids]
[reverse_mapping[i] for i in result[:10]]

['Higurashi no Naku Koro ni',
 'Kimi ni Todoke',
 'Ano Hi Mita Hana no Namae wo Bokutachi wa Mada Shiranai.',
 'Bakemonogatari',
 'Mahou Shoujo Madoka★Magica',
 'Suzumiya Haruhi no Yuuutsu',
 'Ore no Imouto ga Konnani Kawaii Wake ga Nai',
 'Ouran Koukou Host Club',
 'Ao no Exorcist',
 'Gintama']