In [8]:
import os, glob, math, gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, default_collate
from torch_geometric.data import HeteroData
from torch_geometric.utils import sort_edge_index
from sklearn.model_selection import GroupShuffleSplit, GroupKFold
from sklearn.metrics import roc_auc_score
from torchvision import transforms, models
from PIL import Image
from transformers import DistilBertTokenizer, DistilBertModel



In [9]:
# --------------------------
# Configuration
# --------------------------
class Config:
    text_max_length = 64
    batch_size = 128
    emb_dim = 256
    num_heads = 4
    dropout = 0.5
    lr = 5e-5           # Lower LR for stability
    weight_decay = 1e-5
    epochs = 14
    patience = 5        # For early stopping
    k_folds = 2

config = Config()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

# --------------------------
# 1. Comprehensive Preprocessing
# --------------------------
def load_and_preprocess():
    # Load raw CSVs
    articles = pd.read_csv("/kaggle/input/h-and-m-personalized-fashion-recommendations/articles.csv")
    customers = pd.read_csv("/kaggle/input/h-and-m-personalized-fashion-recommendations/customers.csv")
    transactions = pd.read_csv("/kaggle/input/h-and-m-personalized-fashion-recommendations/transactions_train.csv")
    
    # Adjust IDs
    def adjust_id(x):
        x = str(x)
        return "0" + x if len(x) == 9 else x
    transactions["article_id"] = transactions["article_id"].apply(adjust_id)
    articles["article_id"] = articles["article_id"].apply(adjust_id)
    
    # Filter customers and transactions
    customers = customers[['customer_id', 'age']].dropna(subset=['age'])
    valid_article_ids = set(articles['article_id'].unique())
    filtered_transactions = transactions[transactions['article_id'].isin(valid_article_ids)]


    
    # Cold-start handling using transaction counts and customer age (demographics)
    transaction_counts = filtered_transactions['customer_id'].value_counts()
    frequent_customers = transaction_counts[transaction_counts > 8].index.tolist()
    cold_start_few = transaction_counts[transaction_counts <= 2].index.tolist()
    cold_start_no = list(set(customers['customer_id']) - set(filtered_transactions['customer_id']))
    
    # Stratified sampling: (50-30-20)
    def sample_customers(group, target_size):
        return np.random.choice(group, size=min(len(group), target_size), replace=False)
    sample_sizes = [50000, 30000, 20000]
    sampled = [
        sample_customers(frequent_customers, sample_sizes[0]),
        sample_customers(cold_start_few, sample_sizes[1]),
        sample_customers(cold_start_no, sample_sizes[2])
    ]
    sampled_customers = np.concatenate(sampled)
    sampled_customers = customers[customers['customer_id'].isin(sampled_customers)].reset_index(drop=True)
    sampled_customers = sampled_customers[['customer_id', 'age']]
    
    # Filter transactions for sampled customers
    filtered_trans = filtered_transactions[
        (filtered_transactions['customer_id'].isin(sampled_customers['customer_id'])) &
        (filtered_transactions['article_id'].isin(valid_article_ids))
    ]


    
    # Get image paths and filter articles with valid images
    all_image_paths = glob.glob("/kaggle/input/h-and-m-personalized-fashion-recommendations/images/*/*.jpg")
    valid_ids = set(os.path.splitext(os.path.basename(p))[0] for p in all_image_paths)
    def get_image_path(aid):
        subfolder = aid[:3]
        path = f"/kaggle/input/h-and-m-personalized-fashion-recommendations/images/{subfolder}/{aid}.jpg"
        return path if aid in valid_ids else None
    articles['image_path'] = articles['article_id'].apply(get_image_path)
    articles = articles.dropna(subset=['image_path']).reset_index(drop=True)
    
    # Merge price information from transactions
    article_prices = filtered_trans[['article_id', 'price']].drop_duplicates(subset=['article_id'])
    articles = articles.merge(article_prices, on='article_id', how='inner')
    articles['detail_desc'] = articles['detail_desc'].fillna('').astype(str)
    articles = articles[['article_id', 'detail_desc', 'image_path', 'price']]
    
    valid_articles = set(articles['article_id'])
    filtered_trans = filtered_trans[filtered_trans['article_id'].isin(valid_articles)]
    
    # Validate images
    articles['image_exists'] = articles['image_path'].apply(os.path.exists)
    missing = articles[~articles['image_exists']]
    if len(missing) > 0:
        print(f"Found {len(missing)} articles with missing images. Samples:")
        print(missing.sample(min(5, len(missing))))
        articles = articles[articles['image_exists']].drop(columns=['image_exists'])
    else:
        print("All images validated successfully!")

    
    # Create mapped IDs for articles and customers
    articles = articles.reset_index(drop=True)
    articles["article_mapped_id"] = articles.index
    sampled_customers = sampled_customers.reset_index(drop=True)
    sampled_customers["customer_mapped_id"] = sampled_customers.index
    
    filtered_trans = filtered_trans.merge(
        articles[['article_id', 'article_mapped_id']],
        on='article_id',
        how='inner'
    ).merge(
        sampled_customers[['customer_id', 'customer_mapped_id']],
        on='customer_id',
        how='inner'
    )
    
    # --------------------------
    # Precompute product raw features: image and text features.
    # --------------------------
    fp = FeatureProcessor()
    img_feats, txt_feats = [], []
    # Use pretrained models for feature extraction
    resnet_model = models.resnet18(pretrained=True).eval().to(device)
    bert_model = DistilBertModel.from_pretrained('distilbert-base-uncased').eval().to(device)
    for idx, row in tqdm(articles.iterrows(), total=len(articles), desc="Precomputing product features"):
        try:
            # Extract image feature
            img = Image.open(row['image_path']).convert('RGB')
            img_tensor = fp.image_transform(img).unsqueeze(0).to(device)
            img_feat = resnet_model(img_tensor).squeeze().cpu()  # shape: [1000]
            img_feats.append(img_feat)
            
            # Extract text feature using DistilBERT
            text_enc = fp.tokenizer(row['detail_desc'], padding='max_length', truncation=True,
                                      max_length=config.text_max_length, return_tensors='pt').to(device)
            txt_feat = bert_model(**text_enc).last_hidden_state[:, 0].squeeze().cpu()  # shape: [768]
            txt_feats.append(txt_feat)
        except Exception as e:
            print(f"Error precomputing features for article {row['article_id']}: {e}")
            img_feats.append(torch.zeros(1000))
            txt_feats.append(torch.zeros(768))
    articles['img_feat'] = img_feats
    articles['txt_feat'] = txt_feats
    
    return sampled_customers, filtered_trans, articles

