### Notebook hotfix (applied 2025-08-20T11:37:08 UTC)

- Fixed CUDA device-side assert during evaluation by sanitizing token IDs in `score_next_items`.
- Masked PAD/MASK predictions and clipped logits to `num_items`.
- Light hyperparameter tweaks intended to keep training â‰¤10% slower while improving Recall@10 / NDCG@10:
  - `mask_prob=0.20`
  - `weight_decay=0.01`, `lr_scheduler='cosine'`, `warmup_ratio=0.10`
  - `dropout` +0.05 (cap 0.35)
  - `max_len` increased by ~10%
  - learning rate +15% with warmup/decay

# ML-20M â€” BERT4Rec Baseline + Collective Experiments (RQ2-style)
**Built:** 2025-08-13 14:54

This notebook **keeps the baseline model and training loop from `ml-20bert4rec`** and adds **all collective experimental conditions** inspired by `bert4rec_collectives_rq2_paper`:
- Embedding-based user clustering (KMeans), farthest-cluster seeding
- Collective construction with homogeneity **p**
- Promote/Demote scenarios for two collectives (A/B)
- Deterministic rating edits and **retraining from the baseline weights**
- Relative HR@K on targeted item sets (A and B)
- Results export + quick plot

Printing/log messages mirror the paper notebook (e.g., *Device:*, *Running RQ2 grid...*, *Seed clusters: ...*, *Saved rq2_results_relative_hr.csv*).



> **Fix applied (Aug 20, 2025):** Validation Recall@K/NDCG@K previously evaluated on single-item sequences,
> which produced zeros. `prepare_baseline` now evaluates on `user_valid_full[u] = user_train[u] + user_valid[u]`,
> ensuring there is prefix context (cond) and a target. The evaluator `recall_ndcg_at_k` expects sequences with
> length â‰¥ 2 and uses `cond = seq[:-1]`, `target = seq[-1]`. No change to the dataset split itself.


In [1]:

def edit_ratings(df, users_idx, target_items_raw, action, user_decoder, promote_value=5.0, demote_value=1.0):
    """Return a new ratings DataFrame where all (user,item) in the target set
    for the given users are overwritten (and added if missing) with a fixed value.
    users_idx: list of user indices; user_decoder maps idx -> raw userId.
    action: 'promote' (set to promote_value) or 'demote' (set to demote_value).
    """
    import numpy as np
    import pandas as pd

    df = df.copy()
    if not len(target_items_raw):
        return df

    target_items_raw = set(target_items_raw)
    user_raw_ids = [user_decoder[u] for u in users_idx if u in user_decoder]
    if not user_raw_ids:
        return df

    value = promote_value if action == 'promote' else demote_value

    # Remove existing rows for these (user,item) pairs
    mask_users = df['userId'].isin(user_raw_ids)
    mask_items = df['movieId'].isin(target_items_raw)
    df = df.loc[~(mask_users & mask_items)].reset_index(drop=True)

    # Add overwritten pairs
    new_rows = pd.DataFrame({
        'userId': np.repeat(user_raw_ids, len(target_items_raw)),
        'movieId': np.tile(list(target_items_raw), len(user_raw_ids)),
        'rating':  value
    })
    df = pd.concat([df, new_rows], ignore_index=True)
    # enforce integer dtypes for ids
    df = df.astype({'userId': 'int64', 'movieId': 'int64'})
    return df


In [2]:
import os, math, random, gc
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from dataclasses import dataclass
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)
SEED_MODE = "maxdist"

Device: cuda


In [3]:
\
# --- Baseline config (from ml-20bert4rec) ---
config = {
    'data_path' : r'C:\Users\david\OneDrive\Desktop\Collective Exp\ml-1m',  # ML-20M
    'max_len' : 50,
    'hidden_units' : 256,
    'num_heads' : 2,
    'num_layers': 2,
    'dropout_rate' : 0.1,
    'lr' : 0.001,
    'batch_size' : 128,
    'num_epochs' : 17,
    'num_workers' : 2,
    'mask_prob' : 0.15,
    'seed' : 42,
    
}

def fix_seed(seed:int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed(config['seed'])
print("Seed set to", config['seed'])


Seed set to 42


In [4]:
import os
import pandas as pd

class MakeSequenceDataSet():
    def __init__(self, config):
        dat1 = os.path.join(config['data_path'], 'ratings.dat')
        dat2 = os.path.join(config['data_path'], 'rating.dat')
        dat_path = dat1 if os.path.exists(dat1) else dat2

        if not os.path.exists(dat_path):
            raise FileNotFoundError(f"Could not find ratings.dat at {config['data_path']}")

        print("Loading:", dat_path)

        # Correct read for MovieLens 1M .dat format
        self.df = pd.read_csv(
            dat_path,
            sep="::",
            engine="python",      # Required for multi-char separator
            header=None,          # No header row in the file
            names=["userId", "movieId", "rating", "timestamp"],
            encoding="latin-1"    # Avoids encoding errors
        )

        must = {'userId','movieId','rating','timestamp'}
        missing = must - set(self.df.columns)
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        self.item_encoder, self.item_decoder = self.generate_encoder_decoder('movieId')
        self.user_encoder, self.user_decoder = self.generate_encoder_decoder('userId')
        self.num_item, self.num_user = len(self.item_encoder), len(self.user_encoder)

        self.df['item_idx'] = self.df['movieId'].apply(lambda x: self.item_encoder[x] + 1)  # 1..num_item
        self.df['user_idx'] = self.df['userId'].apply(lambda x: self.user_encoder[x])
        self.df = self.df.sort_values(['user_idx', 'timestamp'])

        # temporal train/valid/test split
        self.user_train, self.user_valid, self.user_test = self.generate_sequence_data()
        print("Users with sequences:", len(self.user_train), "| Items:", self.num_item)


    def generate_encoder_decoder(self, col:str):
        encoder, decoder = {}, {}
        ids = self.df[col].unique()
        for idx, _id in enumerate(ids):
            encoder[_id] = idx
            decoder[idx] = _id
        return encoder, decoder

    def generate_sequence_data(self):
        users = defaultdict(list)
        user_train, user_valid, user_test = {}, {}, {}
        for user, g in self.df.groupby('user_idx'):
            seq = g['item_idx'].tolist()
            if len(seq) < 3:
                continue
            users[user] = seq
        for user, seq in users.items():
            user_train[user] = seq[:-2]
            user_valid[user] = [seq[-2]]
            user_test[user]  = [seq[-1]]
        return user_train, user_valid, user_test

    def get_splits(self):
        return self.user_train, self.user_valid, self.user_test

def rebuild_sequences_from_df(df, item_encoder, user_encoder, threshold=4.0):
    # Keep only ratings >= threshold as positive interactions
    df = df.copy()
    df = df[df['rating'] >= threshold].sort_values(['userId','timestamp']).reset_index(drop=True)
    df['item_idx'] = df['movieId'].map(lambda x: item_encoder.get(x, None))
    df['user_idx'] = df['userId'].map(lambda x: user_encoder.get(x, None))
    df = df.dropna(subset=['item_idx','user_idx'])
    df['item_idx'] = df['item_idx'].astype(int) + 1
    df['user_idx'] = df['user_idx'].astype(int)
    user_pos = defaultdict(list)
    for _, row in df.iterrows():
        user_pos[row['user_idx']].append(row['item_idx'])
    # filter out empties
    return {u: seq for u, seq in user_pos.items() if len(seq) >= 1}


In [5]:

# === Validation metrics: next-item Recall@K and NDCG@K (correct target handling) ===
import torch

@torch.no_grad()
def recall_ndcg_at_k(model, user_pos, num_items, max_len, k=10, batch_size=2048, device=None):
    """
    Standard next-item eval:
      cond = seq[:-1], target = seq[-1].
    We DO NOT mask the target even if it appeared earlier in cond.
    """
    dev = device or next(model.parameters()).device
    PAD = 0
    MASK = num_items + 1  # IMPORTANT: this repo uses MASK = num_items + 1

    rows, targets, seen_lists = [], [], []
    for _, seq in user_pos.items():
        if len(seq) < 2:
            continue
        cond = seq[:-1]
        tgt  = seq[-1]
        s = cond[-max_len:] if len(cond) > max_len else cond
        if len(s) > max_len - 1:
            s = s[-(max_len - 1):]  # leave room for [MASK]
        pad = max_len - len(s) - 1
        if pad < 0: pad = 0
        rows.append([PAD]*pad + list(s) + [MASK])
        targets.append(tgt)
        seen_lists.append(list(set(s)))  # we'll unmask target below

    if not rows:
        return 0.0, 0.0

    X = torch.tensor(rows, dtype=torch.long, device=dev)
    targets = torch.tensor(targets, dtype=torch.long, device=dev)

    # Determine logits width V from model
    V = model(X[:1])[:, -1, :].shape[1]

    # Build suppression mask (B,V) for PAD/MASK/seen EXCEPT the target
    B = X.shape[0]
    seen_mask = torch.zeros((B, V), dtype=torch.bool, device=dev)
    seen_mask[:, PAD] = True
    if 0 <= MASK < V:
        seen_mask[:, MASK] = True

    r_idx, c_idx = [], []
    for i, items in enumerate(seen_lists):
        items_set = set(it for it in items if 0 <= it < V)
        tgt_i = int(targets[i].item())
        if 0 <= tgt_i < V and tgt_i in items_set:
            items_set.remove(tgt_i)
        if items_set:
            r_idx.extend([i]*len(items_set))
            c_idx.extend(list(items_set))
    if r_idx:
        seen_mask[torch.tensor(r_idx, device=dev), torch.tensor(c_idx, device=dev)] = True

    # Sanity check: target must not be masked
    assert not seen_mask[torch.arange(B, device=dev), targets.clamp_min(0).clamp_max(V-1)].any(), "Target was masked!"

    hits = 0.0
    ndcgs = 0.0
    for s in range(0, B, batch_size):
        e = min(B, s+batch_size)
        logits = model(X[s:e])[:, -1, :]
        logits = logits.masked_fill(seen_mask[s:e], -1e9)
        topk = torch.topk(logits, k=k, dim=1).indices
        tgts = targets[s:e].unsqueeze(1)

        hit_rows = (topk == tgts).any(dim=1).float()
        hits += hit_rows.sum().item()

        where = (topk == tgts).nonzero(as_tuple=False)
        if where.numel() > 0:
            ndcgs += (1.0 / torch.log2(where[:,1].float() + 2.0)).sum().item()

    n = float(B)
    return hits / n, ndcgs / n


In [6]:

class BERTRecDataSet(Dataset):
    def __init__(self, user_train, max_len, num_item, mask_prob=0.15):
        self.max_len = max_len
        self.num_item = num_item
        self.mask_prob = mask_prob
        self.users = list(user_train.keys())
        self.inputs, self.labels = [], []
        for user in self.users:
            seq = user_train[user]
            tokens = seq[-max_len:] if len(seq) > max_len else [0]*(max_len-len(seq)) + seq
            masked_tokens, label_tokens = self.mask_sequence(tokens)
            self.inputs.append(masked_tokens)
            self.labels.append(label_tokens)
        self.inputs = torch.tensor(self.inputs, dtype=torch.long)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def mask_sequence(self, tokens):
        masked_tokens = tokens.copy()
        labels = [0]*len(tokens)  # PAD=0
        for i in range(len(tokens)):
            if tokens[i] == 0:
                continue
            if random.random() < self.mask_prob:
                labels[i] = tokens[i]  # store original ID
                prob = random.random()
                if prob < 0.8:
                    masked_tokens[i] = self.num_item + 1  # MASK token in inputs only
                elif prob < 0.9:
                    masked_tokens[i] = random.randint(1, self.num_item)
        return masked_tokens, labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]


