In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.recommendation import ALS, ALSModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml import Pipeline
from neo4j import GraphDatabase

In [2]:
# Init Spark

spark = SparkSession.builder \
    .master("yarn") \
    .config("spark.executor.instances", "2") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.memory", "2g") \
    .config("spark.locality.wait.node", "0") \
    .appName("InitAnimeRecommendation") \
    .getOrCreate()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/26 10:14:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/05/26 10:14:07 WARN Client: Neither spark.yarn.jars nor spark.yarn.archive is set, falling back to uploading libraries under SPARK_HOME.


In [3]:
# Load data from HDFS
ratings_path = "hdfs:///dataset/users-score-2023.csv"
anime_path = "hdfs:///dataset/anime-filtered.csv"
users_path = "hdfs:///dataset/users-details-2023.csv"

ratings = spark.read.csv(ratings_path, header=True, inferSchema=True)
anime = spark.read.csv(anime_path, header=True, inferSchema=True)
users = spark.read.csv(users_path, header=True, inferSchema=True)

                                                                                

In [4]:
# EDA for ratings dataset
print("ratings dataset schema:")
ratings.printSchema()

print("First 10 rows of ratings dataset:")
ratings.show(10)

# Delete timestamp column
ratings = ratings.select("user_id", "anime_id", "rating")
print("First 10 rows of ratings after selection:")
ratings.show(10)

# Check duplicates
dropped_duplicates = ratings.dropDuplicates()
duplicate_count = ratings.count() - dropped_duplicates.count()
print(f"Number of duplicate records: {duplicate_count}")

#Print range for rating values
min_rating = ratings.agg({"rating": "min"}).collect()[0][0]
max_rating = ratings.agg({"rating": "max"}).collect()[0][0]
print(f"Ratings range: [{min_rating}, {max_rating}]")

# Print total number of users
num_users = ratings.select("user_id").distinct().count()
print(f"Number of users: {num_users}")

# Print total number of anime
num_anime = ratings.select("anime_id").distinct().count()
print(f"Number of anime: {num_anime}")

# Print total number of ratings
print("Total ratings:", ratings.count())

ratings dataset schema:
root
 |-- user_id: integer (nullable = true)
 |-- Username: string (nullable = true)
 |-- anime_id: integer (nullable = true)
 |-- Anime Title: string (nullable = true)
 |-- rating: string (nullable = true)

First 10 rows of ratings dataset:
+-------+--------+--------+--------------------+------+
|user_id|Username|anime_id|         Anime Title|rating|
+-------+--------+--------+--------------------+------+
|      1|   Xinil|      21|           One Piece|     9|
|      1|   Xinil|      48|         .hack//Sign|     7|
|      1|   Xinil|     320|              A Kite|     5|
|      1|   Xinil|      49|    Aa! Megami-sama!|     8|
|      1|   Xinil|     304|Aa! Megami-sama! ...|     8|
|      1|   Xinil|     306|Abenobashi Mahou☆...|     8|
|      1|   Xinil|      53|       Ai Yori Aoshi|     7|
|      1|   Xinil|      47|               Akira|     5|
|      1|   Xinil|     591|      Amaenaide yo!!|     6|
|      1|   Xinil|      54|   Appleseed (Movie)|     7|
+-----

                                                                                

Number of duplicate records: 0


                                                                                

