# Welcome to Modal notebooks!

Write Python code and collaborate in real time. Your code runs in Modal's
**serverless cloud**, and anyone in the same workspace can join.

This notebook comes with some common Python libraries installed. Run
cells with `Shift+Enter`.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5EncoderModel
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np

import re

batch_size = 32

# -------------------------------------------------
# 1. Load and Clean Data
# -------------------------------------------------
df = pd.read_csv('/root/final_data_with_catalog.csv')

# Check for missing values
print(f"Total rows: {len(df)}")
print(f"Missing price_bin values: {df['price_bin'].isna().sum()}")
print(f"Missing catalog_content values: {df['catalog_content'].isna().sum()}")

# Drop rows with missing price_bin or catalog_content
df = df.dropna(subset=['price_bin', 'catalog_content'])
print(f"Rows after dropping NaN: {len(df)}")

# One signed decimal number, optionally with an exponent ("-0.001", "12.5", "1e3").
_NUMBER_RE = re.compile(r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?")


def _bin_sort_key(value):
    """Return a sort key that orders a price_bin value by its numeric edges.

    Numbers and ``pd.Interval`` objects are ordered by value. Strings (e.g.
    "(10.5, 20.0]") are ordered by the first and last numbers they contain.
    Values containing no number sort after all numeric ones, by their text.
    """
    if isinstance(value, pd.Interval):
        return (0, float(value.left), float(value.right), str(value))
    if isinstance(value, (int, float, np.integer, np.floating)):
        return (0, float(value), float(value), str(value))
    numbers = _NUMBER_RE.findall(str(value))
    if numbers:
        return (0, float(numbers[0]), float(numbers[-1]), str(value))
    return (1, 0.0, 0.0, str(value))


# Convert price_bin values to integer indices ordered by price.
# If the CSV stores interval bins (e.g. written from pd.cut / pd.qcut),
# read_csv returns them as plain strings, and a plain sorted() would be
# lexicographic ("(10.5, ...]" before "(2.0, ...]"). That would scramble the
# ordinal order the MAD metric in evaluate() assumes, so sort numerically.
unique_bins = sorted(df['price_bin'].unique(), key=_bin_sort_key)
bin_to_idx = {bin_val: idx for idx, bin_val in enumerate(unique_bins)}
df['price_bin'] = df['price_bin'].map(bin_to_idx)

print(f"Number of unique price bins: {df['price_bin'].nunique()}")
print(f"Price bin range: {df['price_bin'].min()} to {df['price_bin'].max()}")

# Check class distribution
bin_counts = df['price_bin'].value_counts().sort_index()
print(f"\nClass distribution:")
print(bin_counts)

# Filter out bins with less than 2 samples (required for stratified split)
min_samples_per_bin = 2
valid_bins = bin_counts[bin_counts >= min_samples_per_bin].index
# .copy() so the column assignment below writes to an independent frame
# rather than a view of the pre-filter frame (avoids SettingWithCopyWarning).
df = df[df['price_bin'].isin(valid_bins)].copy()

print(f"\nRows after filtering rare bins: {len(df)}")
print(f"Remaining unique bins: {df['price_bin'].nunique()}")

# Remap price_bin to consecutive integers after filtering
unique_bins_filtered = sorted(df['price_bin'].unique())
bin_remap = {old_idx: new_idx for new_idx, old_idx in enumerate(unique_bins_filtered)}
df['price_bin'] = df['price_bin'].map(bin_remap)
# Final class index -> original price_bin value, for decoding predictions.
idx_to_bin = {new_idx: unique_bins[old_idx] for old_idx, new_idx in bin_remap.items()}

print(f"Final price bin range: {df['price_bin'].min()} to {df['price_bin'].max()}")

# Train-test split with stratification
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['price_bin'])

# -------------------------------------------------
# 2. Tokenizer and T5 Encoder
# -------------------------------------------------
tokenizer = T5Tokenizer.from_pretrained("t5-base")
t5_encoder = T5EncoderModel.from_pretrained("t5-base")

# Freeze the whole encoder in one call: requires_grad_(False) clears the
# requires_grad flag on every parameter of the module.
# (Drop this line if you want to fine-tune T5 as well.)
t5_encoder.requires_grad_(False)