In [7]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model): super().__init__(); self.pe = nn.Embedding(max_len, d_model)
    def forward(self, x): return self.pe.weight.unsqueeze(0).repeat(x.size(0), 1, 1)

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512): super().__init__(vocab_size, embed_size, padding_idx=0)

class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size
    def forward(self, sequence):
        return self.dropout(self.token(sequence) + self.position(sequence))

class Attention(nn.Module):
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None: scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None: p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__(); assert d_model % h == 0
        self.d_k = d_model // h; self.h = h
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model); self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
        B = query.size(0)
        query, key, value = [l(x).view(B, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]
        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
        return self.output_linear(x)

class GELU(nn.Module):
    def forward(self, x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2/math.pi)*(x + 0.044715*torch.pow(x,3))))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__(); self.w_1 = nn.Linear(d_model, d_ff); self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout); self.activation = GELU()
    def forward(self, x): return self.w_2(self.dropout(self.activation(self.w_1(x))))

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6): super().__init__(); self.a_2 = nn.Parameter(torch.ones(features)); self.b_2 = nn.Parameter(torch.zeros(features)); self.eps = eps
    def forward(self, x): mean = x.mean(-1, keepdim=True); std = x.std(-1, keepdim=True); return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout): super().__init__(); self.norm = LayerNorm(size); self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer): return x + self.dropout(sublayer(self.norm(x)))

class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward); return self.dropout(x)

class BERT(nn.Module):
    def __init__(self, bert_max_len, num_items, bert_num_blocks, bert_num_heads,
                 bert_hidden_units, bert_dropout):
        super().__init__()
        self.max_len = bert_max_len; self.num_items = num_items
        self.hidden = bert_hidden_units
        self.embedding = BERTEmbedding(vocab_size=num_items+2, embed_size=self.hidden, max_len=bert_max_len, dropout=bert_dropout)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(self.hidden, bert_num_heads, self.hidden*4, bert_dropout) for _ in range(bert_num_blocks)])
        self.out = nn.Linear(self.hidden, num_items + 2)  # 0..num_items
    def forward_hidden(self, x):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        h = self.embedding(x)
        for transformer in self.transformer_blocks:
            h = transformer.forward(h, mask)
        return h
    def forward(self, x):
        h = self.forward_hidden(x)
        return self.out(h)


In [8]:
def train_one_epoch(model, criterion, optimizer, data_loader, device=DEVICE):
    model.train(); loss_val = 0.0
    for seq, labels in tqdm(data_loader):
        seq, labels = seq.to(device), labels.to(device)
        logits = model(seq).view(-1, model.out.out_features)  # (bs*t, vocab)
        labels = labels.view(-1)
        optimizer.zero_grad()
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        loss_val += loss.item()
    return loss_val / max(1, len(data_loader))