# --------------------------
# 2. Feature Processing (for text tokenization, etc.)
# --------------------------
class FeatureProcessor:
    def __init__(self):
        self.image_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])
        self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    def process_text(self, texts):
        return self.tokenizer(texts, padding='max_length', truncation=True,
                              max_length=config.text_max_length, return_tensors='pt')

# --------------------------
# Build product feature dictionary for fast lookup during training
# --------------------------
def build_product_feature_dict(articles):
    prod_dict = {}
    for idx, row in articles.iterrows():
        prod_dict[int(row['article_mapped_id'])] = {
            'img_feat': row['img_feat'],
            'txt_feat': row['txt_feat'],
            'price': torch.tensor(row['price'], dtype=torch.float32)
        }
    return prod_dict

# --------------------------
# 3. Graph Construction (now we also add user age)
# --------------------------
def build_graph(transactions, articles, customers):
    data = HeteroData()
    data['user'].num_nodes = len(customers)
    data['product'].num_nodes = len(articles)
    edge_index = torch.tensor([
        transactions['customer_mapped_id'].values,
        transactions['article_mapped_id'].values
    ], dtype=torch.long)
    edge_index = sort_edge_index(edge_index)
    data['user', 'buys', 'product'].edge_index = edge_index
    data['product', 'rev_buys', 'user'].edge_index = edge_index.flip(0)
    data['user'].x = torch.arange(len(customers), dtype=torch.long)
    data['user'].age = torch.tensor(customers['age'].values, dtype=torch.float32).unsqueeze(1)
    data['product'].x = torch.zeros(len(articles), dtype=torch.float32)
    data['product'].price = torch.tensor(articles.price.values, dtype=torch.float32).unsqueeze(1)
    return data


# --------------------------
# 4. Model Architecture with Age Incorporation
# --------------------------
from torch_geometric.nn import HGTConv

class MultiModalGNN(nn.Module):
    def __init__(self, metadata, num_users, num_products):
        super(MultiModalGNN, self).__init__()
        # User embedding for ID
        self.user_emb = nn.Embedding(num_users, config.emb_dim)
        # New age encoder for customer age (input dimension 1 to emb_dim)
        self.age_encoder = nn.Linear(1, config.emb_dim)
        # Product feature layers to combine image, text, and price features
        self.img_fc = nn.Linear(1000, config.emb_dim)
        self.txt_fc = nn.Linear(768, config.emb_dim)
        self.price_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, config.emb_dim)
        )
        # Graph Neural Network layers
        self.conv1 = HGTConv(config.emb_dim, config.emb_dim, metadata, heads=config.num_heads)
        self.conv2 = HGTConv(config.emb_dim, config.emb_dim, metadata, heads=config.num_heads)
        self.dropout = nn.Dropout(config.dropout)
    def forward(self, x_dict, edge_index_dict, user_age):
        # x_dict['user'] contains user IDs.
        user_ids = x_dict['user'].to(device)
        base_emb = self.user_emb(user_ids)  # [num_users, emb_dim]
        age_emb = self.age_encoder(user_age.to(device))  # [num_users, emb_dim]
        # Combine the ID embedding with the age embedding
        x_dict['user'] = base_emb + age_emb
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {k: F.gelu(v) for k, v in x_dict.items()}
        x_dict = {k: self.dropout(v) for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return x_dict

# --------------------------
# 5. Data Splitting
# --------------------------
def create_splits(transactions):
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.3, random_state=42)
    train_idx, test_idx = next(splitter.split(transactions, groups=transactions['customer_mapped_id']))
    train_trans = transactions.iloc[train_idx]
    temp_trans = transactions.iloc[test_idx]
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
    val_idx, test_idx = next(splitter.split(temp_trans, groups=temp_trans['customer_mapped_id']))
    return train_trans, temp_trans.iloc[val_idx], temp_trans.iloc[test_idx]

# --------------------------
# 6. Custom Neighbor Loader (unchanged)
# --------------------------
def sample_neighbors(data: HeteroData, edge_type: tuple, src_nodes: torch.Tensor, num_samples: int) -> torch.Tensor:
    edge_index = data[edge_type].edge_index
    src = edge_index[0]
    tgt = edge_index[1]
    sampled_list = []
    for node in src_nodes.tolist():
        mask = (src == node)
        candidates = tgt[mask]
        if candidates.numel() == 0:
            continue
        if candidates.numel() > num_samples:
            perm = torch.randperm(candidates.numel())[:num_samples]
            sampled = candidates[perm]
        else:
            sampled = candidates
        sampled_list.append(sampled)
    if sampled_list:
        sampled_tgts = torch.cat(sampled_list)
    else:
        sampled_tgts = torch.tensor([], dtype=torch.long)
    return torch.unique(sampled_tgts)

