In [1]:
import torch
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark import SparkFiles

In [10]:
# --- Kafka source ---
KAFKA_TOPIC = "movielens"
KAFKA_BOOTSTRAP_SERVER = "localhost:9092"

# --- model ---
# Number of rows of the user / movie embedding matrices.
# NOTE(review): presumably MovieLens counts; ids used as row indices must be mapped to
# 0..NUM_USERS-1 / 0..NUM_MOVIES-1 before indexing -- confirm.
NUM_USERS = 162541
NUM_MOVIES = 59047
EMBEDDING_DIM = 10  # latent factors per user / movie
LR = 0.01  # learning rate for the model's Adam optimiser

# Seed torch so the random (Xavier) initialisation of the model is reproducible.
SEED = 42
torch.manual_seed(SEED)

In [3]:
spark = (
    SparkSession.builder
    .appName("recommender")
    # Kafka connector for Structured Streaming; presumably has to match the
    # installed pyspark / Scala build -- confirm when upgrading.
    .config("spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0")
    .getOrCreate()
)
# Keep notebook output readable: only ERROR-level Spark logs are shown.
spark.sparkContext.setLogLevel("ERROR")

24/06/07 20:20:44 WARN Utils: Your hostname, omen resolves to a loopback address: 127.0.1.1; using 192.168.1.122 instead (on interface wlo1)
24/06/07 20:20:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/home/rafael/Documentos/CAA/Project2/venv/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/rafael/.ivy2/cache
The jars for the packages stored in: /home/rafael/.ivy2/jars
org.apache.spark#spark-sql-kafka-0-10_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-b607804e-e756-4d29-b9c9-6a28853b9e06;1.0
	confs: [default]
	found org.apache.spark#spark-sql-kafka-0-10_2.12;3.5.0 in central
	found org.apache.spark#spark-token-provider-kafka-0-10_2.12;3.5.0 in central
	found org.apache.kafka#kafka-clients;3.4.1 in central
	found org.lz4#lz4-java;1.8.0 in central
	found org.xerial.snappy#snappy-java;1.1.10.3 in central
	found org.slf4j#slf4j-api;2.0.7 in central
	found org.apache.hadoop#hadoop-client-runtime;3.3.4 in central
	found org.apache.hadoop#hadoop-client-api;3.3.4 in central
	found commons-logging#commons-logging;1.1.3 in central
	found com.google.code.findbugs#jsr305;3.0.0 in central
	found org.apache.commons#commons-pool2;2.11.1 in central
:: resolution report :: resolve 501ms :: artifacts dl 25ms
	::

In [4]:
# Kafka source: subscribe to the ratings topic, reading only messages that arrive
# after the query starts ("latest").
raw_data = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP_SERVER)
    .option("subscribe", KAFKA_TOPIC)
    .option("startingOffsets", "latest")
    .load()
)

# Parse the Kafka value (JSON text) into a struct and keep only the fields the model
# uses. from_json yields NULL for records it cannot parse, so rows missing any of the
# three fields are dropped here instead of reaching the trainer.
data = (
    raw_data.selectExpr("CAST(value AS STRING) as value")
    .select(
        from_json(
            "value", "userId INT, movieId INT, rating DOUBLE, timestamp INT"
        ).alias("data")
    )
    .select("data.userId", "data.movieId", "data.rating")  # unpack struct fields
    .dropna(subset=["userId", "movieId", "rating"])
)

