In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""    # Force CPU-only

import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.utils import sort_edge_index
from torch_geometric.nn import HGTConv

# All computation on CPU
device = torch.device("cpu")

# --------------------------
# 1. CONFIG (must match your training config)
# --------------------------
class Config:
    """Hyperparameters shared by the model definition below.

    The values are class-level attributes; they are read through the
    module-level ``config`` instance when the model layers are created.
    """

    # Width of every embedding / hidden representation.
    emb_dim: int = 256
    # Number of attention heads handed to each HGTConv layer.
    num_heads: int = 4
    # Dropout probability applied between the two GNN layers.
    dropout: float = 0.5


config = Config()

# --------------------------
# 2. GRAPH CONSTRUCTION
# --------------------------
def build_graph(transactions, articles, customers):
    """Assemble the user/product purchase graph as a ``HeteroData`` object.

    Args:
        transactions: table indexable by column name, providing
            ``customer_mapped_id`` and ``article_mapped_id`` (one row per
            purchase edge).
        articles: table providing a ``price`` column; its length is the
            number of product nodes.
        customers: table providing an ``age`` column; its length is the
            number of user nodes.

    Returns:
        A ``HeteroData`` with ``user`` and ``product`` node stores, a
        ``('user', 'buys', 'product')`` edge index and the flipped
        ``('product', 'rev_buys', 'user')`` edge index.
    """
    def column_tensor(frame, column, dtype):
        # Convert one column of a table into a tensor of the given dtype.
        return torch.tensor(frame[column].values, dtype=dtype)

    user_count = len(customers)
    product_count = len(articles)

    graph = HeteroData()

    # User nodes: x holds the integer node ids 0..user_count-1, age is
    # stored as an extra trailing-dimension float feature.
    graph['user'].num_nodes = user_count
    graph['user'].x = torch.arange(user_count, dtype=torch.long)
    graph['user'].age = column_tensor(customers, 'age', torch.float32).unsqueeze(1)

    # Product nodes: x is an all-zero float placeholder, price is stored
    # alongside it with a trailing dimension added.
    graph['product'].num_nodes = product_count
    graph['product'].x = torch.zeros(product_count, dtype=torch.float32)
    graph['product'].price = column_tensor(articles, 'price', torch.float32).unsqueeze(1)

    # Purchase edges (user -> product), sorted, plus the reversed relation
    # obtained by swapping the two rows of the sorted edge index.
    buyers = column_tensor(transactions, 'customer_mapped_id', torch.long)
    bought = column_tensor(transactions, 'article_mapped_id', torch.long)
    forward_edges = sort_edge_index(torch.stack((buyers, bought), dim=0))

    graph['user', 'buys', 'product'].edge_index = forward_edges
    graph['product', 'rev_buys', 'user'].edge_index = forward_edges.flip(0)

    return graph

# --------------------------
# 3. MODEL ARCHITECTURE
# --------------------------
class MultiModalGNN(nn.Module):
    """Two-layer HGT network over a user/product heterogeneous graph.

    User input features are the sum of a learned per-user ID embedding and
    a linear projection of the user's age. ``img_fc``, ``txt_fc`` and
    ``price_encoder`` are not called by ``forward``; the main script in
    this file applies them to raw product features to build the
    ``'product'`` entry of ``x_dict`` before calling the model.

    Layer names and shapes must match the checkpoint that is loaded into
    this module.
    """

    def __init__(self, metadata, num_users, num_products):
        """Create the layers.

        Args:
            metadata: graph metadata passed straight to both ``HGTConv``
                layers.
            num_users: number of rows in the user ID embedding table.
            num_products: accepted for interface compatibility; not used
                when building the layers.
        """
        super().__init__()
        # Must match your checkpoint
        self.user_emb      = nn.Embedding(num_users, config.emb_dim)
        self.age_encoder   = nn.Linear(1, config.emb_dim)
        self.img_fc        = nn.Linear(1000, config.emb_dim)
        self.txt_fc        = nn.Linear(768, config.emb_dim)
        self.price_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, config.emb_dim)
        )
        self.conv1   = HGTConv(config.emb_dim, config.emb_dim, metadata, heads=config.num_heads)
        self.conv2   = HGTConv(config.emb_dim, config.emb_dim, metadata, heads=config.num_heads)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x_dict, edge_index_dict, user_age):
        """Run both GNN layers and return the per-node-type outputs.

        Args:
            x_dict: mapping from node type to input tensor. ``'user'`` must
                hold integer user IDs (looked up in ``user_emb``); every
                other entry is forwarded to the first conv layer unchanged.
                The mapping itself is NOT modified.
            edge_index_dict: mapping from edge type to edge index, passed to
                both conv layers.
            user_age: tensor fed to ``age_encoder`` (a ``Linear(1, emb_dim)``),
                so its last dimension must be 1.

        Returns:
            The dict produced by the second conv layer.
        """
        # Combine ID + age embeddings
        id_emb  = self.user_emb(x_dict['user'])
        age_emb = self.age_encoder(user_age)

        # Build a fresh dict instead of assigning into the caller's x_dict:
        # overwriting x_dict['user'] would replace the caller's ID tensor
        # with float features and break any later call that reuses it.
        h_dict = {**x_dict, 'user': id_emb + age_emb}

        # GNN layers
        h_dict = self.conv1(h_dict, edge_index_dict)
        h_dict = {k: self.dropout(F.gelu(v)) for k, v in h_dict.items()}
        return self.conv2(h_dict, edge_index_dict)

