<a href="https://colab.research.google.com/github/BharathSShankar/DSA4212_Assignments/blob/bharath-exp/tfidf_vectoriser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# mount the Google Drive
# (Colab-only: makes the Drive contents available under /content/drive, which
# is where the data folder selected further down in this cell lives)
from google.colab import drive
drive.mount("/content/drive")

import torch
import pandas as pd
import numpy as np
from torch import nn 
from tqdm.auto import tqdm
import torch.nn.functional as F
# Seed every RNG the notebook relies on: NumPy's global generator and PyTorch
# (weight initialisation and DataLoader shuffling), so that a
# Restart & Run All reproduces the same numbers.
np.random.seed(42)
torch.manual_seed(42)

# go to the data folder -- you may need to change this location
%cd /content/drive/MyDrive/DSA4212/Assignment\ 2/assignment_2_data

Mounted at /content/drive
/content/drive/MyDrive/DSA4212/Assignment 2/assignment_2_data


In [None]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

In [None]:
import nltk
# Fetch the NLTK resources used by the title preprocessing below: the
# stop-word corpus (read through stopwords.words) and the punkt tokenizer
# models (presumably what word_tokenize relies on -- verify for the installed
# NLTK version).
nltk.download('stopwords')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
# Load the anime table and the pre-split train/test rating tables from the
# current working directory (the data folder selected in the first cell).
anime_df = pd.read_csv("assignment_2_anime.csv")
train_df = pd.read_csv("assignment_2_ratings_train.csv")
test_df = pd.read_csv("assignment_2_ratings_test.csv")

In [None]:
# Lookup from raw anime_id to the position of its row in anime_df.
animeid2inner = {anime_id: row_pos for row_pos, anime_id in enumerate(anime_df["anime_id"])}

In [None]:
from sklearn.decomposition import TruncatedSVD
# Tokenize the titles and remove stop words (no stemming is applied).
# Custom stop words: tokens that are frequent in anime titles but carry no
# content signal (media type, season markers, ordinals 2nd..9th).
custom_stop_words = ['(TV)', "TV", 'Season', "Animation",  "2nd", "3rd"] + [str(i)+"th" for i in range(4,10)]

# Combine custom stop words with NLTK's stop words.
# Everything is lowercased because preprocess_title compares against
# lowercased tokens; with the mixed-case originals ("TV", "Season", ...) the
# custom entries could never match.
all_stop_words = {
    stop_word.lower()
    for stop_word in set(stopwords.words('english')).union(custom_stop_words)
}


def preprocess_title(title):
    """Lowercase and tokenize a title, then drop every stop word.

    Args:
        title: raw title string.

    Returns:
        The remaining tokens joined by single spaces (may be an empty string
        when every token is a stop word).
    """
    word_tokens = word_tokenize(title.lower())
    words = [word for word in word_tokens if word not in all_stop_words]
    # Sanity check on the filtering; tokens are lowercase at this point.
    assert "animation" not in words
    return ' '.join(words)

# Clean every title; the result keeps one entry per anime_df row.
processed_titles = anime_df['name'].apply(preprocess_title)

# Create a TF-IDF matrix
vectorizer = TfidfVectorizer()
tfidf_matrix = vectorizer.fit_transform(processed_titles)
n_components = 50  # The desired dimensionality of the compressed space
# Reduce the TF-IDF rows to n_components dimensions with truncated SVD.
trunsvd = TruncatedSVD(n_components=n_components)
compressed_tfidf_matrix = trunsvd.fit_transform(tfidf_matrix)

In [None]:
# Peek at the TF-IDF row of the first title, converted to a dense matrix.
tfidf_matrix[0].todense()

matrix([[0., 0., 0., ..., 0., 0., 0.]])