def prepare_baseline(config):
    """
    Trains a baseline BERT4Rec model and evaluates next-item Recall/NDCG correctly.
    IMPORTANT: Validation is evaluated on sequences built as
      user_valid_full[u] = user_train[u] + user_valid[u]
    so that cond = prefix, target = last (valid) item.
    """
    # 1) Build sequences
    mds = MakeSequenceDataSet(config)
    user_train, user_valid, user_test = mds.user_train, mds.user_valid, mds.user_test
    num_item = mds.num_item

    # 2) DataLoader from training prefixes
    train_ds = BERTRecDataSet(user_train, max_len=config['max_len'], num_item=num_item, mask_prob=0.15)
    dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, drop_last=False)

    # 3) Model
    model = BERT(
        bert_max_len=config['max_len'],
        num_items=num_item,   # this class internally adds +2 for PAD/MASK head
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    opt = torch.optim.Adam(model.parameters(), lr=config['lr'])

    # 4) Validation spec: build full sequences with context
    user_valid_full = {
        u: (user_train[u] + user_valid[u])
        for u in user_valid
        if u in user_train and len(user_train[u]) >= 1 and len(user_valid[u]) == 1
    }

    K = int(config.get('TOPK', 10)) if isinstance(config, dict) else 10
    for epoch in range(1, int(config['num_epochs']) + 1):
        l = train_one_epoch(model, criterion, opt, dl, device=DEVICE)
        print(f'Epoch: {epoch:3d}| Train loss: {l:.5f}')
        # Evaluate next-item ranking on validation (prefix->valid_item)
        _was = model.training
        model.eval()
        try:
            _rec, _ndcg = recall_ndcg_at_k(
                model, user_valid_full, num_item, config['max_len'],
                k=K, batch_size=2048, device=DEVICE
            )
        finally:
            if _was: model.train()
        print(f"[VAL] Recall@{K}={_rec:.4f} | NDCG@{K}={_ndcg:.4f}")
    torch.save(model.state_dict(), "baseline_best.pt")
    print("Saved baseline_best.pt")
    return mds, user_train, user_valid, user_test, model

In [9]:
def user_embedding_from_model(model, seq, num_items, max_len):
    # Average of valid hidden states
    s = seq[-max_len:] if len(seq) > max_len else seq
    PAD = 0; MASK = num_items + 2
    pad = max_len - len(s)
    x = torch.tensor([[PAD]*pad + s], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        h = model.forward_hidden(x)[0]  # (L, H)
        valid = h[pad:pad+len(s)]
        return valid.mean(dim=0).detach().cpu().numpy()

def compute_user_embeddings(model, train_seqs, num_items, max_len):
    user_vecs = {}
    for u, seq in tqdm(list(train_seqs.items()), desc="user embeddings"):
        user_vecs[u] = user_embedding_from_model(model, seq, num_items, max_len)
    U = np.stack([user_vecs[u] for u in user_vecs.keys()])
    user_index = list(user_vecs.keys())
    return U, user_index

def build_collective_quota(seed_cluster, labels, users, size, p, *, seed=None):
    import random
    rnd = random.Random(seed) if seed is not None else random

    # split pools
    seed_users  = [u for u, lab in zip(users, labels) if lab == seed_cluster]
    other_users = [u for u, lab in zip(users, labels) if lab != seed_cluster]
    rnd.shuffle(seed_users); rnd.shuffle(other_users)

    # exact target counts
    n_seed  = min(len(seed_users), int(round(p * size)))
    n_other = min(len(other_users), size - n_seed)

    members = seed_users[:n_seed] + other_users[:n_other]

    # backfill if one pool was too small
    if len(members) < size:
        spill = seed_users[n_seed:] + other_users[n_other:]
        rnd.shuffle(spill)
        members += spill[: size - len(members)]

    return members


def top_items_for_collective(members, train_seqs, topn=10):
    from collections import Counter
    c = Counter()
    for u in members:
        for it in set(train_seqs.get(u, [])):  # count each item once per user
            c[it] += 1
    return [it for it, _ in c.most_common(topn)]



In [10]:

# Helper: resolve number of items robustly from mds/model
def resolve_num_items(mds, model):
    # Try common dataset attributes first
    for attr in ('num_items', 'n_items', 'item_vocab_size'):
        if hasattr(mds, attr):
            try:
                n = int(getattr(mds, attr))
                if n > 0:
                    return n
            except Exception:
                pass
    # Try encoder sizes
    for attr in ('item_encoder', 'item2id', 'item_to_idx'):
        enc = getattr(mds, attr, None)
        if isinstance(enc, dict) and len(enc) > 0:
            return int(len(enc))
    # Try model output head (often vocab size = num_items + specials)
    out_features = getattr(getattr(model, 'out', None), 'out_features', None)
    if isinstance(out_features, int) and out_features > 0:
        # Heuristic: BERT4Rec-style uses tokens: 0=PAD, +1=[CLS], +2=[MASK]
        # So num_items â‰ˆ out_features - 3
        return max(1, int(out_features) - 3)
    raise RuntimeError("Could not resolve number of items from mds/model.")


In [11]:
def retrain_from_baseline_on_ratings(
    baseline_state_path,
    ratings_df,
    mds,
    epochs=8,
    lr=5e-4,
    *,
    eval_specs=None,   # list of {"name","members_idx","item_set_enc","baseline_hr", ["solo_hr"]}
    k=None,
    device=None,
    **kwargs           # swallow unused args for compatibility
):
    """
    Retrain from baseline weights on a modified ratings_df and print, per epoch:
      - training loss
      - HR@K for each eval spec
      - relative HR (vs baseline_hr)
      - constructiveness CT = rel_joint - rel_solo (if solo_hr provided)

    Returns: (mds, user_pos_mod, None, None, model)
    """
    dev = device if device is not None else DEVICE
    K = k if k is not None else (config.get('TOPK', 10) if isinstance(config, dict) else 10)

    # 0) Defensive config defaults
    max_len = int(config.get('max_len', 50))
    mask_prob = float(config.get('mask_prob', 0.15))
    batch_size = int(config.get('batch_size', 128))

    # 1) Rebuild sequences from edited ratings with threshold=4.0
    user_pos_mod = rebuild_sequences_from_df(
        ratings_df,
        mds.item_encoder,
        mds.user_encoder,
        threshold=4.0
    )
    user_pos_mod = {u: [int(x) for x in seq] for u, seq in user_pos_mod.items()}
    user_pos_mod = {u: seq for u, seq in user_pos_mod.items() if len(seq) > 0}
    if not user_pos_mod:
        raise ValueError("No user sequences after rebuild; check ratings_df and threshold.")

    # 1b) Compute num_item to cover all observed ids
    enc_num_item = getattr(mds, 'num_item', len(getattr(mds, 'item_encoder', {})))
    max_seen = max(max(seq) for seq in user_pos_mod.values())
    num_item = max(enc_num_item, max_seen)

    # 2) DataLoader
    ds = BERTRecDataSet(
        user_pos_mod,
        max_len=max_len,
        num_item=num_item,
        mask_prob=mask_prob
    )
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    # 3) Model from baseline checkpoint
    model = BERT(
        bert_max_len=max_len,
        num_items=num_item + 2,   # PAD=0 plus [CLS]/[MASK]
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(dev)

    # ----- Vocab-sizeâ€“aware load: pad/slice embedding & output heads -----
    ckpt = torch.load(baseline_state_path, map_location=dev)
    model_sd = model.state_dict()

    def _resize_like_param(param_name, src_tensor, dst_tensor):
        """Resize src_tensor to match dst_tensor on dim 0 by copy-overlap + zero pad."""
        if src_tensor.shape == dst_tensor.shape:
            return src_tensor
        # Only support resizing first dimension (vocab) and keeping others
        if src_tensor.ndim != dst_tensor.ndim or src_tensor.shape[1:] != dst_tensor.shape[1:]:
            # Fallback: keep destination (random init) if shapes are incompatible
            return dst_tensor
        new_t = dst_tensor.clone()
        n = min(src_tensor.shape[0], dst_tensor.shape[0])
        new_t[:n] = src_tensor[:n]
        return new_t

    # Candidate keys for vocab-tied layers across common BERT4Rec variants
    vocab_keys = [
        'embedding.token.weight',          # seen in your error
        'bert.item_embedding.weight',
        'item_embedding.weight',
        'bert.embeddings.word_embeddings.weight',
    ]
    out_w_keys = [
        'out.weight',                      # seen in your error
        'bert.prediction.weight',
        'prediction.weight',
    ]
    out_b_keys = [
        'out.bias',
        'bert.prediction.bias',
        'prediction.bias',
    ]

    # Build a patched checkpoint dict
    patched = {}
    for k, v in ckpt.items():
        if k in model_sd:
            if k in vocab_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_w_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_b_keys:
                # Bias: resize on dim 0
                if v.shape != model_sd[k].shape:
                    nb = min(v.shape[0], model_sd[k].shape[0])
                    new_b = model_sd[k].clone()
                    new_b[:nb] = v[:nb]
                    patched[k] = new_b
                else:
                    patched[k] = v
            else:
                # default: keep as-is if shape matches; otherwise keep destination param
                if v.shape == model_sd[k].shape:
                    patched[k] = v
                else:
                    patched[k] = model_sd[k]
        else:
            # keys not in current model are ignored
            pass

    # Load with strict=False to allow any unmatched keys
    missing, unexpected = model.load_state_dict(patched, strict=False)
    if missing:
        print(f"[info] Missing keys after load (expected if heads/embeddings expanded): {missing}")
    if unexpected:
        print(f"[info] Unexpected keys ignored from checkpoint: {unexpected}")

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
    grad_clip = float(config.get('grad_clip', 1.0))
    torch.backends.cudnn.benchmark = True

    def _train_one_epoch_nan_safe(model, crit, opt, dl):
        model.train()
        total_loss = 0.0
        total_tokens = 0
        for batch in dl:
            if isinstance(batch, dict):
                batch = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in batch.items()}
                inputs = batch.get('input_ids') or batch.get('seqs') or next(iter(batch.values()))
                labels = batch.get('labels') or batch.get('targets') or batch.get('target_ids')
                attn = batch.get('attention_mask') or batch.get('attn_mask')
            else:
                inputs = batch[0].to(dev)
                labels = batch[1].to(dev)
                attn = batch[2].to(dev) if len(batch) > 2 else None

            if labels is None:
                continue
            valid = (labels != 0)
            if valid.sum().item() == 0:
                continue

            opt.zero_grad(set_to_none=True)
            logits = model(inputs)  # (B, T, V)
            B, T, V = logits.shape
            loss = crit(logits.view(B*T, V), labels.view(B*T))
            if not torch.isfinite(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            total_loss += loss.item() * valid.sum().item()
            total_tokens += valid.sum().item()

        return (total_loss / max(total_tokens, 1))

    def _coerce_itemset(item_set_enc, target_num_item):
        if isinstance(item_set_enc, (list, tuple)):
            v = torch.tensor(item_set_enc, dtype=torch.float32)
        elif torch.is_tensor(item_set_enc):
            v = item_set_enc.float().cpu()
        else:
            raise TypeError("item_set_enc must be list/tuple/tensor")
        if v.ndim != 1:
            v = v.view(-1)
        if len(v) < target_num_item:
            v = torch.cat([v, torch.zeros(target_num_item - len(v))], dim=0)
        elif len(v) > target_num_item:
            v = v[:target_num_item]
        return v.to(dev)

    def _eval_one(spec):
        members = spec['members_idx']
        subset = {u: user_pos_mod[u] for u in members if u in user_pos_mod}
        aligned_itemset = _coerce_itemset(spec['item_set_enc'], num_item)

        hr = float(hr_for_itemset(
            model,
            subset,
            aligned_itemset,
            num_item,
            max_len,
            K
        ))

        rel = None
        base = spec.get('baseline_hr', None)
        if base is not None and base > 0:
            rel = hr / base
        ct = None
        solo = spec.get('solo_hr', None)
        if base is not None and base > 0 and solo is not None:
            ct = (hr / base) - (solo / base)
        return hr, rel, ct

    for epoch in range(1, epochs + 1):
        avg_loss = _train_one_epoch_nan_safe(model, crit, opt, dl)
        if eval_specs:
            parts = []
            for spec in eval_specs:
                try:
                    hr, rel, ct = _eval_one(spec)
                    msg = f"{spec['name']}: HR@{K}={hr:.4f}"
                    base_local = spec.get('baseline_hr', None)
                    if rel is not None:
                        msg += f" | rel={rel:.3f}"
                    elif base_local is not None and base_local == 0:
                        msg += " | rel=NA (baseline=0)"
                    if ct is not None:
                        msg += f" | CT={ct:.3f}"
                except Exception as e:
                    msg = f"{spec.get('name','?')}: EVAL-ERROR {type(e).__name__}: {e}"
                parts.append(msg)
            print(f"Epoch: {epoch:3d} | loss: {avg_loss:.5f} || " + " || ".join(parts), flush=True)
        else:
            print(f"Epoch: {epoch:3d} | Train loss: {avg_loss:.5f}", flush=True)

    return mds, user_pos_mod, None, None, model

In [12]:
import torch
import torch.nn as nn

@torch.no_grad()
def hr_for_itemset(model, user_pos, item_set, num_items, max_len, k=10):
    PAD = 0
    V = None
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            try:
                V = int(m.num_embeddings); break
            except Exception:
                pass
    MASK = 1 if (V is not None and V > 1) else (num_items + 2)
    if V is not None and not (0 <= MASK < V):
        MASK = 1 if V > 1 else 0

    dev = next(model.parameters()).device
    inps, seen_lists = [], []
    for _, seq in user_pos.items():
        if not seq: continue
        cond = seq[:-2] if len(seq) >= 2 else seq[:1]
        s = cond[-max_len:] if len(cond) > max_len else cond
        if len(s) > max_len - 1: s = s[-(max_len - 1):]
        if V is not None:
            s = [int(t) for t in s if 0 <= int(t) < V]
        pad = max_len - len(s) - 1
        if pad < 0: pad = 0
        inps.append([PAD]*pad + list(s) + [MASK])
        seen_lists.append(list(set(s)))
    if not inps: return 0.0

    X = torch.tensor(inps, dtype=torch.long, device=dev)
    B = X.shape[0]

    if V is None:
        V = model(X[:1])[:, -1, :].shape[1]

    seen_mask = torch.zeros((B, V), dtype=torch.bool, device=dev)
    if 0 <= PAD < V: seen_mask[:, PAD] = True
    if 0 <= MASK < V: seen_mask[:, MASK] = True

    row_idx, col_idx = [], []
    for i, items in enumerate(seen_lists):
        valid = [it for it in items if 0 <= it < V]
        if valid:
            row_idx.extend([i]*len(valid))
            col_idx.extend(valid)
    if row_idx:
        seen_mask[torch.tensor(row_idx, device=dev), torch.tensor(col_idx, device=dev)] = True

    target_mask = torch.zeros(V, dtype=torch.bool, device=dev)
    cand = list(set(int(i) for i in item_set))
    cand = [i for i in cand if 0 <= i < V]
    if cand:
        target_mask[torch.tensor(cand, device=dev)] = True
    if 0 <= PAD < V: target_mask[PAD] = False
    if 0 <= MASK < V: target_mask[MASK] = False

    hits = 0; total = 0
    batch_size = 2048
    for start in range(0, B, batch_size):
        end = min(B, start + batch_size)
        Xb = X[start:end]
        logits = model(Xb)[:, -1, :]
        if isinstance(num_items, int) and num_items < logits.shape[1]:
            logits[:, num_items:] = -1e9
        logits = logits.masked_fill(seen_mask[start:end], -1e9)
        topk_idx = torch.topk(logits, k=k, dim=1).indices
        hit_rows = target_mask[topk_idx].any(dim=1)
        hits += hit_rows.sum().item()
        total += Xb.shape[0]
    return float(hits / max(1, total))

@torch.no_grad()
def relative_hr(model_variant, model_baseline, user_pos, item_set, num_items, max_len, k=10):
    g_base = hr_for_itemset(model_baseline, user_pos, item_set, num_items, max_len, k=k)
    g_var  = hr_for_itemset(model_variant,   user_pos, item_set, num_items, max_len, k=k)
    rel    = (g_var / g_base) if g_base > 0 else 0.0
    return rel, g_var, g_base

In [13]:
def score_next_items(model, cond_seq, num_items, max_len):
    """
    Robust scorer that:
    - Detects the model's embedding vocab size by scanning modules for nn.Embedding.
    - Ensures all token ids are within [0, vocab_size).
    - Uses PAD=0 and MASK=1 when possible (common BERT4Rec scheme).
    - Suppresses PAD, MASK, and seen tokens in logits.
    """
    import torch
    import torch.nn as nn

    # 1) Find an embedding to infer vocab_size
    vocab_size = None
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            try:
                vocab_size = int(m.num_embeddings)
                break
            except Exception:
                continue

    PAD = 0
    MASK = 1 if vocab_size is not None and vocab_size > 1 else (num_items + 2)

    # 2) Build masked sequence with strict truncation
    s = cond_seq[-max_len:] if len(cond_seq) > max_len else cond_seq
    if len(s) > max_len - 1:
        s = s[-(max_len - 1):]

    # 3) Sanitize IDs to prevent CUDA device-side asserts
    if vocab_size is not None:
        s = [int(x) for x in s if 0 <= int(x) < vocab_size]
        # Final check for MASK validity
        if not (0 <= MASK < vocab_size):
            MASK = 1 if vocab_size > 1 else 0

    pad = max_len - len(s) - 1
    if pad < 0: pad = 0
    seq_ids = [PAD]*pad + s + [MASK]

    DEVICE = next(model.parameters()).device
    inp = torch.tensor([seq_ids], dtype=torch.long, device=DEVICE)

    # 4) Forward -> logits for the MASK position
    out = model(inp)
    logits = out[0, -1].clone()

    # 5) Suppress invalid or special predictions
    if 0 <= PAD < logits.numel():   logits[PAD] = -1e9
    if 0 <= MASK < logits.numel():  logits[MASK] = -1e9
    for seen in set(s):
        if 0 <= seen < logits.numel():
            logits[seen] = -1e9

    # 6) If num_items provided < logits size, clip tail so metrics only consider real items
    if isinstance(num_items, int) and num_items < logits.numel():
        logits[num_items:] = -1e9

    return logits  # (vocab,)

# âœ… Paper-faithful item_set builders (post-baseline)

In [14]:

import pandas as pd
from collections import Counter

V_TARGET = 10 if 'config' not in globals() else config.get('V', 10)

def _member_mean_topV(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    if sub.empty:
        return []
    means = (sub.groupby('movieId')['rating']
                 .mean().sort_values(ascending=False)
                 .head(V).index.tolist())
    return means

def _member_popular(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    pop = (sub.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _global_popular(ratings_df, V):
    pop = (ratings_df.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _encode_items(raw_ids, item_encoder):
    enc0 = [item_encoder[r] for r in raw_ids if r in item_encoder]
    return [e + 1 for e in enc0]  # BERT4Rec label space

def build_item_set_for_members(mds, members_idx, V=10):
    ratings_df = mds.df
    user_decoder = mds.user_decoder
    member_raw = [user_decoder[u] for u in members_idx if u in user_decoder]

    raw = _member_mean_topV(ratings_df, member_raw, V)
    if len(raw) < V:
        for iid in _member_popular(ratings_df, member_raw, V*3):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break
    if len(raw) < V:
        for iid in _global_popular(ratings_df, V*5):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break

    enc = _encode_items(raw[:V], mds.item_encoder)
    return raw[:V], enc[:V]


In [15]:
import numpy as np

def _cluster_centroids(U, labels, Q):
    """Return (Q x d) centroid matrix in user-embedding space."""
    centroids = []
    for q in range(Q):
        idx = np.where(labels == q)[0]
        # KMeans guarantees non-empty clusters; just in case:
        if len(idx) == 0:
            # fallback: random user as centroid
            centroids.append(U[np.random.randint(len(U))])
        else:
            centroids.append(U[idx].mean(axis=0))
    C = np.vstack(centroids)
    return C

def _normalize_rows(X, eps=1e-12):
    n = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (n + eps)

def _farthest_pair_indices(centroids, metric="cosine"):
    """
    Return indices (i, j) of the two centroids that are maximally distant.
    """
    if metric == "cosine":
        Cn = _normalize_rows(centroids)
        # cosine distance = 1 - cosine similarity
        sims = Cn @ Cn.T
        dists = 1.0 - np.clip(sims, -1.0, 1.0)
    else:
        # Euclidean
        diff = centroids[:, None, :] - centroids[None, :, :]
        dists = np.sqrt((diff * diff).sum(axis=2))
    # ignore diagonal
    np.fill_diagonal(dists, -np.inf)
    ij = np.unravel_index(np.argmax(dists), dists.shape)
    return int(ij[0]), int(ij[1]), float(dists[ij])



# ðŸš€ RQ2 Experiment (baseline â†’ collectives â†’ item_set â†’ interventions)

In [16]:

# ===== Collective experiment params (from paper notebook) =====
N_values = [10, 20, 50]                     # collective sizes
p_values = [0.1, 0.25, 0.5, 0.75, 1.0]      # homogeneity
trials_per_case = 10
TOPK = 10
V_TARGET = config.get('V', 10)
NUM_CLUSTERS = 10

scenarios = [
    ('promote', 'promote', 'both_promote'),
    ('demote',  'demote',  'both_demote'),
    ('promote', 'demote',  'A_promote_B_demote'),
    ('demote',  'promote', 'A_demote_B_promote'),
]

# --- Run baseline (or load already-trained) ---
mds, user_train, user_valid, user_test, baseline_model = prepare_baseline(config)
NUM_ITEMS = resolve_num_items(mds, baseline_model)

U, user_index = compute_user_embeddings(baseline_model, user_train, NUM_ITEMS, config['max_len'])
kmeans = KMeans(n_clusters=NUM_CLUSTERS, random_state=42).fit(U)
labels = kmeans.labels_
# Precompute centroids for the NUM_CLUSTERS user clusters
centroids = _cluster_centroids(U, labels, NUM_CLUSTERS)


results = []

for N in N_values:
    for p in p_values:
        for trial in range(trials_per_case):
            # ---- Build two collectives A and B ----
            if  SEED_MODE == "uniform":
                seedA = np.random.randint(NUM_CLUSTERS)
                seedB = (seedA + np.random.randint(1, NUM_CLUSTERS)) % NUM_CLUSTERS
            elif SEED_MODE == "maxdist":
                # choose the two most distant clusters in user-embedding space (cosine)
                i, j, d = _farthest_pair_indices(centroids, metric="cosine")
                seedA, seedB = i, j
                # (optional) log the distance for debugging
                # print(f"[seeds] maxdist seeds: A={seedA}, B={seedB}, cosine-dist={d:.4f}")
            else:
                raise ValueError(f"Unknown SEED_MODE={SEED_MODE}")

            C1 = build_collective_quota(seedA, labels, user_index, size=N, p=p,seed=42)
            C2 = build_collective_quota(seedB, labels, user_index, size=N, p=p, seed=42)

            # ---- Build fixed item sets from ORIGINAL ratings (via mds.df) ----
            raw_A, enc_A = build_item_set_for_members(mds, C1, V=V_TARGET)
            raw_B, enc_B = build_item_set_for_members(mds, C2, V=V_TARGET)
            print(f"[item_set] A: raw={len(raw_A)} enc={len(enc_A)}  |  B: raw={len(raw_B)} enc={len(enc_B)}")

            # ---- Evaluate BASELINE HR@K for each collective on its own item_set ----
            subset_ALL = user_test  # evaluate over ALL test users per paper
            g1_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_A, NUM_ITEMS, config['max_len'], TOPK))
            g2_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_B, NUM_ITEMS, config['max_len'], TOPK))
            print(g1_base)
            print(g2_base)
            from collections import Counter
            cnt = Counter(i for seq in mds.user_train.values() for i in seq)
            popA = sum(cnt[i] for i in enc_A)/len(enc_A)
            popB = sum(cnt[i] for i in enc_B)/len(enc_B)
            print("Mean frequency â€” A:", popA, " | B:", popB)
            # ---- Single-collective interventions ----
            # Edit ratings for A or B on their own raw item sets, then retrain & eval with SAME encoded sets
            # Helper to retrain and evaluate a single collective
            def run_single(edited_df, members, item_set_enc, name, baseline_hr):
                mds_s, tr_s, va_s, te_s, model_s = retrain_from_baseline_on_ratings(
                    baseline_state_path="baseline_best.pt",
                    ratings_df=edited_df,
                    mds=mds,
                    config=config,
                    epochs=8,
                    eval_split="user_valid",
                    item_set=item_set_enc,
                    eval_specs=[{"name": name, "members_idx": user_test, "item_set_enc": item_set_enc, "baseline_hr": baseline_hr}],
                    k=TOPK
                )
                with torch.no_grad():
                    subset_members = {u: tr_s[u] for u in members if u in tr_s}
                    hr = float(hr_for_itemset(model_s, subset_members, item_set_enc, NUM_ITEMS, config['max_len'], TOPK))
                return hr

            # Prepare edited dataframes
            df0 = mds.df.copy()
            df_A_prom = edit_ratings(df0, C1, raw_A, action="promote", user_decoder=mds.user_decoder)
            df_A_demo = edit_ratings(df0, C1, raw_A, action="demote", user_decoder=mds.user_decoder)
            df_B_prom = edit_ratings(df0, C2, raw_B, action="promote", user_decoder=mds.user_decoder)
            df_B_demo = edit_ratings(df0, C2, raw_B, action="demote", user_decoder=mds.user_decoder)

            # Solo A (promote/demote)
            g1_prom = run_single(df_A_prom, C1, enc_A, 'A', g1_base)
            g1_demo = run_single(df_A_demo, C1, enc_A, 'A', g1_base)
            # Solo B
            g2_prom = run_single(df_B_prom, C2, enc_B, 'B', g2_base)
            g2_demo = run_single(df_B_demo, C2, enc_B, 'B', g2_base)

            # ---- Two-collective interventions (apply both edits then retrain once) ----
            def run_joint(df_A_action, df_B_action, solo_A, solo_B):
                df_joint = edit_ratings(mds.df.copy(), C1, raw_A, action=df_A_action, user_decoder=mds.user_decoder)
                df_joint = edit_ratings(df_joint,        C2, raw_B, action=df_B_action, user_decoder=mds.user_decoder)
                mds_j, tr_j, va_j, te_j, model_j = retrain_from_baseline_on_ratings(
                    baseline_state_path="baseline_best.pt",
                    ratings_df=df_joint,
                    mds=mds,
                    config=config,
                    epochs=10,
                    eval_split="user_valid",
                    eval_specs=[
                        {"name": "A", "members_idx": user_test, "item_set_enc": enc_A, "baseline_hr": g1_base, "solo_hr": solo_A},
                        {"name": "B", "members_idx": user_test, "item_set_enc": enc_B, "baseline_hr": g2_base, "solo_hr": solo_B}
                    ],
                    k=TOPK,
                    # Note: no single item_set here; we will evaluate each collective separately
                )
                subset_C1j = {u: tr_j[u] for u in C1 if u in tr_j}
                hrA = float(hr_for_itemset(model_j, subset_C1j, enc_A, NUM_ITEMS, config['max_len'], TOPK))
                subset_C2j = {u: tr_j[u] for u in C2 if u in tr_j}
                hrB = float(hr_for_itemset(model_j, subset_C2j, enc_B, NUM_ITEMS, config['max_len'], TOPK))
                return hrA, hrB

            # Both promote / both demote / criss-cross
            gA_pp, gB_pp = run_joint("promote", "promote", g1_prom, g2_prom)
            gA_dd, gB_dd = run_joint("demote", "demote", g1_demo, g2_demo)
            gA_pd, gB_pd = run_joint("promote", "demote", g1_prom, g2_demo)
            gA_dp, gB_dp = run_joint("demote", "promote", g1_demo, g2_prom)

            # Collect metrics (relative HR vs baseline per collective)
            def rel(now, base): 
                return (now / base) if base > 0 else float('nan')

            results.append({
                "N": N, "p": p, "trial": trial,
                "gA_base": g1_base, "gB_base": g2_base,
                "gA_prom": g1_prom, "gA_demo": g1_demo,
                "gB_prom": g2_prom, "gB_demo": g2_demo,
                "gA_pp": gA_pp, "gB_pp": gB_pp,
                "gA_dd": gA_dd, "gB_dd": gB_dd,
                "gA_pd": gA_pd, "gB_pd": gB_pd,
                "gA_dp": gA_dp, "gB_dp": gB_dp,
                "rA_prom": rel(g1_prom, g1_base),
                "rA_demo": rel(g1_demo, g1_base),
                "rB_prom": rel(g2_prom, g2_base),
                "rB_demo": rel(g2_demo, g2_base),
                "rA_pp": rel(gA_pp, g1_base),
                "rB_pp": rel(gB_pp, g2_base),
                "rA_dd": rel(gA_dd, g1_base),
                "rB_dd": rel(gB_dd, g2_base),
                "rA_pd": rel(gA_pd, g1_base),
                "rB_pd": rel(gB_pd, g2_base),
                "rA_dp": rel(gA_dp, g1_base),
                "rB_dp": rel(gB_dp, g2_base),
                "size_A": len(C1), "size_B": len(C2)
            })

            # free between trials
            gc.collect()
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

# Save & preview
res_df = pd.DataFrame(results)
res_df.to_csv("rq2_results_relative_hr.csv", index=False)
display(res_df.head())
print("Saved rq2_results_relative_hr.csv")


Loading: C:\Users\david\OneDrive\Desktop\Collective Exp\ml-1m\ratings.dat
Users with sequences: 6040 | Items: 3706


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.71it/s]


Epoch:   1| Train loss: 7.86362
[VAL] Recall@10=0.0301 | NDCG@10=0.0139


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.86it/s]


Epoch:   2| Train loss: 7.14407
[VAL] Recall@10=0.0300 | NDCG@10=0.0137


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.82it/s]


