In [None]:
import pandas as pd
import numpy as np
import wandb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

DATA_DIR = "../data"

# Seed every global RNG the notebook draws from: np.random (negative sampling in
# HybridContrastiveDataset) and torch (weight init, dropout, DataLoader shuffling).
# pandas .sample() calls pass random_state explicitly.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Open the W&B run that train() logs its per-epoch metrics into.
run = wandb.init(
    project="book-recommendation-kaggle",
    job_type="train",
    group="dev",
    save_code=True,
)

In [None]:
# ISBN is an identifier, not a number: force str so type inference can never turn an
# all-digit ISBN into an int (dropping leading zeros) and break the ratings -> books
# join. low_memory=False makes pandas infer one dtype per column for Books.csv instead
# of per chunk (which otherwise yields mixed-type columns such as Year-Of-Publication).
books = pd.read_csv(
    f"{DATA_DIR}/Books.csv.zip", compression="zip", dtype={"ISBN": str}, low_memory=False
)
ratings = pd.read_csv(
    f"{DATA_DIR}/Ratings.csv.zip", compression="zip", dtype={"ISBN": str}
)
users = pd.read_csv(f"{DATA_DIR}/Users.csv.zip", compression="zip")

In [None]:
books.head(n=5)  # first rows of the raw books table

In [None]:
ratings.head(n=5)  # first rows of the raw ratings table

In [None]:
users.head(n=5)  # first rows of the raw users table

In [None]:
# Fraction of distinct User-IDs among rows; 1.0 means User-ID is a unique key
# (later cells use it as a lookup index).
users["User-ID"].nunique() / users.shape[0]

In [None]:
# Binary label: a rating of 7 or higher counts as a positive interaction.
# NOTE(review): if Book-Rating 0 encodes implicit (unrated) feedback in this dataset,
# those rows become negatives here -- confirm that is intended.
ratings["target"] = ratings["Book-Rating"].ge(7).astype(int)
ratings.head()

In [None]:
# Create mappings for categorical features

UNKNOWN_BUCKET = "unknown"


def bucketize(values, bins, labels):
    """Bucket ``values`` with ``pd.cut``; missing/out-of-range values go to ``UNKNOWN_BUCKET``.

    ``pd.cut`` bins are right-closed, so anything outside ``(bins[0], bins[-1]]`` (e.g. a
    publication year of 0) or a NaN input comes back as NaN. Left as NaN, those rows are
    later mapped through the vocab and filled with index 0, i.e. silently merged into the
    first real bucket ("<1980" / "<18"). Placing ``UNKNOWN_BUCKET`` as the FIRST category
    gives them their own vocab index 0 while keeping the column categorical.
    """
    return (
        pd.cut(values, bins=bins, labels=labels)
        .cat.set_categories([UNKNOWN_BUCKET, *labels])
        .fillna(UNKNOWN_BUCKET)
    )


books["Year-Of-Publication"] = pd.to_numeric(
    books["Year-Of-Publication"], errors="coerce"
)
books["year_bucket"] = bucketize(
    books["Year-Of-Publication"],
    bins=[0, 1980, 1990, 2000, 2010, 2025],
    labels=["<1980", "1980s", "1990s", "2000s", "2010s+"],
)

# Create vocabs. Title/author/publisher vocabs are built from the same `books` frame the
# ratings are later joined to, so every joined value has an entry.
title_vocab = {title: idx for idx, title in enumerate(books["Book-Title"].unique())}
author_vocab = {author: idx for idx, author in enumerate(books["Book-Author"].unique())}
publisher_vocab = {pub: idx for idx, pub in enumerate(books["Publisher"].unique())}
# "unknown" is category 0, so it gets index 0 here.
year_vocab = {year: idx for idx, year in enumerate(books["year_bucket"].cat.categories)}

# Process user features
users["age_bucket"] = bucketize(
    users["Age"],
    bins=[0, 18, 25, 35, 50, 100],
    labels=["<18", "18-25", "25-35", "35-50", "50+"],
)

# Simple location processing: keep the last comma-separated component (lower-cased)
users["location_simple"] = (
    users["Location"].str.split(",").str[-1].str.strip().str.lower()
)
location_vocab = {loc: idx for idx, loc in enumerate(users["location_simple"].unique())}
age_vocab = {age: idx for idx, age in enumerate(users["age_bucket"].cat.categories)}

