In [26]:
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, LinkNeighborLoader
import torch_geometric.transforms as T
import torch

import networkx as nx

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle

In [3]:
anime = pd.read_csv("../data/anime.csv")
rating = pd.read_csv("../data/rating.csv")

anime['genre'] = anime['genre'].str.split(', ')
#anime = anime.dropna(axis=1)
anime.head()

Unnamed: 0,anime_id,name,genre,type,episodes,rating,members
0,32281,Kimi no Na wa.,"[Drama, Romance, School, Supernatural]",Movie,1,9.37,200630
1,5114,Fullmetal Alchemist: Brotherhood,"[Action, Adventure, Drama, Fantasy, Magic, Mil...",TV,64,9.26,793665
2,28977,Gintama°,"[Action, Comedy, Historical, Parody, Samurai, ...",TV,51,9.25,114262
3,9253,Steins;Gate,"[Sci-Fi, Thriller]",TV,24,9.17,673572
4,9969,Gintama&#039;,"[Action, Comedy, Historical, Parody, Samurai, ...",TV,51,9.16,151266


In [4]:
anime.isna().sum()

anime_id      0
name          0
genre        62
type         25
episodes      0
rating      230
members       0
dtype: int64

In [5]:
anime = anime.dropna()
anime.info()

<class 'pandas.core.frame.DataFrame'>
Index: 12017 entries, 0 to 12293
Data columns (total 7 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   anime_id  12017 non-null  int64  
 1   name      12017 non-null  object 
 2   genre     12017 non-null  object 
 3   type      12017 non-null  object 
 4   episodes  12017 non-null  object 
 5   rating    12017 non-null  float64
 6   members   12017 non-null  int64  
dtypes: float64(1), int64(2), object(4)
memory usage: 751.1+ KB


In [6]:
anime['anime_id'] = anime.loc[:, 'anime_id'].apply(lambda x: x + 0.5)
anime.head()

Unnamed: 0,anime_id,name,genre,type,episodes,rating,members
0,32281.5,Kimi no Na wa.,"[Drama, Romance, School, Supernatural]",Movie,1,9.37,200630
1,5114.5,Fullmetal Alchemist: Brotherhood,"[Action, Adventure, Drama, Fantasy, Magic, Mil...",TV,64,9.26,793665
2,28977.5,Gintama°,"[Action, Comedy, Historical, Parody, Samurai, ...",TV,51,9.25,114262
3,9253.5,Steins;Gate,"[Sci-Fi, Thriller]",TV,24,9.17,673572
4,9969.5,Gintama&#039;,"[Action, Comedy, Historical, Parody, Samurai, ...",TV,51,9.16,151266


In [7]:
rating.head()

Unnamed: 0,user_id,anime_id,rating
0,1,20,-1
1,1,24,-1
2,1,79,-1
3,1,226,-1
4,1,241,-1


In [8]:
types = anime["type"].unique()
genres = anime["genre"].explode().unique()

type2id = {t:i for i, t in enumerate(types)}
id2type = {i:t for i, t in enumerate(types)}

genre2id = {g:i for i, g in enumerate(genres)}
id2genre = {i:g for i, g in enumerate(genres)}


unique_values = {
    "anime_id" : anime["anime_id"].unique(),
    "types" : [type2id[t] for t in anime["type"].unique()],
    "genre" : [genre2id[g] for g in anime["genre"].explode().unique()],
    "user_id": rating["user_id"].unique(),
}

In [9]:
G = nx.Graph()

G.add_nodes_from(unique_values["anime_id"], node_type='anime')
G.add_nodes_from(unique_values["types"], node_type='types')
G.add_nodes_from(unique_values["genre"], node_type='genre')
G.add_nodes_from(unique_values["user_id"], node_type='user')

for anime_id in unique_values["anime_id"]:
    if G.nodes[anime_id]['node_type'] == 'anime':
        G.nodes[anime_id]["rating"] = anime[anime["anime_id"] == anime_id]["rating"].values

for _, row in anime.iterrows():
    anime_id = row["anime_id"]
    
    anime_type = type2id[row["type"]]
    genres = [genre2id[g] for g in row["genre"]]

    G.add_edge(anime_id, anime_type, relation="type")
    
    for genre in genres:
        G.add_edge(anime_id, genre, relation="genre")

for _, row in rating.iterrows():
    user_id = row["user_id"]
    anime_id = row["anime_id"]
    rating_value = row["rating"]

    G.add_edge(user_id, anime_id, weight=rating_value, relation="rating")

In [10]:
edge_index = []
edge_type = []
for u, v, data in G.edges(data=True):
    edge_index.append([u, v])
    if data["relation"] == "rating":
        edge_type.append(0)
    elif data["relation"] == "type":
        edge_type.append(1)
    elif data["relation"] == "genre":
        edge_type.append(2)

edge_index = torch.tensor(edge_index, dtype=torch.long).T
edge_type = torch.tensor(edge_type, dtype=torch.long) 

In [11]:
node_features = []
for node in G:
    node_type = G.nodes[node]['node_type']
    if node_type == 'anime':
        node_features.append(G.nodes[node]["rating"])
    elif node_type == 'genre':
        node_features.append([0])
    elif node_type == 'types':
        node_features.append([1])
    elif node_type == 'user':
        node_features.append([2])

x = torch.tensor(node_features, dtype=torch.float).view(-1, 1)

  x = torch.tensor(node_features, dtype=torch.float).view(-1, 1)


In [12]:
data = Data(x=x, edge_index=edge_index, edge_type=edge_type)
data

Data(x=[85533, 1], edge_index=[2, 7854025], edge_type=[7854025])

In [13]:
data.is_directed()

True

### Split the data into train, test, and validation sets on edge-level

In [14]:
# Normalize and split the data
transforms = T.Compose(
    [
        T.NormalizeFeatures(),
        T.RandomLinkSplit(num_val=0.1, num_test=0.2)
    ]
)

train_data, val_data, test_data = transforms(data)

In [15]:
train_data

Data(x=[85533, 1], edge_index=[2, 5497818], edge_type=[5497818], edge_label=[10995636], edge_label_index=[2, 10995636])

In [16]:
val_data

Data(x=[85533, 1], edge_index=[2, 5497818], edge_type=[5497818], edge_label=[1570804], edge_label_index=[2, 1570804])

In [17]:
test_data

Data(x=[85533, 1], edge_index=[2, 6283220], edge_type=[6283220], edge_label=[3141610], edge_label_index=[2, 3141610])

In [18]:
BATCH_SIZE = 64

# Create DataLoaders for all sets of data
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[40, 40],
    batch_size = BATCH_SIZE,
    edge_label_index=train_data.edge_index,
    edge_label=train_data.edge_label,
    shuffle=True
)

