In [1]:
!pip install transformers
!pip install torch
!pip install tqdm

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
!python -c "import torch; print(torch.__version__)"

2.5.1+cu124


In [1]:
# imports

import json
import pandas as pd


from collections import defaultdict

In [2]:
user_history_filepath = "/scratch/general/vast/u1471428/hugging_face_cache/user_history_data.json" 

In [3]:
json_file = open(user_history_filepath, 'r')
user_history = json.load(json_file)
json_file.close()

In [4]:
def count_user_with_history_sizes(user_h):
    cnt_dict = defaultdict(int)
    for user_id in user_history.keys():
        cnt_dict[len(user_history[user_id])]+=1
    print(cnt_dict)
    
def count_user_with_history_size_above(user_h, above):
    cnt = 0
    for user_id in user_history.keys():
        if len(user_history[user_id]) >= above:
            cnt+=1
    print(cnt)

def filter_users(user_h, min_history_length, max_history_length=-1)->dict:
    filtered_users = {}
    count=0
    for user, history in user_h.items():
        if len(history) >= min_history_length:
            filtered_users[user] = history[:min_history_length]
            count+=1
        
        if count==-1:
            break
    return filtered_users
        
    

In [5]:
filtered_users = filter_users(user_history, 20)
print(len(filtered_users))

7297


In [6]:
for user_id in filtered_users.keys():
    history = (filtered_users[user_id])
    print(len(history))
    print(history[0])
    break