In [None]:
# define linear model for collaborative filtering
class LinearModel(torch.nn.Module):
    """Matrix-factorisation recommender.

    A rating is predicted as the dot product of a user embedding (a row of ``U``)
    and an item embedding (a row of ``I``). The model owns its loss (``MSELoss``)
    and its optimiser (``Adam`` over all parameters), so callers only need ``fit``.

    NOTE(review): rows are addressed by 0-based index; raw ids must lie in
    ``[0, num_users)`` / ``[0, num_items)`` before being passed in -- confirm upstream.
    """

    def __init__(self, num_users, num_items, k, lr):
        """
        Args:
            num_users: number of rows of the user matrix ``U``.
            num_items: number of rows of the item matrix ``I``.
            k: embedding dimension (columns of ``U`` and ``I``).
            lr: learning rate for ``torch.optim.Adam``.
        """
        super().__init__()
        # define user matrix (U) and item matrix (I)
        self.U = torch.nn.Parameter(torch.zeros(num_users, k))
        self.I = torch.nn.Parameter(torch.zeros(num_items, k))

        # initialise the matrices
        torch.nn.init.xavier_uniform_(self.U)
        torch.nn.init.xavier_uniform_(self.I)

        # define the loss function
        self.loss = torch.nn.MSELoss()

        # define the optimizer
        self.optim = torch.optim.Adam(self.parameters(), lr=lr)

    def forward(self, user_ids=None, item_ids=None):
        """Predict ratings.

        With no arguments, return the dense ``(num_users, num_items)`` matrix
        ``U @ I.T``. Beware: for large tables this matrix is huge.

        With ``user_ids`` and ``item_ids`` (equal-length 1-D integer index
        tensors), return only the predictions for those ``(user, item)`` pairs,
        shape ``(len(user_ids),)``, without materialising the full matrix.

        Raises:
            ValueError: if only one of ``user_ids`` / ``item_ids`` is given.
        """
        if user_ids is None and item_ids is None:
            return self.U @ self.I.T
        if user_ids is None or item_ids is None:
            raise ValueError("user_ids and item_ids must be given together")
        # Row-wise dot product of the selected user and item embeddings.
        return (self.U[user_ids] * self.I[item_ids]).sum(dim=-1)

    def calculate_loss(self, actual, predicted):
        """Mean squared error between observed ratings and predictions.

        ``actual`` may be a tensor or anything ``torch.as_tensor`` accepts; it is
        converted to the prediction's dtype/device so e.g. float64 targets work.
        """
        actual = torch.as_tensor(actual, dtype=predicted.dtype, device=predicted.device)
        # MSELoss's signature is (input, target): prediction first, ground truth second.
        return self.loss(predicted, actual)

    def optimise(self, loss):
        """Run one optimiser step on ``loss``."""
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()

    def fit(self, ratings, num_epochs=100, user_ids=None, item_ids=None, verbose=True):
        """Train for ``num_epochs`` full-batch steps.

        Args:
            ratings: target ratings. Without indices, this must match the dense
                ``U @ I.T`` shape. With ``user_ids``/``item_ids``, it holds one
                rating per ``(user, item)`` pair.
            num_epochs: number of optimisation steps.
            user_ids, item_ids: optional index tensors selecting observed pairs
                (see ``forward``).
            verbose: print the loss after every epoch.

        Returns:
            The loss of the last epoch as a float, or None if ``num_epochs`` is 0.
        """
        last_loss = None
        for epoch in range(num_epochs):
            predicted_ratings = self.forward(user_ids, item_ids)
            loss = self.calculate_loss(ratings, predicted_ratings)
            self.optimise(loss)
            last_loss = loss.item()
            if verbose:
                print(f"Epoch {epoch+1}, Loss: {last_loss}")
        return last_loss

lin_model = LinearModel(num_users=NUM_USERS, num_items=NUM_MOVIES, k=EMBEDDING_DIM, lr=LR)

In [None]:
# Stop any earlier run of the 'ratings_' streaming query so re-executing this cell
# does not leave a second stream running.
stale_queries = [sq for sq in spark.streams.active if sq.name == "ratings_"]
for stale_query in stale_queries:
    stale_query.stop()


def fine_tune(batch_df, batch_id):
    """foreachBatch handler: fine-tune the recommender on one micro-batch.

    Args:
        batch_df: static DataFrame with this micro-batch's rows.
        batch_id: micro-batch id supplied by Structured Streaming.

    The load / fit / save / predict steps are still TODO. The prints below are
    progress placeholders only.
    """
    # isEmpty() only needs to find a single row, whereas count() scans the whole batch.
    if batch_df.isEmpty():
        print(f"[BATCH {batch_id}] Empty batch")
        return

    # TODO: load
    print(f"[BATCH {batch_id}] Model loaded")

    print(f"[BATCH {batch_id}] Fitting Model")
    # TODO: fit
    print(f"[BATCH {batch_id}] Model fitted")

    # TODO: save
    print(f"[BATCH {batch_id}] Model saved")

    # TODO: add column 'predictions' to batch_df

    # TODO: calculate RMSE once predictions exist
    # rmse = (
    #     batch_df.withColumn("error", col("rating") - col("prediction"))
    #     .selectExpr("sqrt(avg(error*error)) as rmse")
    #     .collect()[0]
    #     .rmse
    # )
    # print(f"[BATCH {batch_id}] RMSE: {rmse}")


# Run fine_tune on each micro-batch every 15 seconds. The query is named "ratings_"
# so the cleanup loop in this cell can find and stop it when the cell is re-run.
query = (
    data.writeStream.queryName("ratings_")
    .trigger(processingTime="15 seconds")
    .foreachBatch(fine_tune)  # handler defined above (was the undefined `apply_to_stream`)
    .start()
)

# Blocks the kernel until the stream is stopped or fails.
query.awaitTermination()