class CustomNeighborLoader:
    def __init__(self, data: HeteroData, input_nodes: tuple, batch_size: int,
                 num_neighbors: dict, shuffle: bool = True):
        self.data = data
        self.input_nodes = input_nodes
        self.batch_size = batch_size
        self.num_neighbors = num_neighbors
        self.shuffle = shuffle
        self.node_type = input_nodes[0]
        self.node_indices = input_nodes[1]
        if self.shuffle:
            self.node_indices = self.node_indices[torch.randperm(self.node_indices.size(0))]
        self.num_batches = math.ceil(self.node_indices.size(0) / batch_size)
    
    def __len__(self):
        return self.num_batches
    
    def __iter__(self):
        for i in range(self.num_batches):
            batch_seed = self.node_indices[i * self.batch_size: (i+1) * self.batch_size]
            n1 = self.num_neighbors.get(('user', 'buys', 'product'), [0])[0]
            sampled_products = sample_neighbors(self.data, ('user', 'buys', 'product'), batch_seed, n1)
            n2 = self.num_neighbors.get(('product', 'rev_buys', 'user'), [0])[0]
            sampled_users_hop2 = sample_neighbors(self.data, ('product', 'rev_buys', 'user'), sampled_products, n2)
            final_users = torch.unique(torch.cat([batch_seed, sampled_users_hop2]))
            final_products = sampled_products
            sub_data = HeteroData()
            sorted_users, _ = torch.sort(final_users)
            user_map = {int(u.item()): i for i, u in enumerate(sorted_users)}
            sub_data['user'].num_nodes = sorted_users.size(0)
            sub_data['user'].x = sorted_users.clone().to(torch.long)
            # Also pass along user age from the full graph
            sub_data['user'].age = self.data['user'].age[sorted_users]

            sorted_products, _ = torch.sort(final_products)
            prod_map = {int(p.item()): i for i, p in enumerate(sorted_products)}
            sub_data['product'].num_nodes = sorted_products.size(0)
            sub_data['product'].x = torch.zeros(len(sorted_products), dtype=torch.float32)
            edge_index = self.data['user', 'buys', 'product'].edge_index
            mask = (torch.isin(edge_index[0], sorted_users) & torch.isin(edge_index[1], sorted_products))
            sub_edge_index = edge_index[:, mask].clone()
            for j in range(sub_edge_index.size(1)):
                src = int(sub_edge_index[0, j].item())
                tgt = int(sub_edge_index[1, j].item())
                sub_edge_index[0, j] = user_map[src]
                sub_edge_index[1, j] = prod_map[tgt]
            sub_data['user', 'buys', 'product'].edge_index = sub_edge_index
            edge_index = self.data['product', 'rev_buys', 'user'].edge_index
            mask = (torch.isin(edge_index[0], sorted_products) & torch.isin(edge_index[1], sorted_users))
            sub_edge_index = edge_index[:, mask].clone()
            for j in range(sub_edge_index.size(1)):
                src = int(sub_edge_index[0, j].item())
                tgt = int(sub_edge_index[1, j].item())
                sub_edge_index[0, j] = prod_map[src]
                sub_edge_index[1, j] = user_map[tgt]
            sub_data['product', 'rev_buys', 'user'].edge_index = sub_edge_index
            seed_mask = torch.zeros(sorted_users.size(0), dtype=torch.bool)
            for u in batch_seed.tolist():
                if u in user_map:
                    seed_mask[user_map[u]] = True
            sub_data['user'].seed_mask = seed_mask
            yield sub_data

