# Anime Recommendation Training
Train embeddings for users and anime using PyTorch, with genre features and checkpointing.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import glob
import os
from tqdm import tqdm

## 1. Training parameters and paths

In [2]:
THRESHOLD = 100
EPOCHS = 20
CHECKPOINT_PATH = f"checkpoints/hybrid{THRESHOLD}"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
CHECKPOINT_INTERVAL = 30_000
USERNAME = 'sirawesomeness'

test is a test


In [3]:
# Count total ratings

threshold_folder = f"data/pt_files{THRESHOLD}"
pattern = os.path.join(threshold_folder, f"user_anime????????????_filtered{THRESHOLD}.pt")
csv_files = glob.glob(pattern)

# anime_counter = Counter()
count = 0
for file_path in tqdm(csv_files):
    df = torch.load(file_path)
    count += len(df['anime_idx'])
    if not len(df):
        print(file_path)
print("TOTAL_RATINGS:", count)
TOTAL_RATINGS = count
BATCH_SIZE = 1024
TOTAL_BATCHES = TOTAL_RATINGS // BATCH_SIZE

100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:02<00:00, 34.52it/s]

TOTAL_RATINGS: 127761997





## 2. Load anime data

In [4]:
anime_id_to_idx = torch.load(f"data/pt_files{THRESHOLD}/anime_id_to_idx.pt", weights_only=False)
anime_genres = torch.load(f"data/pt_files{THRESHOLD}/anime_genres.pt")

num_anime, num_genres = anime_genres.shape

## 3. Map all users

In [5]:
user_id_to_idx = torch.load(f"data/pt_files{THRESHOLD}/user_id_to_idx.pt")
num_users = len(user_id_to_idx)

## 4. Custom Dataset

In [6]:
class RatingsPTDataset(torch.utils.data.IterableDataset):
    def __init__(self, pt_files):
        super().__init__()
        self.pt_files = pt_files

    def __iter__(self):
        for pt_file in self.pt_files:
            data = torch.load(pt_file)
            for u, a, s in zip(data["user_idx"], data["anime_idx"], data["scores"]):
                yield u, a, s


## 5. Define embedding model

In [7]:
class HybridMF(nn.Module):
    def __init__(self, num_users, num_anime, num_genres, 
                 user_dim=64, anime_dim=128, genre_proj_dim=16, hidden_dim=32):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, user_dim, sparse=True)
        self.anime_emb = nn.Embedding(num_anime, anime_dim, sparse=True)
        self.W_genre = nn.Linear(num_genres, genre_proj_dim)
        self.project = nn.Linear(anime_dim + genre_proj_dim, user_dim)
        
        # Nonlinear MLP
        self.mlp = nn.Sequential(
            nn.Linear(user_dim + anime_dim + genre_proj_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, user_idx, anime_idx, anime_genres):
        u = self.user_emb(user_idx)
        v = self.anime_emb(anime_idx)
        g = self.W_genre(anime_genres[anime_idx])
        
        # ----- Linear path -----
        v_combined = torch.cat([v, g], dim=1)
        v_proj = self.project(v_combined)
        linear_pred = (u * v_proj).sum(dim=1)

        # ----- Nonlinear path -----
        mlp_input = torch.cat([u, v, g], dim=1)
        nonlinear_pred = self.mlp(mlp_input).squeeze(1)
        
        return linear_pred + nonlinear_pred
    
    def recommend(self, user_idx, anime_genres, top_k=10, device="cpu", exclude_ids=None):
        self.eval()
        with torch.no_grad():
            # Repeat user_idx across all anime
            user_tensor = torch.tensor([user_idx] * anime_genres.size(0), 
                                       dtype=torch.long, device=device)
            anime_tensor = torch.arange(anime_genres.size(0), 
                                        dtype=torch.long, device=device)
            
            preds = self.forward(user_tensor, anime_tensor, anime_genres)
            
            if exclude_ids is not None:
                preds[exclude_ids] = float("-inf")  # mask out watched anime

            # Get top-k indices
            top_scores, top_indices = torch.topk(preds, top_k)
        
        return top_indices.cpu().tolist(), top_scores.cpu().tolist()

## 6. Training setup

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridMF(num_users, num_anime, num_genres).to(device)

sparse_params = list(model.user_emb.parameters()) + list(model.anime_emb.parameters())
dense_params = list(model.W_genre.parameters()) + list(model.project.parameters()) + list(model.mlp.parameters())

optimizer_sparse = optim.SparseAdam(sparse_params, lr=1e-2)
optimizer_dense = optim.Adam(dense_params, lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()

def save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch, batch_in_epoch, loss, filename):
    torch.save({
        "epoch": epoch,
        "batch_in_epoch": batch_in_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_sparse_state_dict": optimizer_sparse.state_dict(),
        "optimizer_dense_state_dict": optimizer_dense.state_dict(),
        "loss": loss,
    }, filename)

def load_checkpoint(model, optimizer_sparse, optimizer_dense, filename, device):
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer_sparse.load_state_dict(checkpoint["optimizer_sparse_state_dict"])
    optimizer_dense.load_state_dict(checkpoint["optimizer_dense_state_dict"])
    start_epoch = checkpoint.get("epoch", 0)
    start_batch_in_epoch = checkpoint.get("batch_in_epoch", 0)
    loss = checkpoint["loss"]
    return start_epoch, start_batch_in_epoch, loss

pt_files = sorted(glob.glob(f"data/pt_files{THRESHOLD}/user_anime*_filtered{THRESHOLD}.pt"))
dataset = RatingsPTDataset(pt_files)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
anime_genres = anime_genres.to(device)

## 7. Training loop

In [9]:
start_epoch, start_batch_in_epoch, prev_loss = 0, 0, None
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    start_epoch, start_batch_in_epoch, prev_loss = load_checkpoint(model, optimizer_sparse, optimizer_dense, last_checkpoint, device)
    print(f"Resuming from {last_checkpoint}, epoch={start_epoch}, batch_in_epoch={start_batch_in_epoch}, prev_loss={prev_loss:.4f}")
    
for epoch in range(start_epoch, EPOCHS):
    total_loss = 0
    count = 0
    
    checkpoint_loss = 0
    checkpoint_count = 0
    
    progress_bar = tqdm(loader, total=TOTAL_BATCHES, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    for i, batch in enumerate(progress_bar):
        if epoch == start_epoch and i < start_batch_in_epoch:
            continue
        
        user_idx_batch, anime_idx_batch, score_batch = batch
        user_idx_batch = user_idx_batch.to(device)
        anime_idx_batch = anime_idx_batch.to(device)
        score_batch = score_batch.to(device)

        optimizer_sparse.zero_grad()
        optimizer_dense.zero_grad()
        pred = model(user_idx_batch, anime_idx_batch, anime_genres)
        loss = loss_fn(pred, score_batch)
        loss.backward()
        optimizer_sparse.step()
        optimizer_dense.step()

        total_loss += loss.item() * len(user_idx_batch)
        count += len(user_idx_batch)
        
        checkpoint_loss += loss.item() * len(user_idx_batch)
        checkpoint_count += len(user_idx_batch)
        

        # periodic checkpoint
        if (i + 1) % CHECKPOINT_INTERVAL == 0:
            ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
            save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch, i + 1, total_loss / count, ckpt_path)
            progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}, Checkpoint Loss: {checkpoint_loss / max(1, checkpoint_count)}")
            checkpoint_loss, checkpoint_count = 0, 0

    # end of epoch checkpoint
    ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
    save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch + 1, 0, total_loss / count, ckpt_path)
    progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}, Epoch Loss: {total_loss / count}")