feature_mappings = {
    "title_vocab": title_vocab,
    "author_vocab": author_vocab,
    "publisher_vocab": publisher_vocab,
    "year_vocab": year_vocab,
    "location_vocab": location_vocab,
    "age_vocab": age_vocab,
}

In [None]:
# Better data preparation - map raw ids to row positions and drop unmatched ratings
print("=== Data Mapping and Cleaning ===")

# Lookup tables: raw User-ID / ISBN -> index label of that row in users / books
user_id_to_idx = pd.Series(users.index, index=users["User-ID"])
book_id_to_idx = pd.Series(books.index, index=books["ISBN"])

print(f"User mapping size: {len(user_id_to_idx)}")
print(f"Book mapping size: {len(book_id_to_idx)}")

# Work on a copy and (re)derive the binary target: rating >= 7 is a positive
ratings = ratings.copy()
ratings["target"] = ratings["Book-Rating"].ge(7.0).astype(int)

# Ids with no match become NaN here; they are counted and dropped below
ratings["user_idx"] = ratings["User-ID"].map(user_id_to_idx)
ratings["book_idx"] = ratings["ISBN"].map(book_id_to_idx)

missing_users = ratings["user_idx"].isna().sum()
missing_books = ratings["book_idx"].isna().sum()
print(f"Missing user mappings: {missing_users}")
print(f"Missing book mappings: {missing_books}")

# Keep fully-mapped rows only; the NaN-capable lookups are float, so cast back to int32
ratings_clean = ratings.dropna(subset=["user_idx", "book_idx"]).astype(
    {"user_idx": "int32", "book_idx": "int32"}
)

In [None]:
# Attach the raw book and user side features to every rating
ratings_clean = ratings_clean.merge(
    books[["ISBN", "Book-Title", "Book-Author", "Publisher", "year_bucket"]],
    on="ISBN",
)
ratings_clean = ratings_clean.merge(
    users[["User-ID", "age_bucket", "location_simple"]],
    on="User-ID",
    how="left",
)

# Encode each raw feature through its vocab; values absent from a vocab fall back to 0.
# Each entry: (new id column, raw source column, feature_mappings key).
for idx_col, src_col, vocab_key in [
    ("title_idx", "Book-Title", "title_vocab"),
    ("author_idx", "Book-Author", "author_vocab"),
    ("publisher_idx", "Publisher", "publisher_vocab"),
    ("year_idx", "year_bucket", "year_vocab"),
    ("location_idx", "location_simple", "location_vocab"),
    ("age_idx", "age_bucket", "age_vocab"),
]:
    ratings_clean[idx_col] = (
        ratings_clean[src_col].map(feature_mappings[vocab_key]).fillna(0)
    )

In [None]:
# Drop any image-URL columns (a safeguard: the merges above select none of them)
image_url_cols = ratings_clean.filter(like="Image-URL").columns
ratings_clean = ratings_clean.drop(columns=image_url_cols)

print(f"Final clean dataset size: {len(ratings_clean)} (from {len(ratings)} original)")
print(
    f"Data types: user_idx={ratings_clean['user_idx'].dtype}, book_idx={ratings_clean['book_idx'].dtype}"
)

ratings_clean.head()