# --------------------------
# 7. Training & Evaluation (Improved with In-Batch Negatives, LR Scheduler, Early Stopping)
# --------------------------
def train(model, train_data, val_data, optimizer, articles, prod_feature_dict, save_path='best_model.pth'):
    best_ndcg = -1
    scaler = torch.cuda.amp.GradScaler()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2, verbose=True)
    epochs_no_improve = 0
    for epoch in range(config.epochs):
        model.train()
        epoch_loss = 0
        loader = CustomNeighborLoader(
            data=train_data,
            input_nodes=('user', torch.arange(train_data['user'].num_nodes, device='cpu')),
            batch_size=config.batch_size,
            num_neighbors={('user', 'buys', 'product'): [10],
                           ('product', 'rev_buys', 'user'): [5]},
            shuffle=True
        )
        print(f"\nStarting epoch {epoch+1}")
        for batch_idx, batch in enumerate(loader):
            batch = batch.to(device)
            # Quickly retrieve product raw features from the dictionary
            prod_indices = batch['product'].x.cpu().numpy().astype(int)
            img_feat_list, txt_feat_list, price_list = [], [], []
            for pid in prod_indices:
                feat = prod_feature_dict.get(pid)
                if feat is None:
                    continue
                img_feat_list.append(feat['img_feat'].to(device))
                txt_feat_list.append(feat['txt_feat'].to(device))
                price_list.append(feat['price'].to(device))
            if len(img_feat_list) == 0:
                print("No product features in this batch; skipping.")
                continue
            img_feats = torch.stack(img_feat_list)
            txt_feats = torch.stack(txt_feat_list)
            prices = torch.stack(price_list)
            img_emb = model.img_fc(img_feats)
            txt_emb = model.txt_fc(txt_feats)
            price_emb = model.price_encoder(prices.unsqueeze(1))
            prod_emb_batch = img_emb + txt_emb + price_emb
            batch['product'].x = prod_emb_batch
            optimizer.zero_grad()
            # Get user age from the batch
            user_age = batch['user'].age.to(device)
            with torch.amp.autocast('cuda',enabled=True):
                out = model(batch.x_dict, batch.edge_index_dict, user_age)
                # Get the positive edge indices (shape: [2, E])
                pos_edges = batch['user', 'buys', 'product'].edge_index  
                # Compute positive scores directly for each edge:
                user_pos = out['user'][pos_edges[0]]
                prod_pos = out['product'][pos_edges[1]]
                pos_scores = (user_pos * prod_pos).sum(dim=1)  # shape: [E]
                
                # For each positive edge, sample one negative product from the batch
                num_edges = pos_edges.size(1)
                neg_indices = torch.randint(0, out['product'].size(0), (num_edges,), device=device)
                neg_scores = (user_pos * out['product'][neg_indices]).sum(dim=1)
                

                # Use margin ranking loss (which encourages pos_scores > neg_scores + margin)
                margin = 0.2
                loss = F.margin_ranking_loss(pos_scores, neg_scores, target=torch.ones_like(pos_scores), margin=margin)

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
        avg_loss = epoch_loss / len(loader)
        if val_data is not None:
            ndcg, recall, auc, map12, recall12 = evaluate(model, val_data, articles)
            print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, NDCG@10={ndcg:.4f}, Recall@10={recall:.4f}, AUC={auc:.4f}, MAP@12={map12:.4f}, Recall@12={recall12:.4f}")
            scheduler.step(ndcg)
            if ndcg > best_ndcg:
                best_ndcg = ndcg
                epochs_no_improve = 0
                torch.save(model.state_dict(), save_path)
                print(f"Best model saved at epoch {epoch+1}")
            else:
                epochs_no_improve += 1
                if epochs_no_improve >= config.patience:
                    print("Early stopping triggered.")
                    break
        else:
            print(f"Epoch {epoch+1}: Loss={avg_loss:.4f} (No evaluation)")
    return best_ndcg

def evaluate(model, val_data, articles):
    model.eval()
    scaler = torch.cuda.amp.GradScaler()
    loader = CustomNeighborLoader(
        data=val_data,
        input_nodes=('user', torch.arange(val_data['user'].num_nodes, device='cpu')),
        batch_size=config.batch_size,
        num_neighbors={('user', 'buys', 'product'): [10],
                       ('product', 'rev_buys', 'user'): [5]},
        shuffle=False
    )
    all_ndcgs, all_recalls, all_aucs, all_maps_12, all_recalls_12 = [], [], [], [], []
    for batch_idx, batch in enumerate(loader):
        batch = batch.to(device)
        prod_indices = batch['product'].x.cpu().numpy().astype(int)
        img_feat_list, txt_feat_list, price_list = [], [], []
        for pid in prod_indices:
            feat = prod_feature_dict.get(pid)
            if feat is None:
                continue
            img_feat_list.append(feat['img_feat'].to(device))
            txt_feat_list.append(feat['txt_feat'].to(device))
            price_list.append(feat['price'].to(device))
        if len(img_feat_list) == 0:
            continue
        img_feats = torch.stack(img_feat_list)
        txt_feats = torch.stack(txt_feat_list)
        prices = torch.stack(price_list)
        img_emb = model.img_fc(img_feats)
        txt_emb = model.txt_fc(txt_feats)
        price_emb = model.price_encoder(prices.unsqueeze(1))
        batch['product'].x = (img_emb + txt_emb + price_emb).to(device)
        try:
            with torch.amp.autocast('cuda',enabled=True):
                # Get user age for evaluation as well
                user_age = batch['user'].age.to(device)
                out = model(batch.x_dict, batch.edge_index_dict, user_age)
            user_embeddings = out['user'].detach()
            product_embeddings = out['product'].detach()
            scores = torch.mm(user_embeddings, product_embeddings.t())
            
            pos_edges = batch['user', 'buys', 'product'].edge_index
            ndcg = calculate_ndcg(scores, pos_edges, k=10)
            recall = calculate_recall(scores, pos_edges, k=10)
            auc = calculate_auc(scores, pos_edges)
            map_12 = calculate_map(scores, pos_edges, k=12)
            recall_12 = calculate_recall(scores, pos_edges, k=12)
            all_ndcgs.append(ndcg)
            all_recalls.append(recall)
            all_aucs.append(auc)
            all_maps_12.append(map_12)
            all_recalls_12.append(recall_12)
        except Exception as e:
            print(f"Error during evaluation of batch {batch_idx+1}: {e}")
            continue
    return (np.mean(all_ndcgs) if all_ndcgs else 0,
            np.mean(all_recalls) if all_recalls else 0,
            np.mean(all_aucs) if all_aucs else 0,
            np.mean(all_maps_12) if all_maps_12 else 0,
            np.mean(all_recalls_12) if all_recalls_12 else 0)


# --------------------------
# 8. Metrics (unchanged)
# --------------------------
def calculate_ndcg(scores, edges, k=10):
    scores = scores.detach().cpu()
    user_items = {}
    for u, i in edges.t().tolist():
        if u not in user_items:
            user_items[u] = set()
        user_items[u].add(i)
    ndcgs = []
    for u in user_items.keys():
        relevant_items = user_items[u]
        if len(relevant_items) == 0:
            continue
        user_scores = scores[u]
        top_k_items = torch.topk(user_scores, k=k).indices.tolist()
        dcg = sum(1 / np.log2(rank + 2) for rank, item_id in enumerate(top_k_items) if item_id in relevant_items)
        ideal_dcg = sum(1 / np.log2(i + 2) for i in range(min(len(relevant_items), k)))
        if ideal_dcg > 0:
            ndcgs.append(dcg / ideal_dcg)
    return np.mean(ndcgs) if ndcgs else 0.0