# --------------------------
# 4. MAIN
# --------------------------
if __name__ == "__main__":
    # Script entry point: build and pickle the full graph, restore the
    # trained model, project product features, run the GNN once over the
    # whole graph, and pickle the mean user embedding per integer age.
    from collections import defaultdict

    # Paths to your preprocessed pickles in Kaggle
    PRE = "/kaggle/input/preprocessed-data-7"
    articles     = pd.read_pickle(os.path.join(PRE, "articles.pkl"))
    customers    = pd.read_pickle(os.path.join(PRE, "customers.pkl"))
    transactions = pd.read_pickle(os.path.join(PRE, "transactions.pkl"))

    # 4a) Build & save full_graph.pkl
    full_graph = build_graph(transactions, articles, customers)
    with open("full_graph.pkl", "wb") as f:
        pickle.dump(full_graph, f)
    print("✅ Saved full_graph.pkl")

    # 4b) Load trained model checkpoint
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source. The checkpoint is expected to provide the
    # 'metadata' and 'state_dict' entries read below.
    ckpt = torch.load(
        "/kaggle/input/cold-start-gnn-modal/pytorch/default/1/final_model_retrained.pth",
        map_location="cpu"
    )
    model = MultiModalGNN(
        metadata    = ckpt['metadata'],
        num_users   = full_graph['user'].num_nodes,
        num_products= full_graph['product'].num_nodes
    ).to(device)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # 4c) Load product features and compute 256-dim embeddings
    # NOTE(review): pickle.load executes arbitrary code from the file;
    # only use feature dictionaries from a trusted source.
    with open("/kaggle/input/prod-feature-dict/prod_feature_dict.pkl", "rb") as f:
        prod_feature_dict = pickle.load(f)

    # Inference only: without no_grad every projection would record an
    # autograd graph that is never used, wasting memory for each product.
    with torch.no_grad():
        prod_embeddings = []
        for pid in range(full_graph['product'].num_nodes):
            feat = prod_feature_dict[pid]
            img_emb   = model.img_fc(feat['img_feat'])
            txt_emb   = model.txt_fc(feat['txt_feat'])
            price_emb = model.price_encoder(feat['price'].unsqueeze(0)).squeeze(0)
            prod_embeddings.append(img_emb + txt_emb + price_emb)
        prod_embeddings_tensor = torch.stack(prod_embeddings, dim=0)  # [num_products × emb_dim]

    # 4d) Run full-graph GNN to get user embeddings
    x_dict = {
        'user':    full_graph['user'].x,
        'product': prod_embeddings_tensor
    }
    edge_idx = dict(full_graph.edge_index_dict)

    with torch.no_grad():
        out = model(x_dict, edge_idx, full_graph['user'].age)
        user_embs = out['user'].numpy()
        # reshape(-1) rather than squeeze(): squeeze() would collapse a
        # single-user array to 0-d, which cannot be iterated below.
        user_ages = full_graph['user'].age.numpy().reshape(-1)

    # 4e) Group by age & average
    age_to_list = defaultdict(list)
    for emb, age in zip(user_embs, user_ages):
        # Ages are truncated to int to form the group key.
        age_to_list[int(age)].append(emb)

    age_to_avg = {
        age: np.mean(embs, axis=0)
        for age, embs in age_to_list.items()
    }

    with open("age_to_avg_user_embedding.pkl", "wb") as f:
        pickle.dump(age_to_avg, f)
    print("✅ Saved age_to_avg_user_embedding.pkl")




✅ Saved full_graph.pkl


  ckpt = torch.load(


✅ Saved age_to_avg_user_embedding.pkl
