# Anime Recommendation Training
Train embeddings for users and anime using PyTorch, with genre features and checkpointing.

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import glob
import os
from tqdm import tqdm
import random
from math import ceil

## 1. Training parameters and paths

In [15]:
THRESHOLD = 500
EPOCHS = 20
CHECKPOINT_PATH = f"checkpoints/hybrid{THRESHOLD}"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
CHECKPOINT_INTERVAL = 30_000
USERNAME = 'sirawesomeness'
USERNAME = 'catfire8'

In [16]:
# Count total ratings

threshold_folder = f"data/pt_files{THRESHOLD}"
pattern = os.path.join(threshold_folder, f"user_anime????????????_filtered{THRESHOLD}.pt")
csv_files = glob.glob(pattern)

# anime_counter = Counter()
count = 0
for file_path in tqdm(csv_files):
    df = torch.load(file_path)
    count += len(df['anime_idx'])
    if not len(df):
        print(file_path)
print("TOTAL_RATINGS:", count)
print("THRESHOLD:",THRESHOLD)
TOTAL_RATINGS = count
BATCH_SIZE = 1024
TOTAL_BATCHES = TOTAL_RATINGS // BATCH_SIZE

100%|██████████| 70/70 [00:01<00:00, 36.07it/s]

TOTAL_RATINGS: 127102285
THRESHOLD: 500





## 2. Load anime data

In [17]:
anime_id_to_idx = torch.load(f"data/pt_files{THRESHOLD}/anime_id_to_idx.pt", weights_only=False)
anime_genres = torch.load(f"data/pt_files{THRESHOLD}/anime_genres.pt")

num_anime, num_genres = anime_genres.shape
print("NUMBER OF ANIMES:",num_anime)
print("NUMBER OF GENRES:",num_genres)

NUMBER OF ANIMES: 7448
NUMBER OF GENRES: 44


## 3. Map all users

In [18]:
user_id_to_idx = torch.load(f"data/pt_files{THRESHOLD}/user_id_to_idx.pt")
num_users = len(user_id_to_idx)

print("NUMBER OF USERS:", num_users)

NUMBER OF USERS: 977946


## 4. Custom Dataset

In [19]:
class RatingsPTDataset(torch.utils.data.IterableDataset):
    _global_seed = None
    
    def __init__(self, pt_files, seed=None):
        super().__init__()
        self.pt_files = pt_files
        
        if seed:
            RatingsPTDataset._global_seed = seed

        # If no global seed exists yet, generate one
        if RatingsPTDataset._global_seed is None:
            RatingsPTDataset._global_seed = random.randint(0, 2**32 - 1)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        # Global RNG (shared seed ensures same shuffle across all workers)
        rng = random.Random(RatingsPTDataset._global_seed)

        pt_files = list(self.pt_files)

        # Global shuffle (same order for all workers)
        rng.shuffle(pt_files)

        if worker_info is None:
            assigned_files = pt_files
        else:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id

            # Deterministic split, no overlap
            files_per_worker = ceil(len(pt_files) / num_workers)
            start = worker_id * files_per_worker
            end = min(start + files_per_worker, len(pt_files))
            assigned_files = pt_files[start:end]

        # Each worker now has a unique set of files
        # Local RNG for within-file shuffling
        local_rng = random.Random(RatingsPTDataset._global_seed + (worker_id if worker_info else 0))

        for pt_file in assigned_files:
            data = torch.load(pt_file)
            samples = list(zip(data["user_idx"], data["anime_idx"], data["scores"]))

            local_rng.shuffle(samples)

            for u, a, s in samples:
                yield u, a, s

## 5. Define embedding model

