In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.types as T
import pyspark.sql.functions as F

from itertools import chain
import pickle
import os
import pandas as pd
import glob

import seaborn as sns
# Seaborn's default theme for any figures rendered in this notebook.
sns.set_theme()

# Reuse the kernel's active Spark session if there is one, otherwise start one.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [None]:
# Dataset root. Override with the ANIME2020_DIR environment variable instead of
# editing this cell; the default is the original hardcoded location.
DATA_DIR = os.environ.get("ANIME2020_DIR", "/mnt/d/datasets/anime2020")

input_dir = os.path.join(DATA_DIR, "animelist.csv")        # raw ratings CSV
temp_dir = os.path.join(DATA_DIR, "animelist_temp")        # intermediate single-file parquet
output_dir = os.path.join(DATA_DIR, "animelist_sample")    # final encoded sample + id maps

In [None]:
# Explicit schema for the ratings CSV: every column is a nullable integer.
schema = T.StructType([
    T.StructField(column_name, T.IntegerType(), True)
    for column_name in (
        "user_id",
        "anime_id",
        "rating",
        "watching_status",
        "watched_episodes",
    )
])

In [None]:
# Load the raw ratings CSV using the explicit schema; the first line is a header.
df_animelist = (
    spark.read
    .schema(schema)
    .option("header", True)
    .csv(input_dir)
)

In [None]:
# Quick look at the raw load: sample rows, then column types, then total row count.
df_animelist.show()

raw_column_types = df_animelist.dtypes
print(raw_column_types)

raw_row_count = df_animelist.count()
print(raw_row_count)

In [None]:
# Per-user interaction counts, restricted to "typical" users and then subsampled.
MIN_ANIME_PER_USER = 5
MAX_ANIME_PER_USER = 1823  # author's note: ~99th percentile of per-user counts; drops outliers
USER_SAMPLE_FRACTION = 0.7
USER_SAMPLE_SEED = 42      # fixed so the same users are kept on every Restart & Run All

user_count_anime = (
    df_animelist
    .groupBy("user_id")
    .count()
    # between() is inclusive on both ends, same as the original >= / <= pair.
    .where(F.col("count").between(MIN_ANIME_PER_USER, MAX_ANIME_PER_USER))
    # Seeded sample: reproducible as long as the input data and its
    # partitioning do not change.
    .sample(fraction=USER_SAMPLE_FRACTION, seed=USER_SAMPLE_SEED)
)

In [None]:
# user_count_anime_df = user_count_anime.toPandas()
# user_count_anime_df.describe(percentiles=[0.25, 0.5, 0.75, 0.9, 0.95, 0.99])

In [None]:
# Keep only interactions from the sampled users and divide the raw integer
# rating by 10 (presumably a 0-10 scale -> [0, 1]; confirm against the data).
#
# This cell rebinds `df_animelist`, so it must be safe to re-run:
#   - the rescale only happens while `rating` still has its raw int dtype,
#     so a second run does not divide by 10 again;
#   - a left_semi join filters rows without carrying user_count_anime's
#     "count" column along, so a re-run cannot duplicate that column.
#     user_id is unique in user_count_anime (it comes from groupBy), so the
#     resulting rows are exactly those of the original inner join.
rating_is_raw = dict(df_animelist.dtypes)["rating"] == "int"

df_animelist = (
    df_animelist
    .withColumn(
        "rating",
        (F.col("rating") / F.lit(10.0)).cast(T.FloatType()) if rating_is_raw else F.col("rating"),
    )
    .join(user_count_anime, on="user_id", how="left_semi")
)

In [None]:
# Persist the filtered (user_id, anime_id, rating) triples as one shuffled
# parquet file, then reload it so later cells read the materialized result
# instead of re-executing the whole Spark lineage.
SHUFFLE_SEED = 42  # fixed so the shuffled row order is reproducible across runs

(
    df_animelist
    .select("user_id", "anime_id", "rating")
    # Seeded random sort: deterministic given the same input partitioning.
    .orderBy(F.rand(seed=SHUFFLE_SEED))
    # Collapse to a single partition so exactly one part file is written.
    .coalesce(1)
    .write.mode("overwrite").parquet(temp_dir)
)

df_animelist = spark.read.parquet(temp_dir)

In [None]:
# Number of rows that survived the user filter and sampling.
n_rows_sampled = df_animelist.count()
n_rows_sampled

In [None]:
# Preview the reloaded sample (spelled-out show() defaults: 20 rows, truncated cells).
df_animelist.show(n=20, truncate=True)

In [None]:
# Load the shuffled sample into pandas for the id-encoding step.
# Read every part file Spark produced (in name order) instead of just the first
# glob hit, so nothing is silently dropped if more than one part exists. Fail
# with a clear message if the Spark write left no parquet files behind.
parquet_parts = sorted(glob.glob(os.path.join(temp_dir, "*.parquet")))
if not parquet_parts:
    raise FileNotFoundError(f"no parquet files found in {temp_dir!r}")

df_animelist = pd.concat(
    (pd.read_parquet(part) for part in parquet_parts),
    ignore_index=True,
)
df_animelist

In [None]:
# Encoding categorical data: map user and anime ids to dense, 0-based int32
# indices (assigned in ascending id order) and keep both lookup directions
# so encoded predictions can be translated back to the original ids.


def _index_maps(sorted_ids):
    """Return (id -> index, index -> id) dicts for an already-sorted list of distinct ids."""
    id_to_index = {raw_id: idx for idx, raw_id in enumerate(sorted_ids)}
    index_to_id = dict(enumerate(sorted_ids))
    return id_to_index, index_to_id


user_ids = sorted(df_animelist["user_id"].unique().tolist())
user2user_encoded, user_encoded2user = _index_maps(user_ids)
df_animelist["user"] = df_animelist["user_id"].map(user2user_encoded).astype("int32")
n_users = len(user2user_encoded)

anime_ids = sorted(df_animelist["anime_id"].unique().tolist())
anime2anime_encoded, anime_encoded2anime = _index_maps(anime_ids)
df_animelist["anime"] = df_animelist["anime_id"].map(anime2anime_encoded).astype("int32")
n_animes = len(anime2anime_encoded)

print("Num of users: {}, Num of animes: {}".format(n_users, n_animes))

In [None]:
# Write the encoded ratings. Nothing earlier in the notebook creates output_dir
# (Spark only wrote to temp_dir), so create it first; otherwise the write
# fails on a fresh machine.
os.makedirs(output_dir, exist_ok=True)
df_animelist.to_parquet(os.path.join(output_dir, "anime_ratings.parquet"))

In [None]:
# Persist both directions of each id encoding next to the ratings file,
# one pickle per mapping (written in the same order as listed here).
encodings_to_save = {
    "user2user_encoded.pickle": user2user_encoded,
    "user_encoded2user.pickle": user_encoded2user,
    "anime2anime_encoded.pickle": anime2anime_encoded,
    "anime_encoded2anime.pickle": anime_encoded2anime,
}

for file_name, mapping in encodings_to_save.items():
    with open(os.path.join(output_dir, file_name), "wb") as handle:
        pickle.dump(mapping, handle)