In [None]:
class HybridContrastiveDataset(Dataset):
    """One training sample per positive interaction, paired with negative books.

    Each item pairs a positive ``(user, book)`` row (``target == 1``) with:

    * every explicit negative of that user in ``df`` (rows with ``target == 0``), and
    * ``num_random_negatives`` books drawn uniformly without replacement from the books
      in ``df`` the user has no interaction with. If fewer than that many are
      available, none are drawn -- unless the user would then have no negatives at
      all, in which case a single random negative is drawn.

    Args:
        df: interactions with columns ``user_idx_new``, ``book_idx_new``, ``target``
            and the feature ids ``age_idx``, ``location_idx``, ``title_idx``,
            ``author_idx``, ``publisher_idx``, ``year_idx``.
        num_random_negatives: random negatives requested per positive.

    NOTE(review): "never interacted" is relative to ``df`` only; a book the user has in
    another split can still be drawn as a random negative.
    """

    # Rejection-sampling rounds before falling back to an exact set difference.
    _MAX_REJECTION_ROUNDS = 10

    def __init__(self, df, num_random_negatives=7):
        self.df = df.reset_index(drop=True)
        self.num_random_negatives = num_random_negatives

        # Separate positives and explicit negatives
        self.positives = df[df["target"] == 1].reset_index(drop=True)
        self.explicit_negatives = df[df["target"] == 0].reset_index(drop=True)

        print("Dataset composition:")
        print(f"  Positives: {len(self.positives):,}")
        print(f"  Explicit negatives: {len(self.explicit_negatives):,}")
        print(f"  Random negatives per positive: {num_random_negatives}")

        # Per-user interaction sets for O(1) membership checks
        self.user_positive_books = (
            self.positives.groupby("user_idx_new")["book_idx_new"].apply(set).to_dict()
        )
        self.user_negative_books = (
            self.explicit_negatives.groupby("user_idx_new")["book_idx_new"]
            .apply(set)
            .to_dict()
        )
        self.all_books = set(df["book_idx_new"].unique())
        # Array view of all_books for cheap uniform draws in __getitem__
        self._book_array = np.array(list(self.all_books))

        # Positive (user, book) pairs as arrays: avoids a DataFrame.iloc row per item
        self._pos_users = self.positives["user_idx_new"].to_numpy()
        self._pos_books = self.positives["book_idx_new"].to_numpy()

        # Pre-compute features (first row per id; missing values -> 0)
        self.user_features = (
            df.groupby("user_idx_new")
            .first()[["age_idx", "location_idx"]]
            .fillna(0)
            .to_dict("index")
        )

        self.book_features = (
            df.groupby("book_idx_new")
            .first()[["title_idx", "author_idx", "publisher_idx", "year_idx"]]
            .fillna(0)
            .to_dict("index")
        )

    def __len__(self):
        return len(self.positives)  # One sample per positive interaction

    def _sample_unseen_books(self, interacted, k):
        """Draw ``k`` distinct books uniformly from ``all_books - interacted``.

        Rejection sampling against ``self._book_array`` costs O(k) per round when the
        user has touched a small fraction of the catalogue; after
        ``_MAX_REJECTION_ROUNDS`` rounds it falls back to the exact set difference.
        The caller guarantees ``k <= len(all_books - interacted)``.
        """
        if k <= 0:
            return []
        chosen = []
        seen = set()
        for _ in range(self._MAX_REJECTION_ROUNDS):
            draws = self._book_array[
                np.random.randint(0, len(self._book_array), size=2 * k)
            ]
            for book_id in draws.tolist():
                if book_id in interacted or book_id in seen:
                    continue
                seen.add(book_id)
                chosen.append(book_id)
                if len(chosen) == k:
                    return chosen
        remaining = list(self.all_books - interacted - seen)
        extra = np.random.choice(remaining, size=k - len(chosen), replace=False)
        return chosen + extra.tolist()

    def __getitem__(self, idx):
        user_id = self._pos_users[idx]
        pos_book_id = self._pos_books[idx]

        user_feats = self.user_features[user_id]
        pos_book_feats = self.book_features[pos_book_id]

        # 1. Explicit negatives (the user's own target == 0 interactions)
        explicit_negs = list(self.user_negative_books.get(user_id, set()))

        # 2. Random negatives from books the user never interacted with. Every
        #    interacted book comes from df, so this count equals
        #    len(all_books - interacted) without materializing the difference.
        interacted = self.user_positive_books.get(
            user_id, set()
        ) | self.user_negative_books.get(user_id, set())
        num_available = len(self.all_books) - len(interacted)

        if num_available >= self.num_random_negatives:
            num_random = self.num_random_negatives
        else:
            num_random = 0
        # Guarantee at least one negative when any unseen book exists
        if not explicit_negs and num_random == 0 and num_available > 0:
            num_random = 1
        random_negs = self._sample_unseen_books(interacted, num_random)

        negatives = explicit_negs + random_negs
        negative_features = [self.book_features[book_id] for book_id in negatives]

        return {
            "user_id": int(user_id),
            "user_age_idx": int(user_feats["age_idx"]),
            "user_location_idx": int(user_feats["location_idx"]),
            "pos_book_id": int(pos_book_id),
            "pos_title_idx": int(pos_book_feats["title_idx"]),
            "pos_author_idx": int(pos_book_feats["author_idx"]),
            "pos_publisher_idx": int(pos_book_feats["publisher_idx"]),
            "pos_year_idx": int(pos_book_feats["year_idx"]),
            "neg_book_ids": torch.LongTensor(negatives),
            "neg_title_idx": torch.LongTensor(
                [int(f["title_idx"]) for f in negative_features]
            ),
            "neg_author_idx": torch.LongTensor(
                [int(f["author_idx"]) for f in negative_features]
            ),
            "neg_publisher_idx": torch.LongTensor(
                [int(f["publisher_idx"]) for f in negative_features]
            ),
            "neg_year_idx": torch.LongTensor(
                [int(f["year_idx"]) for f in negative_features]
            ),
            # Actual negative counts in this item, for analysis
            "num_explicit_negs": len(explicit_negs),
            "num_random_negs": len(random_negs),
        }