In [None]:
# Fixed genre vocabulary. A genre's position in this list is the index it gets
# in genre2id and therefore its slot in the multi-hot genre vectors.
# 'NaN' is the placeholder slot used when no listed genre is found for an anime.
genre_map = ['Comedy','Action', 'Adventure','Fantasy', 'Sci-Fi',
 'Drama','Shounen', 'Kids', 'Romance', 'School', 'Slice of Life', 'Hentai', 'Supernatural',
 'Mecha', 'Music', 'Historical', 'Magic', 'Ecchi', 'Shoujo', 'Seinen', 'Sports', 'Mystery',
 'Super Power', 'Military', 'Parody', 'Space','Horror', 'Harem', 'Demons', 'Martial Arts', 'Dementia', 'Psychological',
 'Police', 'Game', 'Samurai', 'Vampire','Thriller',
 'Cars', 'Shounen Ai', 'NaN', 'Shoujo Ai', 'Josei', 'Yuri',
 'Yaoi']

# Distinct values found in the "type" column.
type_map = anime_df["type"].unique()

# Position of each type value inside type_map.
type2id = dict(zip(type_map, range(len(type_map))))

# Position of each genre name inside genre_map.
genre2id = dict(zip(genre_map, range(len(genre_map))))

def get_genre(string):
  """Split a ", "-separated genre string and keep only names listed in genre_map."""
  known_genres = []
  for candidate in string.split(", "):
    if candidate in genre_map:
      known_genres.append(candidate)
  return known_genres

def genres_to_vector(lst):
  """Translate genre names to their genre2id indices.

  An empty list yields a single-element list holding the index of 'NaN'.
  """
  if lst == []:
    return [genre2id['NaN']]
  return [genre2id[genre_name] for genre_name in lst]

def get_genre_vector(anime_id):
  """Return the multi-hot genre vector (one slot per genre_map entry) of an anime.

  The genre cell of the first anime_df row matching anime_id is stringified,
  split by get_genre and converted to slot indices by genres_to_vector.
  """
  multi_hot = np.zeros(len(genre_map), dtype=int)
  matching_rows = anime_df[anime_df["anime_id"] == anime_id]
  genre_text = str(matching_rows["genre"].values[0])
  hot_slots = genres_to_vector(get_genre(genre_text))
  multi_hot[hot_slots] = 1
  return multi_hot

class TrainDatasetEmb(torch.utils.data.Dataset):
    """Pairs each title's TF-IDF row with the genre vector of the same anime.

    Args:
        tfidf_matrix: sparse matrix with one row per row of ``anime_df``, in
            the same order.
        anime_df: frame with an ``anime_id`` column; genre vectors are obtained
            through the notebook-level ``get_genre_vector`` helper.
    """
    def __init__(self, tfidf_matrix, anime_df):
        self.data = tfidf_matrix
        self.anime_df = anime_df
        self.genre_vecs = self.precompute_genre_vectors()

    def __len__(self):
        return self.data.shape[0]

    def precompute_genre_vectors(self):
        """Return one float32 genre tensor per ``anime_df`` row, in row order."""
        # Walk the ids in row order rather than sorted/de-duplicated: __getitem__
        # pairs genre_vecs[idx] with TF-IDF row idx, and the TF-IDF rows follow
        # the frame's row order. Sorting would attach the wrong labels whenever
        # the frame is not already sorted by anime_id.
        return [
            torch.tensor(get_genre_vector(anime_id), dtype=torch.float32)
            for anime_id in self.anime_df['anime_id']
        ]

    def __getitem__(self, idx):
        # The dense row keeps its leading singleton axis, i.e. (1, n_features).
        features = torch.tensor(self.data[idx, :].todense(), dtype = torch.float32)
        return features, self.genre_vecs[idx]

In [None]:
# Sanity check: first (TF-IDF row, genre vector) pair produced by the dataset.
TrainDatasetEmb(tfidf_matrix, anime_df)[0]

