# Head Weight Estimation

This notebook esimates weights for four heads in the recommender system's value model. A value model uses these four heads as targets and learns to predict the probability of each head for every user–video pair. At serving time, the system can combine the predicted head probabilities into a single value score using a weighted sum and then rank candidate videos by this value score:
$$
\text{score} \;=\; w_1 \,\hat y_{\text{complete}}
               \;+\; w_2 \,\hat y_{\text{long}}
               \;+\; w_3 \,\hat y_{\text{rewatch}}
               \;-\; w_4 \,\hat y_{\text{neg}},
$$

Head predictions can be done using GNN, which parameters have been stored.

### 1. Extract (level 1) categorical distribution for each user session

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
from time import time

pd.set_option('future.no_silent_downcasting', True)

BASE = Path("/Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/data/processed")
DATA_PATH = BASE / "processed_data.parquet"

df = pd.read_parquet(DATA_PATH)
df.head()

Unnamed: 0,user_id,video_id,play_duration,time,date,timestamp,watch_ratio,burst_id,session,sess_rank,...,hist_ema_y_rewatch,hist_ema_y_neg,hist_ema_watchratio,hist_cat_ema_complete,hist_cat_entropy_l2,hist_author_recency_days,hist_last_complete_author,hist_has_author_history,hist_prev_sess_len,hist_intersess_gap_h
0,0,3649,13838,2020-07-05 00:08:23.438,20200705,1593879000.0,1.273397,1,1,1,...,,,,,0.0,,,0,0.0,0.0
1,0,9598,13665,2020-07-05 00:13:41.297,20200705,1593879000.0,1.244082,1,1,2,...,0.0,0.0,1.273396,1.0,-0.0,,,0,0.0,0.0
2,0,5262,851,2020-07-05 00:16:06.687,20200705,1593879000.0,0.107613,1,1,3,...,0.0,0.0,1.270465,,-0.0,,,0,0.0,0.0
3,0,1963,862,2020-07-05 00:20:26.792,20200705,1593880000.0,0.089885,1,1,4,...,0.0,0.1,1.15418,1.0,0.636514,,,0,0.0,0.0
4,0,8234,858,2020-07-05 00:43:05.128,20200705,1593881000.0,0.078,1,1,5,...,0.0,0.19,1.04775,,0.562335,,,0,0.0,0.0


In [2]:
# There are some number of rows with missing L1 category
n_nan = df["i_cat_level1_id"].isna().sum()
n_total = len(df)

print(n_nan, n_total)

df["i_cat_level1_id"] = df["i_cat_level1_id"].fillna(-124) # replace it with -124 (UNKNOWN)
df["i_cat_level1_name"] = (
    df["i_cat_level1_name"]
    .replace({None: "UNKNOWN"})
    .fillna("UNKNOWN")
)

0 12527912


In [3]:
# Map for level-1 category
cat_l1_map = (
    df[["i_cat_level1_id", "i_cat_level1_name"]]
    .drop_duplicates()
    .sort_values("i_cat_level1_id")
)

print(len(cat_l1_map))
cat_l1_map

39


Unnamed: 0,i_cat_level1_id,i_cat_level1_name
331,-124.0,UNKNOWN
6,1.0,舞蹈
118,2.0,音乐
611,3.0,游戏
20,4.0,美妆
2,5.0,时尚
4,6.0,明星娱乐
127,7.0,运动
64,8.0,颜值
11,9.0,喜剧


In [4]:
# Create level-1 categorical distribution for each session

# 1) counts by user–session–L1 category
counts = (
    df.groupby(["user_id", "session", "i_cat_level1_id"])
      .size()
      .reset_index(name="n_cat")
)

# 2) grab session length from the original df
sess_len = (
    df.groupby(["user_id", "session"])["sess_len"]
      .first()                      # same within a session
      .rename("sess_total")
      .reset_index()
)

# 3) merge and compute probabilities
counts = counts.merge(sess_len, on=["user_id", "session"], how="left")
counts["p_cat"] = counts["n_cat"] / counts["sess_total"]

# 4) all possible L1 categories (including -124 for UNKNOWN)
ALL_L1_CATS = np.sort(df["i_cat_level1_id"].unique())

# 5) build observed category distribution for ALL sessions
obs_cat_all = (
    counts
    .pivot_table(
        index=["user_id", "session"],
        columns="i_cat_level1_id",
        values="p_cat",
        aggfunc="sum",
        fill_value=0.0,
    )
)

# 6) make sure columns are in a fixed order and include all categories
obs_cat_all = obs_cat_all.reindex(columns=ALL_L1_CATS, fill_value=0.0)

obs_cat_all

Unnamed: 0_level_0,i_cat_level1_id,-124.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,...,29.0,31.0,32.0,33.0,34.0,35.0,36.0,37.0,38.0,39.0
user_id,session,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
0,1,0.0,0.000000,0.0,0.0,0.00,0.166667,0.166667,0.0,0.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.166667,0.0,0.0,0.0,0.0,0.0
0,2,0.0,1.000000,0.0,0.0,0.00,0.000000,0.000000,0.0,0.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0
0,3,0.0,0.000000,0.0,0.0,0.00,0.000000,0.000000,0.0,0.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0
0,4,0.0,0.100000,0.0,0.0,0.05,0.000000,0.050000,0.0,0.000000,0.100000,...,0.0,0.000000,0.0,0.000000,0.050000,0.0,0.0,0.0,0.0,0.0
0,5,0.0,0.090909,0.0,0.0,0.00,0.036364,0.018182,0.0,0.018182,0.090909,...,0.0,0.018182,0.0,0.036364,0.036364,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7175,110,0.0,0.000000,0.0,0.0,0.00,0.000000,1.000000,0.0,0.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0
7175,111,0.0,0.000000,0.0,0.0,0.00,0.000000,0.000000,0.0,0.500000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0
7175,112,0.0,0.000000,0.0,0.0,0.00,0.333333,0.000000,0.0,0.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0
7175,113,0.0,0.000000,0.0,0.0,0.00,0.000000,0.000000,0.0,1.000000,0.000000,...,0.0,0.000000,0.0,0.000000,0.000000,0.0,0.0,0.0,0.0,0.0