In [None]:
class UserTower(nn.Module):
    """Encode a user (id, location, age bucket) into a unit-norm embedding.

    Args:
        num_users, num_locations, num_age_buckets: embedding-table sizes; every id
            passed to ``forward`` must lie in ``[0, size)``.
        embedding_dim: size of the user-id embedding and of the output vector.
        location_dim, age_dim: sizes of the side-feature embeddings (defaults match
            the previous hard-coded 16 + 16).
    """

    def __init__(
        self,
        num_users,
        num_locations,
        num_age_buckets,
        embedding_dim=64,
        location_dim=16,
        age_dim=16,
    ):
        super().__init__()

        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.location_embedding = nn.Embedding(num_locations, location_dim)
        self.age_embedding = nn.Embedding(num_age_buckets, age_dim)

        # Input width is derived from the embedding sizes so they can't drift apart.
        # The trailing Linear (no activation) lets outputs take either sign: normalizing
        # a ReLU output confines every embedding to the non-negative orthant, so all
        # cosine similarities would be >= 0 and negatives could never be pushed apart.
        self.feature_fusion = nn.Sequential(
            nn.Linear(embedding_dim + location_dim + age_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, user_ids, location_ids, age_bucket_ids):
        """Return L2-normalized embeddings of shape ``(*ids.shape, embedding_dim)``.

        The three id tensors must share their shape (their embeddings are
        concatenated on the last dim).
        """
        user_emb = self.user_embedding(user_ids)
        location_emb = self.location_embedding(location_ids)
        age_emb = self.age_embedding(age_bucket_ids)

        features = torch.cat([user_emb, location_emb, age_emb], dim=-1)
        output = self.feature_fusion(features)

        return F.normalize(output, p=2, dim=-1)


class BookTower(nn.Module):
    """Encode a book (id, title, author, publisher, year bucket) into a unit-norm embedding.

    Args:
        num_books, num_titles, num_authors, num_publishers, num_year_buckets:
            embedding-table sizes; every id passed to ``forward`` must lie in
            ``[0, size)``.
        embedding_dim: size of the book-id embedding and of the output vector.
        title_dim, author_dim, publisher_dim, year_dim: side-feature embedding sizes
            (defaults match the previous hard-coded 32 + 16 + 16 + 8).
    """

    def __init__(
        self,
        num_books,
        num_titles,
        num_authors,
        num_publishers,
        num_year_buckets,
        embedding_dim=64,
        title_dim=32,
        author_dim=16,
        publisher_dim=16,
        year_dim=8,
    ):
        super().__init__()

        self.book_embedding = nn.Embedding(num_books, embedding_dim)
        self.title_embedding = nn.Embedding(num_titles, title_dim)
        self.author_embedding = nn.Embedding(num_authors, author_dim)
        self.publisher_embedding = nn.Embedding(num_publishers, publisher_dim)
        self.year_embedding = nn.Embedding(num_year_buckets, year_dim)

        # Input width derived from the embedding sizes. The trailing Linear (no
        # activation) lets outputs take either sign; a ReLU right before
        # normalization would restrict all cosine similarities to [0, 1].
        fused_dim = embedding_dim + title_dim + author_dim + publisher_dim + year_dim
        self.feature_fusion = nn.Sequential(
            nn.Linear(fused_dim, embedding_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, book_ids, title_ids, author_ids, publisher_ids, year_bucket_ids):
        """Return L2-normalized embeddings of shape ``(*ids.shape, embedding_dim)``.

        All five id tensors must share their shape (their embeddings are
        concatenated on the last dim).
        """
        book_emb = self.book_embedding(book_ids)
        title_emb = self.title_embedding(title_ids)
        author_emb = self.author_embedding(author_ids)
        publisher_emb = self.publisher_embedding(publisher_ids)
        year_emb = self.year_embedding(year_bucket_ids)

        features = torch.cat([book_emb, title_emb, author_emb, publisher_emb, year_emb], dim=-1)
        output = self.feature_fusion(features)

        return F.normalize(output, p=2, dim=-1)


class ContrastiveTwoTowerModel(nn.Module):
    """Two-tower model: a UserTower and a BookTower embedding into one shared space.

    ``forward`` encodes a batch of users and a batch of books; ``compute_similarity``
    turns paired embeddings into temperature-scaled dot-product logits.
    """

    def __init__(
        self,
        num_users,
        num_books,
        num_locations,
        num_age_buckets,
        num_titles,
        num_authors,
        num_publishers,
        num_year_buckets,
        embedding_dim=64,
        temperature=0.07,
    ):
        super().__init__()
        # Registration order (user tower first) is kept so parameter order and
        # random initialization are unchanged.
        self.user_tower = UserTower(
            num_users, num_locations, num_age_buckets, embedding_dim
        )
        self.book_tower = BookTower(
            num_books, num_titles, num_authors, num_publishers, num_year_buckets, embedding_dim
        )
        self.temperature = temperature

    def forward(self, user_data, book_data):
        """Encode ``user_data`` / ``book_data`` id dicts; returns ``(user_emb, book_emb)``."""
        user_args = (
            user_data[key] for key in ("user_ids", "location_ids", "age_bucket_ids")
        )
        book_args = (
            book_data[key]
            for key in ("book_ids", "title_ids", "author_ids", "publisher_ids", "year_bucket_ids")
        )
        user_emb = self.user_tower(*user_args)
        book_emb = self.book_tower(*book_args)
        return user_emb, book_emb

    def compute_similarity(self, user_emb, book_emb):
        """Row-wise dot product of paired embeddings divided by ``self.temperature``."""
        return (user_emb * book_emb).sum(dim=-1) / self.temperature

In [None]:
def _long_tensor(values, device):
    """Build a 1-D int64 tensor on ``device`` from a per-sample list of ids."""
    return torch.as_tensor([int(v) for v in values], dtype=torch.long, device=device)


def _concat_neg_ids(per_sample, device):
    """Flatten a list of per-sample 1-D id tensors into one int64 tensor on ``device``."""
    flat = torch.cat([torch.as_tensor(t, dtype=torch.long).reshape(-1) for t in per_sample])
    return flat.to(device)


def _info_nce_batch_loss(model, batch, device):
    """InfoNCE loss for one collated batch, or ``None`` if no sample has negatives.

    Logit column 0 is the sample's positive; columns 1.. hold that sample's own
    negatives, padded with ``-inf`` (zero softmax weight) up to the longest negative
    list in the batch. Samples without negatives are excluded from the mean.
    """
    batch_size = len(batch["user_id"])
    neg_counts = torch.as_tensor(
        [len(ids) for ids in batch["neg_book_ids"]], dtype=torch.long, device=device
    )
    if int(neg_counts.sum()) == 0:
        return None

    user_data = {
        "user_ids": _long_tensor(batch["user_id"], device),
        "location_ids": _long_tensor(batch["user_location_idx"], device),
        "age_bucket_ids": _long_tensor(batch["user_age_idx"], device),
    }
    pos_book_data = {
        "book_ids": _long_tensor(batch["pos_book_id"], device),
        "title_ids": _long_tensor(batch["pos_title_idx"], device),
        "author_ids": _long_tensor(batch["pos_author_idx"], device),
        "publisher_ids": _long_tensor(batch["pos_publisher_idx"], device),
        "year_bucket_ids": _long_tensor(batch["pos_year_idx"], device),
    }
    user_emb, pos_book_emb = model(user_data, pos_book_data)

    neg_book_emb = model.book_tower(
        _concat_neg_ids(batch["neg_book_ids"], device),
        _concat_neg_ids(batch["neg_title_idx"], device),
        _concat_neg_ids(batch["neg_author_idx"], device),
        _concat_neg_ids(batch["neg_publisher_idx"], device),
        _concat_neg_ids(batch["neg_year_idx"], device),
    )

    # User i repeated neg_counts[i] times lines up with the flattened negatives.
    user_emb_for_negs = torch.repeat_interleave(user_emb, neg_counts, dim=0)
    pos_sim = model.compute_similarity(user_emb, pos_book_emb)
    neg_sim = model.compute_similarity(user_emb_for_negs, neg_book_emb)

    # Scatter the flat negatives into (batch_size, max_negs): flat element j belongs to
    # sample row_ids[j], at column j minus that sample's start offset.
    row_ids = torch.repeat_interleave(torch.arange(batch_size, device=device), neg_counts)
    starts = torch.cumsum(neg_counts, dim=0) - neg_counts
    col_ids = torch.arange(row_ids.numel(), device=device) - starts[row_ids]
    padded_neg_sim = neg_sim.new_full(
        (batch_size, int(neg_counts.max())), float("-inf")
    ).index_put((row_ids, col_ids), neg_sim)

    logits = torch.cat([pos_sim.unsqueeze(1), padded_neg_sim], dim=1)
    valid = neg_counts > 0
    labels = torch.zeros(int(valid.sum()), dtype=torch.long, device=device)
    return F.cross_entropy(logits[valid], labels)


def train(model, train_loader, val_loader=None, epochs=10, lr=1e-4, log_fn=None):
    """Train a two-tower model with a per-sample InfoNCE loss.

    Args:
        model: exposes ``model(user_data, book_data) -> (user_emb, book_emb)``,
            ``model.book_tower(book, title, author, publisher, year)`` and
            ``model.compute_similarity(user_emb, book_emb)``.
        train_loader: iterable of batches shaped like ``contrastive_collate_fn`` output:
            per-sample lists for the scalar keys and a list of 1-D id tensors for each
            ``neg_*`` key.
        val_loader: accepted for interface compatibility; no validation pass is run.
        epochs: number of passes over ``train_loader``.
        lr: initial AdamW learning rate (decayed x0.8 every 2 epochs).
        log_fn: called once per epoch with ``{"epoch", "train_loss", "learning_rate"}``;
            defaults to ``wandb.log``.

    Input tensors are created on the device of the model's first parameter.
    """
    if log_fn is None:
        log_fn = wandb.log
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.8)

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        num_batches = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = _info_nce_batch_loss(model, batch, device)
            if loss is None:
                continue  # no sample in this batch has any negatives

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx % 100 == 0:
                current_lr = optimizer.param_groups[0]['lr']
                print(f"Batch {batch_idx}: loss = {loss.item():.4f}, lr = {current_lr:.2e}")

        # Update learning rate
        scheduler.step()

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        if num_batches > 0:
            print(f"Epoch {epoch + 1}: Avg Loss = {avg_loss:.4f}")
        else:
            print(f"Epoch {epoch + 1}: No valid batches processed")

        log_fn({
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            "learning_rate": optimizer.param_groups[0]['lr'],
        })

In [None]:
# Proper collaborative filtering data preparation
print("=== Proper Collaborative Filtering Data Preparation ===")

# Sample data for faster training (use full dataset later)
sample_size = 1_000_000  # Adjust based on your needs
if len(ratings_clean) > sample_size:
    ratings_sample = ratings_clean.sample(n=sample_size, random_state=42)
    print(f"Sampled {sample_size} rows from {len(ratings_clean)} total ratings")
else:
    ratings_sample = ratings_clean
    print(f"Using full dataset: {len(ratings_sample)} rows")

# Filter for users/books with sufficient interactions (counts taken on the sample)
print("\n=== Filtering for Active Users/Popular Books ===")
user_counts = ratings_sample.groupby('user_idx').size()
book_counts = ratings_sample.groupby('book_idx').size()

min_user_interactions = 5
min_book_interactions = 5

active_users = user_counts[user_counts >= min_user_interactions].index
popular_books = book_counts[book_counts >= min_book_interactions].index

print(f"Active users: {len(active_users):,} / {len(user_counts):,}")
print(f"Popular books: {len(popular_books):,} / {len(book_counts):,}")

# Keep interactions whose user is active AND whose book is popular. This is a single
# pass, not an iterative k-core: dropping rows on one side can leave some kept users
# (or books) below the thresholds again.
filtered_ratings = ratings_sample[
    ratings_sample['user_idx'].isin(active_users)
    & ratings_sample['book_idx'].isin(popular_books)
].copy()

print(f"Filtered dataset: {len(filtered_ratings):,} interactions")

# Compact 0..n-1 indices over the users/books that actually remain after the joint
# filter. Building them from active_users / popular_books would reserve embedding rows
# for ids with no interactions left (e.g. an active user whose books were all dropped),
# which are never trained.
user_idx_map = {
    old_idx: new_idx
    for new_idx, old_idx in enumerate(sorted(filtered_ratings['user_idx'].unique()))
}
book_idx_map = {
    old_idx: new_idx
    for new_idx, old_idx in enumerate(sorted(filtered_ratings['book_idx'].unique()))
}

filtered_ratings['user_idx_new'] = filtered_ratings['user_idx'].map(user_idx_map)
filtered_ratings['book_idx_new'] = filtered_ratings['book_idx'].map(book_idx_map)

print(f"New dimensions - Users: {len(user_idx_map)}, Books: {len(book_idx_map)}")

# Check class balance
print(f"\nTarget distribution: {filtered_ratings['target'].value_counts().to_dict()}")

# User-based split (proper for collaborative filtering)
def user_based_split(df, val_ratio=0.2, test_ratio=0.1, min_interactions=5):
    """Split each user's interactions into train / val / test by row position.

    For every user with at least ``min_interactions`` rows, the LAST
    ``max(2, int(n * (val_ratio + test_ratio)))`` rows (in ``df``'s current order) are
    held out; the final ``max(1, ...)`` share of those goes to test, the rest to val.
    Users with fewer rows go entirely to train. The split is only temporal if ``df``
    is already time-ordered.

    Args:
        df: interactions with a ``user_idx_new`` column.
        val_ratio, test_ratio: holdout fractions (their sum must be > 0).
        min_interactions: users below this row count are not split.

    Returns:
        ``(train, val, test)`` DataFrames; a split with no rows is an empty frame with
        ``df``'s columns (instead of ``pd.concat`` raising on an empty list).
    """
    holdout_ratio = val_ratio + test_ratio
    train_parts, val_parts, test_parts = [], [], []

    # sort=False keeps users in order of first appearance; rows keep their order.
    for _, user_data in df.groupby('user_idx_new', sort=False):
        if len(user_data) < min_interactions:
            # Keep all data in training for users with few interactions
            train_parts.append(user_data)
            continue

        # First split: train vs (val+test)
        n_test_val = max(2, int(len(user_data) * holdout_ratio))
        train_parts.append(user_data.iloc[:-n_test_val])
        test_val_data = user_data.iloc[-n_test_val:]

        # Second split: val vs test
        n_test = max(1, int(len(test_val_data) * (test_ratio / holdout_ratio)))
        val_parts.append(test_val_data.iloc[:-n_test])
        test_parts.append(test_val_data.iloc[-n_test:])

    empty = df.iloc[0:0]

    def _concat(parts):
        """Concatenate split parts, or return an empty frame when there are none."""
        return pd.concat(parts) if parts else empty

    return _concat(train_parts), _concat(val_parts), _concat(test_parts)

# Per-user holdout split, then record the embedding-table sizes for the model
X_train, X_val, X_test = user_based_split(filtered_ratings)
num_users_filtered, num_books_filtered = len(user_idx_map), len(book_idx_map)

print("Proper collaborative filtering splits:")
print(f"Train: {len(X_train):,} interactions")
print(f"Val: {len(X_val):,} interactions")
print(f"Test: {len(X_test):,} interactions")

print(f"\nModel dimensions: {num_users_filtered} users, {num_books_filtered} books")

In [None]:
# Simple dataset for validation (no negative sampling needed)
class SimpleRatingsDataset(Dataset):
    """Row-per-item view of an interactions frame, for evaluation.

    Each item is a dict of 1-element tensors built from the row's ``user_idx_new``
    (``'user_id'``, int64), ``book_idx_new`` (``'book_id'``, int64) and ``target``
    (``'target'``, float32).
    """

    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        user_id = torch.LongTensor([record['user_idx_new']])
        book_id = torch.LongTensor([record['book_idx_new']])
        target = torch.FloatTensor([record['target']])
        return {'user_id': user_id, 'book_id': book_id, 'target': target}

# Create the validation dataset/loader. The contrastive training dataset and loader
# are built once, in the following cell, with `contrastive_collate_fn`. Building them
# here too cost a second full HybridContrastiveDataset construction. The loader also
# used `collate_fn=lambda x: x`, which yields a list of per-sample dicts that
# `train()` cannot index by key. Both were overwritten before use.
print("\n=== Creating Datasets ===")
val_dataset = SimpleRatingsDataset(X_val)

batch_size = 128
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"DataLoader sizes - Val: {len(val_loader)} batches")