20
{'rating': 3.0, 'review_title': "Didn't work for my needs", 'review_text': "Truth be told, I actually got this playpen as I was looking for an alternative to having a crate for the puppy I was getting. I am truly disappointed when I first used it. First, I had a real hard time putting this together. As a single 60 something year old woman, I had to painstakingly, reading instructions, try to put this together by myself. I guess I was hoping I could just open the box and 'pop' open the playpen. Instead, I had to put a load of pipe like pieces together and then string the netting and pace over it. After finally getting it together, I was ready to try the new puppy in there. The puppy was not happy (of course, probably to be expected) and I found her almost able to rip the netting apart. It also did not form the strongest of a surround set up (it was kind of flimsy). My puppy had more fun trying to hide under it when she was outside of it. So, this all being said, I am not sure how it 

In [7]:
assert all(len(user) == 20 for user in filtered_users.values()), "Not all users have a history length of 20"

In [8]:
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from collections import Counter

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [10]:
# device = torch.device("cpu")

In [11]:
users = list(filtered_users.keys())
train_users, test_users = train_test_split(users, test_size=0.2, random_state=42)
train_data = {user: filtered_users[user] for user in train_users}
test_data = {user: filtered_users[user] for user in test_users}

In [12]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
main_category_encoder = LabelEncoder()
category_encoder = LabelEncoder()
product_encoder = LabelEncoder()

In [13]:
#Fit Encoders

main_categories = [entry["main_category"] for user in users for entry in filtered_users[user]]
main_category_encoder.fit(main_categories)
categories = [entry["categories"][-1] for user in users for entry in filtered_users[user]]
category_encoder.fit(categories)
product_ids = [entry["product_id"] for user in users for entry in filtered_users[user]]
product_encoder.fit(product_ids)

LabelEncoder()

In [14]:
print(categories)

['Playards', 'Playard Bedding', 'Door & Stair Gates', 'Door & Stair Gates', 'Snack Foods', 'Playard Bedding', 'Crib Bedding Sets', 'Sun Protection', 'Wipes & Refills', 'Wipes & Refills', 'Step Stools', 'Lightweight', 'Door & Stair Gates', 'Disposable Diapers', 'Growth Charts', 'Video Monitors', 'Disposable Diapers', 'Disposable Diapers', 'Disposable Diapers', 'Disposable Diapers', 'Receiving Blankets', 'Diaper Stackers & Caddies', 'Disposable Diapers', 'Washcloths & Wash Gloves', 'Swaddling Blankets', 'Gift Sets', 'Grooming & Healthcare Kits', 'Food Storage', 'Thermometers', 'Storage Bins & Boxes', 'Night Lights', 'Audio Monitors', 'Parent Cup Holders', 'Lotions', 'Swaddling Blankets', 'Swaddling Blankets', 'Strap & Belt Covers', 'Booster', 'Convertible', 'Audio Monitors', 'Bottles', 'Disposable Diapers', 'Gift Sets', 'Cups', 'Diaper Stackers & Caddies', 'Disposable Diapers', 'Bottles', 'Disposable Diapers', 'Disposable Diapers', 'Sun Protection', 'Drooling Bibs', 'Cups', 'Flatware Set

In [15]:
def generate_category_product_mappings(data, category_encoder, product_encoder):
    category_index_to_products = defaultdict(list)
    for user, histories in data.items():
        for history in histories:
            #print([history['categories'][-1]])
            category_idx = category_encoder.transform([history['categories'][-1]])[0]
            product_idx = product_encoder.transform([history['product_id']])[0]
            category_index_to_products[category_idx].append(product_idx)
        
        category_index_to_products = {k: list(set(v)) for k, v in category_index_to_products.items()}
        return category_index_to_products

category_index_to_products = generate_category_product_mappings(filtered_users, category_encoder, product_encoder)

In [16]:
print(len(category_index_to_products))

12


In [17]:
def normalize_ratings(histories):
    all_ratings = [h["rating"] for user in histories for h in histories[user]]
    scaler = MinMaxScaler(feature_range=(0,1))
    scaler.fit([[r] for r in all_ratings])
    return scaler

rating_scaler = normalize_ratings(filtered_users)

def preprocess_history(history):
    texts = [
        f"{h.get('review_title','')} {h.get('title','')} {h.get('main_category')}" for h in history
    ]
    
    texts = [text for text in texts if text.strip()]
    
    tokens = tokenizer(texts, padding="max_length", truncation=True, return_tensors="pt", max_length=128)
    
    ratings = torch.tensor([rating_scaler.transform([[h["rating"]]])[0][0] for h in history], dtype=torch.float32)
    
    #categories = torch.tensor(category_encoder.transform([h["main_category"] for h in history]))
    
    categories = torch.tensor(category_encoder.transform([h["categories"][-1] for h in history]))
    
    product_ids = torch.tensor(product_encoder.transform([h["product_id"] for h in history]))
    
    return tokens, ratings, categories, product_ids

In [18]:
class UserDataset(Dataset):
    def __init__(self, data, tokenizer, category_encoder, product_encoder, seq_len=15, pred_len=5):
        self.data = data
        self.tokenizer = tokenizer
        self.category_encoder = category_encoder # Is it required?
        self.product_encoder = product_encoder # Is it required?
        self.seq_len = seq_len
        self.pred_len = pred_len
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        user, history = list(self.data.items())[idx]
        tokens, ratings, categories, product_ids = preprocess_history(history)
        input_tokens = {
            "input_ids":tokens["input_ids"][:self.seq_len],
            "attention_mask": tokens["attention_mask"][:self.seq_len],
            "token_type_ids": tokens["token_type_ids"][:self.seq_len],
        }
        target_category_ids = categories[self.seq_len: self.seq_len+self.pred_len]
        target_ids = product_ids[self.seq_len: self.seq_len+self.pred_len]
        
        return input_tokens, ratings[:self.seq_len], categories[:self.seq_len], target_category_ids, target_ids

train_dataset = UserDataset(train_data, tokenizer, category_encoder, product_encoder)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [19]:

class MultiStageTransformerRecommendationModel(nn.Module):
    def __init__(self, bert_model_name="bert-base-uncased", num_categories=800, num_total_products=1000, d_model=128, nhead=8, num_decoder_layers=3):
        super(MultiStageTransformerRecommendationModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        bert_hidden_size = self.bert.config.hidden_size 
        
        # Category prediction branch
        self.category_fc = nn.Linear(bert_hidden_size+1, num_categories)
        
        # Product prediction branch
        self.fc_features = nn.Linear(bert_hidden_size+1, d_model) # Bert output + rating
        self.category_embedding = nn.Embedding(num_categories, d_model)
        self.tgt_embedding = nn.Embedding(num_total_products, d_model)
        
        # Transformer Decoder
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_decoder_layers)
        
        self.product_category_fc = nn.Linear(d_model, num_categories)
        self.product_id_fc = nn.Linear(d_model, num_total_products)
        
        self.activation = nn.ReLU()
        self.layer_norm = nn.LayerNorm(d_model)
        
        self.init_weights()
    
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        
    
    def forward(self, tokens, ratings, categories, tgt, stage="category", category_to_product_map=None):
        batch_size, seq_len, max_token_length = tokens["input_ids"].shape
        output_seq_len = tgt.shape[1]
        
        input_ids = tokens["input_ids"].view(-1, max_token_length)  # [batch_size * seq_len, max_token_length]
        attention_mask = tokens["attention_mask"].view(-1, max_token_length)
        token_type_ids = tokens["token_type_ids"].view(-1, max_token_length)
        
        bert_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        
        #sequence_output = bert_output.last_hidden_state
        pooler_ouput = bert_output.pooler_output
        
        ratings = ratings.view(batch_size * seq_len, -1).float()
        features = torch.cat([pooler_ouput, ratings], dim=-1)
        
        if stage == "category":
            category_logits = self.category_fc(features)
            return category_logits.view(batch_size, seq_len,-1)
        
        elif stage == "product":
            x = self.fc_features(features)
            x = self.activation(x)
            x = self.layer_norm(x)
            
            category_embeds = self.category_embedding(categories.view(-1)).float()
            x = x + category_embeds
            x = x.view(batch_size, seq_len, -1)
            
            memory = x
            
            tgt_embeds = self.tgt_embedding(tgt)
            decoded = self.decoder(tgt=tgt_embeds, memory=memory)
            
            product_category_logits = self.product_category_fc(decoded)
            product_id_logits = self.product_id_fc(decoded)
            
            
            if category_to_product_map is not None:
                mask = torch.zeros_like(product_id_logits)
                for i in range(batch_size):
                    for j in range(output_seq_len):
                        predicted_category = torch.argmax(product_category_logits[i,j],dim=-1).item()
                        if predicted_category in category_to_product_map:
                            product_indices = category_to_product_map[predicted_category]
                        else:
                            product_indices = list(range(product_id_logits.size(-1)))
#                         product_indices = category_to_product_map[predicted_category]
                        mask[i, j, product_indices] = 1
                product_id_logits = product_id_logits*mask
            
            return product_category_logits,product_id_logits

In [20]:
category_indices = category_encoder.transform(categories)
category_counts = Counter(category_indices)

print(category_counts)
# Compute weights: Inverse of frequency
num_categories = len(category_counts)
total_samples = len(category_indices)
class_weights = {category: total_samples / count for category, count in category_counts.items()}

# Normalize weights (optional, ensures sum of weights = 1)
normalized_weights = {category: weight / sum(class_weights.values()) for category, weight in class_weights.items()}

# Convert to tensor for use in PyTorch
weights_tensor = torch.tensor([normalized_weights[i] for i in range(num_categories)], dtype=torch.float)

category_loss_fn = torch.nn.CrossEntropyLoss(weight=weights_tensor.to(device))

Counter({165: 6696, 187: 6059, 502: 4512, 149: 2991, 93: 2908, 496: 2851, 190: 2773, 555: 2741, 379: 2271, 109: 2226, 160: 2121, 547: 2117, 231: 2020, 175: 1910, 59: 1718, 29: 1702, 410: 1586, 401: 1375, 84: 1350, 342: 1293, 515: 1293, 209: 1278, 143: 1267, 478: 1266, 176: 1232, 223: 1231, 79: 1213, 99: 1199, 458: 1165, 227: 1135, 288: 1088, 378: 1084, 482: 1057, 469: 1042, 512: 1041, 395: 1027, 212: 1021, 85: 993, 399: 978, 334: 976, 398: 957, 199: 936, 531: 934, 495: 931, 251: 927, 6: 905, 62: 844, 528: 835, 433: 833, 400: 830, 427: 822, 83: 809, 98: 800, 159: 789, 51: 740, 262: 739, 8: 721, 91: 715, 364: 712, 465: 708, 275: 697, 357: 680, 55: 678, 307: 677, 158: 672, 95: 671, 137: 667, 358: 667, 405: 654, 476: 650, 211: 650, 321: 640, 260: 632, 311: 626, 350: 616, 474: 610, 532: 587, 34: 577, 19: 573, 552: 552, 329: 532, 513: 532, 92: 530, 363: 528, 461: 527, 48: 524, 540: 522, 183: 513, 64: 513, 234: 509, 277: 506, 411: 503, 113: 501, 497: 495, 108: 490, 161: 489, 181: 486, 372: 45

In [21]:
model = MultiStageTransformerRecommendationModel(num_categories=len(category_encoder.classes_), num_total_products=len(product_encoder.classes_)).to(device)
# category_loss_fn = nn.CrossEntropyLoss()
product_loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

In [22]:
print(len(category_encoder.classes_))
print(len(product_encoder.classes_))

561
40356


In [23]:
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    total_category_loss = 0.0
    total_product_category_loss = 0.0
    total_product_id_loss = 0.0
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=True)
    
    for tokens, ratings, categories, target_categories, target_ids in progress_bar:
        
        tokens = {
            "input_ids": tokens["input_ids"].to(device),
            "attention_mask": tokens["attention_mask"].to(device),
            "token_type_ids": tokens["token_type_ids"].to(device),
        }
        ratings = ratings.to(device)
        categories = categories.to(device)
        target_categories = target_categories.to(device)
        target_ids = target_ids.to(device)
        
        
        # Stage1: Category Prediction
        optimizer.zero_grad()
        category_logits = model(tokens=tokens, ratings=ratings, categories=categories, tgt=target_ids, stage="category")
        category_logits = category_logits.view(-1, category_logits.size(-1))
#         print(category_logits.shape)
#         print(categories.shape)
        category_loss = category_loss_fn(category_logits, categories.view(-1))
#         category_loss.backward()
#         optimizer.step()
        total_category_loss+=category_loss.item()
        
#         with torch.no_grad():
#             predicted_categories = torch.argmax(category_logits, dim=-1)
#             print(predicted_categories)
        
#         optimizer.zero_grad()
        product_category_logits, product_id_logits = model(
            tokens = tokens,
            ratings = ratings,
            categories = categories,
            tgt = target_ids,
            stage="product",
            category_to_product_map = category_index_to_products
        )
        
        product_category_loss = category_loss_fn(
            product_category_logits.view(-1, product_category_logits.size(-1)), target_categories.view(-1)
        )
#         category_loss += product_category_loss
#         product_category_loss.backward()
        
        product_id_loss = product_loss_fn(
            product_id_logits.view(-1, product_id_logits.size(-1)), target_ids.view(-1)
        )
#         product_id_loss.backward()

        total_loss = category_loss + product_category_loss + product_id_loss
        total_loss.backward()
        optimizer.step()
        
        total_category_loss += category_loss.item()
        total_product_category_loss += product_category_loss.item()
        total_product_id_loss += product_id_loss.item()
        
        progress_bar.set_postfix({
            "Cat Loss": category_loss.item(),
            "Prod Cat Loss": product_category_loss.item(),
            "Prod ID Loss": product_id_loss.item(),
        })
        
        
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"  Avg Category Loss: {total_category_loss / len(train_loader):.4f}")
    print(f"  Avg Product Category Loss: {total_product_category_loss / len(train_loader):.4f}")
    print(f"  Avg Product ID Loss: {total_product_id_loss / len(train_loader):.4f}")

Epoch 1: 100%|██████████| 365/365 [19:02<00:00,  3.13s/it, Cat Loss=6.38, Prod Cat Loss=6.58, Prod ID Loss=10.6]

Epoch 1/1
  Avg Category Loss: 12.6021
  Avg Product Category Loss: 6.4747
  Avg Product ID Loss: 10.5776