### 2. Predict 4 heads using estimated GNN parameters

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GNN(nn.Module):
    """
    GNN that learns node embeddings with three GCN layers and predicts
    4 edge-level scores by feeding [h_src, h_dst, edge_attr] into an MLP.
    """
    def __init__(self, num_nodes, num_edge_features, hidden_dim=64):
        super().__init__()

        self.node_emb = nn.Embedding(num_nodes, hidden_dim)

        self.gcn1 = GCNConv(hidden_dim, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.gcn3 = GCNConv(hidden_dim, hidden_dim)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim + num_edge_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def encode_nodes(self, edge_index):
        x = self.node_emb.weight
        x = F.relu(self.gcn1(x, edge_index))
        x = F.relu(self.gcn2(x, edge_index))
        x = F.relu(self.gcn3(x, edge_index))
        return x                    # [num_nodes, hidden_dim]

    def score_edges(self, x, edge_index, edge_attr):
        src, dst = edge_index
        h_src = x[src]
        h_dst = x[dst]

        h_edge = torch.cat([h_src, h_dst, edge_attr], dim=-1)
        logits = self.edge_mlp(h_edge)    # [E_batch, 4]
        return logits

In [6]:
# Load graph tensors used in training
gnn_data_path = BASE / "gnn_data.pt"

data_gnn = torch.load(gnn_data_path)

edge_index = data_gnn["edge_index"]          # [2, E]
train_idx  = data_gnn["train_idx"]
test_idx   = data_gnn["test_idx"]
edge_attr  = torch.nan_to_num(data_gnn["edge_attr"], nan=0.0)  # [E, D_edge]
num_nodes  = data_gnn["num_nodes"]
D_edge     = edge_attr.size(1)

with torch.no_grad():
    mu  = edge_attr[train_idx].mean(dim=0, keepdim=True)
    std = edge_attr[train_idx].std(dim=0, keepdim=True)
    std[std < 1e-6] = 1.0          # avoid divide by 0
    edge_attr = (edge_attr - mu) / std
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
edge_index = edge_index.to(device)
edge_attr  = edge_attr.to(device)

In [7]:
# Rebuild model and load checkpoint
SAVE_DIR  = Path("/Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/models/GNN")
ckpt_path = SAVE_DIR / "GNN_multihead_best.pt"

ckpt = torch.load(ckpt_path, map_location=device)

model = GNN(
    num_nodes=num_nodes,
    num_edge_features=D_edge,
    hidden_dim=128,
).to(device)

model.load_state_dict(ckpt["state_dict"])
model.eval()


GNN(
  (node_emb): Embedding(17904, 128)
  (gcn1): GCNConv(128, 128)
  (gcn2): GCNConv(128, 128)
  (gcn3): GCNConv(128, 128)
  (edge_mlp): Sequential(
    (0): Linear(in_features=320, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=4, bias=True)
  )
)

In [8]:
# Predict 4 heads for all edges
@torch.no_grad()
def predict_all_heads(batch_size=50000):
    E = edge_index.size(1)
    all_probs = []

    # node embeddings once
    x = model.encode_nodes(edge_index)

    for start in range(0, E, batch_size):
        end = min(start + batch_size, E)
        ei_batch = edge_index[:, start:end]        # [2, B]
        ea_batch = edge_attr[start:end]            # [B, D_edge]

        logits_batch = model.score_edges(x, ei_batch, ea_batch)  # [B, 4]
        probs_batch  = torch.sigmoid(logits_batch)               # [B, 4]

        all_probs.append(probs_batch.cpu())

    probs = torch.cat(all_probs, dim=0)   # [E, 4], aligned with edge_index rows
    return probs

probs_all = predict_all_heads()
probs_all.shape   # should be [num_edges, 4]


torch.Size([12527912, 4])

In [9]:
# Double check the AUCs on test set
from sklearn.metrics import roc_auc_score

# 1) grab labels and test index from gnn_data
y_all    = data_gnn["y"].cpu().numpy()      # shape [N, 4]
test_idx = data_gnn["test_idx"].cpu().numpy()   # shape [N_test]

# 2) predicted probs as numpy
probs_np = probs_all.cpu().numpy()              # shape [N, 4]

# 3) compute AUC per head on test set
test_auc_heads = []
for h in range(4):
    y_h = y_all[test_idx, h]
    p_h = probs_np[test_idx, h]

    # guard against degenerate labels
    if np.all(y_h == 0) or np.all(y_h == 1):
        auc_h = float("nan")
    else:
        auc_h = roc_auc_score(y_h, p_h)
    test_auc_heads.append(auc_h)

test_auc_mean = np.nanmean(test_auc_heads)

print("Test AUC per head:", [f"{a:.4f}" for a in test_auc_heads])
print("Test mean AUC:", f"{test_auc_mean:.4f}")

Test AUC per head: ['0.8082', '0.7961', '0.8321', '0.8537']
Test mean AUC: 0.8225


In [10]:
# Merge back to df
probs_np = probs_all.cpu().numpy()
df["yhat_complete"] = probs_np[:, 0]
df["yhat_long"]     = probs_np[:, 1]
df["yhat_rewatch"]  = probs_np[:, 2]
df["yhat_neg"]      = probs_np[:, 3]

df[["y_complete", "yhat_complete"]].head()

Unnamed: 0,y_complete,yhat_complete
0,1,0.315376
1,1,0.464749
2,0,0.598627
3,0,0.669049
4,0,0.476548


### 3. Load the toy recommender system

#### (1) Split edge features into user / item / session-group and create corresponding feature tables 

In [11]:
# Encode categorical columns as integers
cat_cols = df.select_dtypes(include="category").columns.tolist()
for c in cat_cols:
    df[c] = df[c].cat.codes.astype("int32")

In [12]:
feature_cols = data_gnn["feature_names"]
len(feature_cols)

feature_cols = list(feature_cols)

# user-level features
u_cols = [c for c in feature_cols if c.startswith("u_")]

# item-level features
i_cols = [c for c in feature_cols if c.startswith("i_")]

# session-level features: burst_id, session, ctx_*, hist_*
sess_fixed = ["burst_id", "session"]
ctx_cols   = [c for c in feature_cols if c.startswith("ctx_")]
hist_cols  = [c for c in feature_cols if c.startswith("hist_")]

sess_cols = sess_fixed + ctx_cols + hist_cols

len(u_cols), len(i_cols), len(sess_cols)


(34, 13, 17)

In [13]:
# User features
user_feats = (
    df[["user_id"] + u_cols]
    .drop_duplicates(subset="user_id")
    .set_index("user_id")
    .sort_index()
)

user_feats.shape

(7176, 34)

In [14]:
# Item features
item_feats = (
    df[["video_id"] + i_cols]
    .drop_duplicates(subset="video_id")
    .set_index("video_id")
    .sort_index()
)

item_feats.shape

(10728, 13)

In [15]:
# Sessioin features

session_feats = (
    df[["user_id"] + sess_cols]
    .groupby(["user_id", "session"], sort=False)
    .first()
)

# make sure 'session' is also a column (used as a feature)
session_feats = session_feats.copy()
session_feats["session"] = session_feats.index.get_level_values("session")

# overwrite author-history features with simple defaults
session_feats["hist_has_author_history"]   = 0.0     # treat as "no author history"
session_feats["hist_author_recency_days"]  = 60.0   # a large-ish recency cap
session_feats["hist_last_complete_author"] = 0.0     # no strong prior on last completion

session_feats = session_feats.sort_index()

session_feats.shape

(555591, 17)

In [16]:
# Select valid session (length >= 5) to reduce noise in estimation

MIN_SESS_LEN = 5

sess_len_by_sess = (
    df.groupby(["user_id", "session"])["sess_len"]
      .first()
)

valid_sess_idx = sess_len_by_sess[sess_len_by_sess >= MIN_SESS_LEN].index

session_feats = session_feats.loc[valid_sess_idx].sort_index()

valid_sessions = session_feats.index.to_frame(index=False)


In [17]:
item_feats.columns

Index(['i_aspect_ratio', 'i_author_id', 'i_video_type', 'i_upload_type',
       'i_visible_status', 'i_music_id', 'i_video_tag_id', 'i_video_tag_name',
       'i_video_duration', 'i_age_since_upload_days', 'i_cat_level1_id',
       'i_cat_level2_id', 'i_cat_level3_id'],
      dtype='object')

#### (2) Predict heads of all candidate videos for each session

In [18]:
def build_edge_features_for_session(
    user_id,
    session,
    user_feats,
    session_feats,
    item_feats,
    feature_cols,
    mu,
    std,
    device="cpu",
    candidate_video_ids=None,
    return_df=False,   # <--- new arg
):
    """
    For a given (user_id, session), build edge features for a set of candidate
    videos (default: all videos in the catalog).

    Returns:
      video_ids: 1D np.array of candidate video_ids (length V)
      edge_tensor: torch.FloatTensor of shape [V, D] with normalized features,
                   columns ordered as feature_cols.
      edge_df (optional): pandas DataFrame with raw (unnormalized) features
                          in feature_cols order (if return_df=True).
    """

    # 0) choose candidate videos: default = all items
    if candidate_video_ids is None:
        candidate_video_ids = item_feats.index.values

    video_ids = np.asarray(candidate_video_ids)

    # 1) item features for these videos
    item_sub = item_feats.loc[video_ids].copy()   # rows: videos, columns: i_*

    # 2) single user + session rows
    u_row = user_feats.loc[user_id]                # Series with u_* features
    s_row = session_feats.loc[(user_id, session)]  # Series with burst/ctx/hist/session

    # 3) start edge_df from item features
    edge_df = item_sub.copy()

    # add user-level features as constant columns
    for col in u_cols:
        edge_df[col] = u_row[col]

    # add session-level features as constant columns
    for col in sess_cols:
        edge_df[col] = s_row[col]

    # 4) reorder columns to match the exact GNN feature order
    edge_df = edge_df[feature_cols]


    # 4.5) make sure there are no pandas NA scalars left
    edge_df = edge_df.replace({pd.NA: np.nan})
    edge_df = edge_df.infer_objects(copy=False)
    
    # fill any remaining missing numeric values with 0.0
    num_cols = edge_df.select_dtypes(include=["number"]).columns
    edge_df[num_cols] = edge_df[num_cols].fillna(0.0)
    
    edge_np = edge_df.to_numpy(dtype="float32", copy=False)
    edge_np = np.nan_to_num(edge_np, nan=0.0)
    edge_tensor = torch.tensor(edge_np, dtype=torch.float32, device=device)

    
    # 5) convert to tensor and normalize
    edge_tensor = torch.tensor(
        edge_df.to_numpy().astype("float32"),
        dtype=torch.float32,
        device=device,
    )

    mu_t  = mu.to(device)
    std_t = std.to(device)

    edge_tensor = (edge_tensor - mu_t) / std_t

    if return_df:
        return video_ids, edge_tensor, edge_df
    else:
        return video_ids, edge_tensor


In [19]:
# test: pick the first (user_id, session) from valid sessions
u0, s0 = session_feats.index[0]

video_ids_0, edge_tensor_0, edge_df_0 = build_edge_features_for_session(
    user_id=u0,
    session=s0,
    user_feats=user_feats,
    session_feats=session_feats,
    item_feats=item_feats,
    feature_cols=feature_cols,
    mu=mu,
    std=std,
    device=device,
    return_df=True,
)

print(edge_df_0.shape)                 # should be (num_videos, 64)
edge_df_0.columns.tolist()      # check column names

(10728, 64)


['burst_id',
 'session',
 'u_user_active_degree',
 'u_is_lowactive_period',
 'u_is_live_streamer',
 'u_is_video_author',
 'u_follow_user_num',
 'u_follow_user_num_range',
 'u_fans_user_num',
 'u_fans_user_num_range',
 'u_friend_user_num',
 'u_friend_user_num_range',
 'u_register_days',
 'u_register_days_range',
 'u_onehot_feat0',
 'u_onehot_feat1',
 'u_onehot_feat2',
 'u_onehot_feat3',
 'u_onehot_feat4',
 'u_onehot_feat5',
 'u_onehot_feat6',
 'u_onehot_feat7',
 'u_onehot_feat8',
 'u_onehot_feat9',
 'u_onehot_feat10',
 'u_onehot_feat11',
 'u_onehot_feat12',
 'u_onehot_feat13',
 'u_onehot_feat14',
 'u_onehot_feat15',
 'u_onehot_feat16',
 'u_onehot_feat17',
 'u_follow_user_num_log1p',
 'u_fans_user_num_log1p',
 'u_friend_user_num_log1p',
 'u_register_days_log1p',
 'i_aspect_ratio',
 'i_author_id',
 'i_video_type',
 'i_upload_type',
 'i_visible_status',
 'i_music_id',
 'i_video_tag_id',
 'i_video_tag_name',
 'i_video_duration',
 'i_age_since_upload_days',
 'i_cat_level1_id',
 'i_cat_level2

In [20]:
# Precompute node embeddings once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

edge_index_global = data_gnn["edge_index"].to(device)

model.to(device)
model.eval()

with torch.no_grad():
    x_nodes = model.encode_nodes(edge_index_global)   # [num_nodes, hidden_dim]


In [21]:
n_users = int(df["user_id"].max() + 1)

@torch.no_grad()
def predict_heads_for_session(user_id, session, candidate_video_ids=None):
    # 1) build edge features for this session and candidate videos
    video_ids, edge_feat = build_edge_features_for_session(
        user_id=user_id,
        session=session,
        user_feats=user_feats,
        session_feats=session_feats,
        item_feats=item_feats,
        feature_cols=feature_cols,
        mu=mu,
        std=std,
        device=device,
        candidate_video_ids=candidate_video_ids,
    )

    V = edge_feat.size(0)

    # 2) build edge_index consistent with training:
    #    src = user_id
    #    dst = n_users + video_id
    src_idx = torch.full(
        (V,),
        fill_value=int(user_id),
        dtype=torch.long,
        device=device,
    )
    dst_idx = torch.as_tensor(video_ids, dtype=torch.long, device=device) + n_users

    edge_index_cf = torch.stack([src_idx, dst_idx], dim=0)  # [2, V]

    # 3) use precomputed node embeddings from the full training graph
    logits = model.score_edges(x_nodes, edge_index_cf, edge_feat)   # [V, 4]
    probs  = torch.sigmoid(logits)                                  # [V, 4]

    return video_ids, probs


In [35]:
# test
tik = time()
video_ids, probs = predict_heads_for_session(u0, s0)
tok = time()
print(video_ids.shape)  # (num_videos,)
print(probs.shape)      # (num_videos, 4)
print("Execution time (s):",tok-tik)


(10728,)
torch.Size([10728, 4])
Execution time (s): 0.03943586349487305


#### (3) Precompute heads for all valid test sessions and save output
Only sessions with length >= 5 are considered as valid to reduce noise in estimation.

In [39]:
# Extract valid test sessions

# 1) define "valid" sessions: sess_len >= 5
valid_sessions_df = (
    df.loc[df["sess_len"] >= 5, ["user_id", "session"]]
      .drop_duplicates()
      .astype({"user_id": int, "session": int})
)

valid_sessions = pd.MultiIndex.from_frame(
    valid_sessions_df,
    names=["user_id", "session"],
)

# 2) sessions that appear in the test edges
test_pairs = df.loc[test_idx, ["user_id", "session"]].astype({"user_id": int, "session": int}).drop_duplicates()

test_sessions = pd.MultiIndex.from_frame(
    test_pairs,
    names=["user_id", "session"],
)

# 3) intersection: valid sessions that are also in test set
valid_test_sessions = test_sessions.intersection(valid_sessions)

print("Total test sessions:", len(test_sessions))
print("Valid test sessions (sess_len >= 5):", len(valid_test_sessions))

valid_test_sessions_list = list(valid_test_sessions)

Total test sessions: 118762
Valid test sessions (sess_len >= 5): 66423


In [40]:
import joblib
from joblib import Parallel, delayed
from tqdm.auto import tqdm

sessions = list(valid_test_sessions_list)   # list of (user_id, session) pairs

def _compute_heads_one_session(us):
    user_id, sess = us

    video_ids, probs = predict_heads_for_session(user_id, sess)
    # probs: torch tensor [V, 4]

    return (
        (int(user_id), int(sess)),
        np.asarray(video_ids, dtype=np.int32),
        probs.detach().cpu().numpy().astype(np.float32),
    )

# parallel over sessions, with a progress bar
results = Parallel(
    n_jobs=-1,           # adjust based on your CPU
    prefer="threads",   # safer with PyTorch
)(
    delayed(_compute_heads_one_session)(us)
    for us in tqdm(sessions, desc="Sessions", mininterval=5)
)

# pack into a dict
precomputed_heads_test = {}
for key, vids, probs in results:
    precomputed_heads_test[key] = {
        "video_ids": vids,   # shape [V]
        "probs": probs,      # shape [V, 4]
    }

print(f"Finished {len(precomputed_heads_test)} sessions.")

# save to file
out_path = BASE / "precomputed_heads.pkl"
joblib.dump(precomputed_heads_test, out_path)
print(f"Saved precomputed heads to: {out_path}")


Sessions:   0%|          | 0/66423 [00:00<?, ?it/s]

Finished 66423 sessions.
Saved precomputed heads to: /Users/haozhangao/Desktop/RecSys Research/KuaiRec 2.0/data/processed/precomputed_heads.pkl


#### (4) Score and recommend top-K videos

In [42]:
# build sess_len map (K per session) from df
sess_len_map = (
    df[["user_id", "session", "sess_len"]]
    .drop_duplicates()
    .set_index(["user_id", "session"])["sess_len"]
)

def recommend_for_all_sessions(weights, precomputed_heads, sess_len_map):
    """
    weights: array-like of shape (4,)
    precomputed_heads: dict[(user_id, session)] -> {"video_ids": [V], "probs": [V,4]}
    sess_len_map: Series indexed by (user_id, session) giving true sess_len

    Returns:
      recs: dict[(user_id, session)] -> {
          "video_ids": np.array[K],   # recommended video ids in rank order
          "scores":    np.array[K],   # corresponding scores
      }
    """
    w = np.asarray(weights, dtype=float).reshape(-1)   # [4]
    recs = {}

    for (user_id, sess), data in precomputed_heads.items():
        vids = data["video_ids"]          # [V]
        probs = data["probs"]             # [V, 4]

        # 1) score each candidate video
        scores = probs @ w                # [V]

        # 2) K = session length
        K = int(sess_len_map.loc[(user_id, sess)])
        K_eff = min(K, scores.shape[0])

        # 3) top-K
        top_idx = np.argpartition(-scores, K_eff - 1)[:K_eff]
        top_idx = top_idx[np.argsort(-scores[top_idx])]

        rec_vids = vids[top_idx]
        rec_scores = scores[top_idx]

        recs[(user_id, sess)] = {
            "video_ids": rec_vids,
            "scores": rec_scores,
        }

    return recs

In [45]:
def cat_distribution_from_recs(
    recs,
    item_feats,
    level_col="i_cat_level1_id",
    all_cats=None,
):
    """
    recs: dict[(user_id, session)] -> {"video_ids": np.array[K], "scores": np.array[K]}
    item_feats: DataFrame indexed by video_id, with a categorical column level_col
    level_col: which category level to use (e.g. "i_cat_level1_id")
    all_cats: optional list/array of category ids to fix the column order;
              if None, inferred from item_feats[level_col].
    Returns:
      DataFrame: index = (user_id, session), columns = category ids,
                 values = probabilities from recommended videos.
    """
    if all_cats is None:
        all_cats = np.sort(item_feats[level_col].dropna().unique())
    all_cats = np.asarray(all_cats)
    cat_to_pos = {c: i for i, c in enumerate(all_cats)}
    num_cats = len(all_cats)

    rows = []
    idx = []

    for (user_id, sess), d in recs.items():
        vids = d["video_ids"]
        # lookup categories
        cats = item_feats.loc[vids, level_col].to_numpy()

        # map to positions, skip NaNs
        mask = ~pd.isna(cats)
        cats = cats[mask]

        cat_pos = np.fromiter(
            (cat_to_pos[c] for c in cats),
            dtype=np.int64,
            count=len(cats),
        ) if len(cats) > 0 else np.array([], dtype=np.int64)

        counts = np.bincount(cat_pos, minlength=num_cats).astype(float)
        if counts.sum() > 0:
            probs = counts / counts.sum()
        else:
            probs = np.zeros(num_cats, dtype=float)

        rows.append(probs)
        idx.append((user_id, sess))

    cat_df = pd.DataFrame(
        rows,
        index=pd.MultiIndex.from_tuples(idx, names=["user_id", "session"]),
        columns=all_cats,
    )
    return cat_df


In [55]:
# test
tick = time()
recs_test = recommend_for_all_sessions(
    weights=[0.25, 0.25, 0.25, 0.25],
    precomputed_heads=precomputed_heads_test,
    sess_len_map=sess_len_map,
)

sim_cat_l1_test = cat_distribution_from_recs(
    recs_test,
    item_feats=item_feats,
    level_col="i_cat_level1_id",
    all_cats=ALL_L1_CATS,
)
tock = time()

print("run time:", tok-tik)
print(sim_cat_l1_test.shape)

8.116857290267944
(66423, 39)


### 3. Weight estimation

#### (1) Loss function and objective function

In [48]:
def cat_cross_entropy_loss(sim_cat_df, obs_cat_df, eps=1e-8):
    """
    Average cross-entropy between observed and simulated category distributions.

    obs_cat_df: observed probs Q, index=(user_id, session), columns=category ids
    sim_cat_df: simulated probs P, same structure (rows/cols may be a superset)

    Returns:
      float: mean cross-entropy over sessions
    """
    # align rows (sessions) and columns (categories)
    obs_aligned, sim_aligned = obs_cat_df.align(sim_cat_df, join="inner", axis=0)
    sim_aligned = sim_aligned[obs_aligned.columns]

    Q = obs_aligned.to_numpy(dtype=float)   # observed
    P = sim_aligned.to_numpy(dtype=float)   # simulated

    P_clipped = np.clip(P, eps, 1.0)
    per_sess_loss = -(Q * np.log(P_clipped)).sum(axis=1)

    return float(per_sess_loss.mean())


In [56]:
# test
w0 = [0.25, 0.25, 0.25, 0.25]

# restrict observed distributions to valid test sessions
obs_cat_valid_test = obs_cat_all.loc[valid_test_sessions]

# recommend for all valid test sessions using precomputed heads
recs_valid_test = recommend_for_all_sessions(
    weights=w0,
    precomputed_heads=precomputed_heads_test,   # dict for valid test sessions
    sess_len_map=sess_len_map,                  # from df["sess_len"]
)

# simulated categorical distribution for these sessions
sim_cat_valid_test = cat_distribution_from_recs(
    recs=recs_valid_test,
    item_feats=item_feats,
    level_col="i_cat_level1_id",
    all_cats=obs_cat_valid_test.columns,        # ensure same category set/order
)

# compute loss
loss_valid_test = cat_cross_entropy_loss(
    sim_cat_df=sim_cat_valid_test,
    obs_cat_df=obs_cat_valid_test,
)

print("Avg cross-entropy loss on valid test set:", loss_valid_test)

Avg cross-entropy loss on valid test set: 12.700291245452902


In [57]:
def obj_head_weights(weights, obs_cat_df):
    """
    Objective: given head weights and observed category distributions,
    return average cross-entropy loss on the valid test sessions.

    weights: array-like of shape (4,)
    obs_cat_df: DataFrame with observed probs, index=(user_id, session),
                columns = category ids (e.g. L1 category ids).
    """
    # sessions we actually have precomputed heads for
    sess_keys = list(precomputed_heads_test.keys())
    obs_sub = obs_cat_df.loc[sess_keys]

    # 1) recommend videos for all these sessions under given weights
    recs = recommend_for_all_sessions(
        weights=weights,
        precomputed_heads=precomputed_heads_test,
        sess_len_map=sess_len_map,
    )

    # 2) simulated categorical distributions
    sim_cat = cat_distribution_from_recs(
        recs=recs,
        item_feats=item_feats,
        level_col="i_cat_level1_id",
        all_cats=obs_sub.columns,
    )

    # 3) cross-entropy loss
    loss = cat_cross_entropy_loss(sim_cat_df=sim_cat, obs_cat_df=obs_sub)
    
    return loss

#### (2) Optimization

In [61]:
from scipy.optimize import minimize

def optimize_head_weights_L1(
    obs_cat_df,
    w0_full=(0.25, 0.25, 0.25, 0.25),
    maxiter=200,
    xatol=1e-3,
    fatol=1e-4,
):
    """
    Optimize 4 head weights with:
      w4_raw = 1 - (w1_raw + w2_raw + w3_raw)
      w = w_raw / sum(|w_raw|)
    Only x = (w1_raw, w2_raw, w3_raw) are decision variables.
    """

    # initialize 3-d params from your starting 4-d weights
    x0 = np.asarray(w0_full[:3], dtype=float)

    def objective_3(x):
        x1, x2, x3 = x
        w_raw = np.array(
            [x1, x2, x3, 1.0 - (x1 + x2 + x3)],
            dtype=float,
        )

        # L1-normalize to kill the scale degeneracy
        norm = np.sum(np.abs(w_raw))
        if norm < 1e-8:
            # avoid division by ~0 – just use uniform as a fallback
            w = np.array([0.25, 0.25, 0.25, 0.25], dtype=float)
        else:
            w = w_raw / norm

        L = obj_head_weights(w, obs_cat_df)
        print(f"loss = {L:.6f}, w = {w}")
        return L

    res = minimize(
        objective_3,
        x0=x0,
        method="Nelder-Mead",
        options={
            "maxiter": maxiter,
            "xatol": xatol,
            "fatol": fatol,
            "disp": True,
        },
    )

    # recover final normalized weights from the optimizer’s x
    x1_opt, x2_opt, x3_opt = res.x
    w_raw_opt = np.array(
        [x1_opt, x2_opt, x3_opt, 1.0 - (x1_opt + x2_opt + x3_opt)],
        dtype=float,
    )
    norm_opt = np.sum(np.abs(w_raw_opt))
    if norm_opt < 1e-8:
        w_opt = np.array([0.25, 0.25, 0.25, 0.25], dtype=float)
    else:
        w_opt = w_raw_opt / norm_opt

    return w_opt, res


In [62]:
# Optimization
w0 = (0.25, 0.25, 0.25, 0.25)

w_opt, res = optimize_head_weights_L1(
    obs_cat_df=obs_cat_valid_test,
    w0_full=w0,
    maxiter=500,
    xatol=1e-3,
    fatol=1e-4,
)

print("Final weights:", w_opt)
print("Final loss:", res.fun)

loss = 12.700291, w = [0.25 0.25 0.25 0.25]
loss = 12.681638, w = [0.2625 0.25   0.25   0.2375]
loss = 12.692298, w = [0.25   0.2625 0.25   0.2375]
loss = 12.680256, w = [0.25   0.25   0.2625 0.2375]
loss = 12.669948, w = [0.25833333 0.25833333 0.25833333 0.225     ]
loss = 12.653733, w = [0.2625 0.2625 0.2625 0.2125]
loss = 12.654739, w = [0.26666667 0.24583333 0.26666667 0.22083333]
loss = 12.644488, w = [0.25694444 0.25555556 0.27777778 0.20972222]
loss = 12.628388, w = [0.25416667 0.25833333 0.29166667 0.19583333]
loss = 12.615002, w = [0.27222222 0.26111111 0.28472222 0.18194444]
loss = 12.590354, w = [0.28333333 0.26666667 0.29583333 0.15416667]
loss = 12.594848, w = [0.26666667 0.27916667 0.3        0.15416667]
loss = 12.566129, w = [0.27361111 0.27361111 0.32916667 0.12361111]
loss = 12.538311, w = [0.27916667 0.27916667 0.3625     0.07916667]
loss = 12.534589, w = [0.29861111 0.29166667 0.34722222 0.0625    ]
loss = 12.505223, w = [ 0.31818182  0.30578512  0.37190083 -0.004132

#### (3) Robustness check: Matching on average video embeddings per session

In [65]:
# Extract video embeddings

with torch.no_grad():
    video_emb = x_nodes[n_users:].detach().cpu().numpy()   # [num_videos, d]

d_emb = video_emb.shape[1]
d_emb, video_emb.shape[0]

(128, 10728)

In [81]:
def build_observed_session_embeddings(df, test_idx, valid_test_sessions):
    """
    df: long data with columns ['user_id', 'session', 'video_id', ...]
    test_idx: indices of test edges (same as used in GNN)
    valid_test_sessions: MultiIndex of (user_id, session) (sess_len>=5 & in test)
    """
    # subset to test edges only
    df_test = df.loc[test_idx, ["user_id", "session", "video_id"]].copy()
    df_test["user_id"] = df_test["user_id"].astype(int)
    df_test["session"] = df_test["session"].astype(int)
    df_test["video_id"] = df_test["video_id"].astype(int)

    # keep only valid test sessions
    df_test = df_test.set_index(["user_id", "session"])
    df_test = df_test.loc[df_test.index.intersection(valid_test_sessions)]
    df_test = df_test.reset_index()

    obs_emb = {}
    for (uid, sess), g in df_test.groupby(["user_id", "session"]):
        vids = g["video_id"].to_numpy()
        emb = video_emb[vids]            # [K, d]
        mu_obs = emb.mean(axis=0)        # [d]
        obs_emb[(int(uid), int(sess))] = mu_obs

    return obs_emb

# Create a dict to store avg video embeddings for each session (user_id, session)
obs_session_emb = build_observed_session_embeddings(df, test_idx, valid_test_sessions)
print("Observed embeddings for sessions:", len(obs_session_emb))


Observed embeddings for sessions: 66423


In [71]:
def build_simulated_session_embeddings(weights, precomputed_heads, sess_len_map):
    """
    weights: array-like (4,) head weights
    precomputed_heads: dict[(user_id, session)] -> {"video_ids": [V], "probs": [V,4]}
    sess_len_map: Series indexed by (user_id, session) giving sess_len
    """
    from copy import deepcopy

    w = np.asarray(weights, dtype=float).reshape(-1)  # [4]

    sim_emb = {}
    for (uid, sess), data in precomputed_heads.items():
        vids = data["video_ids"]          # [V]
        probs = data["probs"]             # [V, 4]

        # 1) score and top-K
        scores = probs @ w                # [V]
        K = int(sess_len_map.loc[(uid, sess)])
        K_eff = min(K, scores.shape[0])

        top_idx = np.argpartition(-scores, K_eff - 1)[:K_eff]
        top_idx = top_idx[np.argsort(-scores[top_idx])]

        rec_vids = vids[top_idx]          # [K_eff]

        # 2) mean embedding
        emb = video_emb[rec_vids]         # [K_eff, d]
        mu_sim = emb.mean(axis=0)         # [d]

        sim_emb[(int(uid), int(sess))] = mu_sim

    return sim_emb

In [72]:
def embedding_matching_loss(weights,
                            obs_session_emb,
                            precomputed_heads,
                            sess_len_map):
    """
    L_emb(w) = average squared L2 distance between
               observed and simulated mean embeddings
               across sessions where both are defined.
    """
    # simulated embeddings for this w
    sim_session_emb = build_simulated_session_embeddings(
        weights, precomputed_heads, sess_len_map
    )

    # intersection of sessions
    common_keys = set(obs_session_emb.keys()).intersection(sim_session_emb.keys())
    if not common_keys:
        raise ValueError("No overlapping sessions between observed and simulated embeddings.")

    diffs = []
    for key in common_keys:
        mu_obs = obs_session_emb[key]   # [d]
        mu_sim = sim_session_emb[key]   # [d]
        diff = mu_obs - mu_sim
        diffs.append(np.dot(diff, diff))   # squared L2

    return float(np.mean(diffs))

In [74]:
# test
tick = time()
L_emb = embedding_matching_loss(
    weights=w_opt,                       # optimized weights
    obs_session_emb=obs_session_emb,
    precomputed_heads=precomputed_heads_test,
    sess_len_map=sess_len_map,
)
tock = time()
print("Embedding-matching loss at w_opt:", L_emb)
print("Run time:", tock-tick)

Embedding-matching loss at w_opt: 2476.6748046875
Run time: 8.754406213760376


In [91]:
from scipy.optimize import minimize
import numpy as np

def optimize_head_weights_L1_embed(
    obs_session_emb,
    precomputed_heads,
    sess_len_map,
    w0_full,
    maxiter=200,
    xatol=1e-3,
    fatol=1e-4,
):
    """
    Optimize 4 head weights w for the embedding-matching loss:

      L_emb(w) = avg_s || mu_s^obs - mu_s^sim(w) ||^2

    Parametrization:
      x = (w1_raw, w2_raw, w3_raw)
      w4_raw = -(1 - (w1_raw + w2_raw + w3_raw))   # always negative
      w = w_raw / sum(|w_raw|)                     # L1-normalized

    Only x (first 3 raw weights) are decision variables.
    """

    w0_full = np.asarray(w0_full, dtype=float)
    x0 = w0_full[:3].copy()

    def x_to_w(x):
        """Map 3-d decision variable -> 4-d normalized weights."""
        x1, x2, x3 = x
        w_raw = np.array(
            [x1, x2, x3, -(1.0 - (x1 + x2 + x3))],
            dtype=float,
        )
        norm = np.sum(np.abs(w_raw))
        if norm < 1e-8:
            # fallback if optimizer goes crazy
            return np.array([0.25, 0.25, 0.25, -0.25], dtype=float)
        return w_raw / norm

    def objective_3(x):
        w = x_to_w(x)
        L = embedding_matching_loss(
            weights=w,
            obs_session_emb=obs_session_emb,
            precomputed_heads=precomputed_heads,
            sess_len_map=sess_len_map,
        )
        print(f"loss = {L:.6f}, w = {w}")
        return L

    res = minimize(
        objective_3,
        x0=x0,
        method="Nelder-Mead",
        options={
            "maxiter": maxiter,
            "xatol": xatol,
            "fatol": fatol,
            "disp": True,
        },
    )

    # use the SAME mapping for the final weights
    w_opt = x_to_w(res.x)
    return w_opt, res


In [92]:
# Reoptimize
w0_embed = w_opt  # from previous optimization

w_opt_embed, res_embed = optimize_head_weights_L1_embed(
    obs_session_emb=obs_session_emb,
    precomputed_heads=precomputed_heads_test,
    sess_len_map=sess_len_map,
    w0_full=w0_embed,
    maxiter=200,
    xatol=1e-3,
    fatol=1e-4,
)

print("Final weights (embedding loss):", w_opt_embed)
print("Final embedding loss:", res_embed.fun)

loss = 2476.674805, w = [ 0.14549531  0.08109316  0.2734943  -0.49991723]
loss = 2476.633301, w = [ 0.15277007  0.08109316  0.2734943  -0.49264247]
loss = 2476.712891, w = [ 0.14549531  0.08514782  0.2734943  -0.49586257]
loss = 2477.350830, w = [ 0.14549531  0.08109316  0.28716902 -0.48624252]
loss = 2475.879395, w = [ 0.15034515  0.08379626  0.25981959 -0.506039  ]
loss = 2474.844482, w = [ 0.15277007  0.08514782  0.24614487 -0.51593724]
loss = 2475.459961, w = [ 0.155195    0.07974161  0.25526135 -0.50980205]
loss = 2474.547607, w = [ 0.16166145  0.08289523  0.24310605 -0.51233727]
loss = 2473.087402, w = [ 0.16974453  0.08379626  0.22791192 -0.51854729]
loss = 2470.986572, w = [ 0.16570299  0.0846973   0.21271779 -0.53688192]
loss = 2463.661133, w = [ 0.17216945  0.08649937  0.18232953 -0.55900165]
loss = 2464.105225, w = [ 0.17459437  0.09055403  0.18232953 -0.55252207]
loss = 2446.764404, w = [ 0.19156882  0.08875196  0.14890245 -0.57077677]
loss = 2355.878418, w = [ 0.2109682   

In [89]:
w_opt_embed

array([ 0.22281076,  0.09543717, -0.01856394,  0.66318813])

In [90]:
# test
tick = time()
L_emb = embedding_matching_loss(
    weights=w_opt_embed,                       # optimized weights
    obs_session_emb=obs_session_emb,
    precomputed_heads=precomputed_heads_test,
    sess_len_map=sess_len_map,
)
tock = time()
print("Embedding-matching loss at w_opt:", L_emb)
print("Run time:", tock-tick)

Embedding-matching loss at w_opt: 2439.9248046875
Run time: 13.820978879928589