In [None]:
class HybridMF(nn.Module):
    def __init__(self, num_users, num_anime, num_genres, 
                 user_dim=64, anime_dim=128, genre_proj_dim=16, hidden_dim=32):
        super().__init__()
        self.user_emb = nn.Embedding(num_users, user_dim, sparse=True)
        self.anime_emb = nn.Embedding(num_anime, anime_dim, sparse=True)
        self.W_genre = nn.Linear(num_genres, genre_proj_dim)
        self.project = nn.Linear(anime_dim + genre_proj_dim, user_dim)
        
        # Nonlinear MLP
        self.mlp = nn.Sequential(
            nn.Linear(user_dim + anime_dim + genre_proj_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, user_idx, anime_idx, anime_genres):
        u = self.user_emb(user_idx)
        v = self.anime_emb(anime_idx)
        g = self.W_genre(anime_genres[anime_idx])
        
        # ----- Linear path -----
        v_combined = torch.cat([v, g], dim=1)
        v_proj = self.project(v_combined)
        linear_pred = (u * v_proj).sum(dim=1)

        # ----- Nonlinear path -----
        mlp_input = torch.cat([u, v, g], dim=1)
        nonlinear_pred = self.mlp(mlp_input).squeeze(1)
        
        return linear_pred + nonlinear_pred
    
    def recommend(self, user_idx, anime_genres, top_k=10, device="cpu", exclude_ids=None):
        self.eval()
        with torch.no_grad():
            # Repeat user_idx across all anime
            user_tensor = torch.tensor([user_idx] * anime_genres.size(0), 
                                       dtype=torch.long, device=device)
            anime_tensor = torch.arange(anime_genres.size(0), 
                                        dtype=torch.long, device=device)
            
            preds = self.forward(user_tensor, anime_tensor, anime_genres)
            
            if exclude_ids is not None:
                preds[exclude_ids] = float("-inf")  # mask out watched anime

            # Get top-k indices
            top_scores, top_indices = torch.topk(preds, top_k)
        
        return top_indices.cpu().tolist(), top_scores.cpu().tolist()

## 6. Training setup

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HybridMF(num_users, num_anime, num_genres).to(device)

sparse_params = list(model.user_emb.parameters()) + list(model.anime_emb.parameters())
dense_params = list(model.W_genre.parameters()) + list(model.project.parameters()) + list(model.mlp.parameters())

optimizer_sparse = optim.SparseAdam(sparse_params, lr=5e-3)
optimizer_dense = optim.Adam(dense_params, lr=1e-3, weight_decay=1e-5)
loss_fn = nn.MSELoss()

def save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch, batch_in_epoch, loss, filename):
    torch.save({
        "epoch": epoch,
        "batch_in_epoch": batch_in_epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_sparse_state_dict": optimizer_sparse.state_dict(),
        "optimizer_dense_state_dict": optimizer_dense.state_dict(),
        "loss": loss,
        "seed": RatingsPTDataset._global_seed,
    }, filename)

def load_checkpoint(model, optimizer_sparse, optimizer_dense, filename, device):
    checkpoint = torch.load(filename, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer_sparse.load_state_dict(checkpoint["optimizer_sparse_state_dict"])
    optimizer_dense.load_state_dict(checkpoint["optimizer_dense_state_dict"])
    start_epoch = checkpoint.get("epoch", 0)
    start_batch_in_epoch = checkpoint.get("batch_in_epoch", 0)
    loss = checkpoint["loss"]
    RatingsPTDataset._global_seed = checkpoint["seed"]
    return start_epoch, start_batch_in_epoch, loss

pt_files = sorted(glob.glob(f"data/pt_files{THRESHOLD}/user_anime*_filtered{THRESHOLD}.pt"))
dataset = RatingsPTDataset(pt_files)
loader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
anime_genres = anime_genres.to(device)

## 7. Training loop

In [33]:
start_epoch, start_batch_in_epoch, prev_loss = 0, 0, None
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    start_epoch, start_batch_in_epoch, prev_loss = load_checkpoint(model, optimizer_sparse, optimizer_dense, last_checkpoint, device)
    print(f"Resuming from {last_checkpoint}, epoch={start_epoch}, batch_in_epoch={start_batch_in_epoch}, prev_loss={prev_loss:.4f}")
    
for epoch in range(start_epoch, EPOCHS):
    total_loss = 0
    count = 0
    
    checkpoint_loss = 0
    checkpoint_count = 0
    
    progress_bar = tqdm(loader, total=TOTAL_BATCHES, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    for i, batch in enumerate(progress_bar):
        if epoch == start_epoch and i < start_batch_in_epoch:
            continue
        
        user_idx_batch, anime_idx_batch, score_batch = batch
        user_idx_batch = user_idx_batch.to(device)
        anime_idx_batch = anime_idx_batch.to(device)
        score_batch = score_batch.to(device)

        optimizer_sparse.zero_grad()
        optimizer_dense.zero_grad()
        pred = model(user_idx_batch, anime_idx_batch, anime_genres)
        loss = loss_fn(pred, score_batch)
        loss.backward()
        optimizer_sparse.step()
        optimizer_dense.step()

        total_loss += loss.item() * len(user_idx_batch)
        count += len(user_idx_batch)
        
        checkpoint_loss += loss.item() * len(user_idx_batch)
        checkpoint_count += len(user_idx_batch)
        

        # periodic checkpoint
        if (i + 1) % CHECKPOINT_INTERVAL == 0:
            ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
            save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch, i + 1, total_loss / count, ckpt_path)
            progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}, Checkpoint Loss: {checkpoint_loss / max(1, checkpoint_count)}")
            checkpoint_loss, checkpoint_count = 0, 0

    # end of epoch checkpoint
    ckpt_path = os.path.join(CHECKPOINT_PATH, f"epoch{epoch+1:02d}.pth")
    save_checkpoint(model, optimizer_sparse, optimizer_dense, epoch + 1, 0, total_loss / count, ckpt_path)
    progress_bar.write(f"💾 Saved checkpoint: {ckpt_path}, Epoch Loss: {total_loss / count}")