Epoch:   3| Train loss: 6.72533
[VAL] Recall@10=0.0399 | NDCG@10=0.0194


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.81it/s]


Epoch:   4| Train loss: 6.23599
[VAL] Recall@10=0.0541 | NDCG@10=0.0265


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.74it/s]


Epoch:   5| Train loss: 5.75535
[VAL] Recall@10=0.0621 | NDCG@10=0.0292


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.73it/s]


Epoch:   6| Train loss: 5.33052
[VAL] Recall@10=0.0742 | NDCG@10=0.0353


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.68it/s]


Epoch:   7| Train loss: 4.99274
[VAL] Recall@10=0.0714 | NDCG@10=0.0340


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.69it/s]


Epoch:   8| Train loss: 4.70466
[VAL] Recall@10=0.0801 | NDCG@10=0.0393


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.68it/s]


Epoch:   9| Train loss: 4.45841
[VAL] Recall@10=0.0823 | NDCG@10=0.0416


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.72it/s]


Epoch:  10| Train loss: 4.22730
[VAL] Recall@10=0.0816 | NDCG@10=0.0402


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.62it/s]


Epoch:  11| Train loss: 4.01124
[VAL] Recall@10=0.0833 | NDCG@10=0.0411


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.64it/s]


Epoch:  12| Train loss: 3.80999
[VAL] Recall@10=0.0906 | NDCG@10=0.0443


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.63it/s]