# -------------------------------------------------
# 3. Dataset Class
# -------------------------------------------------
class TextBinDataset(Dataset):
    """Map-style dataset pairing ``catalog_content`` text with a ``price_bin`` label.

    Text is tokenized lazily in ``__getitem__``; every sample is padded or
    truncated to exactly ``max_len`` tokens.

    Args:
        df: frame with ``catalog_content`` and ``price_bin`` columns.
        tokenizer: callable Hugging Face-style tokenizer.
        max_len: fixed sequence length passed to the tokenizer.
    """

    def __init__(self, df, tokenizer, max_len=105):
        self.texts = df['catalog_content'].tolist()
        self.labels = df['price_bin'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        # texts and labels come from the same frame, so their lengths match.
        return len(self.labels)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt',
        )
        # The tokenizer returns (1, max_len) tensors; drop the batch axis so the
        # DataLoader can stack samples itself.
        sample = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        sample['label'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# -------------------------------------------------
# 4. Model Definition
# -------------------------------------------------
class T5MLPClassifier(nn.Module):
    """Frozen T5 encoder followed by a trainable MLP classification head.

    The encoder always runs under ``torch.no_grad()``, so gradients only reach
    ``self.mlp``. Encoder token states are mean-pooled over the positions where
    ``attention_mask`` is non-zero, so padding does not dilute the embedding.
    The pooled vector is then classified by the MLP.

    Args:
        t5_model: encoder called as ``t5_model(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        hidden_dim: width of the first MLP layer; the second layer is half of it.
        num_classes: number of output logits.
        input_dim: size of the encoder hidden states. When None it is read
            from ``t5_model.config.d_model`` if available, otherwise 768
            (the t5-base width this notebook was written for).
    """

    def __init__(self, t5_model, hidden_dim=512, num_classes=50, input_dim=None):
        super().__init__()
        self.t5 = t5_model
        if input_dim is None:
            input_dim = getattr(getattr(t5_model, "config", None), "d_model", 768)
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, num_classes)
        )

    def train(self, mode=True):
        """Set train/eval mode on the head while keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every submodule. Without this override
        it would enable any dropout inside the encoder during training even
        though the encoder never updates, so training and evaluation would see
        different features. ``eval()`` also routes through this method.
        """
        super().train(mode)
        self.t5.eval()
        return self

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over the sequence axis, counting only unmasked positions."""
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for an all-padding row
        counts = mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            outputs = self.t5(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = self._masked_mean(outputs.last_hidden_state, attention_mask)
        logits = self.mlp(embeddings)
        return logits

# -------------------------------------------------
# 5. Dataloaders
# -------------------------------------------------
# Both splits share the tokenizer; tokenization happens per sample on access.
train_dataset, val_dataset = (TextBinDataset(split, tokenizer) for split in (train_df, val_df))

loader_kwargs = dict(batch_size=batch_size)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# -------------------------------------------------
# 6. Training Setup
# -------------------------------------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = T5MLPClassifier(
    t5_encoder,
    hidden_dim=512,
    num_classes=df['price_bin'].nunique(),  # labels are consecutive 0..K-1
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# Only the MLP head is handed to the optimizer; the encoder weights are frozen.
optimizer = optim.Adam(params=model.mlp.parameters(), lr=1e-4)

# -------------------------------------------------
# 7. Training Loop
# -------------------------------------------------
def evaluate(model, dataloader, *, criterion=None):
    """Compute loss, accuracy, and MAD on a validation set.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)`` that
            returns logits of shape (batch, num_classes). Batches are moved to
            the device of the model's first parameter (CPU if it has none).
        dataloader: iterable of dicts with 'input_ids', 'attention_mask' and
            'label' tensors.
        criterion: loss applied to ``(logits, labels)``. Defaults to the
            notebook-level ``criterion`` when one is defined, otherwise a
            plain ``nn.CrossEntropyLoss()``.

    Returns:
        ``(avg_loss, accuracy, mad)``. ``mad`` is the mean absolute difference
        between predicted and true bin indices. All three are NaN when the
        dataloader yields no samples.
    """
    if criterion is None:
        criterion = globals().get("criterion")
        if criterion is None:
            criterion = nn.CrossEntropyLoss()

    # Evaluate on whatever device the model's weights already live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    loss_sum, correct, abs_err_sum, total = 0.0, 0, 0, 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            logits = model(input_ids, attention_mask)
            n = labels.size(0)
            loss_sum += criterion(logits, labels).item() * n

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            # Running sum of |pred - label| instead of collecting every prediction.
            abs_err_sum += (preds - labels).abs().sum().item()
            total += n

    if total == 0:
        nan = float("nan")
        return nan, nan, nan
    return loss_sum / total, correct / total, abs_err_sum / total


epochs = 60
for epoch_idx in range(1, epochs + 1):
    # Fresh accumulators for this epoch's training metrics.
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch_idx}/{epochs} [Train]")
    for batch in progress:
        ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        targets = batch['label'].to(device)

        optimizer.zero_grad()
        logits = model(ids, mask)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the mean batch loss by batch size so the epoch average is per-sample.
        batch_n = targets.size(0)
        running_loss += batch_loss.item() * batch_n
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch_n

    train_loss = running_loss / n_seen
    train_acc = n_correct / n_seen

    # Validation pass after every epoch.
    val_loss, val_acc, val_mad = evaluate(model, val_loader)

    print("\n".join([
        f"\nEpoch {epoch_idx} Summary:",
        f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}",
        f"  Val   Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val MAD: {val_mad:.4f}",
        "-" * 60,
    ]))

Total rows: 75000
Missing price_bin values: 0
Missing catalog_content values: 0
Rows after dropping NaN: 75000
Number of unique price bins: 54
Price bin range: 0 to 53

Class distribution:
price_bin
0      738
1      746
2     1506
3     1505
4     1147
5     1860
6      750
7     1481
8     1473
9     1554
10    1497
11     739
12    1449
13    1540
14    1497
15    1498
16    1510
17    1494
18    1500
19    1498
20    1335
21    1668
22    1426
23    1575
24    1500
25    1500
26    1515
27    1446
28    1477
29    1500
30    1498
31    1500
32    1501
33    1576
34    1491
35    1497
36    1504
37    1502
38    1174
39    1818
40    1485
41    1491
42    1244
43    1769
44       1
45      10
46    1519
47    1133
48    1865
49    1505
50    1505
51    1495
52    1502
53    1491
Name: count, dtype: int64

Rows after filtering rare bins: 74999
Remaining unique bins: 53
Final price bin range: 0 to 52


spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


model.safetensors:   0%|          | 0.00/892M [00:00<?, ?B/s]

Epoch 1/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 1/60 [Train]:   0%|                                                     | 1/1875 [00:01<41:37,  1.33s/it]Epoch 1/60 [Train]:   0%|                                                     | 2/1875 [00:01<20:27,  1.53it/s]Epoch 1/60 [Train]:   0%|                                                     | 3/1875 [00:01<13:39,  2.28it/s]Epoch 1/60 [Train]:   0%|                                                     | 4/1875 [00:01<10:28,  2.98it/s]Epoch 1/60 [Train]:   0%|▏                                                    | 5/1875 [00:02<08:54,  3.50it/s]Epoch 1/60 [Train]:   0%|▏                                                    | 6/1875 [00:02<07:48,  3.99it/s]Epoch 1/60 [Train]:   0%|▏                                                    | 7/1875 [00:02<07:06,  4.38it/s]Epoch 1/60 [Train]:   0%|▏                                                    | 8/1875 [00:02<06:37,  4


Epoch 1 Summary:
  Train Loss: 3.7541 | Train Acc: 0.0488
  Val   Loss: 3.3292 | Val Acc: 0.1070 | Val MAD: 13.7677
------------------------------------------------------------


Epoch 2/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 2/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:35,  5.58it/s]Epoch 2/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:37,  5.55it/s]Epoch 2/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:38,  5.53it/s]Epoch 2/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:38,  5.53it/s]Epoch 2/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:38,  5.52it/s]Epoch 2/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:38,  5.52it/s]Epoch 2/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:37,  5.53it/s]Epoch 2/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:37,  5


Epoch 2 Summary:
  Train Loss: 3.0984 | Train Acc: 0.1336
  Val   Loss: 2.6965 | Val Acc: 0.2244 | Val MAD: 10.7993
------------------------------------------------------------


Epoch 3/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 3/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:47,  5.40it/s]Epoch 3/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:41,  5.48it/s]Epoch 3/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:42,  5.47it/s]Epoch 3/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:42,  5.46it/s]Epoch 3/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:43,  5.44it/s]Epoch 3/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:42,  5.46it/s]Epoch 3/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:41,  5.48it/s]Epoch 3/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:40,  5


Epoch 3 Summary:
  Train Loss: 2.6457 | Train Acc: 0.2140
  Val   Loss: 2.3706 | Val Acc: 0.2793 | Val MAD: 10.0455
------------------------------------------------------------


Epoch 4/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 4/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:41,  5.49it/s]Epoch 4/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:40,  5.50it/s]Epoch 4/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:50,  5.34it/s]Epoch 4/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:45,  5.42it/s]Epoch 4/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:46,  5.40it/s]Epoch 4/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:46,  5.39it/s]Epoch 4/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:49,  5.35it/s]Epoch 4/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:49,  5


Epoch 4 Summary:
  Train Loss: 2.3386 | Train Acc: 0.2810
  Val   Loss: 2.0430 | Val Acc: 0.3609 | Val MAD: 8.7309
------------------------------------------------------------


Epoch 5/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 5/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:41,  5.49it/s]Epoch 5/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:39,  5.52it/s]Epoch 5/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:39,  5.52it/s]Epoch 5/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:39,  5.50it/s]Epoch 5/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:40,  5.50it/s]Epoch 5/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:40,  5.50it/s]Epoch 5/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:40,  5.49it/s]Epoch 5/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:39,  5


Epoch 5 Summary:
  Train Loss: 2.1211 | Train Acc: 0.3363
  Val   Loss: 1.8231 | Val Acc: 0.4309 | Val MAD: 7.6001
------------------------------------------------------------


Epoch 6/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 6/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:40,  5.50it/s]Epoch 6/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:42,  5.48it/s]Epoch 6/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:42,  5.46it/s]Epoch 6/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:41,  5.48it/s]Epoch 6/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:39,  5.50it/s]Epoch 6/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:38,  5.51it/s]Epoch 6/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:41,  5.47it/s]Epoch 6/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:40,  5


Epoch 6 Summary:
  Train Loss: 1.9669 | Train Acc: 0.3795
  Val   Loss: 1.6641 | Val Acc: 0.4795 | Val MAD: 6.7883
------------------------------------------------------------


Epoch 7/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 7/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:40,  5.50it/s]Epoch 7/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:39,  5.51it/s]Epoch 7/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:40,  5.50it/s]Epoch 7/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:40,  5.49it/s]Epoch 7/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:40,  5.49it/s]Epoch 7/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:40,  5.49it/s]Epoch 7/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:39,  5.51it/s]Epoch 7/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:39,  5


Epoch 7 Summary:
  Train Loss: 1.8325 | Train Acc: 0.4180
  Val   Loss: 1.5334 | Val Acc: 0.5110 | Val MAD: 6.3609
------------------------------------------------------------


Epoch 8/60 [Train]:   0%|                                                             | 0/1875 [00:00<?, ?it/s]Epoch 8/60 [Train]:   0%|                                                     | 1/1875 [00:00<05:39,  5.51it/s]Epoch 8/60 [Train]:   0%|                                                     | 2/1875 [00:00<05:39,  5.52it/s]Epoch 8/60 [Train]:   0%|                                                     | 3/1875 [00:00<05:39,  5.51it/s]Epoch 8/60 [Train]:   0%|                                                     | 4/1875 [00:00<05:40,  5.50it/s]Epoch 8/60 [Train]:   0%|▏                                                    | 5/1875 [00:00<05:39,  5.51it/s]Epoch 8/60 [Train]:   0%|▏                                                    | 6/1875 [00:01<05:38,  5.51it/s]Epoch 8/60 [Train]:   0%|▏                                                    | 7/1875 [00:01<05:38,  5.52it/s]Epoch 8/60 [Train]:   0%|▏                                                    | 8/1875 [00:01<05:39,  5