In [1]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark import SparkFiles
import numpy as np
import wandb

In [2]:
wandb.login(key="1bf6d96598e920a3fe32392d71154f5e9011cdbd", relogin=True)
wandb.init(project="proj-caa-2")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rafael/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrfg[0m ([33mrafaelgoncalvesua[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
KAFKA_TOPIC = "movielens"
KAFKA_BOOTSTRAP_SERVER = "localhost:9092"
MODEL_NAME = "non_linear"
NUM_USERS = 162541
NUM_MOVIES = 59047
EMBEDDING_DIM = 10
BATCH_SIZE = 64
LR = 0.01

In [4]:
spark = (
    SparkSession.builder.config(
        "spark.jars.packages", "org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0"
    )
    .appName("recommender")
    .getOrCreate()
)

sc = spark.sparkContext
sc.setLogLevel("ERROR")


24/06/15 01:59:27 WARN Utils: Your hostname, omen resolves to a loopback address: 127.0.1.1; using 192.168.1.122 instead (on interface wlo1)
24/06/15 01:59:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/home/rafael/Documentos/CAA/Project2/venv/lib/python3.10/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/rafael/.ivy2/cache
The jars for the packages stored in: /home/rafael/.ivy2/jars
org.apache.spark#spark-sql-kafka-0-10_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-7c9ca5d2-6977-4565-8096-083d3716f5dc;1.0
	confs: [default]
	found org.apache.spark#spark-sql-kafka-0-10_2.12;3.5.0 in central
	found org.apache.spark#spark-token-provider-kafka-0-10_2.12;3.5.0 in central
	found org.apache.kafka#kafka-clients;3.4.1 in central
	found org.lz4#lz4-java;1.8.0 in central
	found org.xerial.snappy#snappy-java;1.1.10.3 in central
	found org.slf4j#slf4j-api;2.0.7 in central
	found org.apache.hadoop#hadoop-client-runtime;3.3.4 in central
	found org.apache.hadoop#hadoop-client-api;3.3.4 in central
	found commons-logging#commons-logging;1.1.3 in central
	found com.google.code.findbugs#jsr305;3.0.0 in central
	found org.apache.commons#commons-pool2;2.11.1 in central
:: resolution report :: resolve 534ms :: artifacts dl 21ms
	::

In [5]:
raw_data = (
    spark.readStream.format("kafka")
    .option("kafka.bootstrap.servers", KAFKA_BOOTSTRAP_SERVER)
    .option("subscribe", KAFKA_TOPIC)
    .option("startingOffsets", "latest")
    .load()
)

data = (
    raw_data.selectExpr("CAST(value AS STRING) as value")
    .select(
        from_json(
            "value", "userId INT, movieId INT, rating DOUBLE, timestamp INT"
        ).alias("data")
    )
    .select("data.*")  # unpack dict
    .selectExpr("userId", "movieId", "rating")
)

In [6]:
class MovieLensDataset(Dataset):
    def __init__(self, ratings):
        self.users = ratings['userId'].values
        self.items = ratings['movieId'].values
        self.ratings = ratings['rating'].values

    def __len__(self):
        return len(self.ratings)

    def __getitem__(self, idx):
        return (self.users[idx], self.items[idx], self.ratings[idx])
    
class CollaborativeFilteringModel(torch.nn.Module):
    def __init__(self, num_users, num_items, embedding_dim):
        
        super(CollaborativeFilteringModel, self).__init__()
        self.user_embedding = torch.nn.Embedding(num_users, embedding_dim)
        self.item_embedding = torch.nn.Embedding(num_items, embedding_dim)
        self.fc1 = torch.nn.Linear(embedding_dim * 2, 128)
        self.fc2 = torch.nn.Linear(128, 64)
        self.fc3 = torch.nn.Linear(64, 1)

        self.criterion = torch.nn.MSELoss()
        self.optim = torch.optim.Adam(self.parameters(), lr=0.001)
    
    def forward(self, user_ids, item_ids):
        if torch.any(user_ids >= self.user_embedding.num_embeddings):
            raise ValueError(f"user_ids contain indices outside the range: {user_ids} | {self.user_embedding.num_embeddings}")
        if torch.any(item_ids >= self.item_embedding.num_embeddings):
            raise ValueError(f"item_ids contain indices outside the range: {item_ids} | {self.item_embedding.num_embeddings}")

        user_embeds = self.user_embedding(user_ids)
        item_embeds = self.item_embedding(item_ids)
        x = torch.cat([user_embeds, item_embeds], dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    def train_(self, train_loader):
        self.train()
        total_loss = 0

        for user_ids, item_ids, ratings in train_loader:
            user_ids, item_ids, ratings = user_ids.to(self.device), item_ids.to(self.device), ratings.to(self.device).float()
            ratings = ratings.float().view(-1, 1)

            self.optim.zero_grad()
            predictions = self(user_ids, item_ids)
            loss = self.criterion(predictions, ratings)
            loss.backward()
            self.optim.step()

            total_loss += loss.item()

        return total_loss / len(train_loader)
    
    def evaluate(self, test_loader):
        self.eval()
        total_loss = 0

        with torch.no_grad():
            for user_ids, item_ids, ratings in test_loader:
                user_ids, item_ids, ratings = user_ids.to(self.device), item_ids.to(self.device), ratings.to(self.device).float()
                ratings = ratings.float().view(-1, 1)

                predictions = self(user_ids, item_ids)
                loss = self.criterion(predictions, ratings)

                total_loss += loss.item()

        return total_loss / len(test_loader)

    
    def fit(self, train_loader, val_loader, test_loader=None, num_epochs=10):
        # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device('cpu')

        for epoch in range(num_epochs):
            train_loss = self.train_(train_loader)
            val_loss = self.evaluate(val_loader)
            print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss}, Val Loss: {val_loss}')
            wandb.log({"train_loss": train_loss, "val_loss": val_loss})

        if test_loader:
            test_loss = self.evaluate(test_loader)
            print(f'Test Loss: {test_loss}')
            wandb.log({"test_loss": test_loss})

            

In [7]:
model = CollaborativeFilteringModel(NUM_USERS, NUM_MOVIES, EMBEDDING_DIM)

# save torch model
torch.save(model.state_dict(), f"models/{MODEL_NAME}.pth")

# load torch model
model.load_state_dict(torch.load(f"models/{MODEL_NAME}.pth"))

# broadcast data buffer
broad_data_buffer = sc.broadcast([])

# static test dataset and user/item index mapping
test_df = pd.read_csv("data/ratings_test.csv")
user2idx = {user: idx for idx, user in enumerate(test_df['userId'].unique())}
item2idx = {item: idx for idx, item in enumerate(test_df['movieId'].unique())}

test_df['userId'] = test_df['userId'].map(user2idx)
test_df['movieId'] = test_df['movieId'].map(item2idx)
test_loader = DataLoader(MovieLensDataset(test_df), batch_size=BATCH_SIZE, shuffle=False)

In [8]:
# delete query 'ratings_' if it exists
for q in spark.streams.active:
    if q.name == "ratings_":
        q.stop()


def fine_tune(batch_df, batch_id):
    if batch_df.count() == 0:
        print(f"[BATCH {batch_id}] Empty batch")
        return
    
    batch_pandas = batch_df.toPandas()

    for user in batch_pandas["userId"]:
        if user not in user2idx:
            user2idx[user] = len(user2idx)

    for item in batch_pandas["movieId"]:
        if item not in item2idx:
            item2idx[item] = len(item2idx)
    
    # update accumulated DataFrame with new batch as dict
    global broad_data_buffer
    broad_data_buffer.unpersist()
    data_buffer = broad_data_buffer.value.copy()
    data_buffer.extend(batch_pandas.to_dict(orient="records"))

    while len(data_buffer) < 10000:
        broad_data_buffer = sc.broadcast(data_buffer)
        print(f"[BATCH {batch_id}] Current buffer size: {len(data_buffer)}")
        return

    model.load_state_dict(torch.load("models/non_linear.pth"))
    print(f"[BATCH {batch_id}] Model loaded")

    data_buffer_df = pd.DataFrame(data_buffer, index=None)
    data_buffer_df["userId"] = data_buffer_df["userId"].map(user2idx)
    data_buffer_df["movieId"] = data_buffer_df["movieId"].map(item2idx)

    dataset = MovieLensDataset(data_buffer_df)

    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"[BATCH {batch_id}] Data loaded")

    print(f"[BATCH {batch_id}] Fitting Model")
    model.fit(train_loader, val_loader, test_loader)
    print(f"[BATCH {batch_id}] Model fitted")

    torch.save(model.state_dict(), f"models/{MODEL_NAME}.pth")
    print(f"[BATCH {batch_id}] Model saved")

    # clear data buffer
    broad_data_buffer = sc.broadcast([])
    

query = (
    data.writeStream.trigger(processingTime="60 seconds")
    .foreachBatch(fine_tune)
    .start()
)

query.awaitTermination()

                                                                                

[BATCH 0] Empty batch


                                                                                

[BATCH 1] Current buffer size: 2000
[BATCH 2] Current buffer size: 7000


                                                                                

[BATCH 3] Model loaded
[BATCH 3] Data loaded
[BATCH 3] Fitting Model
Epoch 1/10, Loss: 2.7856151593243417, Val Loss: 1.418178610685395
Epoch 2/10, Loss: 1.2810803498227172, Val Loss: 1.3542559001503922
Epoch 3/10, Loss: 1.198266336157278, Val Loss: 1.3303112213204547
Epoch 4/10, Loss: 1.1386875210364171, Val Loss: 1.3023858215750717
Epoch 5/10, Loss: 1.090739487138994, Val Loss: 1.2961150349640265
Epoch 6/10, Loss: 1.058551635478903, Val Loss: 1.2939593632046769
Epoch 7/10, Loss: 1.0248008900624843, Val Loss: 1.2797560720908931
Epoch 8/10, Loss: 0.9909287171861145, Val Loss: 1.2921760939970248
Epoch 9/10, Loss: 0.9570004431747952, Val Loss: 1.308965182885891
Epoch 10/10, Loss: 0.9225196428825518, Val Loss: 1.2884252987256863
Test Loss: 1.2194591492951063
[BATCH 3] Model fitted
[BATCH 3] Model saved
[BATCH 4] Current buffer size: 6000
[BATCH 5] Model loaded
[BATCH 5] Data loaded
[BATCH 5] Fitting Model
Epoch 1/10, Loss: 1.1991066281000773, Val Loss: 1.1606443242022866
Epoch 2/10, Loss: 

ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/home/rafael/Documentos/CAA/Project2/venv/lib/python3.10/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/home/rafael/Documentos/CAA/Project2/venv/lib/python3.10/site-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/usr/lib/python3.10/socket.py", line 705, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 

Epoch 3/10, Loss: 0.9757293415242347, Val Loss: 1.232354007448469
Epoch 4/10, Loss: 0.9236061322516289, Val Loss: 1.2108759726796832
Epoch 5/10, Loss: 0.8775900053805199, Val Loss: 1.2336673395974296
Epoch 6/10, Loss: 0.8273890716010246, Val Loss: 1.2446716615131923
Epoch 7/10, Loss: 0.7742068385300429, Val Loss: 1.2652670792170932
Epoch 8/10, Loss: 0.721555026329082, Val Loss: 1.2974375213895526
Epoch 9/10, Loss: 0.668317010221274, Val Loss: 1.331337457043784
Epoch 10/10, Loss: 0.6081610272328059, Val Loss: 1.3677057198115758
Test Loss: 1.2853894289142698
[BATCH 7] Model fitted
[BATCH 7] Model saved
[BATCH 8] Current buffer size: 6000
[BATCH 9] Model loaded
[BATCH 9] Data loaded
[BATCH 9] Fitting Model
Epoch 1/10, Loss: 1.1727715166409811, Val Loss: 1.1554991540155912
Epoch 2/10, Loss: 1.0208570404847463, Val Loss: 1.1584658387460207
Epoch 3/10, Loss: 0.9540578377246857, Val Loss: 1.153820552323994
Epoch 4/10, Loss: 0.9029183801015218, Val Loss: 1.165286969197424
Epoch 5/10, Loss: 0.8