Ratings range: [ Doushite Konna Otoko ni......""", 9]


                                                                                

Number of users: 270033


                                                                                

Number of anime: 16500




Total ratings: 24325191


                                                                                

In [5]:
# Calcolo della frazione di campionamento
fraction = 1/100  # 0.01

# Esegui il campionamento
ratings_sampled = ratings.sample(withReplacement=False, fraction=fraction, seed=42).cache()

# Print total number of ratings
print("Total ratings after sampling:", ratings_sampled.count())



Total ratings after sampling: 244390


                                                                                

In [6]:
ratings_sampled.select("rating").distinct().show()

+-------------------+
|             rating|
+-------------------+
|                  7|
|                  3|
|                  8|
|          Igi Ari!"|
|                  5|
|                  6|
| Igi Ari! Season 2"|
|                  9|
|                  1|
|                 10|
|                  4|
|                  2|
+-------------------+



In [7]:
users = users.select("Mal ID", "Username")
users = users.withColumnRenamed("Mal ID", "user_id")

In [8]:
# EDA for anime dataset
print("anime dataset schema:")
anime.printSchema()
print("First 10 rows of anime dataset:")
anime.show(10)

anime dataset schema:
root
 |-- anime_id: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Score: string (nullable = true)
 |-- Genres: string (nullable = true)
 |-- English name: string (nullable = true)
 |-- Japanese name: string (nullable = true)
 |-- sypnopsis: string (nullable = true)
 |-- Type: string (nullable = true)
 |-- Episodes: string (nullable = true)
 |-- Aired: string (nullable = true)
 |-- Premiered: string (nullable = true)
 |-- Producers: string (nullable = true)
 |-- Licensors: string (nullable = true)
 |-- Studios: string (nullable = true)
 |-- Source: string (nullable = true)
 |-- Duration: string (nullable = true)
 |-- Rating: string (nullable = true)
 |-- Ranked: string (nullable = true)
 |-- Popularity: string (nullable = true)
 |-- Members: string (nullable = true)
 |-- Favorites: string (nullable = true)
 |-- Watching: string (nullable = true)
 |-- Completed: string (nullable = true)
 |-- On-Hold: string (nullable = true)
 |-- Dropped: string

In [9]:
cols_to_drop = ['English name', 'Japanese name', "Aired", "Duration", "Premiered", "Ranked", "Popularity", "Members", "Favorites", "Watching", "Completed", "On-Hold", "Dropped", "Producers", "Licensors", "Studios"]

anime = anime.drop(*cols_to_drop)

In [10]:
def extract_elements(iterator):
    distinct_genres_local = set()
    for row in iterator:
        if(row["Genres"] != None):
            genres_list = row["Genres"].split(", ")
            distinct_genres_local.update(genres_list)
    yield distinct_genres_local  # attenzione: questo è un set per partizione

# Ottieni i set locali da ogni partizione
distinct_genres_sets = anime.rdd.mapPartitions(extract_elements).collect()

# Unisci tutti i set (in memoria sul driver, ma dopo aggregazione parziale)
distinct_genres = set().union(*distinct_genres_sets)
print(distinct_genres)

[Stage 46:>                                                         (0 + 2) / 2]

{'Kids', 'Shounen', 'Seinen', 'Mystery', 'Military', 'Slice of Life', 'Drama', 'Psychological', 'Mecha', 'Supernatural', 'Shoujo', 'Harem', 'Comedy', 'Vampire', 'Game', 'Music', 'Romance', 'Parody', '6.51', 'Samurai', 'Thriller', 'Sports', '6.36', 'Yuri', 'Magic', 'Fantasy', 'Martial Arts', 'Dementia', 'Ecchi', '7.14', 'Super Power', 'Demons', 'School', 'Historical', 'Yaoi', 'Sci-Fi', 'Adventure', 'Shounen Ai', 'Josei', 'Shoujo Ai', '6.45', 'Police', 'Cars', 'Unknown', 'Space', 'Horror', 'Hentai', 'Action'}


                                                                                

In [11]:
genres_to_be_removed = ["6.36", "6.45", "6.51", "7.14"]

In [12]:
#sostituisce la colonna dei generi in formato stringa con una colonna dei generi in formato array di stringhe, includendo solo i generi validi
from pyspark.sql.functions import udf, size, when, array
from pyspark.sql.types import ArrayType, StringType

def create_genres_list(genres):
    genres_list = genres.split(', ') if genres != None else []
    genres_list = [x for x in genres_list if x not in genres_to_be_removed]
    return genres_list

create_genres_list_udf = udf(create_genres_list, ArrayType(StringType()))
anime = anime.withColumn("genres_list", create_genres_list_udf(col("Genres")))

anime = anime.filter(size(col("genres_list")) > 0)
anime = anime.withColumn("genres_list", when(col("genres_list")[0] == "Unknown", array()).otherwise(col("genres_list")))
anime = anime.drop("Genres")

In [13]:
#from pyspark.sql.functions import when

#anime.select("Type").distinct().show(anime.count(), truncate=False)

anime = anime.filter(anime["Type"].isin(["TV", "OVA", "Special", "Movie", "Music", "ONA"]))
anime = anime.withColumn("Type", when(col("Type").isNull(), "").otherwise(col("Type")))
#anime.select("Type").distinct().show(anime.count(), truncate=False)

In [14]:
#consider only those anime having valuid synopsis
anime = anime.withColumn("sypnopsis", when(col("sypnopsis").isin(["No synopsis information has been added to this title. Help improve our database by adding a synopsis here ."]) | col("sypnopsis").isNull(), "").otherwise(col("sypnopsis")))

In [15]:
#convert number of episodes to int. Inavlid values are converted to None
anime = anime.filter(col("Episodes").cast("int").isNotNull())

In [16]:
anime.show(10)

[Stage 47:>                                                         (0 + 1) / 1]

+--------+--------------------+-----+--------------------+-----+--------+-----------+--------------------+--------------------+
|anime_id|                Name|Score|           sypnopsis| Type|Episodes|     Source|              Rating|         genres_list|
+--------+--------------------+-----+--------------------+-----+--------+-----------+--------------------+--------------------+
|       5|Cowboy Bebop: Ten...| 8.39|other day, anothe...|Movie|       1|   Original|R - 17+ (violence...|[Action, Drama, M...|
|       7|  Witch Hunter Robin| 7.27|ches are individu...|   TV|      26|   Original|PG-13 - Teens 13 ...|[Action, Mystery,...|
|       8|      Bouken Ou Beet| 6.98|It is the dark ce...|   TV|      52|      Manga|       PG - Children|[Adventure, Fanta...|
|      16|Hachimitsu to Clover| 8.06|Yuuta Takemoto, a...|   TV|      24|      Manga|PG-13 - Teens 13 ...|[Comedy, Drama, J...|
|      17|Hungry Heart: Wil...| 7.59|Kyosuke Kano has ...|   TV|      52|      Manga|PG-13 - Teens 13 ..

                                                                                

In [17]:
anime.select("Source").distinct().show(100000)



+-------------+
|       Source|
+-------------+
| Visual novel|
|Digital manga|
|     Original|
|        Novel|
| Picture book|
|         Book|
|      Unknown|
|        Other|
|        Radio|
|        Manga|
| 4-koma manga|
|        Music|
|         Game|
|    Web manga|
|    Card game|
|  Light novel|
+-------------+



                                                                                

In [18]:
anime = anime.withColumn("Source", when(col("Source") == "Unknown", "").otherwise(col("Source")))
anime.select("Source").distinct().show()

+-------------+
|       Source|
+-------------+
| Visual novel|
|Digital manga|
|     Original|
|        Novel|
| Picture book|
|         Book|
|        Other|
|        Radio|
|        Manga|
| 4-koma manga|
|        Music|
|         Game|
|    Web manga|
|    Card game|
|             |
|  Light novel|
+-------------+



In [19]:
anime.select("Rating").distinct().show(10)

+--------------------+
|              Rating|
+--------------------+
|R - 17+ (violence...|
|        G - All Ages|
|             Unknown|
|       PG - Children|
|    R+ - Mild Nudity|
|PG-13 - Teens 13 ...|
|         Rx - Hentai|
+--------------------+



In [20]:
anime = anime.withColumn("Rating", when(col("Rating") == "Unknown", "").otherwise(col("Rating")))

In [21]:
#Keep only the ratings of anime in "anime"

ratings_sampled = ratings_sampled.withColumn("rating", col("rating").cast("int")).filter(col("rating").isNotNull())

ratings_sampled = ratings_sampled.join(anime.select("anime_id"), on="anime_id", how="inner")

ratings_sampled.cache()
ratings_sampled.show(10)
print(f"Numero di anime con valutazione: {ratings_sampled.select("anime_id").distinct().count()}")

                                                                                

+--------+-------+------+
|anime_id|user_id|rating|
+--------+-------+------+
|    3229|      1|     6|
|   28805|      4|     8|
|   14353|      4|     7|
|    6702|     20|     6|
|    1362|     47|     2|
|     379|     47|     6|
|      45|     47|     6|
|    1858|     48|     7|
|    5978|     48|     6|
|     323|     48|     5|
+--------+-------+------+
only showing top 10 rows

Numero di anime con valutazione: 6675


In [22]:
ratings_sampled.count()

162254

In [23]:
ratings_path = "hdfs:///ratings"
ratings_sampled.write.format("parquet").mode("overwrite").option("header", True).save(ratings_path)

                                                                                

# Prepare anime subset for content based filtering

In [21]:
anime_cbf = anime.filter((col("sypnopsis") != "") & (size(col("genres_list")) > 0) & (col("Source") != "") & (col("Rating") != "") & (col("Type") != ""))

print(f"Numero di anime per content base filtering: {anime_cbf.count()}")

Numero di anime per content base filtering: 8199


In [22]:
#Vettorizzazione di sypnopsis
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF

tokenizer = Tokenizer(inputCol="sypnopsis", outputCol="words")
remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
hashingTF = HashingTF(inputCol="filtered_words", outputCol="rawFeatures", numFeatures=5000)
idf = IDF(inputCol="rawFeatures", outputCol="tfidfFeatures")

pipeline = Pipeline(stages=[tokenizer, remover, hashingTF, idf])
model = pipeline.fit(anime_cbf)
anime_cbf = model.transform(anime_cbf)

                                                                                

In [23]:
#One-Hot-Encoding di "Source", "Rating", "Type"
from pyspark.ml.feature import StringIndexer, OneHotEncoder

categorical_cols = ['Source', 'Rating', 'Type']
indexers = [StringIndexer(inputCol=c, outputCol=c+"_idx") for c in categorical_cols]
encoders = [OneHotEncoder(inputCol=c+"_idx", outputCol=c+"_vec") for c in categorical_cols]

# Esegui le trasformazioni
for idx, enc in zip(indexers, encoders):
    anime_cbf = idx.fit(anime_cbf).transform(anime_cbf)
    anime_cbf = enc.fit(anime_cbf).transform(anime_cbf)

                                                                                

In [24]:
#Multi-Hot-Encoding di genres_list (In this case CountVectorizer is equivalent to MHE)
from pyspark.ml.feature import CountVectorizer

cv = CountVectorizer(inputCol="genres_list", outputCol="genres_vec")
cv_model = cv.fit(anime_cbf)
anime_cbf = cv_model.transform(anime_cbf)


                                                                                

In [25]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["tfidfFeatures", "Source_vec", "Rating_vec", "Type_vec", "genres_vec"],
    outputCol="combined_features"
)
anime_cbf = assembler.transform(anime_cbf)
anime_cbf = anime_cbf.select("anime_id", "combined_features")


In [26]:
from pyspark.sql.functions import col, expr, row_number, collect_list, struct
from pyspark.sql.window import Window
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf
import numpy as np
from numpy import dot
from numpy.linalg import norm
from pyspark.sql.types import DoubleType, IntegerType, StructType, ArrayType, StructField

# Cosine similarity UDF
def cosine_similarity(v1, v2):
    v1 = np.array(v1.toArray())
    v2 = np.array(v2.toArray())
    denom = norm(v1) * norm(v2)
    if denom == 0:
        return float(0.0)
    return float(dot(v1, v2) / denom)

cosine_similarity_udf = udf(cosine_similarity, DoubleType())

# 1. Esegui una self-join (crossJoin)
anime_cbf_blocks = anime_cbf.randomSplit([0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125], seed='40')

similar_anime_struct = StructType([
    StructField("anime_id", IntegerType(),True),
    StructField("similarity", DoubleType(),True)])

schema = StructType([
    StructField("anime_id", IntegerType(),True),
    StructField("similar_animes", ArrayType(similar_anime_struct),True)])

top_similar = spark.createDataFrame([], schema)


for a_cbf_block in anime_cbf_blocks:
    
    anime_pairs = a_cbf_block.alias("a").crossJoin(anime_cbf.alias("b")) \
        .filter(col("a.anime_id") != col("b.anime_id"))
    
    #anime_pairs.persist(pyspark.StorageLevel.MEMORY_AND_DISK)
    
    # 2. Calcola la similarità
    anime_similarities = anime_pairs.withColumn(
        "similarity",
        cosine_similarity_udf(col("a.combined_features"), col("b.combined_features"))
    ).select(
        col("a.anime_id").alias("anime_id"),
        col("b.anime_id").alias("similar_anime_id"),
        col("similarity")
    )
    
    # 3. Aggiungi un ranking per top-N
    windowSpec = Window.partitionBy("anime_id").orderBy(col("similarity").desc())
    
    top_n = 10
    
    anime_top_n = anime_similarities.withColumn(
        "rank", row_number().over(windowSpec)
    ).filter(col("rank") <= top_n)
    
    # 4. Raggruppa i risultati in array
    anime_top_n_grouped = anime_top_n.groupBy("anime_id").agg(
        collect_list(struct(col("similar_anime_id").alias("anime_id"), col("similarity"))).alias("similar_animes")
    )

    top_similar = top_similar.union(anime_top_n_grouped)

# Risultato finale
top_similar.cache()
top_similar.show(10)

25/05/22 21:46:32 WARN DAGScheduler: Broadcasting large task binary with size 1071.1 KiB
25/05/22 21:47:25 WARN DAGScheduler: Broadcasting large task binary with size 1077.4 KiB
25/05/22 21:47:25 WARN DAGScheduler: Broadcasting large task binary with size 1077.4 KiB


+--------+--------------------+
|anime_id|      similar_animes|
+--------+--------------------+
|     148|[{41545, 0.290830...|
|    2366|[{36752, 0.200686...|
|   24171|[{31327, 0.518369...|
|   24347|[{2245, 0.4312990...|
|   31236|[{31244, 0.255097...|
|   38220|[{37184, 0.603756...|
|   38422|[{39570, 0.183520...|
|    4158|[{8619, 0.2802179...|
|    4190|[{632, 0.42003983...|
|   12715|[{12223, 0.361718...|
+--------+--------------------+
only showing top 10 rows



In [27]:
dataframe_path = "hdfs:///dataframe"
top_similar.write.format("parquet").mode("overwrite").option("header", True).save(dataframe_path)

25/05/22 21:47:27 WARN DAGScheduler: Broadcasting large task binary with size 1276.8 KiB
                                                                                

In [4]:
# Percorso alla cartella su HDFS (sostituisci con il tuo percorso reale)
top_similars_path = "hdfs:///dataframe"

# Leggi tutti i file parquet nella cartella
top_similars = spark.read.parquet(top_similars_path).cache()

                                                                                

# Collaborative filtering algorithm (ALS)

In [24]:
# Split data in training and test
(training, test) = ratings_sampled.randomSplit([0.8, 0.2], seed='40')

print("Total training ratings:", training.count())
print("Total test ratings:", test.count())

                                                                                

Total training ratings: 129872
Total test ratings: 32382


In [25]:
# Configure ALS algorithm
#als = ALS(userCol="userId", itemCol="anime_id", ratingCol="rating", coldStartStrategy="drop", maxIter=10, regParam=0.5, rank=40)
als = ALS(userCol="user_id", itemCol="anime_id", ratingCol="rating", coldStartStrategy="drop")

# create pipeline
pipeline = Pipeline(stages=[als])

param_grid = ParamGridBuilder() \
    .addGrid(als.rank, [10, 30, 40]) \
    .addGrid(als.regParam, [0.01, 0.1, 0.5]) \
    .addGrid(als.maxIter, [10]) \
    .build()

evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")

cross_validator = CrossValidator(estimator=pipeline,
                                 estimatorParamMaps=param_grid,
                                 evaluator=evaluator,
                                 numFolds=3)                                 


In [26]:
# Fit the model on training data
#model=pipeline.fit(training)
cv_model = cross_validator.fit(training)
model = cv_model.bestModel


                                                                                

In [27]:
print(model.stages[0].rank) 
print(model.stages[0]._java_obj.parent().getMaxIter()) 
print(model.stages[0]._java_obj.parent().getRegParam()) 

# Model prediction
predictions = model.transform(test)

40
10
0.5


In [28]:
#SCORES
#evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
mae_evaluator = RegressionEvaluator(metricName="mae", labelCol="rating", predictionCol="prediction")
mae = mae_evaluator.evaluate(predictions)

print(f"Root-mean-square error = {rmse}") 
print(f"Mean Absolute Error = {mae}")

                                                                                

Root-mean-square error = 2.3258878695370018
Mean Absolute Error = 1.8648179615182792


In [None]:
#Save model in hdfs (Optional)
model_path = "hdfs:///model"
model.write().overwrite().save(model_path)



In [4]:
#get 5 anime recommendations for users
userRecs = model.stages[0].recommendForAllUsers(5).cache()  
print("userRecs dataset schema:")
userRecs.printSchema()

userRecs dataset schema:
root
 |-- user_id: integer (nullable = false)
 |-- recommendations: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- anime_id: integer (nullable = true)
 |    |    |-- rating: float (nullable = true)



In [5]:
# Save recommendation in hdfs (Optional)
recommendation_path = "hdfs:///output"
userRecs.write.format("parquet").mode("overwrite").option("overwrite", True).save(recommendation_path)


                                                                                

# Load data to Neo4j

In [3]:
# Connection to Neo4j
uri = "bolt://localhost:7687"
driver = GraphDatabase.driver(uri, auth=("neo4j", "bigdata"))

In [5]:
# Delete all from Neo4j
def delete_all(tx):
    tx.run("MATCH ()-[r:SIMILAR_TO]->() DELETE r") #delete all relationships
    # tx.run("MATCH (n:User) DELETE n") #delete all nodes
           
with driver.session() as session:
    session.execute_write(delete_all)

In [25]:
# create nodes Genre for anime genres
def create_genre_node(tx, name):
    tx.run("MERGE (:Genre {name: $name})",
           name=name)
    
with driver.session() as session:
    for genre in distinct_genres:
        session.execute_write(create_genre_node, genre)

In [27]:
# create nodes Genre for anime genres

distinct_types = anime.select("Type").filter(col("Type") != "").distinct().collect()

def create_type_node(tx, name):
    tx.run("MERGE (:Type {name: $name})",
           name=name)
    
with driver.session() as session:
    for t in distinct_types:
        session.execute_write(create_type_node, t.Type)

In [28]:
distinct_sources = anime.select("Source").filter(col("Source") != "").distinct().collect()

def create_source_node(tx, name):
    tx.run("MERGE (:Source {name: $name})",
           name=name)
    
with driver.session() as session:
    for source in distinct_sources:
        session.execute_write(create_source_node, source.Source)

In [29]:
distinct_ratings = anime.select("Rating").filter(col("Rating") != "").distinct().collect()

def create_rating_node(tx, name):
    tx.run("MERGE (:Rating {name: $name})",
           name=name)
    
with driver.session() as session:
    for rating in distinct_ratings:
        session.execute_write(create_rating_node, rating.Rating)

In [32]:
def create_anime(tx, anime_id, name, synopsis, episodes, score, genres_list, anime_type, source, rating):
    tx.run("MERGE (a:Anime {id: $anime_id, title: $name, synopsis: $synopsis, episodes: $episodes, average_score: $score}) ",
           anime_id=anime_id, name=name, synopsis=synopsis, episodes=episodes, score=score)
    for genre in genres_list:
        tx.run(
            "MATCH (a:Anime {id: $anime_id}), (g:Genre {name: $genre})"
            "MERGE (a)-[:BELONGS_TO_GENRE]->(g)",
            anime_id=anime_id, genre=genre)
    tx.run(
        "MATCH (a:Anime {id: $anime_id}), (t:Type {name: $anime_type})"
        "MERGE (a)-[:TYPE]->(t)",
        anime_id=anime_id, anime_type=anime_type)

    tx.run(
        "MATCH (a:Anime {id: $anime_id}), (s:Source {name: $source})"
        "MERGE (a)-[:ADAPTED_FROM]->(s)",
        anime_id=anime_id, source=source)

    tx.run(
        "MATCH (a:Anime {id: $anime_id}), (r:Rating {name: $rating})"
        "MERGE (a)-[:HAS_RATING]->(r)",
        anime_id=anime_id, rating=rating)

with driver.session() as session:
    for a in anime.collect():
        session.execute_write(create_anime, a.anime_id, a.Name, a.sypnopsis, a.Episodes, a.Score, a.genres_list, a.Type, a.Source, a.Rating)

                                                                                

In [29]:
users_who_rated = ratings_sampled.select("user_id").distinct().join(users, on="user_id", how="inner")
print(users_who_rated.count())

def create_user_node(tx, user_id, username):
    tx.run("MERGE (u:User {user_id: $user_id, username: $username})",
            user_id=user_id, username=username)

with driver.session() as session:
    for row in users_who_rated.collect():
        session.execute_write(create_user_node, row.user_id, row.Username)

                                                                                

In [31]:
import numpy as np
import math

ratings_sampled.show(5)

def create_user_ratings(tx,row):
    tx.run("""
           MATCH (u:User {user_id: $user_id})
           MATCH (a:Anime {id: $anime_id})
           MERGE (u)-[:RATED {rating: $rating}]->(a)
           """,
           user_id=row.user_id, anime_id=row.anime_id, rating=row.rating)

weights = np.full(math.ceil(ratings_sampled.count()/5000), 1.0).tolist()
ratings_split = ratings_sampled.randomSplit(weights, seed=40)

with driver.session() as session:
    for split in ratings_split:
        for row in split.collect():
           session.execute_write(create_user_ratings,row)

+--------+-------+------+
|anime_id|user_id|rating|
+--------+-------+------+
|    3229|      1|     6|
|   28805|      4|     8|
|   14353|      4|     7|
|    6702|     20|     6|
|    1362|     47|     2|
+--------+-------+------+
only showing top 5 rows



In [8]:
# create relations :RECOMMENDED
def create_recommendations(tx, user_id, recs):
    for rec in recs:
        tx.run("""
                MATCH (u:User {user_id: $user_id})
                MATCH (a:Anime {id: $anime_id})
                MERGE (u)-[:RECOMMENDED{rating: $rating}]->(a)
                """,
               user_id=user_id, anime_id=rec.anime_id, rating=rec.rating)

with driver.session() as session:
    for row in userRecs.collect():
        recs = row.recommendations
        recs = recs if isinstance(recs, list) else [recs]
        session.execute_write(create_recommendations, row.user_id, recs)

                                                                                

In [6]:
def create_similar(tx, anime1_id, recs):
    for rec in recs:
        tx.run("""
                MATCH (a1:Anime {id: $anime1_id})
                MATCH (a2:Anime {id: $anime2_id})
                MERGE (a1)-[:SIMILAR_TO{similarity: $similarity}]->(a2)
                """,
               anime1_id=anime1_id, anime2_id=rec.anime_id, similarity=rec.similarity)

with driver.session() as session:
    for row in top_similars.collect():
        recs = row.similar_animes
        recs = recs if isinstance(recs, list) else [recs]
        session.execute_write(create_similar, row.anime_id, recs)

                                                                                

In [None]:
#close all
driver.close()
spark.stop()