# Anime Recommendation Training
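Train a linear matrix-factorization recommender in PyTorch: users and anime get learned embeddings, each anime embedding is augmented with a projection of its genre features, and training is checkpointed periodically so it can resume after an interruption.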

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import glob
import os
from tqdm import tqdm

## 1. Training parameters and paths

In [3]:
# Expected number of ratings across all training shards; in this notebook it is only
# used to size the tqdm progress bar.
TOTAL_RATINGS = 127_866_421
BATCH_SIZE = 1024
# Ceiling division: the DataLoader keeps the final partial batch (drop_last=False by
# default), so an epoch has ceil(TOTAL_RATINGS / BATCH_SIZE) batches, not floor.
TOTAL_BATCHES = (TOTAL_RATINGS + BATCH_SIZE - 1) // BATCH_SIZE
CHECKPOINT_PATH = "checkpoints/linear"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
CHECKPOINT_INTERVAL = 15_000  # batches between mid-epoch checkpoints
EPOCHS = 20

## 2. Load anime data

In [11]:
# Raw anime id -> model row index.
# NOTE(review): weights_only=False performs full unpickling (arbitrary code execution on a
# malicious file) — only load this from a trusted source.
anime_id_to_idx = torch.load(
    "data/pt_files/anime_id_to_idx.pt",
    weights_only=False,
)

# Genre feature matrix: rows are indexed by model anime index, columns are genres.
anime_genres = torch.load("data/pt_files/anime_genres.pt")
num_anime, num_genres = tuple(anime_genres.shape)

## 3. Load the user index mapping

In [5]:
# Username -> model row index. Its length sizes the user embedding table, so the
# indices are presumably contiguous 0..N-1 — verify against the preprocessing step.
user_id_to_idx = torch.load(
    "data/pt_files/user_id_to_idx.pt"
)
num_users = len(user_id_to_idx)

## 4. Streaming ratings dataset

In [None]:
class RatingsPTDataset(torch.utils.data.IterableDataset):
    """Stream ``(user_idx, anime_idx, score)`` triples from pre-tensorised ``.pt`` shards.

    Each file must hold a dict with ``"user_idx"``, ``"anime_idx"`` and ``"scores"``
    entries that are zipped together element-wise (NOTE(review): presumably equal-length
    1-D tensors — confirm against the preprocessing script). Files are loaded one at a
    time, in the order given, and nothing is shuffled, so every pass yields the same
    sequence of ratings.

    With a ``DataLoader`` using ``num_workers > 0``, files are divided round-robin
    between workers so each rating is yielded once rather than once per worker.
    """

    def __init__(self, pt_files):
        super().__init__()
        # Materialise the list so a generator argument is not exhausted after one epoch.
        self.pt_files = list(pt_files)

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        if worker is None:
            # Single-process loading: this iterator owns every file.
            files = self.pt_files
        else:
            # Each worker takes every num_workers-th file, so no rating is duplicated.
            files = self.pt_files[worker.id::worker.num_workers]

        for pt_file in files:
            data = torch.load(pt_file)
            yield from zip(data["user_idx"], data["anime_idx"], data["scores"])


## 5. Define embedding model

In [6]:
class LinearMF(nn.Module):
    """Dot-product rating model combining learned anime embeddings with projected genres.

    Each user has a ``user_dim`` embedding. Each anime is represented by its learned
    ``anime_dim`` embedding concatenated with a ``genre_proj_dim`` linear projection of
    its genre row. That concatenation is linearly mapped to ``user_dim`` and dotted with
    the user embedding to produce the predicted score. No non-linearity is applied.
    """

    def __init__(self, num_users, num_anime, num_genres,
                 user_dim=64, anime_dim=128, genre_proj_dim=16):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, user_dim)
        self.anime_emb = nn.Embedding(num_anime, anime_dim)
        self.W_genre = nn.Linear(num_genres, genre_proj_dim)
        self.project = nn.Linear(anime_dim + genre_proj_dim, user_dim)

    def forward(self, user_idx, anime_idx, anime_genres):
        """Predict scores for aligned (user, anime) index pairs.

        Args:
            user_idx: 1-D long tensor of user indices.
            anime_idx: 1-D long tensor of anime indices, same length as ``user_idx``.
            anime_genres: genre matrix for *all* anime (``num_genres`` columns); rows are
                gathered with ``anime_idx``. It feeds ``nn.Linear``, so it must be a
                floating dtype on the model's device.

        Returns:
            1-D tensor with one predicted score per pair.
        """
        u = self.user_emb(user_idx)
        v = self.anime_emb(anime_idx)
        g = self.W_genre(anime_genres[anime_idx])
        v_combined = torch.cat([v, g], dim=1)
        v_proj = self.project(v_combined)
        r_hat = (u * v_proj).sum(dim=1)
        return r_hat

    def recommend(self, user_idx, anime_genres, top_k=10, device="cpu", exclude_ids=None):
        """Return the highest-scoring anime for a single user.

        Args:
            user_idx: integer index of the user.
            anime_genres: genre matrix for all anime; every row is scored.
            top_k: number of results wanted. It is clamped to the number of anime so a
                large value does not make ``torch.topk`` raise.
            device: device on which the index tensors are created; should match the
                model's and ``anime_genres``'s device.
            exclude_ids: optional anime indices (e.g. already-watched titles) whose
                scores are forced to ``-inf`` before ranking.

        Returns:
            ``(indices, scores)`` as Python lists, best first.

        The module's previous train/eval mode is restored afterwards, so calling this
        between training steps does not leave the model in eval mode.
        """
        num_items = anime_genres.size(0)
        k = min(top_k, num_items)
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Pair the same user with every anime so all titles are scored in one pass.
                user_tensor = torch.full((num_items,), int(user_idx),
                                         dtype=torch.long, device=device)
                anime_tensor = torch.arange(num_items, dtype=torch.long, device=device)

                preds = self(user_tensor, anime_tensor, anime_genres)

                if exclude_ids is not None:
                    preds[exclude_ids] = float("-inf")  # mask out watched anime

                top_scores, top_indices = torch.topk(preds, k)
        finally:
            self.train(was_training)

        return top_indices.cpu().tolist(), top_scores.cpu().tolist()

## 6. Training setup

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed parameter initialisation so a fresh (non-resumed) run is reproducible.
SEED = 42
torch.manual_seed(SEED)
model = LinearMF(num_users, num_anime, num_genres).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Plain squared error between predicted and observed scores.
loss_fn = nn.MSELoss()

def save_checkpoint(model, optimizer, epoch, batch_in_epoch, loss, filename):
    """Persist model/optimizer state and training progress to ``filename``.

    Args:
        model: module whose ``state_dict`` is saved.
        optimizer: optimizer whose ``state_dict`` is saved, so optimizer state
            (e.g. Adam moment estimates) survives a resume.
        epoch: 0-based epoch index training should resume in.
        batch_in_epoch: number of batches of ``epoch`` already processed
            (0 means start that epoch from the beginning).
        loss: running average loss to record with the state.
        filename: destination path; an existing file is replaced.

    The state is written to ``filename + ".tmp"`` first and then moved into place with
    ``os.replace``. An interruption during ``torch.save`` therefore cannot leave a
    truncated file where the previous good checkpoint was. The ``.tmp`` suffix also
    keeps a partial file out of ``epoch*.pth`` globs.
    """
    tmp_filename = f"{filename}.tmp"
    torch.save({
        "epoch": epoch,
        "batch_in_epoch": batch_in_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, tmp_filename)
    os.replace(tmp_filename, filename)

def load_checkpoint(model, optimizer, filename, device):
    """Restore model and optimizer state in place from a checkpoint file.

    Tensors are mapped onto ``device`` while loading.

    Returns:
        ``(epoch, batch_in_epoch, loss)``. The two counters fall back to 0 when absent
        from the file; a missing ``"loss"`` entry raises ``KeyError``.
    """
    state = torch.load(filename, map_location=device)
    model.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return (
        state.get("epoch", 0),
        state.get("batch_in_epoch", 0),
        state["loss"],
    )

# Training shards, read in sorted filename order.
pt_files = sorted(glob.glob("data/pt_files/user_anime*_filtered.pt"))
if not pt_files:
    # An empty file list would make every epoch a silent no-op, so fail loudly instead.
    raise FileNotFoundError("No training shards matched data/pt_files/user_anime*_filtered.pt")
dataset = RatingsPTDataset(pt_files)
# The dataset does not shuffle and drop_last defaults to False, so each epoch yields every
# rating once, in file order, with a final partial batch.
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
# Move the genre matrix to the model's device once; forward() indexes it directly.
anime_genres = anime_genres.to(device)

## 7. Training loop

In [None]:
# Resume from the newest checkpoint if one exists. Files are named epochNN.pth with a
# zero-padded epoch number, so lexicographic order is epoch order and [-1] is the latest.
start_epoch, start_batch_in_epoch, prev_loss = 0, 0, None
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    start_epoch, start_batch_in_epoch, prev_loss = load_checkpoint(model, optimizer, last_checkpoint, device)
    print(f"Resuming from {last_checkpoint}, epoch={start_epoch}, batch_in_epoch={start_batch_in_epoch}, prev_loss={prev_loss:.4f}")

# TODO
# - Add weight decay
# - Switch to sparse Adam
# - Maybe add a hybrid linear/non-linear layer
# - Checkpoint losses + remove I/O

# The genre matrix never changes during training. Move it to the device once rather
# than calling .to(device) on every batch.
genres_on_device = anime_genres.to(device)

for epoch in range(start_epoch, EPOCHS):
    # model.recommend() calls self.eval(); make sure optimisation runs in train mode.
    model.train()
    total_loss = 0.0
    count = 0
    # Leading batches of this epoch already covered by the resumed checkpoint. The dataset
    # is not shuffled, so batch i is the same in every epoch and skipping by position
    # resumes exactly where the checkpoint stopped. Skipped batches are still read from
    # disk, which is why a mid-epoch resume starts slowly.
    skip = start_batch_in_epoch if epoch == start_epoch else 0

    progress_bar = tqdm(loader, total=TOTAL_BATCHES, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, (user_idx_batch, anime_idx_batch, score_batch) in enumerate(progress_bar):
        if i < skip:
            continue

        user_idx_batch = user_idx_batch.to(device)
        anime_idx_batch = anime_idx_batch.to(device)
        score_batch = score_batch.to(device)

        optimizer.zero_grad()
        pred = model(user_idx_batch, anime_idx_batch, genres_on_device)
        loss = loss_fn(pred, score_batch)
        loss.backward()
        optimizer.step()

        # Running mean of per-rating loss over the batches processed since this epoch
        # started (or since resuming); skipped batches are not included.
        batch_len = len(user_idx_batch)
        total_loss += loss.item() * batch_len
        count += batch_len
        avg_loss = total_loss / count
        progress_bar.set_postfix(loss=avg_loss)

        # Periodic mid-epoch checkpoint: resume in this same epoch after batch i + 1.
        if (i + 1) % CHECKPOINT_INTERVAL == 0:
            ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
            save_checkpoint(model, optimizer, epoch, i + 1, avg_loss, ckpt_path)
            progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}")

    # End-of-epoch checkpoint: resume at the start of the next epoch.
    # count is 0 when no batch ran this epoch (empty loader, or a resume point at the very
    # end of the epoch). Record NaN instead of raising ZeroDivisionError.
    epoch_loss = total_loss / count if count else float("nan")
    ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
    save_checkpoint(model, optimizer, epoch + 1, 0, epoch_loss, ckpt_path)
    progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}")

Resuming from checkpoints\epoch17.pth, epoch=16, batch_in_epoch=75000, prev_loss=2.7415


Epoch 17/20:  72%|███████▏  | 90003/124869 [19:53<4:42:32,  2.06it/s, loss=2.68] 

💾 Saved checkpoint: checkpoints\epoch17.pth


Epoch 17/20:  84%|████████▍ | 105002/124869 [29:21<1:15:39,  4.38it/s, loss=2.68]

💾 Saved checkpoint: checkpoints\epoch17.pth


Epoch 17/20:  96%|█████████▌| 120002/124869 [41:18<22:57,  3.53it/s, loss=2.68]   

💾 Saved checkpoint: checkpoints\epoch17.pth


Epoch 17/20: 124870it [44:23, 46.89it/s, loss=2.67]                              


💾 Saved checkpoint: checkpoints\epoch17.pth


Epoch 18/20:  12%|█▏        | 15003/124869 [09:52<10:10:25,  3.00it/s, loss=2.68]

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  24%|██▍       | 30003/124869 [19:09<8:56:13,  2.95it/s, loss=2.68] 

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  36%|███▌      | 45005/124869 [28:33<3:44:25,  5.93it/s, loss=2.63] 

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  48%|████▊     | 60004/124869 [38:10<4:14:22,  4.25it/s, loss=2.61] 

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  60%|██████    | 74998/124869 [48:09<26:15, 31.66it/s, loss=2.6]    

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  72%|███████▏  | 89999/124869 [58:15<18:30, 31.41it/s, loss=2.59]   

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  84%|████████▍ | 105004/124869 [1:08:35<1:32:42,  3.57it/s, loss=2.59] 

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20:  96%|█████████▌| 120003/124869 [1:18:51<39:11,  2.07it/s, loss=2.59]  

💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 18/20: 124870it [1:21:52, 25.42it/s, loss=2.59]                            


💾 Saved checkpoint: checkpoints\epoch18.pth


Epoch 19/20:  12%|█▏        | 15003/124869 [09:56<7:10:08,  4.26it/s, loss=2.57] 

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  24%|██▍       | 30005/124869 [19:23<3:34:45,  7.36it/s, loss=2.56] 

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  36%|███▌      | 45003/124869 [29:40<7:01:08,  3.16it/s, loss=2.54] 

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  48%|████▊     | 59998/124869 [40:42<5:43:52,  3.14it/s, loss=2.52] 

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  60%|██████    | 75003/124869 [51:12<4:28:51,  3.09it/s, loss=2.52] 

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  72%|███████▏  | 90003/124869 [1:01:09<2:27:07,  3.95it/s, loss=2.52]

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  84%|████████▍ | 105006/124869 [1:10:26<1:35:06,  3.48it/s, loss=2.51]

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20:  96%|█████████▌| 120003/124869 [1:19:45<14:23,  5.64it/s, loss=2.51]  

💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 19/20: 124870it [1:23:02, 25.06it/s, loss=2.51]                              


💾 Saved checkpoint: checkpoints\epoch19.pth


Epoch 20/20:  12%|█▏        | 15004/124869 [09:43<2:40:29, 11.41it/s, loss=2.5]  

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  24%|██▍       | 30003/124869 [19:21<3:12:29,  8.21it/s, loss=2.5]  

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  36%|███▌      | 45003/124869 [29:19<2:49:53,  7.83it/s, loss=2.48] 

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  48%|████▊     | 60005/124869 [39:33<4:41:55,  3.83it/s, loss=2.47] 

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  60%|██████    | 75005/124869 [49:23<1:58:15,  7.03it/s, loss=2.46] 

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  72%|███████▏  | 90000/124869 [59:25<3:18:18,  2.93it/s, loss=2.46] 

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  84%|████████▍ | 105006/124869 [1:09:28<1:23:57,  3.94it/s, loss=2.46]

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20:  96%|█████████▌| 119998/124869 [1:19:31<02:34, 31.48it/s, loss=2.46]  

💾 Saved checkpoint: checkpoints\epoch20.pth


Epoch 20/20: 124870it [1:22:41, 25.17it/s, loss=2.46]                              


💾 Saved checkpoint: checkpoints\epoch20.pth


## 8. Make Recommendations


In [None]:
# Load the newest checkpoint (zero-padded epochNN.pth names sort in epoch order) so the
# recommendations come from the latest trained weights, not whatever is in memory.
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if not checkpoint_files:
    raise FileNotFoundError(f"No checkpoints found in {CHECKPOINT_PATH}; run the training cell first")
last_checkpoint = checkpoint_files[-1]
_, _, loss = load_checkpoint(model, optimizer, last_checkpoint, device)
print(f"Making prediction from {last_checkpoint}, loss={loss:.4f}")

USERNAME = 'sirawesomeness'
# USERNAME = 'catfire8'

if USERNAME not in user_id_to_idx:
    raise KeyError(f"Unknown user {USERNAME!r}: not present in user_id_to_idx")
user_idx = user_id_to_idx[USERNAME]
# Already-rated anime are not excluded here (no exclude_ids is passed).
top_indices, top_scores = model.recommend(user_idx, anime_genres, device=device)

# NOTE(review): df.iloc selects by row position, so this assumes row i of anime.csv is
# the anime with model index i (the order used to build anime_id_to_idx / anime_genres).
# If that does not hold, map the indices back through an inverted anime_id_to_idx
# instead — confirm.
df = pd.read_csv('data/original/anime.csv',sep='\t')
df = df.iloc[top_indices]
print(f"Recommendations for {USERNAME}")
print(df[['title','anime_url']])

Making prediction from checkpoints/linear\epoch20.pth, loss=2.4619
Recommendations for catfire8
                                                  title  \
6411  Chounai Shinsengumi wo Tasuke-gumi Mezashi-tai...   
532                        Seiken Gakuin no Makentsukai   
191   Tensei Kenja no Isekai Life: Dai-2 no Shokugyo...   
497                   Benriya Saitou-san, Isekai ni Iku   
9642                                           Guo Qiao   
258                          IDOLiSH7 3rd Season Part 2   
656                                  Kocchi Muite Miiko   
6579                    Kaitou Queen wa Circus ga Osuki   
9774                                Xiao Xiao Ji Qi Ren   
192   Shijou Saikyou no Daimaou, Murabito A ni Tense...   

                                              anime_url  
6411  https://myanimelist.net/anime/35152/Chounai_Sh...  
532   https://myanimelist.net/anime/50184/Seiken_Gak...  
191   https://myanimelist.net/anime/47163/Tensei_Ken...  
497   https://myanimel