def calculate_recall(scores, edges, k=10):
    user_items = {}
    for u, i in edges.t().tolist():
        user_items.setdefault(u, set()).add(i)
    recalls = []
    for u in range(scores.size(0)):
        pred = set(scores[u].argsort(descending=True)[:k].tolist())
        rel = user_items.get(u, set())
        if len(rel) == 0:
            continue
        recalls.append(len(pred & rel)/len(rel))
    return np.mean(recalls) if recalls else 0

def calculate_auc(scores, edges):
    pos_pairs = scores[edges[0], edges[1]]
    neg_pairs = scores[torch.randint(0, scores.size(0), (len(pos_pairs),)),
                       torch.randint(0, scores.size(1), (len(pos_pairs),))]
    y_true = torch.cat([torch.ones_like(pos_pairs), torch.zeros_like(neg_pairs)])
    y_score = torch.cat([pos_pairs.sigmoid(), neg_pairs.sigmoid()])
    return roc_auc_score(y_true.cpu().numpy(), y_score.cpu().numpy())

def calculate_map(scores, edges, k=12):
    user_items = {}
    for u, i in edges.t().tolist():
        user_items.setdefault(u, set()).add(i)
    maps = []
    for u in range(scores.size(0)):
        rel = user_items.get(u, set())
        if not rel: continue
        pred = scores[u].argsort(descending=True)[:k].tolist()
        hits, ap = 0, 0
        for i, item in enumerate(pred):
            if item in rel:
                hits += 1
                ap += hits / (i + 1)
        maps.append(ap / min(len(rel), k))
    return np.mean(maps) if maps else 0.0
# --------------------------
# 9. Cross-Validation (unchanged)
# --------------------------
def cross_validate(articles, customers, transactions):
    kf = GroupKFold(config.k_folds)
    for fold, (train_idx, val_idx) in enumerate(kf.split(transactions, groups=transactions['customer_mapped_id'])):
        print(f"\n=== Fold {fold+1} ===")
        
        fold_train_trans = transactions.iloc[train_idx]
        fold_val_trans = transactions.iloc[val_idx]
        
        # Extract user IDs from the split transactions
        train_users = fold_train_trans['customer_mapped_id'].unique()
        val_users = fold_val_trans['customer_mapped_id'].unique()

        # Keep only the users involved in this fold
        fold_customers = customers[customers['customer_mapped_id'].isin(np.concatenate([train_users, val_users]))].copy()

        # Reindex fold_customers to have contiguous IDs
        fold_customers = fold_customers.reset_index(drop=True)
        fold_customers['customer_mapped_id'] = fold_customers.index

        # Re-map the transactions accordingly
        user_id_map = dict(zip(fold_customers['customer_id'], fold_customers['customer_mapped_id']))
        fold_train_trans = fold_train_trans.copy()
        fold_val_trans = fold_val_trans.copy()
        fold_train_trans['customer_mapped_id'] = fold_train_trans['customer_id'].map(user_id_map)
        fold_val_trans['customer_mapped_id'] = fold_val_trans['customer_id'].map(user_id_map)

        # Filter articles
        fold_article_ids = set(fold_train_trans.article_mapped_id).union(set(fold_val_trans.article_mapped_id))
        fold_articles = articles[articles.article_mapped_id.isin(fold_article_ids)].copy()

        train_data = build_graph(fold_train_trans, fold_articles, fold_customers)
        val_data = build_graph(fold_val_trans, fold_articles, fold_customers)

        num_users = fold_customers['customer_mapped_id'].max() + 1
        num_products = fold_articles['article_mapped_id'].max() + 1
        model = MultiModalGNN(train_data.metadata(), num_users, num_products).to(device)
        optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        _ = train(model, train_data, val_data, optimizer, fold_articles, prod_feature_dict, save_path=f"best_model_fold{fold+1}.pth")
        del model, optimizer
        torch.cuda.empty_cache()
    print("\nCross-validation complete.")