Resuming from checkpoints/hybrid100\epoch01.pth, epoch=1, batch_in_epoch=0, prev_loss=2.6973


Epoch 2/20:  24%|██████████████████▎                                                         | 30009/124767 [10:39<1:22:49, 19.07it/s]

💾 Saved checkpoint: checkpoints/hybrid100\epoch02.pth, Checkpoint Loss: 2.493814210877816


Epoch 2/20:  25%|███████████████████▎                                                          | 30970/124767 [11:09<33:47, 46.26it/s]


KeyboardInterrupt: 

## 8. Make Recommendations


In [None]:
# Get watched animes

filtered_folder = f"data/pt_files{THRESHOLD}"
pattern = os.path.join(filtered_folder, f"user_anime????????????_filtered{THRESHOLD}.pt")
pt_files = glob.glob(pattern)
user_idx = int(user_id_to_idx[USERNAME]) # convert to Python int

anime_indices_list = []
scores = []

for file in tqdm(pt_files):
    data = torch.load(file, map_location='cpu')  # dict of tensors

    # Boolean mask for rows of this user
    mask = data['user_idx'] == user_idx

    # Extract anime indices for this user
    user_animes = data['anime_idx'][mask]
    user_scores = data['scores'][mask]

    if user_animes.numel() > 0:
        anime_indices_list.extend(user_animes)
        scores.extend(user_scores)
        
df = pd.read_csv('data/original/anime.csv',sep='\t')
df = df.iloc[anime_indices_list]

for title,score in zip(df['title'],scores):
    # print(score, title)
    print(int(score), title)



  0%|          | 0/70 [00:00<?, ?it/s]

{'user_idx': tensor([959803, 959803, 959803,  ..., 298264, 298264, 298264]), 'anime_idx': tensor([ 3116, 12587,  9123,  ...,  2995,  3843,  3901]), 'scores': tensor([8., 8., 7.,  ..., 6., 7., 6.])}





RuntimeError: No active exception to reraise

In [35]:
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    _, _, loss = load_checkpoint(model, optimizer_sparse, optimizer_dense, last_checkpoint, device)
    print(f"Making prediction from {last_checkpoint}, loss={loss:.4f}")

user_idx = user_id_to_idx[USERNAME]
top_indices, top_scores = model.recommend(user_idx, anime_genres, device=device,exclude_ids=[])
df = pd.read_csv('data/original/anime.csv',sep='\t')
df = df.iloc[top_indices]
print(f"Recommendations for {USERNAME}")
print(df[['title','anime_url']])

Making prediction from checkpoints/hybrid\epoch20.pth, loss=1.5750
Recommendations for killerlovers
                                                 title  \
9953  Ginga Eiyuu Densetsu: Die Neue These - Gekitotsu   
4892                        Xingguang Jiazu Tebie Pian   
6496                            Osakini Douzo Arigatou   
445                        Konu to Tanoshii Otomodachi   
9429                      Pokopon no Yukai na Saiyuuki   
4436                               Xiang Shi Chuanshuo   
353                                         Jigokuraku   
549                                    Ore wa Chokkaku   
9963                                Matsugae wo Musubi   
9862                         Xiong Chumo: Bian Xing Ji   

                                              anime_url  
9953  https://myanimelist.net/anime/42886/Ginga_Eiyu...  
4892  https://myanimelist.net/anime/48040/Xingguang_...  
6496  https://myanimelist.net/anime/35175/Osakini_Do...  
445   https://myanimelist.net