In [None]:
# Custom collate function for variable-length negatives
def contrastive_collate_fn(batch):
    """Collate HybridContrastiveDataset items into a dict of per-sample lists.

    Scalar fields become plain Python lists. The variable-length ``neg_*`` tensors are
    kept as a list of tensors (one per sample), since they cannot be stacked. Keys
    missing from the first item are omitted.
    """
    scalar_keys = (
        'user_id', 'pos_book_id', 'pos_title_idx', 'pos_author_idx',
        'pos_publisher_idx', 'pos_year_idx', 'user_age_idx', 'user_location_idx',
        'num_explicit_negs', 'num_random_negs',
    )
    neg_keys = ('neg_book_ids', 'neg_title_idx', 'neg_author_idx', 'neg_publisher_idx', 'neg_year_idx')
    first_item = batch[0]
    return {
        key: [item[key] for item in batch]
        for key in scalar_keys + neg_keys
        if key in first_item
    }

# Build the contrastive training dataset/loader and a model sized from the filtered
# data (users/books) and the feature vocabularies.
print("=== Recreating Datasets and Model ===")

train_dataset = HybridContrastiveDataset(X_train, num_random_negatives=3)

vocab_sizes = {
    "num_locations": len(feature_mappings['location_vocab']),
    "num_age_buckets": len(feature_mappings['age_vocab']),
    "num_titles": len(feature_mappings['title_vocab']),
    "num_authors": len(feature_mappings['author_vocab']),
    "num_publishers": len(feature_mappings['publisher_vocab']),
    "num_year_buckets": len(feature_mappings['year_vocab']),
}
model = ContrastiveTwoTowerModel(
    num_users=num_users_filtered,
    num_books=num_books_filtered,
    embedding_dim=64,
    temperature=0.2,
    **vocab_sizes,
)