In [10]:
import pickle
# --------------------------
# 10. Main Execution: Step-by-Step
# --------------------------
if __name__ == "__main__":
    # Step 1: Preprocess data
    # sampled_customers, transactions, articles = load_and_preprocess()
    PREPROCESSED_DIR = "/kaggle/input/preprocessed-data-7/"

    # Load dataframes
    articles = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "articles.pkl"))
    customers = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "customers.pkl"))
    transactions = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "transactions.pkl"))
    
    print(f"Articles: {len(articles)}, Customers: {len(customers)}, Transactions: {len(transactions)}")
    
    # Make this dictionary globally accessible for evaluate() and train()
    global prod_feature_dict

    with open("/kaggle/input/prod-feature-dict/prod_feature_dict.pkl", "rb") as f:
        prod_feature_dict = pickle.load(f)

    # # Step 2: Build product feature dictionary from precomputed features
    # prod_feature_dict = build_product_feature_dict(articles)
    # # Define a directory to save preprocessed files
    # PROD_FEATUREDICT_DIR = "/kaggle/working/prod_feature_dict"
    # ZIP_FILE = "/kaggle/working/prod_feature_dict.zip"
    # os.makedirs(PROD_FEATUREDICT_DIR, exist_ok=True)
    
    # # Save dataframes (articles, customers, transactions)
    # with open(os.path.join(PROD_FEATUREDICT_DIR, "prod_feature_dict.pkl"), 'wb') as f:
    #     pickle.dump(prod_feature_dict, f)
    
    # # ✅ Create a ZIP archive containing all preprocessed files
    # with zipfile.ZipFile(ZIP_FILE, 'w') as zipf:
    #     for file in os.listdir(PROD_FEATUREDICT_DIR):
    #         zipf.write(os.path.join(PROD_FEATUREDICT_DIR, file), arcname=file)
    
    # print(f"✅ MULTI completed! Saved as {ZIP_FILE}")
    
    # Step 3: Create train/val/test splits
    train_trans, val_trans, test_trans = create_splits(transactions)
    # Step 4: Cross-validate on training subset (optional)
    print("Starting cross-validation on training data...")
    cross_validate(articles, customers, train_trans)
    print("Cross-validation complete.")
     # ✅ Step A: Combine train + val transactions
    trainval_trans = pd.concat([train_trans, val_trans], ignore_index=True)
    
    # ✅ Step B: Rebuild graph using train + val data
    trainval_data = build_graph(trainval_trans, articles, customers)
    
    # ✅ Step C: Initialize model with same architecture
    model = MultiModalGNN(trainval_data.metadata(), customers['customer_mapped_id'].max() + 1, articles['article_mapped_id'].max() + 1).to(device)
    #  # ✅ Step D: Load best weights from CV
    # model.load_state_dict(torch.load("best_model.pth"))
    # print("✅ Loaded best model from CV.")
    # # Step 5: Build graphs for train, validation, and test splits
    # # train_data = build_graph(train_trans, articles, customers)
    # # val_data   = build_graph(val_trans, articles, customers)
    # test_data  = build_graph(test_trans, articles, customers)    
    # # Step 6: Initialize model with metadata from train_data and incorporate user age
    # # num_users = sampled_customers['customer_mapped_id'].max() + 1
    # # num_products = articles['article_mapped_id'].max() + 1
    # # model = MultiModalGNN(train_data.metadata(), num_users, num_products).to(device)
    # optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
    
    # # Step 7: Train model using the improved training loop (with in-batch negatives, LR scheduler, and early stopping)
    # best_ndcg = train(model, train_data, val_data, optimizer, articles, prod_feature_dict)
    
    # # Step 8: Evaluate on test set
    # ndcg, recall, auc = evaluate(model, test_data, articles)
    # print(f"\nFinal Test Performance:")
    # print(f"NDCG@10: {ndcg:.4f}")
    # print(f"Recall@10: {recall:.4f}")
    # print(f"AUC: {auc:.4f}")
    
    # # Step 9: Save final model
    # torch.save({
    #     'state_dict': model.state_dict(),
    #     'metadata': train_data.metadata(),
    #     'config': config.__dict__
    # }, "final_model.pth")
    # print("Training and evaluation complete.")

Articles: 80654, Customers: 88647, Transactions: 2092109
Starting cross-validation on training data...