Resuming from checkpoints/hybrid500\epoch16.pth, epoch=15, batch_in_epoch=30000, prev_loss=1.6277


Epoch 16/20:  48%|████▊     | 60022/124123 [19:44<29:24, 36.34it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch16.pth, Checkpoint Loss: 1.6264327452371519


Epoch 16/20:  73%|███████▎  | 90015/124123 [31:02<16:36, 34.23it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch16.pth, Checkpoint Loss: 1.6187181806628903


Epoch 16/20:  97%|█████████▋| 120008/124123 [41:11<03:38, 18.86it/s]  

💾 Saved checkpoint: checkpoints/hybrid500\epoch16.pth, Checkpoint Loss: 1.6194991505141059


Epoch 16/20: 124124it [42:55, 48.20it/s]                            


💾 Saved checkpoint: checkpoints/hybrid500\epoch16.pth, Epoch Loss: 1.62213110350526


Epoch 17/20:  24%|██▍       | 30013/124123 [12:05<1:04:15, 24.41it/s] 

💾 Saved checkpoint: checkpoints/hybrid500\epoch17.pth, Checkpoint Loss: 1.6141579028129578


Epoch 17/20:  48%|████▊     | 60012/124123 [24:45<44:13, 24.16it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch17.pth, Checkpoint Loss: 1.6126426739409565


Epoch 17/20:  73%|███████▎  | 90013/124123 [37:10<24:08, 23.55it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch17.pth, Checkpoint Loss: 1.6049756613428394


Epoch 17/20:  97%|█████████▋| 120010/124123 [50:29<02:35, 26.46it/s]  

💾 Saved checkpoint: checkpoints/hybrid500\epoch17.pth, Checkpoint Loss: 1.6059218395690122


Epoch 17/20: 124124it [52:14, 39.60it/s]                             


💾 Saved checkpoint: checkpoints/hybrid500\epoch17.pth, Epoch Loss: 1.6097923115418045


Epoch 18/20:  24%|██▍       | 30009/124123 [12:25<1:26:10, 18.20it/s] 

💾 Saved checkpoint: checkpoints/hybrid500\epoch18.pth, Checkpoint Loss: 1.6017880773713191


Epoch 18/20:  48%|████▊     | 60017/124123 [24:39<41:47, 25.56it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch18.pth, Checkpoint Loss: 1.6005885450065136


Epoch 18/20:  73%|███████▎  | 90014/124123 [35:11<23:17, 24.40it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch18.pth, Checkpoint Loss: 1.5926718141600489


Epoch 18/20:  97%|█████████▋| 120013/124123 [46:39<02:27, 27.89it/s]  

💾 Saved checkpoint: checkpoints/hybrid500\epoch18.pth, Checkpoint Loss: 1.5938965509931247


Epoch 18/20: 124124it [48:15, 42.87it/s]                             


💾 Saved checkpoint: checkpoints/hybrid500\epoch18.pth, Epoch Loss: 1.5975936180401593


Epoch 19/20:  24%|██▍       | 30019/124123 [11:28<52:20, 29.96it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch19.pth, Checkpoint Loss: 1.5901327853664755


Epoch 19/20:  48%|████▊     | 60016/124123 [23:28<36:51, 28.99it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch19.pth, Checkpoint Loss: 1.5891346212198336


Epoch 19/20:  73%|███████▎  | 90016/124123 [34:33<18:21, 30.95it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch19.pth, Checkpoint Loss: 1.5814408552984396


Epoch 19/20:  97%|█████████▋| 120019/124123 [46:10<02:02, 33.63it/s]  

💾 Saved checkpoint: checkpoints/hybrid500\epoch19.pth, Checkpoint Loss: 1.5819073496490716


Epoch 19/20: 124124it [47:42, 43.36it/s]                             


💾 Saved checkpoint: checkpoints/hybrid500\epoch19.pth, Epoch Loss: 1.585927664000823


Epoch 20/20:  24%|██▍       | 30018/124123 [11:22<54:22, 28.84it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch20.pth, Checkpoint Loss: 1.5798948830217123


Epoch 20/20:  48%|████▊     | 60014/124123 [23:17<34:19, 31.13it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch20.pth, Checkpoint Loss: 1.5780280316467086


Epoch 20/20:  73%|███████▎  | 90010/124123 [34:13<20:36, 27.58it/s]   

💾 Saved checkpoint: checkpoints/hybrid500\epoch20.pth, Checkpoint Loss: 1.5710084089492757


Epoch 20/20:  97%|█████████▋| 120013/124123 [45:46<02:52, 23.79it/s]  

💾 Saved checkpoint: checkpoints/hybrid500\epoch20.pth, Checkpoint Loss: 1.5710629819765687


Epoch 20/20: 124124it [47:20, 43.70it/s]                             


💾 Saved checkpoint: checkpoints/hybrid500\epoch20.pth, Epoch Loss: 1.5752388100541495


## 8. Make Recommendations


In [34]:
# Get watched animes

filtered_folder = f"data/pt_files{THRESHOLD}"
pattern = os.path.join(filtered_folder, f"user_anime????????????_filtered{THRESHOLD}.pt")
pt_files = glob.glob(pattern)
user_idx = int(user_id_to_idx[USERNAME]) # convert to Python int
anime_indices_list = []
scores = []

for file in tqdm(pt_files):
    data = torch.load(file, map_location='cpu')  # dict of tensors

    # Boolean mask for rows of this user
    mask = data['user_idx'] == user_idx

    # Extract anime indices for this user
    user_animes = data['anime_idx'][mask]
    user_scores = data['scores'][mask]

    if user_animes.numel() > 0:
        anime_indices_list.extend(user_animes)
        scores.extend(user_scores)
        
anime_path = "data/filteredRatings/anime_filtered.csv" if not THRESHOLD else f"data/filteredRatingsThreshold{THRESHOLD}/anime_filtered.csv"
df = pd.read_csv(anime_path)
df = df.iloc[anime_indices_list]
print(f"Anime {USERNAME} rated 10 to verify")
for title,score in zip(df['title'],scores):
    score = int(score)
    if score == 10:
        print(score, title)


  0%|          | 0/70 [00:00<?, ?it/s]

100%|██████████| 70/70 [00:02<00:00, 33.38it/s]


Anime sirawesomeness rated 10 to verify
10 Naruto: Shippuuden
10 Naruto
10 Fate/stay night Movie: Heaven's Feel - I. Presage Flower
10 JoJo no Kimyou na Bouken Part 3: Stardust Crusaders 2nd Season
10 Fate/stay night: Unlimited Blade Works 2nd Season
10 Koe no Katachi
10 Gintama°
10 Code Geass: Hangyaku no Lelouch R2
10 Re:Zero kara Hajimeru Isekai Seikatsu
10 Fate/stay night Movie: Heaven's Feel - II. Lost Butterfly
10 Fate/stay night Movie: Heaven's Feel - III. Spring Song
10 Haikyuu!! Movie 4: Concept no Tatakai
10 Boku no Hero Academia 4th Season
10 Mushoku Tensei: Isekai Ittara Honki Dasu
10 Re:Zero kara Hajimeru Isekai Seikatsu 2nd Season
10 Shokugeki no Souma: Shin no Sara
10 Shingeki no Kyojin: The Final Season
10 BNA
10 Haikyuu!! To the Top 2nd Season
10 Shokugeki no Souma: Gou no Sara
10 Mushoku Tensei: Isekai Ittara Honki Dasu Part 2
10 Odd Taxi


In [39]:
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))
if checkpoint_files:
    last_checkpoint = checkpoint_files[-1]
    _, _, loss = load_checkpoint(model, optimizer_sparse, optimizer_dense, last_checkpoint, device)
    print(f"Making prediction from {last_checkpoint}, loss={loss:.4f}")

user_idx = user_id_to_idx[USERNAME]
# top_indices, top_scores = model.recommend(user_idx, anime_genres, device=device)
top_indices, top_scores = model.recommend(user_idx, anime_genres, device=device,exclude_ids=anime_indices_list)
anime_path = "data/filteredRatings/anime_filtered.csv" if not THRESHOLD else f"data/filteredRatingsThreshold{THRESHOLD}/anime_filtered.csv"
df = pd.read_csv(anime_path)
df = df.iloc[top_indices]
print(f"Recommendations for {USERNAME}")

# Calculate dynamic padding based on the longest title
max_title_len = df['title'].map(len).max()
title_col_width = max(max_title_len, len("Title")) + 2  # +2 for a little breathing room

# Print formatted header
print(f"{'Index':<6}  {'Title':<{title_col_width}}  {'URL'}")
print("-" * (title_col_width + 40))  # Adjust line length for aesthetics

# Print rows with dynamic spacing
index = 1
for _, row in df.iterrows():
    url = f"https://myanimelist.net/anime/{row['anime_id']}"
    print(f"{index:<6}  {row['title']:<{title_col_width}}  {url}")
    index += 1




Making prediction from checkpoints/hybrid500\epoch20.pth, loss=1.5752
Recommendations for catfire8
Index   Title                                  URL
-----------------------------------------------------------------------------
1       Monster                                https://myanimelist.net/anime/19
2       Ginga Eiyuu Densetsu                   https://myanimelist.net/anime/820
3       Owarimonogatari 2nd Season             https://myanimelist.net/anime/35247
4       Shoujo☆Kageki Revue Starlight Movie    https://myanimelist.net/anime/40664
5       3-gatsu no Lion 2nd Season             https://myanimelist.net/anime/35180
6       Mushishi Zoku Shou 2nd Season          https://myanimelist.net/anime/24701
7       Mushishi                               https://myanimelist.net/anime/457
8       Mushishi Zoku Shou: Suzu no Shizuku    https://myanimelist.net/anime/28957
9       Ping Pong the Animation                https://myanimelist.net/anime/22135
10      Monogatari Series: Secon

In [13]:
def relu_hook(name, stats_dict):
    def hook(module, input, output):
        zeros = (output == 0).sum().item()
        total = output.numel()
        stats_dict[name] = {
            "zero_count": zeros,
            "total_count": total,
            "zero_fraction": zeros / total
        }
    return hook

def check_dead_neurons(model, user_idx_batch, anime_idx_batch, anime_genres, device):
    activation_stats = {}

    # Register hooks on all ReLU layers
    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            hooks.append(module.register_forward_hook(relu_hook(name, activation_stats)))

    model.eval()
    with torch.no_grad():
        user_idx_batch = user_idx_batch.to(device)
        anime_idx_batch = anime_idx_batch.to(device)
        anime_genres = anime_genres.to(device)
        _ = model(user_idx_batch, anime_idx_batch, anime_genres)

    # Remove hooks
    for h in hooks:
        h.remove()

    return activation_stats

BATCH_SIZE_CHECK = 32
sample_data = torch.load(pt_files[0], map_location='cpu')
user_idx_batch = torch.tensor(sample_data["user_idx"][:BATCH_SIZE_CHECK])
anime_idx_batch = torch.tensor(sample_data["anime_idx"][:BATCH_SIZE_CHECK])

# Loop over all checkpoints
checkpoint_files = sorted(glob.glob(os.path.join(CHECKPOINT_PATH, "epoch*.pth")))

print("\nDead ReLU neuron fraction per checkpoint:")
print("Epoch | % Dead Neurons")
print("----------------------")
for ckpt_file in checkpoint_files:
    epoch = int(os.path.basename(ckpt_file).replace("epoch", "").replace(".pth", ""))
    load_checkpoint(model, optimizer_sparse, optimizer_dense, ckpt_file, device)
    stats = check_dead_neurons(model, user_idx_batch, anime_idx_batch, anime_genres, device)

    for layer, s in stats.items():
        print(f"Epoch {epoch}: {s['zero_fraction']*100:.2f}% zeros")

  user_idx_batch = torch.tensor(sample_data["user_idx"][:BATCH_SIZE_CHECK])
  anime_idx_batch = torch.tensor(sample_data["anime_idx"][:BATCH_SIZE_CHECK])



Dead ReLU neuron fraction per checkpoint:
Epoch | % Dead Neurons
----------------------
  Epoch 1: 63.96% zeros
  Epoch 2: 72.46% zeros
  Epoch 3: 83.20% zeros
  Epoch 4: 88.09% zeros
  Epoch 5: 79.69% zeros
  Epoch 6: 91.89% zeros
  Epoch 7: 86.72% zeros
  Epoch 8: 85.74% zeros
  Epoch 9: 91.99% zeros
  Epoch 10: 94.92% zeros
  Epoch 11: 90.43% zeros
  Epoch 12: 91.89% zeros
  Epoch 13: 90.62% zeros
  Epoch 14: 92.09% zeros
  Epoch 15: 89.16% zeros
  Epoch 16: 90.72% zeros
  Epoch 17: 88.18% zeros
  Epoch 18: 89.55% zeros
  Epoch 19: 91.41% zeros
  Epoch 20: 93.75% zeros