test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[40, 40],
    batch_size = BATCH_SIZE,
    edge_label_index=test_data.edge_index,
    edge_label=test_data.edge_label,
    shuffle=False
)

val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[40, 40],
    batch_size = BATCH_SIZE,
    edge_label_index=val_data.edge_index,
    edge_label=val_data.edge_label,
    shuffle=False
)

In [23]:
print(f'Train DataLoader length: {len(train_loader)}')
print(f'Test DataLoader length: {len(test_loader)}')
print(f'Val DataLoader length: {len(val_loader)}')

Train DataLoader length: 85904
Test DataLoader length: 98176
Val DataLoader length: 85904


In [25]:
# Example of batch in train DataLoader
for batch in train_loader:
    print(batch)
    break

Data(x=[10778, 1], edge_index=[2, 76193], edge_type=[76193], edge_label=[64], edge_label_index=[2, 64], n_id=[10778], e_id=[76193], num_sampled_nodes=[3], num_sampled_edges=[2], input_id=[64])


### Save everything into the ".pkl" file (optional)

In [28]:
data_dict = {
    'train_loader': train_loader,
    'test_loader': test_loader,
    'val_loader': val_loader,
    'data': data,
    'graph': G
}

In [30]:
with open('../data/pickle_checkpoints/data_stats_v1.pkl', 'wb') as file:
    pickle.dump(data_dict, file)