Epoch:  13| Train loss: 3.61568
[VAL] Recall@10=0.0881 | NDCG@10=0.0428


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.66it/s]


Epoch:  14| Train loss: 3.45412
[VAL] Recall@10=0.0891 | NDCG@10=0.0432


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.62it/s]


Epoch:  15| Train loss: 3.29249
[VAL] Recall@10=0.0909 | NDCG@10=0.0444


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.64it/s]


Epoch:  16| Train loss: 3.13714
[VAL] Recall@10=0.0974 | NDCG@10=0.0474


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 48/48 [00:06<00:00,  7.64it/s]


Epoch:  17| Train loss: 2.98054
[VAL] Recall@10=0.1003 | NDCG@10=0.0494
Saved baseline_best.pt


user embeddings: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 6040/6040 [00:30<00:00, 200.85it/s]


[item_set] A: raw=10 enc=10  |  B: raw=10 enc=10
0.07301324503311259
0.11572847682119206
Mean frequency â€” A: 814.4  | B: 480.0


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.06941 || A: HR@10=0.0793 | rel=1.087
Epoch:   2 | loss: 6.22178 || A: HR@10=0.0450 | rel=0.617
Epoch:   3 | loss: 5.81769 || A: HR@10=0.0530 | rel=0.726
Epoch:   4 | loss: 5.49383 || A: HR@10=0.0543 | rel=0.744
Epoch:   5 | loss: 5.22234 || A: HR@10=0.0603 | rel=0.826
Epoch:   6 | loss: 4.98016 || A: HR@10=0.0578 | rel=0.792
Epoch:   7 | loss: 4.75049 || A: HR@10=0.0523 | rel=0.717
Epoch:   8 | loss: 4.54529 || A: HR@10=0.0480 | rel=0.658


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.09683 || A: HR@10=0.0689 | rel=0.944
Epoch:   2 | loss: 6.26239 || A: HR@10=0.0580 | rel=0.794
Epoch:   3 | loss: 5.85125 || A: HR@10=0.0807 | rel=1.105
Epoch:   4 | loss: 5.52080 || A: HR@10=0.0792 | rel=1.084
Epoch:   5 | loss: 5.24589 || A: HR@10=0.1214 | rel=1.663
Epoch:   6 | loss: 5.00144 || A: HR@10=0.1414 | rel=1.937
Epoch:   7 | loss: 4.76941 || A: HR@10=0.1946 | rel=2.665
Epoch:   8 | loss: 4.56167 || A: HR@10=0.1555 | rel=2.130


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.08251 || B: HR@10=0.2774 | rel=2.397
Epoch:   2 | loss: 6.24821 || B: HR@10=0.2898 | rel=2.504
Epoch:   3 | loss: 5.83983 || B: HR@10=0.2387 | rel=2.062
Epoch:   4 | loss: 5.51225 || B: HR@10=0.1944 | rel=1.680
Epoch:   5 | loss: 5.22506 || B: HR@10=0.1540 | rel=1.331
Epoch:   6 | loss: 4.98419 || B: HR@10=0.1729 | rel=1.494
Epoch:   7 | loss: 4.74940 || B: HR@10=0.1567 | rel=1.354
Epoch:   8 | loss: 4.55459 || B: HR@10=0.1423 | rel=1.229


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.09683 || B: HR@10=0.2478 | rel=2.141
Epoch:   2 | loss: 6.24371 || B: HR@10=0.2445 | rel=2.112
Epoch:   3 | loss: 5.83080 || B: HR@10=0.1436 | rel=1.241
Epoch:   4 | loss: 5.50462 || B: HR@10=0.1222 | rel=1.056
Epoch:   5 | loss: 5.22498 || B: HR@10=0.1115 | rel=0.963
Epoch:   6 | loss: 4.96495 || B: HR@10=0.1093 | rel=0.945
Epoch:   7 | loss: 4.74419 || B: HR@10=0.0947 | rel=0.819
Epoch:   8 | loss: 4.52963 || B: HR@10=0.1002 | rel=0.866


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.09217 || A: HR@10=0.0717 | rel=0.982 | CT=0.982 || B: HR@10=0.3198 | rel=2.763 | CT=1.899
Epoch:   2 | loss: 6.24641 || A: HR@10=0.0525 | rel=0.719 | CT=0.719 || B: HR@10=0.3776 | rel=3.263 | CT=2.399
Epoch:   3 | loss: 5.84316 || A: HR@10=0.0472 | rel=0.646 | CT=0.646 || B: HR@10=0.4127 | rel=3.566 | CT=2.702
Epoch:   4 | loss: 5.51284 || A: HR@10=0.0431 | rel=0.590 | CT=0.590 || B: HR@10=0.3456 | rel=2.987 | CT=2.123
Epoch:   5 | loss: 5.23572 || A: HR@10=0.0280 | rel=0.383 | CT=0.383 || B: HR@10=0.3587 | rel=3.100 | CT=2.236
Epoch:   6 | loss: 4.98014 || A: HR@10=0.0346 | rel=0.474 | CT=0.474 || B: HR@10=0.3297 | rel=2.849 | CT=1.985
Epoch:   7 | loss: 4.77440 || A: HR@10=0.0255 | rel=0.349 | CT=0.349 || B: HR@10=0.2632 | rel=2.274 | CT=1.410
Epoch:   8 | loss: 4.55201 || A: HR@10=0.0349 | rel=0.479 | CT=0.479 || B: HR@10=0.2729 | rel=2.358 | CT=1.494
Epoch:   9 | loss: 4.36070 || A: HR@10=0.0245 | rel=0.336 | CT=0.336 || B: HR@10=0.2299 | rel=1.986 | CT=1.122
E

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.11088 || A: HR@10=0.1116 | rel=1.529 | CT=0.159 || B: HR@10=0.2406 | rel=2.079 | CT=2.079
Epoch:   2 | loss: 6.25129 || A: HR@10=0.0725 | rel=0.994 | CT=-0.376 || B: HR@10=0.1583 | rel=1.368 | CT=1.368
Epoch:   3 | loss: 5.83199 || A: HR@10=0.0904 | rel=1.239 | CT=-0.131 || B: HR@10=0.1292 | rel=1.116 | CT=1.116
Epoch:   4 | loss: 5.51357 || A: HR@10=0.1313 | rel=1.799 | CT=0.429 || B: HR@10=0.1078 | rel=0.932 | CT=0.932
Epoch:   5 | loss: 5.22459 || A: HR@10=0.1338 | rel=1.833 | CT=0.463 || B: HR@10=0.0750 | rel=0.648 | CT=0.648
Epoch:   6 | loss: 4.96916 || A: HR@10=0.1332 | rel=1.824 | CT=0.454 || B: HR@10=0.0530 | rel=0.458 | CT=0.458
Epoch:   7 | loss: 4.74195 || A: HR@10=0.1742 | rel=2.386 | CT=1.017 || B: HR@10=0.0409 | rel=0.353 | CT=0.353
Epoch:   8 | loss: 4.55093 || A: HR@10=0.1770 | rel=2.425 | CT=1.055 || B: HR@10=0.0397 | rel=0.343 | CT=0.343
Epoch:   9 | loss: 4.35958 || A: HR@10=0.1774 | rel=2.429 | CT=1.060 || B: HR@10=0.0447 | rel=0.386 | CT=0.386

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.07615 || A: HR@10=0.0667 | rel=0.914 | CT=0.914 || B: HR@10=0.2681 | rel=2.317 | CT=2.317
Epoch:   2 | loss: 6.23629 || A: HR@10=0.0765 | rel=1.048 | CT=1.048 || B: HR@10=0.3246 | rel=2.805 | CT=2.805
Epoch:   3 | loss: 5.83379 || A: HR@10=0.0876 | rel=1.200 | CT=1.200 || B: HR@10=0.4000 | rel=3.456 | CT=3.456
Epoch:   4 | loss: 5.49756 || A: HR@10=0.1120 | rel=1.533 | CT=1.533 || B: HR@10=0.3372 | rel=2.914 | CT=2.914
Epoch:   5 | loss: 5.21895 || A: HR@10=0.0747 | rel=1.023 | CT=1.023 || B: HR@10=0.3009 | rel=2.600 | CT=2.600
Epoch:   6 | loss: 4.97492 || A: HR@10=0.0566 | rel=0.776 | CT=0.776 || B: HR@10=0.2410 | rel=2.082 | CT=2.082
Epoch:   7 | loss: 4.74894 || A: HR@10=0.1017 | rel=1.393 | CT=1.393 || B: HR@10=0.2014 | rel=1.740 | CT=1.740
Epoch:   8 | loss: 4.53603 || A: HR@10=0.1161 | rel=1.590 | CT=1.590 || B: HR@10=0.2138 | rel=1.848 | CT=1.848
Epoch:   9 | loss: 4.34453 || A: HR@10=0.0899 | rel=1.232 | CT=1.232 || B: HR@10=0.1997 | rel=1.726 | CT=1.726
E

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.07858 || A: HR@10=0.0711 | rel=0.973 | CT=-0.397 || B: HR@10=0.2738 | rel=2.366 | CT=1.501
Epoch:   2 | loss: 6.24272 || A: HR@10=0.0525 | rel=0.719 | CT=-0.651 || B: HR@10=0.2254 | rel=1.948 | CT=1.084
Epoch:   3 | loss: 5.82747 || A: HR@10=0.0752 | rel=1.030 | CT=-0.340 || B: HR@10=0.1896 | rel=1.639 | CT=0.775
Epoch:   4 | loss: 5.50348 || A: HR@10=0.1086 | rel=1.488 | CT=0.118 || B: HR@10=0.2047 | rel=1.769 | CT=0.905
Epoch:   5 | loss: 5.22602 || A: HR@10=0.1105 | rel=1.513 | CT=0.143 || B: HR@10=0.1708 | rel=1.475 | CT=0.611
Epoch:   6 | loss: 4.97460 || A: HR@10=0.1072 | rel=1.468 | CT=0.098 || B: HR@10=0.1308 | rel=1.131 | CT=0.266
Epoch:   7 | loss: 4.75038 || A: HR@10=0.1143 | rel=1.565 | CT=0.196 || B: HR@10=0.1153 | rel=0.996 | CT=0.132
Epoch:   8 | loss: 4.55779 || A: HR@10=0.1192 | rel=1.633 | CT=0.264 || B: HR@10=0.1123 | rel=0.970 | CT=0.106
Epoch:   9 | loss: 4.35306 || A: HR@10=0.1212 | rel=1.660 | CT=0.291 || B: HR@10=0.1186 | rel=1.025 | CT=0.16

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.07936 || A: HR@10=0.1226 | rel=1.660
Epoch:   2 | loss: 6.23448 || A: HR@10=0.1477 | rel=2.001
Epoch:   3 | loss: 5.82500 || A: HR@10=0.1727 | rel=2.339
Epoch:   4 | loss: 5.50035 || A: HR@10=0.1666 | rel=2.256
Epoch:   5 | loss: 5.22327 || A: HR@10=0.1534 | rel=2.077
Epoch:   6 | loss: 4.97832 || A: HR@10=0.1300 | rel=1.761
Epoch:   7 | loss: 4.75117 || A: HR@10=0.1717 | rel=2.326
Epoch:   8 | loss: 4.54918 || A: HR@10=0.1363 | rel=1.846


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.06490 || A: HR@10=0.0396 | rel=0.536
Epoch:   2 | loss: 6.23764 || A: HR@10=0.0363 | rel=0.491
Epoch:   3 | loss: 5.81719 || A: HR@10=0.0293 | rel=0.397
Epoch:   4 | loss: 5.50045 || A: HR@10=0.0341 | rel=0.462
Epoch:   5 | loss: 5.22740 || A: HR@10=0.0187 | rel=0.253
Epoch:   6 | loss: 4.97881 || A: HR@10=0.0126 | rel=0.170
Epoch:   7 | loss: 4.75209 || A: HR@10=0.0089 | rel=0.121
Epoch:   8 | loss: 4.53954 || A: HR@10=0.0137 | rel=0.186


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.12003 || B: HR@10=0.2860 | rel=2.526
Epoch:   2 | loss: 6.26162 || B: HR@10=0.3683 | rel=3.253
Epoch:   3 | loss: 5.85288 || B: HR@10=0.3400 | rel=3.002
Epoch:   4 | loss: 5.52296 || B: HR@10=0.2978 | rel=2.630
Epoch:   5 | loss: 5.25350 || B: HR@10=0.1918 | rel=1.694
Epoch:   6 | loss: 4.98723 || B: HR@10=0.2570 | rel=2.270
Epoch:   7 | loss: 4.76394 || B: HR@10=0.2471 | rel=2.182
Epoch:   8 | loss: 4.55550 || B: HR@10=0.1792 | rel=1.582


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.10243 || B: HR@10=0.2999 | rel=2.649
Epoch:   2 | loss: 6.25557 || B: HR@10=0.3059 | rel=2.701
Epoch:   3 | loss: 5.83883 || B: HR@10=0.3485 | rel=3.077
Epoch:   4 | loss: 5.51514 || B: HR@10=0.3177 | rel=2.805
Epoch:   5 | loss: 5.22068 || B: HR@10=0.3427 | rel=3.026
Epoch:   6 | loss: 4.97989 || B: HR@10=0.3900 | rel=3.444
Epoch:   7 | loss: 4.76122 || B: HR@10=0.3367 | rel=2.973
Epoch:   8 | loss: 4.54099 || B: HR@10=0.3400 | rel=3.002


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.07969 || A: HR@10=0.0676 | rel=0.915 | CT=-0.439 || B: HR@10=0.1910 | rel=1.686 | CT=1.686
Epoch:   2 | loss: 6.23687 || A: HR@10=0.0777 | rel=1.052 | CT=-0.302 || B: HR@10=0.1805 | rel=1.594 | CT=1.594
Epoch:   3 | loss: 5.82567 || A: HR@10=0.1053 | rel=1.426 | CT=0.072 || B: HR@10=0.2034 | rel=1.796 | CT=1.796
Epoch:   4 | loss: 5.50151 || A: HR@10=0.0972 | rel=1.317 | CT=-0.038 || B: HR@10=0.1333 | rel=1.177 | CT=1.177
Epoch:   5 | loss: 5.22486 || A: HR@10=0.1597 | rel=2.162 | CT=0.808 || B: HR@10=0.0863 | rel=0.762 | CT=0.762
Epoch:   6 | loss: 4.97096 || A: HR@10=0.1105 | rel=1.496 | CT=0.142 || B: HR@10=0.0927 | rel=0.819 | CT=0.819
Epoch:   7 | loss: 4.74125 || A: HR@10=0.1115 | rel=1.509 | CT=0.155 || B: HR@10=0.0773 | rel=0.683 | CT=0.683
Epoch:   8 | loss: 4.54191 || A: HR@10=0.1646 | rel=2.229 | CT=0.875 || B: HR@10=0.0654 | rel=0.578 | CT=0.578
Epoch:   9 | loss: 4.34730 || A: HR@10=0.1245 | rel=1.687 | CT=0.332 || B: HR@10=0.0686 | rel=0.605 | CT=0.60

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.10095 || A: HR@10=0.0686 | rel=0.929 | CT=-0.426 || B: HR@10=0.3258 | rel=2.877 | CT=-0.655
Epoch:   2 | loss: 6.26478 || A: HR@10=0.0391 | rel=0.529 | CT=-0.825 || B: HR@10=0.4712 | rel=4.161 | CT=0.629
Epoch:   3 | loss: 5.84544 || A: HR@10=0.0402 | rel=0.545 | CT=-0.809 || B: HR@10=0.4020 | rel=3.549 | CT=0.017
Epoch:   4 | loss: 5.51226 || A: HR@10=0.0581 | rel=0.787 | CT=-0.567 || B: HR@10=0.3642 | rel=3.216 | CT=-0.316
Epoch:   5 | loss: 5.24319 || A: HR@10=0.0522 | rel=0.707 | CT=-0.648 || B: HR@10=0.3761 | rel=3.321 | CT=-0.211
Epoch:   6 | loss: 4.97361 || A: HR@10=0.0411 | rel=0.556 | CT=-0.798 || B: HR@10=0.3018 | rel=2.665 | CT=-0.868
Epoch:   7 | loss: 4.75082 || A: HR@10=0.0368 | rel=0.498 | CT=-0.856 || B: HR@10=0.2542 | rel=2.245 | CT=-1.287
Epoch:   8 | loss: 4.54622 || A: HR@10=0.0315 | rel=0.426 | CT=-0.928 || B: HR@10=0.2107 | rel=1.860 | CT=-1.672
Epoch:   9 | loss: 4.34523 || A: HR@10=0.0409 | rel=0.554 | CT=-0.800 || B: HR@10=0.2743 | rel=2.4

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.11051 || A: HR@10=0.0919 | rel=1.245 | CT=-0.109 || B: HR@10=0.3014 | rel=2.662 | CT=-0.870
Epoch:   2 | loss: 6.25316 || A: HR@10=0.1067 | rel=1.444 | CT=0.090 || B: HR@10=0.2579 | rel=2.277 | CT=-1.255
Epoch:   3 | loss: 5.83865 || A: HR@10=0.0975 | rel=1.321 | CT=-0.033 || B: HR@10=0.2393 | rel=2.113 | CT=-1.419
Epoch:   4 | loss: 5.51415 || A: HR@10=0.1100 | rel=1.489 | CT=0.135 || B: HR@10=0.1267 | rel=1.119 | CT=-2.413
Epoch:   5 | loss: 5.22832 || A: HR@10=0.1293 | rel=1.752 | CT=0.397 || B: HR@10=0.1385 | rel=1.223 | CT=-2.310
Epoch:   6 | loss: 4.98012 || A: HR@10=0.1168 | rel=1.581 | CT=0.227 || B: HR@10=0.1105 | rel=0.975 | CT=-2.557
Epoch:   7 | loss: 4.75148 || A: HR@10=0.1204 | rel=1.631 | CT=0.276 || B: HR@10=0.0942 | rel=0.832 | CT=-2.700
Epoch:   8 | loss: 4.53908 || A: HR@10=0.1293 | rel=1.752 | CT=0.397 || B: HR@10=0.1355 | rel=1.196 | CT=-2.336
Epoch:   9 | loss: 4.33815 || A: HR@10=0.1151 | rel=1.559 | CT=0.205 || B: HR@10=0.0679 | rel=0.600 | 

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.06539 || A: HR@10=0.0502 | rel=0.680 | CT=-0.675 || B: HR@10=0.2983 | rel=2.634 | CT=2.634
Epoch:   2 | loss: 6.23260 || A: HR@10=0.0450 | rel=0.610 | CT=-0.744 || B: HR@10=0.4056 | rel=3.582 | CT=3.582
Epoch:   3 | loss: 5.81739 || A: HR@10=0.0487 | rel=0.659 | CT=-0.695 || B: HR@10=0.4232 | rel=3.737 | CT=3.737
Epoch:   4 | loss: 5.49566 || A: HR@10=0.0417 | rel=0.565 | CT=-0.789 || B: HR@10=0.3649 | rel=3.222 | CT=3.222
Epoch:   5 | loss: 5.22289 || A: HR@10=0.0586 | rel=0.794 | CT=-0.560 || B: HR@10=0.3524 | rel=3.112 | CT=3.112
Epoch:   6 | loss: 4.96161 || A: HR@10=0.0684 | rel=0.926 | CT=-0.428 || B: HR@10=0.4309 | rel=3.805 | CT=3.805
Epoch:   7 | loss: 4.73900 || A: HR@10=0.0614 | rel=0.832 | CT=-0.522 || B: HR@10=0.5185 | rel=4.579 | CT=4.579
Epoch:   8 | loss: 4.52151 || A: HR@10=0.0411 | rel=0.556 | CT=-0.798 || B: HR@10=0.4730 | rel=4.177 | CT=4.177
Epoch:   9 | loss: 4.33242 || A: HR@10=0.0422 | rel=0.572 | CT=-0.782 || B: HR@10=0.4500 | rel=3.974 | C

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.12041 || A: HR@10=0.0792 | rel=1.028
Epoch:   2 | loss: 6.26812 || A: HR@10=0.1098 | rel=1.426
Epoch:   3 | loss: 5.85965 || A: HR@10=0.0898 | rel=1.166
Epoch:   4 | loss: 5.52721 || A: HR@10=0.0866 | rel=1.125


KeyboardInterrupt: 

In [None]:
# --- Quick plot mirroring the paper notebook ---
summary = (res_df.groupby(['scenario','N_size','p']).agg({'relHR_A':'mean','relHR_B':'mean'}).reset_index()
                    .sort_values(['scenario','N_size','p']))
plt.figure(figsize=(10,6))
for scen in summary["scenario"].unique():
    sub = summary[summary["scenario"]==scen]
    xs = range(len(sub))
    plt.plot(xs, sub["relHR_A"], marker="o", label=f"{scen} (A)")
    plt.plot(xs, sub["relHR_B"], marker="x", label=f"{scen} (B)")
plt.title(f"Relative HR@{TOPK} by scenario (mean over trials)")
plt.xlabel("Condition index (ordered by scenario/N/p)")
plt.ylabel("Mean Relative HR")
plt.legend()
plt.tight_layout()
plt.show()