print(f"Updated model created with {sum(p.numel() for p in model.parameters()):,} parameters")

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=contrastive_collate_fn,
)

print(f"Updated train_loader with {len(train_loader)} batches")

# Sanity-check the structure of one collated batch
test_batch = next(iter(train_loader))
print(f"Batch keys: {sorted(test_batch.keys())}")
print(f"Batch size: {len(test_batch['user_id'])}")

In [None]:
# Smoke-test the data -> model path on a few batches, then run the real training.
# This check runs forward passes only (under no_grad, with no optimizer). An earlier
# version stepped an optimizer on `-pos_sim.mean()`. That is a positives-only
# objective, and it moved the weights toward scoring every pair highly just before
# train(). It also overwrote the global `batch_size`.
print("=== Testing Improved Training Setup ===")

num_smoke_batches = 5
print(f"Running {num_smoke_batches} test batches...")
model.eval()  # no dropout for the check; train() switches back via model.train()

with torch.no_grad():
    for batch_idx, batch in enumerate(train_loader):
        if batch_idx >= num_smoke_batches:
            break

        user_data = {
            "user_ids": torch.LongTensor(batch["user_id"]),
            "location_ids": torch.LongTensor(batch["user_location_idx"]),
            "age_bucket_ids": torch.LongTensor(batch["user_age_idx"]),
        }
        pos_book_data = {
            "book_ids": torch.LongTensor(batch["pos_book_id"]),
            "title_ids": torch.LongTensor(batch["pos_title_idx"]),
            "author_ids": torch.LongTensor(batch["pos_author_idx"]),
            "publisher_ids": torch.LongTensor(batch["pos_publisher_idx"]),
            "year_bucket_ids": torch.LongTensor(batch["pos_year_idx"]),
        }

        user_emb, pos_book_emb = model(user_data, pos_book_data)
        pos_sim = model.compute_similarity(user_emb, pos_book_emb)
        print(f"Batch {batch_idx}: user_emb shape={user_emb.shape}, pos_sim range=[{pos_sim.min():.3f}, {pos_sim.max():.3f}]")

        # Also push the variable-length negatives through the book tower
        neg_ids = torch.cat(batch["neg_book_ids"])
        neg_book_emb = model.book_tower(
            neg_ids,
            torch.cat(batch["neg_title_idx"]),
            torch.cat(batch["neg_author_idx"]),
            torch.cat(batch["neg_publisher_idx"]),
            torch.cat(batch["neg_year_idx"]),
        )
        print(f"  Negatives: {neg_ids.numel()} books, neg_book_emb shape={tuple(neg_book_emb.shape)}")

print("\n✅ Test successful! All features are working.")
print("\nNow starting full training with improved setup...")

# Run full training
train(model, train_loader, epochs=3, lr=1e-4)

In [None]:
# Close the W&B run opened by `wandb.init` at the top of the notebook.
run.finish()