=== Fold 1 ===


  edge_index = torch.tensor([
  scaler = torch.cuda.amp.GradScaler()



Starting epoch 1


  scaler = torch.cuda.amp.GradScaler()


Epoch 1: Loss=0.4435, NDCG@10=0.0519, Recall@10=0.0812, AUC=0.5980, MAP@12=0.0285, Recall@12=0.0949
Best model saved at epoch 1

Starting epoch 2


  scaler = torch.cuda.amp.GradScaler()


Epoch 2: Loss=0.0200, NDCG@10=0.0580, Recall@10=0.0925, AUC=0.5771, MAP@12=0.0320, Recall@12=0.1086
Best model saved at epoch 2

Starting epoch 3


  scaler = torch.cuda.amp.GradScaler()


Epoch 3: Loss=0.0178, NDCG@10=0.0805, Recall@10=0.1293, AUC=0.5723, MAP@12=0.0465, Recall@12=0.1506
Best model saved at epoch 3

Starting epoch 4


  scaler = torch.cuda.amp.GradScaler()


Epoch 4: Loss=0.0162, NDCG@10=0.0812, Recall@10=0.1304, AUC=0.5756, MAP@12=0.0471, Recall@12=0.1515
Best model saved at epoch 4

Starting epoch 5


  scaler = torch.cuda.amp.GradScaler()


Epoch 5: Loss=0.0147, NDCG@10=0.0812, Recall@10=0.1302, AUC=0.5722, MAP@12=0.0473, Recall@12=0.1513
Best model saved at epoch 5

Starting epoch 6
No product features in this batch; skipping.


  scaler = torch.cuda.amp.GradScaler()


Epoch 6: Loss=0.0135, NDCG@10=0.0864, Recall@10=0.1373, AUC=0.5700, MAP@12=0.0505, Recall@12=0.1589
Best model saved at epoch 6

Starting epoch 7


  scaler = torch.cuda.amp.GradScaler()


Error during evaluation of batch 433: selected index k out of range
Epoch 7: Loss=0.0130, NDCG@10=0.0861, Recall@10=0.1363, AUC=0.5816, MAP@12=0.0499, Recall@12=0.1581

Starting epoch 8


  scaler = torch.cuda.amp.GradScaler()


Epoch 8: Loss=0.0126, NDCG@10=0.0839, Recall@10=0.1340, AUC=0.5799, MAP@12=0.0488, Recall@12=0.1555

Starting epoch 9


  scaler = torch.cuda.amp.GradScaler()


Error during evaluation of batch 433: selected index k out of range
Epoch 9: Loss=0.0121, NDCG@10=0.0882, Recall@10=0.1402, AUC=0.5870, MAP@12=0.0511, Recall@12=0.1630
Best model saved at epoch 9

Starting epoch 10


  scaler = torch.cuda.amp.GradScaler()


Error during evaluation of batch 433: selected index k out of range
Epoch 10: Loss=0.0117, NDCG@10=0.0968, Recall@10=0.1538, AUC=0.5870, MAP@12=0.0566, Recall@12=0.1786
Best model saved at epoch 10

Starting epoch 11


  scaler = torch.cuda.amp.GradScaler()


Epoch 11: Loss=0.0114, NDCG@10=0.0925, Recall@10=0.1482, AUC=0.5900, MAP@12=0.0541, Recall@12=0.1720

Starting epoch 12
No product features in this batch; skipping.


  scaler = torch.cuda.amp.GradScaler()


Epoch 12: Loss=0.0111, NDCG@10=0.0946, Recall@10=0.1514, AUC=0.5915, MAP@12=0.0556, Recall@12=0.1756

Starting epoch 13


  scaler = torch.cuda.amp.GradScaler()


Epoch 13: Loss=0.0111, NDCG@10=0.0968, Recall@10=0.1544, AUC=0.5994, MAP@12=0.0569, Recall@12=0.1789
Best model saved at epoch 13

Starting epoch 14


  scaler = torch.cuda.amp.GradScaler()


Epoch 14: Loss=0.0107, NDCG@10=0.0974, Recall@10=0.1554, AUC=0.5894, MAP@12=0.0578, Recall@12=0.1801
Best model saved at epoch 14

=== Fold 2 ===


  scaler = torch.cuda.amp.GradScaler()



Starting epoch 1


  scaler = torch.cuda.amp.GradScaler()


Epoch 1: Loss=0.4151, NDCG@10=0.0504, Recall@10=0.0773, AUC=0.6050, MAP@12=0.0273, Recall@12=0.0900
Best model saved at epoch 1

Starting epoch 2


  scaler = torch.cuda.amp.GradScaler()


Epoch 2: Loss=0.0260, NDCG@10=0.0588, Recall@10=0.0915, AUC=0.5795, MAP@12=0.0324, Recall@12=0.1061
Best model saved at epoch 2

Starting epoch 3


  scaler = torch.cuda.amp.GradScaler()


Epoch 3: Loss=0.0215, NDCG@10=0.0662, Recall@10=0.1023, AUC=0.5635, MAP@12=0.0367, Recall@12=0.1194
Best model saved at epoch 3

Starting epoch 4
No product features in this batch; skipping.


  scaler = torch.cuda.amp.GradScaler()


Epoch 4: Loss=0.0194, NDCG@10=0.0710, Recall@10=0.1098, AUC=0.5716, MAP@12=0.0396, Recall@12=0.1275
Best model saved at epoch 4

Starting epoch 5


  scaler = torch.cuda.amp.GradScaler()


Epoch 5: Loss=0.0174, NDCG@10=0.0697, Recall@10=0.1072, AUC=0.5793, MAP@12=0.0387, Recall@12=0.1249

Starting epoch 6


  scaler = torch.cuda.amp.GradScaler()


Epoch 6: Loss=0.0161, NDCG@10=0.0697, Recall@10=0.1071, AUC=0.5765, MAP@12=0.0387, Recall@12=0.1250

Starting epoch 7


  scaler = torch.cuda.amp.GradScaler()


Epoch 7: Loss=0.0155, NDCG@10=0.0725, Recall@10=0.1117, AUC=0.5681, MAP@12=0.0405, Recall@12=0.1301
Best model saved at epoch 7

Starting epoch 8
No product features in this batch; skipping.


  scaler = torch.cuda.amp.GradScaler()


Epoch 8: Loss=0.0148, NDCG@10=0.0727, Recall@10=0.1121, AUC=0.5755, MAP@12=0.0405, Recall@12=0.1307
Best model saved at epoch 8

Starting epoch 9


  scaler = torch.cuda.amp.GradScaler()


Epoch 9: Loss=0.0143, NDCG@10=0.0743, Recall@10=0.1142, AUC=0.5861, MAP@12=0.0416, Recall@12=0.1331
Best model saved at epoch 9

Starting epoch 10


  scaler = torch.cuda.amp.GradScaler()


Epoch 10: Loss=0.0138, NDCG@10=0.0747, Recall@10=0.1149, AUC=0.5910, MAP@12=0.0417, Recall@12=0.1343
Best model saved at epoch 10

Starting epoch 11
No product features in this batch; skipping.


  scaler = torch.cuda.amp.GradScaler()


Epoch 11: Loss=0.0133, NDCG@10=0.0756, Recall@10=0.1160, AUC=0.5889, MAP@12=0.0424, Recall@12=0.1349
Best model saved at epoch 11

Starting epoch 12


  scaler = torch.cuda.amp.GradScaler()


Epoch 12: Loss=0.0130, NDCG@10=0.0764, Recall@10=0.1167, AUC=0.5955, MAP@12=0.0429, Recall@12=0.1355
Best model saved at epoch 12

Starting epoch 13


  scaler = torch.cuda.amp.GradScaler()


Epoch 13: Loss=0.0132, NDCG@10=0.0768, Recall@10=0.1177, AUC=0.6022, MAP@12=0.0431, Recall@12=0.1370
Best model saved at epoch 13

Starting epoch 14


  scaler = torch.cuda.amp.GradScaler()


Epoch 14: Loss=0.0127, NDCG@10=0.0767, Recall@10=0.1181, AUC=0.6010, MAP@12=0.0429, Recall@12=0.1376

Cross-validation complete.
Cross-validation complete.


fold 1
Epoch 14: Loss=0.0107, NDCG@10=0.0974, Recall@10=0.1554, AUC=0.5894, MAP@12=0.0578, Recall@12=0.1801
fold 2
Epoch 14: Loss=0.0127, NDCG@10=0.0767, Recall@10=0.1181, AUC=0.6010, MAP@12=0.0429, Recall@12=0.1376


In [16]:
import pickle
global prod_feature_dict

with open("/kaggle/input/prod-feature-dict/prod_feature_dict.pkl", "rb") as f:
    prod_feature_dict = pickle.load(f)

PREPROCESSED_DIR = "/kaggle/input/preprocessed-data-7/"
articles = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "articles.pkl"))
customers = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "customers.pkl"))
transactions = pd.read_pickle(os.path.join(PREPROCESSED_DIR, "transactions.pkl"))
train_trans, val_trans, test_trans = create_splits(transactions)