(tensor([[0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([1., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0.]))

In [None]:
class BetterEmbeddingsFromTitles(nn.Module):
    """Feed-forward network: two hidden layers, an embedding layer, a prediction head.

    All layers are ``nn.LazyLinear``, so input widths are inferred on first use.
    ``forward`` returns sigmoid-activated predictions; ``embed`` returns the
    (pre-ReLU) output of the embedding layer.
    """
    def __init__(self, layer_1, layer_2, emb_layer, pred_layer):
        super().__init__()
        self.layer1 = nn.LazyLinear(layer_1)
        self.layer2 = nn.LazyLinear(layer_2)
        self.embedding = nn.LazyLinear(emb_layer)
        self.pred = nn.LazyLinear(pred_layer)

    def forward(self, x):
        # Same trunk as embed(), followed by ReLU, the prediction head and a sigmoid.
        activated = F.relu(self.embed(x))
        return F.sigmoid(self.pred(activated))

    def embed(self, x):
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.embedding(hidden)

In [None]:
# The model's forward() already ends in a sigmoid, so its outputs are
# probabilities and the matching criterion is BCELoss. BCEWithLogitsLoss would
# apply a second sigmoid on top of them.
loss_fn = nn.BCELoss()
dataset = TrainDatasetEmb(tfidf_matrix, anime_df)
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BetterEmbeddingsFromTitles(2048, 256, 50, len(genre_map)).to(device)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
n_epochs = 200
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

with tqdm(range(n_epochs)) as t:
    for i in t:
        for batch_tfidf, batch_genre in tqdm(dataloader, leave = False):
            batch_tfidf, batch_genre = batch_tfidf.to(device), batch_genre.to(device)

            optimizer.zero_grad()

            predictions = model(batch_tfidf)
            # Each sample carries a leading singleton axis (see
            # TrainDatasetEmb.__getitem__); drop it so the predictions line up
            # with the (batch, n_genres) targets.
            predictions = predictions.squeeze(1)
            loss = loss_fn(predictions, batch_genre)

            loss.backward()
            t.set_description(f"loss : {loss.item():.4f}")
            optimizer.step()

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

  0%|          | 0/193 [00:00<?, ?it/s]

In [None]:
# Inspect the SVD-compressed TF-IDF title features.
compressed_tfidf_matrix

array([[ 1.34218800e-01,  7.18599832e-03,  1.32158115e-01, ...,
        -2.33012261e-02,  1.14511993e-02,  1.96477477e-02],
       [ 5.39708938e-03,  9.92387695e-03, -3.77308582e-03, ...,
        -3.57136882e-05, -3.41747311e-03,  1.17743839e-03],
       [ 8.79490468e-03, -3.63809604e-03,  7.20979757e-04, ...,
         5.93199225e-03,  8.98738654e-04, -1.01026416e-02],
       ...,
       [ 1.28895148e-02, -5.65483920e-03,  4.21427985e-03, ...,
         2.59641261e-02, -3.49949485e-03, -3.61148069e-02],
       [ 5.17836528e-02, -2.69567556e-02,  4.25236671e-03, ...,
         1.64328531e-04, -1.77999551e-02, -4.28284466e-02],
       [-2.34536335e-11,  8.65166243e-11,  1.41024619e-10, ...,
         1.04202954e-08, -4.10561984e-08, -3.49274634e-08]])

In [None]:
%%capture
!pip install scikit-surprise

In [None]:
from surprise import Dataset, Reader
from surprise import SVD
from surprise.model_selection import train_test_split
from surprise import accuracy

In [None]:
# Create a Reader object for parsing the ratings dataframes
reader = Reader(rating_scale=(1, 10))

# Load trainset and testset from the pre-split train_df and test_df dataframes
# NOTE(review): load_from_df presumably expects the columns ordered as
# (user id, item id, rating) -- confirm this matches train_df / test_df.
train_data = Dataset.load_from_df(train_df, reader)
trainset = train_data.build_full_trainset()

test_data = Dataset.load_from_df(test_df, reader)
testset = test_data.construct_testset(raw_testset=test_data.raw_ratings)

In [None]:
# Train the SVD algorithm on the trainset
# (default hyper-parameters; the fitted model is later passed to
# WideAndDeep.load to initialise the wide part of that network)
svd = SVD()
svd.fit(trainset)

In [None]:
class CustomDataset2(torch.utils.data.Dataset):
    """Rating dataset yielding ``(user index, item index, item features, rating)``.

    Args:
        df: ratings frame with ``user_id``, ``anime_id`` and ``rating`` columns.
        trainset: object exposing the raw-id -> inner-id dictionaries
            ``_raw2inner_id_users`` and ``_raw2inner_id_items`` (a Surprise
            trainset in this notebook).
        anime_df: anime frame with an ``anime_id`` column.

    Raises:
        KeyError: if ``df`` contains a user or anime id that ``trainset`` does
            not know about.
    """
    def __init__(self, df, trainset, anime_df):
        self.df = df
        self.trainset = trainset
        self.anime_df = anime_df
        self.genre_vectors = self.precompute_genre_vectors()
        self.user_tensors, self.item_tensors = self.precompute_user_item_tensors()

    def precompute_genre_vectors(self):
        """Map every anime_id to its float64 row of ``compressed_tfidf_matrix``.

        Despite the name, the vectors are the compressed title features looked
        up through the notebook-level ``animeid2inner`` mapping, not genres.
        """
        return {
            anime_id: torch.tensor(compressed_tfidf_matrix[animeid2inner[anime_id]], dtype=torch.float64)
            for anime_id in self.anime_df['anime_id'].unique()
        }

    def precompute_user_item_tensors(self):
        """Build raw-id -> inner-index tensors for the users and items present in ``df``."""
        # Read the mappings from the trainset given to this instance; the
        # previous code silently depended on a notebook-global `trainset`.
        raw2inner_users = self.trainset._raw2inner_id_users
        raw2inner_items = self.trainset._raw2inner_id_items

        user_tensors = {
            user_id: torch.tensor(raw2inner_users[user_id], dtype=torch.long)
            for user_id in self.df['user_id'].unique()
        }
        item_tensors = {
            anime_id: torch.tensor(raw2inner_items[anime_id], dtype=torch.long)
            for anime_id in self.df['anime_id'].unique()
        }

        return user_tensors, item_tensors

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        user = self.user_tensors[row['user_id']]
        item = self.item_tensors[row['anime_id']]
        genre = self.genre_vectors[row['anime_id']]
        rating = torch.tensor(row['rating'], dtype=torch.float64)
        return user, item, genre, rating


class WideAndDeep(torch.nn.Module):
    """Matrix-factorisation ("wide") predictor plus an MLP ("deep") correction.

    prediction = mu + <user_factors, item_factors> + user_bias + item_bias
                 + MLP(concat(item_factors, side_features))

    Args:
        n_users: number of rows of the user embedding tables.
        n_items: number of rows of the item embedding tables.
        n_genres: width of the per-item side-feature vector given to forward().
        n_factors: latent dimension of the user/item factors.
        deep_layers: sizes of the hidden layers of the MLP.
    """
    def __init__(self, n_users, n_items, n_genres, n_factors=100, deep_layers=(64, 32)):
        super().__init__()
        # Global offset added to every prediction; overwritten by load().
        self.mu = 0.0
        self.user_factors = nn.Embedding(n_users, n_factors)
        self.item_factors = nn.Embedding(n_items, n_factors)
        self.user_biases = nn.Embedding(n_users, 1)
        self.item_biases = nn.Embedding(n_items, 1)

        # Deep layers
        self.deep_layers = nn.ModuleList()
        # The MLP input is item_factors concatenated with the side features, so
        # its width comes from the constructor arguments, not from a global.
        input_size = n_factors + n_genres
        for layer_size in deep_layers:
            linear_layer = nn.Linear(input_size, layer_size)
            nn.init.xavier_normal_(linear_layer.weight)
            nn.init.zeros_(linear_layer.bias)
            self.deep_layers.append(linear_layer)
            self.deep_layers.append(nn.ReLU())
            input_size = layer_size

        # Output layer
        self.output = nn.Linear(input_size, 1)

    def load(self, algo, mu):
        """Initialise the wide part from a fitted factorisation model.

        Args:
            algo: object exposing NumPy arrays ``pu`` (user factors), ``qi``
                (item factors), ``bu`` (user biases) and ``bi`` (item biases).
            mu: global offset added to every prediction.
        """
        self.mu = mu

        # Convert the NumPy arrays to PyTorch tensors; biases get a trailing
        # axis so they match the (n, 1) embedding tables.
        pu = torch.from_numpy(algo.pu)
        qi = torch.from_numpy(algo.qi)
        bu = torch.from_numpy(algo.bu[..., np.newaxis])
        bi = torch.from_numpy(algo.bi[..., np.newaxis])

        # Replace the embedding weights with the loaded tensors.
        self.user_factors.weight.data = pu
        self.item_factors.weight.data = qi
        self.user_biases.weight.data = bu
        self.item_biases.weight.data = bi

    def freeze_wide(self):
        """Exclude the factor and bias tables from gradient updates."""
        for param in [self.user_factors.weight, self.item_factors.weight,
                      self.user_biases.weight, self.item_biases.weight]:
            param.requires_grad = False

    def forward(self, user, item, genre_embeddings):
        """Predict one value per (user, item) pair.

        Args:
            user: 1-D tensor of user indices.
            item: 1-D tensor of item indices, same length as ``user``.
            genre_embeddings: (batch, n_genres) side features of the items.

        Returns:
            1-D tensor of length ``batch``.
        """
        user_embedding = self.user_factors(user)
        item_embedding = self.item_factors(item)
        user_bias = self.user_biases(user)
        item_bias = self.item_biases(item)

        # Wide part
        wide = (user_embedding * item_embedding).sum(dim=1, keepdim=True)

        # Deep part
        x = torch.cat([item_embedding, genre_embeddings], dim=1)
        for layer in self.deep_layers:
            x = layer(x)

        # Combine wide and deep parts
        output = self.mu + wide + user_bias + item_bias + self.output(x)
        # Drop only the trailing axis: a bare squeeze() would also collapse a
        # batch of one to a 0-d tensor and break the match with the targets.
        return output.squeeze(1)


In [None]:
dataset = CustomDataset2(train_df, trainset, anime_df)
dev = torch.device("cpu")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=512, shuffle=True)

# Instantiate the WideAndDeep model, initialise its wide part from the fitted
# SVD, and freeze that part so only the deep layers are trained.
model = WideAndDeep(trainset.n_users,
                    trainset.n_items,
                    n_components)
model.load(svd, trainset.global_mean)
model.freeze_wide()
# The dataset yields float64 features and ratings, so run the model in float64.
model = model.to(torch.float64)
model = model.to(dev)

# Define loss function and optimizer
loss_func = nn.MSELoss()

# Only parameters that are still trainable after freeze_wide() are optimised.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=0.01,
    weight_decay=0.01,
)

# Training loop
n_epochs = 1
for epoch in tqdm(range(n_epochs)):
    batches = tqdm(dataloader, leave = False)
    for batch_user, batch_item, batch_genre, batch_rating in batches:
        batch_user, batch_item, batch_genre, batch_rating = batch_user.to(dev), batch_item.to(dev), batch_genre.to(dev), batch_rating.to(dev)

        optimizer.zero_grad()

        predictions = model(batch_user, batch_item, batch_genre)
        loss = loss_func(predictions, batch_rating)

        loss.backward()
        optimizer.step()
        # Report the running loss on the progress bar instead of printing one
        # line per batch, which floods the cell output.
        batches.set_description(f"loss: {loss.item():.4f}")
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/8665 [00:00<?, ?it/s]

tensor(0.5721, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.6231, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5803, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5098, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5463, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5617, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5658, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5551, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.6405, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5049, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5119, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.6822, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5920, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5947, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.6038, dtype=torch.float64, grad_fn=<MseLossBackward0>)
tensor(0.5130, dtype=torch.float64, grad

KeyboardInterrupt: ignored