SIMPLE TESTS

In [1]:
import os

os.environ["PYSPARK_PYTHON"] = "python"

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession\
    .builder\
    .master("local[*]")\
    .getOrCreate()

In [3]:
from pyspark.sql.functions import col

csv_anime = spark.read.format("csv")\
    .option("header", "true")\
    .option("escape", "\"")\
    .option("inferSchema", "true")\
    .load("../csv/data.csv")\
    .withColumn("score", col("score").cast("double"))

In [5]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler()\
    .setInputCols([col for col in csv_anime.columns if col != "score"])\
    .setOutputCol("features")

data_df = assembler\
    .transform(csv_anime)\
    .select("features", "score")

In [6]:
train_df, test_df = data_df.randomSplit([0.8, 0.2])

In [76]:
from pyspark.ml.regression import (
    DecisionTreeRegressor,
    RandomForestRegressor,
    GBTRegressor,
)

print("Decision Tree Regressor")
model_tree = DecisionTreeRegressor()\
    .setLabelCol("score")\
    .setFeaturesCol("features")\
    .fit(train_df)

print("Random Forest Regressor")
model_random_forest = RandomForestRegressor()\
    .setLabelCol("score")\
    .setFeaturesCol("features")\
    .setNumTrees(10)\
    .fit(train_df)

print("Gradient Boosted Tree Regressor")
model_gbt = GBTRegressor()\
    .setLabelCol("score")\
    .setFeaturesCol("features")\
    .fit(train_df)


In [77]:
predictions_tree = model_tree.transform(test_df)
predictions_random_forest = model_random_forest.transform(test_df)
predictions_gbt = model_gbt.transform(test_df)

In [78]:
from pyspark.ml.evaluation import RegressionEvaluator

evaluator = RegressionEvaluator(
    labelCol="score",
    predictionCol="prediction",
    metricName="rmse",
)

error_tree = evaluator.evaluate(predictions_tree)
print(f"Decision tree: rmse = {error_tree}")

error_random_forest = evaluator.evaluate(predictions_random_forest)
print(f"Random forest: rmse = {error_random_forest}")

error_gbt = evaluator.evaluate(predictions_gbt)
print(f"Gradient Boosted Tree: rmse = {error_gbt}")

Decision tree: rmse = 3.044417074181475
Random forest: rmse = 3.0435001772367203
Gradient Boosted Tree: rmse = 2.9433101966278787


MATRIX FACTORIZATION

In [3]:
from pyspark.mllib.recommendation import ALS, Rating

from pyspark.sql.functions import col

csv_anime = spark.read.format("csv")\
    .option("header", "true")\
    .option("escape", "\"")\
    .option("inferSchema", "true")\
    .load("../csv/data2.csv")\
    .withColumn("rating", col("rating").cast("double"))

In [4]:
ratings = csv_anime.rdd.map(lambda row: Rating(row.user, row.product, row.rating))

In [5]:
rank = 10
numIterations = 10
model = ALS.train(ratings, rank, numIterations)

In [6]:
testdata = ratings.map(lambda p: (p[0], p[1]))

In [7]:
predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))

In [8]:
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)

In [9]:
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).mean()

In [10]:
import math

RMSE = math.sqrt(MSE)

In [11]:
print("RMSE: " + str(RMSE))

RMSE: 0.311464908266571


In [9]:
import sqlite3

sqlite_conn = sqlite3.connect("../prisma/dev.db")
sqlite_cur = sqlite_conn.cursor()

users = list(set([u["user"] for u in csv_anime.select("user").collect()]))
animes = list(set([a["product"] for a in csv_anime.select("product").collect()]))

def get_user_name(id: int) -> str:
    return sqlite_cur.execute(f"SELECT name FROM user WHERE id={id}").fetchone()[0]

def get_anime_title(id: int) -> str:
    return sqlite_cur.execute(f"SELECT titleEnglish FROM anime WHERE id={id}").fetchone()[0]

def get_best_animes(user: int, n: int = 5) -> list:
    return [r.product for r in model.recommendProducts(user, n)]

def recommend_user(id: int):
    username = get_user_name(id)
    best_animes = get_best_animes(id)
    print(f"Recommendations for {username}:")
    for anime in best_animes:
        print(f"\t{get_anime_title(anime)}")

In [15]:
recommend_user(2255153)

Recommendations for karthiga:
	Naruto
	Log Horizon
	D.Gray-man
	Neon Genesis Evangelion
	K-ON!


In [18]:
model.predict(2255153, 232)

5.797619985466222

In [23]:
[get_user_name(r.user) for r in model.recommendUsers(232, 3)]

['MistButterfly', 'Tomoki-sama', 'bskai']

In [14]:
dot = """
graph G {
    fontname="Helvetica,Arial,sans-serif"
    node [fontname="Helvetica,Arial,sans-serif"]
    edge [fontname="Helvetica,Arial,sans-serif"]
"""

def add_line(line: str):
    global dot
    dot += "    " + line + "\n"

animes_added = []

N = 2

for user_id in users:
    user_name = get_user_name(user_id)
    add_line(f"U{user_id} [shape=ellipse,color=red,style=bold,label=\"{user_name}\",labelloc=b];")
    best_animes = get_best_animes(user_id, N)
    for anime_id in best_animes:
        if anime_id in animes_added:
            continue
        animes_added.append(anime_id)
        anime_title = get_anime_title(anime_id)
        anime_title = anime_title.replace("\"", "")
        add_line(f"A{anime_id} [shape=box,color=blue,style=bold,label=\"{anime_title}\",labelloc=b];")

for user_id in users:
    best_animes = get_best_animes(user_id, N)
    for anime_id in best_animes:
        add_line(f"U{user_id} -- A{anime_id}  [style=bold,color=blue];")

dot += "}\n"
print(dot)

with open("../graph.dot", "w") as f:
    f.write(dot)


graph G {
    fontname="Helvetica,Arial,sans-serif"
    node [fontname="Helvetica,Arial,sans-serif"]
    edge [fontname="Helvetica,Arial,sans-serif"]
    U1 [shape=ellipse,color=red,style=bold,label="Xinil",labelloc=b];
    A801 [shape=box,color=blue,style=bold,label="Ghost in the Shell: Stand Alone Complex 2nd GIG",labelloc=b];
    A72 [shape=box,color=blue,style=bold,label="Full Metal Panic? Fumoffu",labelloc=b];
    U14658 [shape=ellipse,color=red,style=bold,label="L-LawlietDN",labelloc=b];
    A30 [shape=box,color=blue,style=bold,label="Neon Genesis Evangelion",labelloc=b];
    A165 [shape=box,color=blue,style=bold,label="RahXephon",labelloc=b];
    U167812 [shape=ellipse,color=red,style=bold,label="LyannaStark",labelloc=b];
    A12189 [shape=box,color=blue,style=bold,label="Hyouka",labelloc=b];
    A11597 [shape=box,color=blue,style=bold,label="Nisemonogatari",labelloc=b];
    U2637159 [shape=ellipse,color=red,style=bold,label="Lithuelle",labelloc=b];
    A1482 [shape=box,color=b