trainval_trans = pd.concat([train_trans, val_trans], ignore_index=True)

# ✅ Step B: Rebuild graph using train + val data
trainval_data = build_graph(trainval_trans, articles, customers)

# ✅ Step C: Initialize model with same architecture
model = MultiModalGNN(trainval_data.metadata(), customers['customer_mapped_id'].max() + 1, articles['article_mapped_id'].max() + 1).to(device)


# Load the checkpoint
checkpoint = torch.load("/kaggle/input/cold-start-rec-modal/other/default/1/best_model_fold1.pth", weights_only=True)
state_dict = checkpoint

# Get the checkpoint user embeddings and current model's user embeddings
old_user_emb = state_dict["user_emb.weight"]   # Shape: [55299, 256]
new_user_emb = model.user_emb.weight             # Shape: [88647, 256]

# Check how many rows to copy
num_overlap = old_user_emb.size(0)

# Replace the first num_overlap rows of the new model's embedding with the loaded weights
new_user_emb.data[:num_overlap] = old_user_emb

# Optionally, leave the remaining rows as is (randomly initialized) or initialize them as desired
state_dict["user_emb.weight"] = new_user_emb

# Now load state_dict with strict=False to avoid errors on the rest of the mismatched keys
model.load_state_dict(state_dict, strict=False)

print("✅ Loaded best model from CV.")

# ✅ Step E: Reinitialize optimizer (for retraining)
optimizer = optim.AdamW(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

# ✅ Step F: Retrain model on train + val (no val_data used now)
print("\n🔁 Retraining on Train + Validation set...")
best_ndcg = train(model, trainval_data, val_data=None, optimizer=optimizer, articles=articles, prod_feature_dict=prod_feature_dict)
print("best_ndcg", best_ndcg)
# ✅ Step G: Evaluate on the test set
print("\n🧪 Final Evaluation on Test Set")
test_data  = build_graph(test_trans, articles, customers)    

ndcg, recall, auc, map12, recall12 = evaluate(model, test_data, articles)
print(f"\n✅ Final Test Performance:")
print(f"NDCG@10: {ndcg:.4f}")
print(f"Recall@10: {recall:.4f}")
print(f"AUC: {auc:.4f}")
print(f"MAP@12: {map12:.4f}")
print(f"Recall@12: {recall12:.4f}")

# ✅ Step H: Save the retrained final model
torch.save({
    'state_dict': model.state_dict(),
    'metadata': trainval_data.metadata(),
    'config': config.__dict__
}, "final_model_retrained.pth")


✅ Loaded best model from CV.

🔁 Retraining on Train + Validation set...

Starting epoch 1


  scaler = torch.cuda.amp.GradScaler()


Epoch 1: Loss=0.0144 (No evaluation)

Starting epoch 2
Epoch 2: Loss=0.0139 (No evaluation)

Starting epoch 3
Epoch 3: Loss=0.0135 (No evaluation)

Starting epoch 4
Epoch 4: Loss=0.0132 (No evaluation)

Starting epoch 5
Epoch 5: Loss=0.0130 (No evaluation)

Starting epoch 6
Epoch 6: Loss=0.0127 (No evaluation)

Starting epoch 7
Epoch 7: Loss=0.0125 (No evaluation)

Starting epoch 8
Epoch 8: Loss=0.0123 (No evaluation)

Starting epoch 9
Epoch 9: Loss=0.0122 (No evaluation)

Starting epoch 10
Epoch 10: Loss=0.0121 (No evaluation)

Starting epoch 11
Epoch 11: Loss=0.0122 (No evaluation)

Starting epoch 12
Epoch 12: Loss=0.0120 (No evaluation)

Starting epoch 13
Epoch 13: Loss=0.0120 (No evaluation)

Starting epoch 14
Epoch 14: Loss=0.0120 (No evaluation)
best_ndcg -1

🧪 Final Evaluation on Test Set


  scaler = torch.cuda.amp.GradScaler()



✅ Final Test Performance:
NDCG@10: 0.2743
Recall@10: 0.4915
AUC: 0.6716
MAP@12: 0.1916
Recall@12: 0.5502


In [None]:
batch_size = 128
emb_dim = 256
fold 1
Epoch 1: Loss=0.3129, NDCG=0.0806, Recall=0.1324, AUC=0.5680
Epoch 2: Loss=0.0271, NDCG=0.0975, Recall=0.1651, AUC=0.5530
Epoch 3: Loss=0.0230, NDCG=0.1106, Recall=0.1876, AUC=0.5580
fold 2
Epoch 1: Loss=0.3602, NDCG=0.0699, Recall=0.1122, AUC=0.5223
Epoch 2: Loss=0.0305, NDCG=0.0835, Recall=0.1372, AUC=0.5217   
Epoch 3: Loss=0.0231, NDCG=0.0859, Recall=0.1403, AUC=0.5170

In [None]:
14