### Notebook hotfix (applied 2025-08-20T11:37:08 UTC)

- Fixed CUDA device-side assert during evaluation by sanitizing token IDs in `score_next_items`.
- Masked PAD/MASK predictions and clipped logits to `num_items`.
- Light hyperparameter tweaks intended to improve Recall@10 / NDCG@10 while keeping the training slowdown at or below 10%:
  - `mask_prob=0.20`
  - `weight_decay=0.01`, `lr_scheduler='cosine'`, `warmup_ratio=0.10`
  - `dropout` +0.05 (cap 0.35)
  - `max_len` increased by ~10%
  - learning rate +15% with warmup/decay
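
The first two fixes above can be illustrated without PyTorch. The following is a minimal sketch, assuming that out-of-range IDs are replaced with PAD and that scores are indexed by token ID. `score_next_items` itself is not shown in this section, and the helper names `sanitize_ids` and `mask_special_scores` are hypothetical.

```python
def sanitize_ids(seq, num_items, pad=0):
    """Replace IDs outside 1..num_items with PAD so embedding lookups stay in range.

    Assumption: invalid IDs are mapped to PAD rather than dropped.
    """
    return [t if 1 <= t <= num_items else pad for t in seq]


def mask_special_scores(scores, num_items, neg=float("-inf")):
    """Keep scores for real items (1..num_items) only.

    Index 0 (PAD) and every index past num_items (e.g. MASK) get `neg`,
    so they can never appear in a top-k list.
    """
    return [s if 1 <= i <= num_items else neg for i, s in enumerate(scores)]


clean = sanitize_ids([3, 0, 7, -1, 5], num_items=5)
masked = mask_special_scores([0.9, 0.1, 0.5, 0.8], num_items=2)
print(clean, masked)
```

An out-of-range ID reaching an embedding lookup on the GPU is a common cause of device-side asserts, which is why the sanitizing step runs before the model is called.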

# ML-20M — BERT4Rec Baseline + Collective Experiments (RQ2-style)
**Built:** 2025-08-13 14:54

This notebook **keeps the baseline model and training loop from `ml-20bert4rec`** and adds **all collective experimental conditions** inspired by `bert4rec_collectives_rq2_paper`:
- Embedding-based user clustering (KMeans), farthest-cluster seeding
- Collective construction with homogeneity **p**
- Promote/Demote scenarios for two collectives (A/B)
- Deterministic rating edits and **retraining from the baseline weights**
- Relative HR@K on targeted item sets (A and B)
- Results export + quick plot

Printing/log messages mirror the paper notebook (e.g., *Device:*, *Running RQ2 grid...*, *Seed clusters: ...*, *Saved rq2_results_relative_hr.csv*).
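
The farthest-cluster seeding listed above can be sketched as picking the two cluster centroids with the largest pairwise distance. This is an illustration only: it assumes cosine distance between KMeans centroids (the notebook imports `cosine_distances` and sets `SEED_MODE = "balanced_maxdist"`, but the seeding code is not part of this section), and `farthest_cluster_pair` is a hypothetical name.

```python
import math
from itertools import combinations


def cosine_distance(a, b):
    """1 - cosine similarity; zero-length vectors are treated as distance 1.0 (assumption)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        return 1.0
    return 1.0 - dot / (na * nb)


def farthest_cluster_pair(centroids):
    """Return the index pair (i, j), i < j, of the two most distant centroids.

    Requires at least two centroids.
    """
    return max(
        combinations(range(len(centroids)), 2),
        key=lambda ij: cosine_distance(centroids[ij[0]], centroids[ij[1]]),
    )


centroids = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
print("Seed clusters:", farthest_cluster_pair(centroids))
```

Seeding collectives A and B from the most distant clusters makes their members as dissimilar as the embedding space allows, which is the point of the A/B contrast.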



> **Fix applied (Aug 20, 2025):** Validation Recall@K/NDCG@K previously evaluated on single-item sequences,
> which produced zeros. `prepare_baseline` now evaluates on `user_valid_full[u] = user_train[u] + user_valid[u]`,
> ensuring there is prefix context (cond) and a target. The evaluator `recall_ndcg_at_k` expects sequences with
> length ≥ 2 and uses `cond = seq[:-1]`, `target = seq[-1]`. No change to the dataset split itself.
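
The fix described above can be sketched in plain Python. This is a simplified illustration of how a validation row is assembled, assuming `max_len >= 2` and `MASK = num_items + 1`; the helper names `build_valid_full` and `make_eval_row` are hypothetical, and the real evaluator batches these rows into tensors.

```python
PAD = 0


def build_valid_full(user_train, user_valid):
    """Append each user's single validation item to their training prefix."""
    return {
        u: user_train[u] + user_valid[u]
        for u in user_valid
        if user_train.get(u) and len(user_valid[u]) == 1
    }


def make_eval_row(seq, num_items, max_len):
    """Return (row, target) for one user, or None when there is no context."""
    if len(seq) < 2:
        return None                      # a single item has no prefix to condition on
    cond, target = seq[:-1], seq[-1]
    cond = cond[-(max_len - 1):]         # leave one slot for [MASK]
    pad = max_len - 1 - len(cond)
    return [PAD] * pad + cond + [num_items + 1], target


full = build_valid_full({1: [4, 7, 9]}, {1: [3]})
print(make_eval_row(full[1], num_items=10, max_len=6))
```

Before the fix, the evaluator received only `user_valid[u]`, a one-item list, so every user hit the `len(seq) < 2` branch and the metrics came out as zero.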


In [None]:

def edit_ratings(df, users_idx, target_items_raw, action, user_decoder, promote_value=5.0, demote_value=1.0):
    """Return a new ratings DataFrame where all (user,item) in the target set
    for the given users are overwritten (and added if missing) with a fixed value.
    users_idx: list of user indices; user_decoder maps idx -> raw userId.
    action: 'promote' (set to promote_value) or 'demote' (set to demote_value).
    """
    import numpy as np
    import pandas as pd

    df = df.copy()
    if not len(target_items_raw):
        return df

    target_items_raw = set(target_items_raw)
    user_raw_ids = [user_decoder[u] for u in users_idx if u in user_decoder]
    if not user_raw_ids:
        return df

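    if action not in ('promote', 'demote'):
        raise ValueError(f"action must be 'promote' or 'demote', got {action!r}")
    value = promote_value if action == 'promote' else demote_value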

    # Remove existing rows for these (user,item) pairs
    mask_users = df['userId'].isin(user_raw_ids)
    mask_items = df['movieId'].isin(target_items_raw)
    df = df.loc[~(mask_users & mask_items)].reset_index(drop=True)

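    # Add overwritten pairs. The new rows carry no timestamp (NaN), so a later
    # sort by ['userId', 'timestamp'] places them last within each user.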
    new_rows = pd.DataFrame({
        'userId': np.repeat(user_raw_ids, len(target_items_raw)),
        'movieId': np.tile(list(target_items_raw), len(user_raw_ids)),
        'rating':  value
    })
    df = pd.concat([df, new_rows], ignore_index=True)
    # enforce integer dtypes for ids
    df = df.astype({'userId': 'int64', 'movieId': 'int64'})
    return df


In [None]:
import os, math, random, gc
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from dataclasses import dataclass
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)
SEED_MODE = "balanced_maxdist"

Device: cuda


In [None]:
# --- Baseline config (from ml-20bert4rec) ---
config = {
    'data_path' : r'C:\Users\david\OneDrive\Desktop\Collective Exp\ml-1m',  # MovieLens directory holding ratings.dat (this path points at ml-1m)
    'max_len' : 50,
    'hidden_units' : 256,
    'num_heads' : 2,
    'num_layers': 2,
    'dropout_rate' : 0.1,
    'lr' : 0.001,
    'batch_size' : 128,
    'num_epochs' : 17,
    'num_workers' : 2,
    'mask_prob' : 0.15,
    'seed' : 42,
    
}

def fix_seed(seed:int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed(config['seed'])
print("Seed set to", config['seed'])


Seed set to 42


In [None]:
import os
import pandas as pd

class MakeSequenceDataSet():
    def __init__(self, config):
        dat1 = os.path.join(config['data_path'], 'ratings.dat')
        dat2 = os.path.join(config['data_path'], 'rating.dat')
        dat_path = dat1 if os.path.exists(dat1) else dat2

        if not os.path.exists(dat_path):
            raise FileNotFoundError(f"Could not find ratings.dat at {config['data_path']}")

        print("Loading:", dat_path)

        # Correct read for MovieLens 1M .dat format
        self.df = pd.read_csv(
            dat_path,
            sep="::",
            engine="python",      # Required for multi-char separator
            header=None,          # No header row in the file
            names=["userId", "movieId", "rating", "timestamp"],
            encoding="latin-1"    # Avoids encoding errors
        )

        must = {'userId','movieId','rating','timestamp'}
        missing = must - set(self.df.columns)
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        self.item_encoder, self.item_decoder = self.generate_encoder_decoder('movieId')
        self.user_encoder, self.user_decoder = self.generate_encoder_decoder('userId')
        self.num_item, self.num_user = len(self.item_encoder), len(self.user_encoder)

        self.df['item_idx'] = self.df['movieId'].apply(lambda x: self.item_encoder[x] + 1)  # 1..num_item
        self.df['user_idx'] = self.df['userId'].apply(lambda x: self.user_encoder[x])
        self.df = self.df.sort_values(['user_idx', 'timestamp'])

        # temporal train/valid/test split
        self.user_train, self.user_valid, self.user_test = self.generate_sequence_data()
        print("Users with sequences:", len(self.user_train), "| Items:", self.num_item)


    def generate_encoder_decoder(self, col:str):
        encoder, decoder = {}, {}
        ids = self.df[col].unique()
        for idx, _id in enumerate(ids):
            encoder[_id] = idx
            decoder[idx] = _id
        return encoder, decoder

    def generate_sequence_data(self):
        users = defaultdict(list)
        user_train, user_valid, user_test = {}, {}, {}
        for user, g in self.df.groupby('user_idx'):
            seq = g['item_idx'].tolist()
            if len(seq) < 3:
                continue
            users[user] = seq
        for user, seq in users.items():
            user_train[user] = seq[:-2]
            user_valid[user] = [seq[-2]]
            user_test[user]  = [seq[-1]]
        return user_train, user_valid, user_test

    def get_splits(self):
        return self.user_train, self.user_valid, self.user_test

def rebuild_sequences_from_df(df, item_encoder, user_encoder, threshold=4.0):
    # Keep only ratings >= threshold as positive interactions
    df = df.copy()
    df = df[df['rating'] >= threshold].sort_values(['userId','timestamp']).reset_index(drop=True)
    df['item_idx'] = df['movieId'].map(lambda x: item_encoder.get(x, None))
    df['user_idx'] = df['userId'].map(lambda x: user_encoder.get(x, None))
    df = df.dropna(subset=['item_idx','user_idx'])
    df['item_idx'] = df['item_idx'].astype(int) + 1
    df['user_idx'] = df['user_idx'].astype(int)
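    # Group per user; each list keeps the timestamp-sorted row order from above.
    # groupby avoids the slow row-by-row iterrows loop and keeps integer keys.
    user_pos = {
        int(u): g.tolist()
        for u, g in df.groupby('user_idx', sort=False)['item_idx']
    }
    # filter out empties
    return {u: seq for u, seq in user_pos.items() if len(seq) >= 1}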


In [None]:

# === Validation metrics: next-item Recall@K and NDCG@K (correct target handling) ===
import torch

@torch.no_grad()
def recall_ndcg_at_k(model, user_pos, num_items, max_len, k=10, batch_size=2048, device=None):
    """
    Standard next-item eval:
      cond = seq[:-1], target = seq[-1].
    We DO NOT mask the target even if it appeared earlier in cond.
    """
    dev = device or next(model.parameters()).device
    PAD = 0
    MASK = num_items + 1  # IMPORTANT: this repo uses MASK = num_items + 1

    rows, targets, seen_lists = [], [], []
    for _, seq in user_pos.items():
        if len(seq) < 2:
            continue
        cond = seq[:-1]
        tgt  = seq[-1]
        s = cond[-max_len:] if len(cond) > max_len else cond
        if len(s) > max_len - 1:
            s = s[-(max_len - 1):]  # leave room for [MASK]
        pad = max_len - len(s) - 1
        if pad < 0: pad = 0
        rows.append([PAD]*pad + list(s) + [MASK])
        targets.append(tgt)
        seen_lists.append(list(set(s)))  # we'll unmask target below

    if not rows:
        return 0.0, 0.0

    X = torch.tensor(rows, dtype=torch.long, device=dev)
    targets = torch.tensor(targets, dtype=torch.long, device=dev)

    # Determine logits width V from model
    V = model(X[:1])[:, -1, :].shape[1]

    # Build suppression mask (B,V) for PAD/MASK/seen EXCEPT the target
    B = X.shape[0]
    seen_mask = torch.zeros((B, V), dtype=torch.bool, device=dev)
    seen_mask[:, PAD] = True
    if 0 <= MASK < V:
        seen_mask[:, MASK] = True

    r_idx, c_idx = [], []
    for i, items in enumerate(seen_lists):
        items_set = set(it for it in items if 0 <= it < V)
        tgt_i = int(targets[i].item())
        if 0 <= tgt_i < V and tgt_i in items_set:
            items_set.remove(tgt_i)
        if items_set:
            r_idx.extend([i]*len(items_set))
            c_idx.extend(list(items_set))
    if r_idx:
        seen_mask[torch.tensor(r_idx, device=dev), torch.tensor(c_idx, device=dev)] = True

    # Sanity check: target must not be masked
    assert not seen_mask[torch.arange(B, device=dev), targets.clamp_min(0).clamp_max(V-1)].any(), "Target was masked!"

    hits = 0.0
    ndcgs = 0.0
    for s in range(0, B, batch_size):
        e = min(B, s+batch_size)
        logits = model(X[s:e])[:, -1, :]
        logits = logits.masked_fill(seen_mask[s:e], -1e9)
        topk = torch.topk(logits, k=k, dim=1).indices
        tgts = targets[s:e].unsqueeze(1)

        hit_rows = (topk == tgts).any(dim=1).float()
        hits += hit_rows.sum().item()

        where = (topk == tgts).nonzero(as_tuple=False)
        if where.numel() > 0:
            ndcgs += (1.0 / torch.log2(where[:,1].float() + 2.0)).sum().item()

    n = float(B)
    return hits / n, ndcgs / n
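
As a worked example of the metric arithmetic in `recall_ndcg_at_k`: with one target per user, a hit at 0-based position `r` in the top-k list contributes `1 / log2(r + 2)` to NDCG, and both metrics are averaged over all evaluated users. The sketch below reproduces only that arithmetic from precomputed ranks; `recall_ndcg_from_ranks` is a hypothetical helper, not part of the notebook.

```python
import math


def recall_ndcg_from_ranks(ranks, k=10):
    """ranks: 0-based position of each user's target in the ranking, or None if absent."""
    n = len(ranks)
    if n == 0:
        return 0.0, 0.0
    hits = [r for r in ranks if r is not None and r < k]
    recall = len(hits) / n
    ndcg = sum(1.0 / math.log2(r + 2) for r in hits) / n
    return recall, ndcg


# Four users: ranked 1st, ranked 3rd, not retrieved, ranked 13th (outside top 10)
print(recall_ndcg_from_ranks([0, 2, None, 12], k=10))  # (0.5, 0.375)
```

Two of the four targets land in the top 10, so Recall@10 is 0.5; their gains are 1/log2(2) = 1 and 1/log2(4) = 0.5, giving NDCG@10 = 1.5 / 4 = 0.375.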


In [None]:

class BERTRecDataSet(Dataset):
    def __init__(self, user_train, max_len, num_item, mask_prob=0.15):
        self.max_len = max_len
        self.num_item = num_item
        self.mask_prob = mask_prob
        self.users = list(user_train.keys())
        self.inputs, self.labels = [], []
        for user in self.users:
            seq = user_train[user]
            tokens = seq[-max_len:] if len(seq) > max_len else [0]*(max_len-len(seq)) + seq
            masked_tokens, label_tokens = self.mask_sequence(tokens)
            self.inputs.append(masked_tokens)
            self.labels.append(label_tokens)
        self.inputs = torch.tensor(self.inputs, dtype=torch.long)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def mask_sequence(self, tokens):
        masked_tokens = tokens.copy()
        labels = [0]*len(tokens)  # PAD=0
        for i in range(len(tokens)):
            if tokens[i] == 0:
                continue
            if random.random() < self.mask_prob:
                labels[i] = tokens[i]  # store original ID
                prob = random.random()
                if prob < 0.8:
                    masked_tokens[i] = self.num_item + 1  # MASK token in inputs only
                elif prob < 0.9:
                    masked_tokens[i] = random.randint(1, self.num_item)
        return masked_tokens, labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]


In [None]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model): super().__init__(); self.pe = nn.Embedding(max_len, d_model)
    def forward(self, x): return self.pe.weight.unsqueeze(0).repeat(x.size(0), 1, 1)

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512): super().__init__(vocab_size, embed_size, padding_idx=0)

class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size
    def forward(self, sequence):
        return self.dropout(self.token(sequence) + self.position(sequence))

class Attention(nn.Module):
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None: scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None: p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__(); assert d_model % h == 0
        self.d_k = d_model // h; self.h = h
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model); self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
        B = query.size(0)
        query, key, value = [l(x).view(B, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]
        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
        return self.output_linear(x)

class GELU(nn.Module):
    def forward(self, x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2/math.pi)*(x + 0.044715*torch.pow(x,3))))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__(); self.w_1 = nn.Linear(d_model, d_ff); self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout); self.activation = GELU()
    def forward(self, x): return self.w_2(self.dropout(self.activation(self.w_1(x))))

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6): super().__init__(); self.a_2 = nn.Parameter(torch.ones(features)); self.b_2 = nn.Parameter(torch.zeros(features)); self.eps = eps
    def forward(self, x): mean = x.mean(-1, keepdim=True); std = x.std(-1, keepdim=True); return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout): super().__init__(); self.norm = LayerNorm(size); self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer): return x + self.dropout(sublayer(self.norm(x)))

class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward); return self.dropout(x)

class BERT(nn.Module):
    def __init__(self, bert_max_len, num_items, bert_num_blocks, bert_num_heads,
                 bert_hidden_units, bert_dropout):
        super().__init__()
        self.max_len = bert_max_len; self.num_items = num_items
        self.hidden = bert_hidden_units
        self.embedding = BERTEmbedding(vocab_size=num_items+2, embed_size=self.hidden, max_len=bert_max_len, dropout=bert_dropout)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(self.hidden, bert_num_heads, self.hidden*4, bert_dropout) for _ in range(bert_num_blocks)])
        self.out = nn.Linear(self.hidden, num_items + 2)  # logits over 0..num_items+1 (PAD, items, MASK)
    def forward_hidden(self, x):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        h = self.embedding(x)
        for transformer in self.transformer_blocks:
            h = transformer.forward(h, mask)
        return h
    def forward(self, x):
        h = self.forward_hidden(x)
        return self.out(h)


In [None]:
def train_one_epoch(model, criterion, optimizer, data_loader, device=DEVICE):
    model.train(); loss_val = 0.0
    for seq, labels in tqdm(data_loader):
        seq, labels = seq.to(device), labels.to(device)
        logits = model(seq).view(-1, model.out.out_features)  # (bs*t, vocab)
        labels = labels.view(-1)
        optimizer.zero_grad()
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        loss_val += loss.item()
    return loss_val / max(1, len(data_loader))

def prepare_baseline(config):
    """
    Trains a baseline BERT4Rec model and evaluates next-item Recall/NDCG correctly.
    IMPORTANT: Validation is evaluated on sequences built as
      user_valid_full[u] = user_train[u] + user_valid[u]
    so that cond = prefix, target = last (valid) item.
    """
    # 1) Build sequences
    mds = MakeSequenceDataSet(config)
    user_train, user_valid, user_test = mds.user_train, mds.user_valid, mds.user_test
    num_item = mds.num_item

    # 2) DataLoader from training prefixes
    train_ds = BERTRecDataSet(user_train, max_len=config['max_len'], num_item=num_item,
                              mask_prob=config.get('mask_prob', 0.15))  # read from config instead of hard-coding
    dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, drop_last=False)

    # 3) Model
    model = BERT(
        bert_max_len=config['max_len'],
        num_items=num_item,   # this class internally adds +2 for PAD/MASK head
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    opt = torch.optim.Adam(model.parameters(), lr=config['lr'])

    # 4) Validation spec: build full sequences with context
    user_valid_full = {
        u: (user_train[u] + user_valid[u])
        for u in user_valid
        if u in user_train and len(user_train[u]) >= 1 and len(user_valid[u]) == 1
    }

    K = int(config.get('TOPK', 10)) if isinstance(config, dict) else 10
    for epoch in range(1, int(config['num_epochs']) + 1):
        l = train_one_epoch(model, criterion, opt, dl, device=DEVICE)
        print(f'Epoch: {epoch:3d}| Train loss: {l:.5f}')
        # Evaluate next-item ranking on validation (prefix->valid_item)
        _was = model.training
        model.eval()
        try:
            _rec, _ndcg = recall_ndcg_at_k(
                model, user_valid_full, num_item, config['max_len'],
                k=K, batch_size=2048, device=DEVICE
            )
        finally:
            if _was: model.train()
        print(f"[VAL] Recall@{K}={_rec:.4f} | NDCG@{K}={_ndcg:.4f}")
    # Note: despite the filename, these are the final-epoch weights (no best-epoch selection).
    torch.save(model.state_dict(), "baseline_best.pt")
    print("Saved baseline_best.pt")
    return mds, user_train, user_valid, user_test, model

In [None]:
def user_embedding_from_model(model, seq, num_items, max_len):
    # Average of valid hidden states
    s = seq[-max_len:] if len(seq) > max_len else seq
    PAD = 0  # no MASK token is needed here: only real item positions are averaged
    pad = max_len - len(s)
    x = torch.tensor([[PAD]*pad + s], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        h = model.forward_hidden(x)[0]  # (L, H)
        valid = h[pad:pad+len(s)]
        return valid.mean(dim=0).detach().cpu().numpy()

def compute_user_embeddings(model, train_seqs, num_items, max_len):
    user_vecs = {}
    for u, seq in tqdm(list(train_seqs.items()), desc="user embeddings"):
        user_vecs[u] = user_embedding_from_model(model, seq, num_items, max_len)
    U = np.stack([user_vecs[u] for u in user_vecs.keys()])
    user_index = list(user_vecs.keys())
    return U, user_index

def build_collective_quota(seed_cluster, labels, users, size, p, *, seed=None):
    import random
    rnd = random.Random(seed) if seed is not None else random

    # split pools
    seed_users  = [u for u, lab in zip(users, labels) if lab == seed_cluster]
    other_users = [u for u, lab in zip(users, labels) if lab != seed_cluster]
    rnd.shuffle(seed_users); rnd.shuffle(other_users)

    # exact target counts
    n_seed  = min(len(seed_users), int(round(p * size)))
    n_other = min(len(other_users), size - n_seed)

    members = seed_users[:n_seed] + other_users[:n_other]

    # backfill if one pool was too small
    if len(members) < size:
        spill = seed_users[n_seed:] + other_users[n_other:]
        rnd.shuffle(spill)
        members += spill[: size - len(members)]

    return members


def top_items_for_collective(members, train_seqs, topn=10):
    from collections import Counter
    c = Counter()
    for u in members:
        for it in set(train_seqs.get(u, [])):  # count each item once per user
            c[it] += 1
    return [it for it, _ in c.most_common(topn)]
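
To make the homogeneity parameter `p` in `build_collective_quota` concrete, the sketch below computes only the resulting member counts, following the same rounding, capping, and backfill rules. `quota_counts` is a hypothetical helper for illustration; it does not sample users.

```python
def quota_counts(n_seed_pool, n_other_pool, size, p):
    """Return (members from the seed cluster, members from other clusters)."""
    n_seed = min(n_seed_pool, int(round(p * size)))
    n_other = min(n_other_pool, size - n_seed)
    shortfall = size - n_seed - n_other
    # A shortfall means the other pool is exhausted, so any backfill
    # can only come from leftover seed-cluster users.
    n_seed += min(shortfall, n_seed_pool - n_seed)
    return n_seed, n_other


print(quota_counts(100, 100, size=10, p=0.7))   # (7, 3)
print(quota_counts(4, 100, size=10, p=0.7))     # (4, 6): seed cluster too small
print(quota_counts(100, 2, size=10, p=0.7))     # (8, 2): backfilled from seed cluster
```

Python's `round` uses round-half-to-even, so `p * size = 2.5` yields 2 seed members rather than 3. When either pool is too small, the realized homogeneity differs from `p`, which is worth logging alongside the results.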



In [None]:

# Helper: resolve number of items robustly from mds/model
def resolve_num_items(mds, model):
    # Try common dataset attributes first
    for attr in ('num_items', 'n_items', 'item_vocab_size'):
        if hasattr(mds, attr):
            try:
                n = int(getattr(mds, attr))
                if n > 0:
                    return n
            except Exception:
                pass
    # Try encoder sizes
    for attr in ('item_encoder', 'item2id', 'item_to_idx'):
        enc = getattr(mds, attr, None)
        if isinstance(enc, dict) and len(enc) > 0:
            return int(len(enc))
    # Try model output head (vocab size = num_items + specials)
    out_features = getattr(getattr(model, 'out', None), 'out_features', None)
    if isinstance(out_features, int) and out_features > 0:
        # This notebook's BERT head has num_items + 2 outputs:
        # 0=PAD, 1..num_items=items, num_items+1=[MASK]
        return max(1, int(out_features) - 2)
    raise RuntimeError("Could not resolve number of items from mds/model.")


In [None]:
def retrain_from_baseline_on_ratings(
    baseline_state_path,
    ratings_df,
    mds,
    epochs=8,
    lr=5e-4,
    *,
    eval_specs=None,   # list of {"name","members_idx","item_set_enc","baseline_hr", ["solo_hr"]}
    eval_split='user_test',
    k=None,
    device=None,
    **kwargs           # swallow unused args for compatibility
):
    """
    Retrain from baseline weights on a modified ratings_df and print, per epoch:
      - training loss
      - HR@K for each eval spec
      - relative HR (vs baseline_hr)
      - constructiveness CT = rel_joint - rel_solo (if solo_hr provided)

    Returns: (mds, user_pos_mod, None, None, model)
    """
    dev = device if device is not None else DEVICE
    K = k if k is not None else (config.get('TOPK', 10) if isinstance(config, dict) else 10)
    # Paper protocol: evaluate HR over all users in chosen split
    global_eval_subset = None
    try:
        global_eval_subset = getattr(mds, eval_split)
    except Exception:
        global_eval_subset = None


    # 0) Defensive config defaults
    max_len = int(config.get('max_len', 50))
    mask_prob = float(config.get('mask_prob', 0.15))
    batch_size = int(config.get('batch_size', 128))

    # 1) Rebuild sequences from edited ratings with threshold=4.0
    user_pos_mod = rebuild_sequences_from_df(
        ratings_df,
        mds.item_encoder,
        mds.user_encoder,
        threshold=4.0
    )
    user_pos_mod = {u: [int(x) for x in seq] for u, seq in user_pos_mod.items()}
    user_pos_mod = {u: seq for u, seq in user_pos_mod.items() if len(seq) > 0}
    if not user_pos_mod:
        raise ValueError("No user sequences after rebuild; check ratings_df and threshold.")

    # 1b) Compute num_item to cover all observed ids
    enc_num_item = getattr(mds, 'num_item', len(getattr(mds, 'item_encoder', {})))
    max_seen = max(max(seq) for seq in user_pos_mod.values())
    num_item = max(enc_num_item, max_seen)

    # 2) DataLoader
    ds = BERTRecDataSet(
        user_pos_mod,
        max_len=max_len,
        num_item=num_item,
        mask_prob=mask_prob
    )
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    # 3) Model from baseline checkpoint
    model = BERT(
        bert_max_len=max_len,
        num_items=num_item,   # BERT adds +2 internally for PAD and [MASK]; adding it here too would inflate the vocab
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(dev)

    # ----- Vocab-size–aware load: pad/slice embedding & output heads -----
    ckpt = torch.load(baseline_state_path, map_location=dev)
    model_sd = model.state_dict()

    def _resize_like_param(param_name, src_tensor, dst_tensor):
        """Resize src_tensor to match dst_tensor on dim 0 by copy-overlap + zero pad."""
        if src_tensor.shape == dst_tensor.shape:
            return src_tensor
        # Only support resizing first dimension (vocab) and keeping others
        if src_tensor.ndim != dst_tensor.ndim or src_tensor.shape[1:] != dst_tensor.shape[1:]:
            # Fallback: keep destination (random init) if shapes are incompatible
            return dst_tensor
        new_t = dst_tensor.clone()
        n = min(src_tensor.shape[0], dst_tensor.shape[0])
        new_t[:n] = src_tensor[:n]
        return new_t

    # Candidate keys for vocab-tied layers across common BERT4Rec variants
    vocab_keys = [
        'embedding.token.weight',          # key used by this notebook's BERT class
        'bert.item_embedding.weight',
        'item_embedding.weight',
        'bert.embeddings.word_embeddings.weight',
    ]
    out_w_keys = [
        'out.weight',                      # key used by this notebook's BERT class
        'bert.prediction.weight',
        'prediction.weight',
    ]
    out_b_keys = [
        'out.bias',
        'bert.prediction.bias',
        'prediction.bias',
    ]

    # Build a patched checkpoint dict
    patched = {}
    for k, v in ckpt.items():
        if k in model_sd:
            if k in vocab_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_w_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_b_keys:
                # Bias: resize on dim 0
                if v.shape != model_sd[k].shape:
                    nb = min(v.shape[0], model_sd[k].shape[0])
                    new_b = model_sd[k].clone()
                    new_b[:nb] = v[:nb]
                    patched[k] = new_b
                else:
                    patched[k] = v
            else:
                # default: keep as-is if shape matches; otherwise keep destination param
                if v.shape == model_sd[k].shape:
                    patched[k] = v
                else:
                    patched[k] = model_sd[k]
        else:
            # keys not in current model are ignored
            pass

    # Load with strict=False to allow any unmatched keys
    missing, unexpected = model.load_state_dict(patched, strict=False)
    if missing:
        print(f"[info] Missing keys after load (expected if heads/embeddings expanded): {missing}")
    if unexpected:
        print(f"[info] Unexpected keys ignored from checkpoint: {unexpected}")

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
    grad_clip = float(config.get('grad_clip', 1.0))
    torch.backends.cudnn.benchmark = True

    def _train_one_epoch_nan_safe(model, crit, opt, dl):
        model.train()
        total_loss = 0.0
        total_tokens = 0
        for batch in dl:
            if isinstance(batch, dict):
                batch = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in batch.items()}
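                # Pick the first key that is present. Chaining tensors with `or`
                # raises for multi-element tensors, so test for None explicitly.
                def _first(*keys):
                    return next((batch[k] for k in keys if batch.get(k) is not None), None)
                inputs = _first('input_ids', 'seqs')
                if inputs is None:
                    inputs = next(iter(batch.values()))
                labels = _first('labels', 'targets', 'target_ids')
                attn = _first('attention_mask', 'attn_mask')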
            else:
                inputs = batch[0].to(dev)
                labels = batch[1].to(dev)
                attn = batch[2].to(dev) if len(batch) > 2 else None

            if labels is None:
                continue
            valid = (labels != 0)
            if valid.sum().item() == 0:
                continue

            opt.zero_grad(set_to_none=True)
            logits = model(inputs)  # (B, T, V)
            B, T, V = logits.shape
            loss = crit(logits.view(B*T, V), labels.view(B*T))
            if not torch.isfinite(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            total_loss += loss.item() * valid.sum().item()
            total_tokens += valid.sum().item()

        return (total_loss / max(total_tokens, 1))

    def _coerce_itemset(item_set_enc, target_num_item):
        if isinstance(item_set_enc, (list, tuple)):
            v = torch.tensor(item_set_enc, dtype=torch.float32)
        elif torch.is_tensor(item_set_enc):
            v = item_set_enc.float().cpu()
        else:
            raise TypeError("item_set_enc must be list/tuple/tensor")
        if v.ndim != 1:
            v = v.view(-1)
        if len(v) < target_num_item:
            v = torch.cat([v, torch.zeros(target_num_item - len(v))], dim=0)
        elif len(v) > target_num_item:
            v = v[:target_num_item]
        return v.to(dev)

    def _eval_one(spec):
        members = spec['members_idx']
        subset = global_eval_subset if global_eval_subset else user_pos_mod
        aligned_itemset = _coerce_itemset(spec['item_set_enc'], num_item)

        hr = float(hr_for_itemset(
            model,
            subset,
            aligned_itemset,
            num_item,
            max_len,
            K
        ))

        rel = None
        base = spec.get('baseline_hr', None)
        if base is not None and base > 0:
            rel = hr / base
        ct = None
        solo = spec.get('solo_hr', None)
        if base is not None and base > 0 and solo is not None:
            ct = (hr / base) - (solo / base)
        return hr, rel, ct

    
    # --- Collect per-epoch evaluation stats to average at the end ---
    _per_spec = {}
    if eval_specs:
        for _spec in eval_specs:
            _nm = _spec.get('name', f"spec_{len(_per_spec)+1}")
            _per_spec[_nm] = {'hr': [], 'rel': [], 'ct': []}
    
    for epoch in range(1, epochs + 1):
        avg_loss = _train_one_epoch_nan_safe(model, crit, opt, dl)
        if eval_specs:
            parts = []
            for spec in eval_specs:
                try:
                    hr, rel, ct = _eval_one(spec)
                    # record per-epoch stats
                    _nm = spec.get('name', '?')
                    if _nm in _per_spec:
                        _per_spec[_nm]['hr'].append(hr)
                        _per_spec[_nm]['rel'].append(rel)
                        _per_spec[_nm]['ct'].append(ct)
                    msg = f"{spec['name']}: HR@{K}={hr:.4f}"
                    base_local = spec.get('baseline_hr', None)
                    if rel is not None:
                        msg += f" | rel={rel:.3f}"
                    elif base_local is not None and base_local == 0:
                        msg += " | rel=NA (baseline=0)"
                    if ct is not None:
                        msg += f" | CT={ct:.3f}"
                except Exception as e:
                    msg = f"{spec.get('name','?')}: EVAL-ERROR {type(e).__name__}: {e}"
                parts.append(msg)
            print(f"Epoch: {epoch:3d} | loss: {avg_loss:.5f} || " + " || ".join(parts), flush=True)
        else:
            print(f"Epoch: {epoch:3d} | Train loss: {avg_loss:.5f}", flush=True)

    
    # --- Compute averages across epochs and attach to model ---
    def _mean_ignore_none(xs):
        _vals = [x for x in xs if x is not None]
        return float(sum(_vals)/len(_vals)) if _vals else None
    avg_stats = {}
    if eval_specs:
        for _nm, d in _per_spec.items():
            avg_stats[_nm] = {
                'hr': _mean_ignore_none(d['hr']),
                'rel': _mean_ignore_none(d['rel']),
                'ct': _mean_ignore_none(d['ct']),
                'epochs': len(d['hr'])
            }
    try:
        model._avg_eval_stats = avg_stats
    except Exception:
        pass
    
    return mds, user_pos_mod, None, None, model
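
The per-epoch `rel` and `CT` values printed by the loop above come down to two ratios. As a rough sketch, the arithmetic of `_eval_one` can be restated as a standalone function (the name `relative_and_ct` is hypothetical and not part of this notebook). The inputs below are copied from this notebook's logged output, where they are already rounded, so the result should only be expected to agree with the log to about three decimals.

```python
def relative_and_ct(hr, baseline_hr=None, solo_hr=None):
    """Hypothetical standalone mirror of the arithmetic in _eval_one.

    rel = hr / baseline_hr
    ct  = hr / baseline_hr - solo_hr / baseline_hr
    Both are None when the baseline is missing or not positive.
    """
    if baseline_hr is None or baseline_hr <= 0:
        return None, None
    rel = hr / baseline_hr
    ct = None if solo_hr is None else rel - solo_hr / baseline_hr
    return rel, ct


# Inputs taken from the logged run (rounded in the logs):
#   baseline HR for A (gA_base), joint-run HR for A at epoch 1,
#   and the solo-promote epoch average for A (gA_prom).
rel, ct = relative_and_ct(hr=0.1075, baseline_hr=0.073013, solo_hr=0.11404)
print(f"rel={rel:.3f} | CT={ct:.3f}")
```

A negative `CT` here means the joint intervention left collective A with a lower relative HR than its solo intervention did.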

In [None]:
import torch
import torch.nn as nn

@torch.no_grad()
def hr_for_itemset(model, user_pos, item_set, num_items, max_len, k=10):
    PAD = 0
    V = None
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            try:
                V = int(m.num_embeddings); break
            except Exception:
                pass
    MASK = 1 if (V is not None and V > 1) else (num_items + 2)
    if V is not None and not (0 <= MASK < V):
        MASK = 1 if V > 1 else 0

    dev = next(model.parameters()).device
    inps, seen_lists = [], []
    for _, seq in user_pos.items():
        if not seq: continue
        cond = seq[:-2] if len(seq) >= 2 else seq[:1]
        s = cond[-max_len:] if len(cond) > max_len else cond
        if len(s) > max_len - 1: s = s[-(max_len - 1):]
        if V is not None:
            s = [int(t) for t in s if 0 <= int(t) < V]
        pad = max_len - len(s) - 1
        if pad < 0: pad = 0
        inps.append([PAD]*pad + list(s) + [MASK])
        seen_lists.append(list(set(s)))
    if not inps: return 0.0

    X = torch.tensor(inps, dtype=torch.long, device=dev)
    B = X.shape[0]

    if V is None:
        V = model(X[:1])[:, -1, :].shape[1]

    seen_mask = torch.zeros((B, V), dtype=torch.bool, device=dev)
    if 0 <= PAD < V: seen_mask[:, PAD] = True
    if 0 <= MASK < V: seen_mask[:, MASK] = True

    row_idx, col_idx = [], []
    for i, items in enumerate(seen_lists):
        valid = [it for it in items if 0 <= it < V]
        if valid:
            row_idx.extend([i]*len(valid))
            col_idx.extend(valid)
    if row_idx:
        seen_mask[torch.tensor(row_idx, device=dev), torch.tensor(col_idx, device=dev)] = True

    target_mask = torch.zeros(V, dtype=torch.bool, device=dev)
    cand = list(set(int(i) for i in item_set))
    cand = [i for i in cand if 0 <= i < V]
    if cand:
        target_mask[torch.tensor(cand, device=dev)] = True
    if 0 <= PAD < V: target_mask[PAD] = False
    if 0 <= MASK < V: target_mask[MASK] = False

    hits = 0; total = 0
    batch_size = 2048
    for start in range(0, B, batch_size):
        end = min(B, start + batch_size)
        Xb = X[start:end]
        logits = model(Xb)[:, -1, :]
        if isinstance(num_items, int) and num_items < logits.shape[1]:
            logits[:, num_items:] = -1e9
        logits = logits.masked_fill(seen_mask[start:end], -1e9)
        topk_idx = torch.topk(logits, k=k, dim=1).indices
        hit_rows = target_mask[topk_idx].any(dim=1)
        hits += hit_rows.sum().item()
        total += Xb.shape[0]
    return float(hits / max(1, total))

@torch.no_grad()
def relative_hr(model_variant, model_baseline, user_pos, item_set, num_items, max_len, k=10):
    g_base = hr_for_itemset(model_baseline, user_pos, item_set, num_items, max_len, k=k)
    g_var  = hr_for_itemset(model_variant,   user_pos, item_set, num_items, max_len, k=k)
    rel    = (g_var / g_base) if g_base > 0 else 0.0
    return rel, g_var, g_base
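
To make the metric concrete, here is a dependency-free sketch of the idea behind `hr_for_itemset`: a user counts as a hit when any item of the target set appears in that user's top-k after already-seen items are masked out. The function name `hit_rate_for_itemset` and the toy scores are assumptions for illustration; the notebook's version additionally masks the PAD and MASK tokens and runs batched on the model's logits.

```python
def hit_rate_for_itemset(scores_per_user, seen_per_user, target_items, k=10):
    """Fraction of users whose top-k (after masking seen items) hits the target set.

    scores_per_user: list of score lists, one score per item id (index = item id)
    seen_per_user:   list of sets with the item ids already in each user's history
    target_items:    set of item ids that count as a hit
    """
    hits = 0
    for scores, seen in zip(scores_per_user, seen_per_user):
        # Rank only unseen items, highest score first.
        candidates = [i for i in range(len(scores)) if i not in seen]
        top_k = sorted(candidates, key=lambda i: scores[i], reverse=True)[:k]
        if any(i in target_items for i in top_k):
            hits += 1
    return hits / max(1, len(scores_per_user))


# Toy data: 3 users, 5 items (ids 0..4), target set {4}, k=2.
scores = [
    [0.1, 0.9, 0.2, 0.3, 0.8],  # top-2 unseen items: 1 and 4 -> hit
    [0.1, 0.9, 0.2, 0.3, 0.8],  # item 4 already seen -> top-2: 1 and 3 -> miss
    [0.5, 0.4, 0.3, 0.2, 0.1],  # top-2: 0 and 1 -> miss
]
seen = [set(), {4}, set()]
print(hit_rate_for_itemset(scores, seen, target_items={4}, k=2))
```

Because a hit only requires one target item in the top-k, the metric grows with the size of the target set, which is one reason the experiment reports it relative to the baseline.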

In [None]:
def score_next_items(model, cond_seq, num_items, max_len):
    """
    Robust scorer that:
    - Detects the model's embedding vocab size by scanning modules for nn.Embedding.
    - Ensures all token ids are within [0, vocab_size).
    - Uses PAD=0 and MASK=1 when an embedding table is found (this assumes the
      model was trained with that token layout).
    - Suppresses PAD, MASK, and seen tokens in logits.
    """
    import torch
    import torch.nn as nn

    # 1) Find an embedding to infer vocab_size
    vocab_size = None
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            try:
                vocab_size = int(m.num_embeddings)
                break
            except Exception:
                continue

    PAD = 0
    MASK = 1 if vocab_size is not None and vocab_size > 1 else (num_items + 2)

    # 2) Build masked sequence with strict truncation
    s = cond_seq[-max_len:] if len(cond_seq) > max_len else cond_seq
    if len(s) > max_len - 1:
        s = s[-(max_len - 1):]

    # 3) Sanitize IDs to prevent CUDA device-side asserts
    if vocab_size is not None:
        s = [int(x) for x in s if 0 <= int(x) < vocab_size]
        # Final check for MASK validity
        if not (0 <= MASK < vocab_size):
            MASK = 1 if vocab_size > 1 else 0

    pad = max_len - len(s) - 1
    if pad < 0: pad = 0
    seq_ids = [PAD]*pad + list(s) + [MASK]  # list() so tuple/array inputs also work

    DEVICE = next(model.parameters()).device
    inp = torch.tensor([seq_ids], dtype=torch.long, device=DEVICE)

    # 4) Forward -> logits for the MASK position (no gradients needed for scoring)
    with torch.no_grad():
        out = model(inp)
    logits = out[0, -1].clone()

    # 5) Suppress invalid or special predictions
    if 0 <= PAD < logits.numel():   logits[PAD] = -1e9
    if 0 <= MASK < logits.numel():  logits[MASK] = -1e9
    for seen in set(s):
        if 0 <= seen < logits.numel():
            logits[seen] = -1e9

    # 6) If num_items provided < logits size, clip tail so metrics only consider real items
    if isinstance(num_items, int) and num_items < logits.numel():
        logits[num_items:] = -1e9

    return logits  # (vocab,)
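
The input layout used by `score_next_items` (truncate, sanitize, left-pad, append the mask token) can be checked in isolation. The helper below is a minimal sketch under the same assumptions as the function above (PAD=0, MASK=1); the name `build_masked_input` is hypothetical and does not exist in this notebook.

```python
def build_masked_input(cond_seq, max_len, vocab_size, pad=0, mask=1):
    """Left-pad a history to max_len, with the mask token in the last position."""
    # Keep the most recent max_len - 1 items so there is room for the mask token.
    keep = max_len - 1
    s = list(cond_seq)[-keep:] if keep > 0 else []
    # Drop ids the embedding table cannot index; out-of-range ids are what
    # trigger CUDA device-side asserts in an embedding lookup.
    s = [int(x) for x in s if 0 <= int(x) < vocab_size]
    return [pad] * (keep - len(s)) + s + [mask]


print(build_masked_input([5, 6, 7], max_len=5, vocab_size=10))
print(build_masked_input([3, 4, 5, 6, 7, 8], max_len=4, vocab_size=10))
print(build_masked_input([2, 99, 3], max_len=5, vocab_size=10))
```

For any `max_len >= 1` the returned list has exactly `max_len` entries, so a batch of such lists can be stacked into one tensor without further padding.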

# ✅ Paper-faithful item_set builders (post-baseline)

In [None]:

import pandas as pd
from collections import Counter

V_TARGET = 10 if 'config' not in globals() else config.get('V', 10)

def _member_mean_topV(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    if sub.empty:
        return []
    means = (sub.groupby('movieId')['rating']
                 .mean().sort_values(ascending=False)
                 .head(V).index.tolist())
    return means

def _member_popular(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    pop = (sub.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _global_popular(ratings_df, V):
    pop = (ratings_df.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _encode_items(raw_ids, item_encoder):
    enc0 = [item_encoder[r] for r in raw_ids if r in item_encoder]
    return [e + 1 for e in enc0]  # BERT4Rec label space

def build_item_set_for_members(mds, members_idx, V=10):
    ratings_df = mds.df
    user_decoder = mds.user_decoder
    member_raw = [user_decoder[u] for u in members_idx if u in user_decoder]

    raw = _member_mean_topV(ratings_df, member_raw, V)
    if len(raw) < V:
        for iid in _member_popular(ratings_df, member_raw, V*3):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break
    if len(raw) < V:
        for iid in _global_popular(ratings_df, V*5):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break

    enc = _encode_items(raw[:V], mds.item_encoder)
    return raw[:V], enc[:V]
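
The selection logic in `build_item_set_for_members` can be illustrated without pandas. The sketch below (the name `top_v_items` and the toy ratings are assumptions) keeps the members' highest mean-rated items first and tops up from globally popular items when the members rated fewer than `v` distinct items. It skips the intermediate member-popularity stage, because in that situation every member-rated item has already been selected by the first stage.

```python
from collections import defaultdict


def top_v_items(ratings, members, v):
    """Pick v item ids: members' highest mean rating first, then global popularity.

    ratings: iterable of (user_id, item_id, rating) tuples
    members: set of user ids forming the collective
    """
    member_sum, member_cnt = defaultdict(float), defaultdict(int)
    global_cnt = defaultdict(int)
    for user, item, rating in ratings:
        global_cnt[item] += 1
        if user in members:
            member_sum[item] += rating
            member_cnt[item] += 1

    # Stage 1: items rated by members, ordered by mean rating (descending).
    by_mean = sorted(member_cnt, key=lambda i: member_sum[i] / member_cnt[i], reverse=True)
    chosen = by_mean[:v]

    # Stage 2: top up from globally popular items if fewer than v were found.
    for item in sorted(global_cnt, key=global_cnt.get, reverse=True):
        if len(chosen) >= v:
            break
        if item not in chosen:
            chosen.append(item)
    return chosen


ratings = [
    (1, 10, 5.0), (1, 20, 3.0), (2, 10, 4.0),   # users 1 and 2 are the collective
    (3, 30, 2.0), (3, 40, 5.0), (4, 30, 4.0), (4, 10, 1.0),
]
print(top_v_items(ratings, members={1, 2}, v=3))
```

In the notebook the raw ids returned by this step are then mapped through `item_encoder` and shifted by one into the model's label space.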


In [None]:
import numpy as np

def _cluster_centroids(U, labels, Q):
    """Return (Q x d) centroid matrix in user-embedding space."""
    centroids = []
    for q in range(Q):
        idx = np.where(labels == q)[0]
        # KMeans guarantees non-empty clusters; just in case:
        if len(idx) == 0:
            # fallback: random user as centroid
            centroids.append(U[np.random.randint(len(U))])
        else:
            centroids.append(U[idx].mean(axis=0))
    C = np.vstack(centroids)
    return C

def _normalize_rows(X, eps=1e-12):
    n = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (n + eps)

def _farthest_pair_indices(centroids, metric="cosine"):
    """
    Return (i, j, distance) for the pair of centroids that are farthest apart.
    """
    if metric == "cosine":
        Cn = _normalize_rows(centroids)
        # cosine distance = 1 - cosine similarity
        sims = Cn @ Cn.T
        dists = 1.0 - np.clip(sims, -1.0, 1.0)
    else:
        # Euclidean
        diff = centroids[:, None, :] - centroids[None, :, :]
        dists = np.sqrt((diff * diff).sum(axis=2))
    # ignore diagonal
    np.fill_diagonal(dists, -np.inf)
    ij = np.unravel_index(np.argmax(dists), dists.shape)
    return int(ij[0]), int(ij[1]), float(dists[ij])


def _balanced_farthest_pair(centroids, labels, user_index, mds, metric="cosine", balance_tol=0.25):
    """
    Pick two clusters that are far apart but balanced in mean item popularity.
    balance_tol: allowable relative difference in mean popularity (e.g., 0.25 = 25%).
    Returns (i, j, distance).
    """
    # Popularity of each item from training data
    from collections import Counter
    cnt = Counter(i for seq in mds.user_train.values() for i in seq)

    # Mean popularity per cluster (average of users' mean item popularity)
    cluster_pop = {}
    Q = len(centroids)
    for q in range(Q):
        user_idxs = [k for k, lab in enumerate(labels) if lab == q]
        pops = []
        for k in user_idxs:
            u = user_index[k]
            seq = mds.user_train.get(u, [])
            if seq:
                pops.append(float(np.mean([cnt[i] for i in seq])))
        cluster_pop[q] = float(np.mean(pops)) if len(pops) > 0 else 0.0

    # Pairwise distances between centroids
    if metric == "cosine":
        Cn = centroids / (np.linalg.norm(centroids, axis=1, keepdims=True) + 1e-12)
        sims = Cn @ Cn.T
        dists = 1.0 - np.clip(sims, -1.0, 1.0)
    else:
        diff = centroids[:, None, :] - centroids[None, :, :]
        dists = np.sqrt((diff * diff).sum(axis=2))
    np.fill_diagonal(dists, -np.inf)

    # Choose farthest pair among those with similar popularity
    best_pair, best_dist = None, -np.inf
    for i in range(Q):
        for j in range(i+1, Q):
            pa, pb = cluster_pop.get(i, 0.0), cluster_pop.get(j, 0.0)
            if pa == 0.0 or pb == 0.0:
                continue
            rel_diff = abs(pa - pb) / max(pa, pb)
            if rel_diff <= balance_tol:
                d = dists[i, j]
                if d > best_dist:
                    best_dist = d
                    best_pair = (i, j)

    if best_pair is None:
        # Fallback to farthest regardless of balance
        i, j = np.unravel_index(np.argmax(dists), dists.shape)
        return int(i), int(j), float(dists[i, j])
    return int(best_pair[0]), int(best_pair[1]), float(best_dist)
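
The `maxdist` seeding rule is easier to verify by hand on a tiny example. The following is a dependency-free sketch of the same idea as `_farthest_pair_indices` with the cosine metric; the names `cosine_distance` and `farthest_pair` and the toy centroids are assumptions, and the notebook's NumPy version computes the full distance matrix instead of looping over pairs.

```python
import math
from itertools import combinations


def cosine_distance(a, b, eps=1e-12):
    """1 - cosine similarity, with a small eps to avoid division by zero."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / ((na + eps) * (nb + eps))


def farthest_pair(centroids):
    """Return (i, j, distance) for the pair with the largest cosine distance."""
    return max(
        ((i, j, cosine_distance(centroids[i], centroids[j]))
         for i, j in combinations(range(len(centroids)), 2)),
        key=lambda t: t[2],
    )


centroids = [
    [1.0, 0.0],   # cluster 0
    [0.9, 0.1],   # cluster 1: nearly parallel to cluster 0
    [-1.0, 0.2],  # cluster 2: points almost opposite to cluster 0
]
i, j, d = farthest_pair(centroids)
print(i, j, round(d, 3))
```

Cosine distance ranges from 0 (same direction) to 2 (opposite direction), so a value close to 2 indicates that the two seed clusters sit on nearly opposite sides of the embedding space.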


# 🚀 RQ2 Experiment (baseline → collectives → item_set → interventions)

In [None]:

import gc

from sklearn.cluster import KMeans

# ===== Collective experiment params (from paper notebook) =====
N_values = [10, 20, 50]               # collective sizes
p_values = [0.1, 0.25, 0.5, 0.75, 1]  # homogeneity
trials_per_case = 20
TOPK = 10
V_TARGET = config.get('V', 10)
NUM_CLUSTERS = 10

scenarios = [
    ('promote', 'promote', 'both_promote'),
    ('demote',  'demote',  'both_demote'),
    ('promote', 'demote',  'A_promote_B_demote'),
    ('demote',  'promote', 'A_demote_B_promote'),
]

# --- Run baseline (or load already-trained) ---
mds, user_train, user_valid, user_test, baseline_model = prepare_baseline(config)
NUM_ITEMS = resolve_num_items(mds, baseline_model)

U, user_index = compute_user_embeddings(baseline_model, user_train, NUM_ITEMS, config['max_len'])
kmeans = KMeans(n_clusters=NUM_CLUSTERS, random_state=42).fit(U)
labels = kmeans.labels_
# Precompute centroids for the NUM_CLUSTERS user clusters
centroids = _cluster_centroids(U, labels, NUM_CLUSTERS)


results = []

for N in N_values:
    for p in p_values:

        # ---- Build two collectives A and B ONCE per (N, p) ----
        if SEED_MODE == "uniform":
            seedA = np.random.randint(NUM_CLUSTERS)
            seedB = (seedA + np.random.randint(1, NUM_CLUSTERS)) % NUM_CLUSTERS
        elif SEED_MODE == "balanced_maxdist":
            i, j, d = _balanced_farthest_pair(centroids, labels, user_index, mds, metric="cosine", balance_tol=0.25)
            seedA, seedB = i, j
        elif SEED_MODE == "maxdist":
            i, j, d = _farthest_pair_indices(centroids, metric="cosine")
            seedA, seedB = i, j
        else:
            raise ValueError(f"Unknown SEED_MODE={SEED_MODE}")

        C1 = build_collective_quota(seedA, labels, user_index, size=N, p=p, seed=42)
        C2 = build_collective_quota(seedB, labels, user_index, size=N, p=p, seed=42)

        # ---- Build fixed item sets from ORIGINAL ratings (via mds.df) ----
        raw_A, enc_A = build_item_set_for_members(mds, C1, V=V_TARGET)
        raw_B, enc_B = build_item_set_for_members(mds, C2, V=V_TARGET)
        print(f"[item_set] A: raw={len(raw_A)} enc={len(enc_A)}  |  B: raw={len(raw_B)} enc={len(enc_B)}")

        # ---- Evaluate BASELINE HR@K for each collective on its own item_set ----
        subset_ALL = user_test  # paper: evaluate over all test users
        g1_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_A, NUM_ITEMS, config['max_len'], TOPK))
        g2_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_B, NUM_ITEMS, config['max_len'], TOPK))
        print(g1_base)
        print(g2_base)

        from collections import Counter
        cnt = Counter(i for seq in mds.user_train.values() for i in seq)
        popA = sum(cnt[i] for i in enc_A)/len(enc_A)
        popB = sum(cnt[i] for i in enc_B)/len(enc_B)
        print("Mean frequency — A:", popA, " | B:", popB)

        # ---- Precompute edited dataframes ONCE per (N, p) ----
        df0 = mds.df.copy()
        df_A_prom = edit_ratings(df0, C1, raw_A, action="promote", user_decoder=mds.user_decoder)
        df_A_demo = edit_ratings(df0, C1, raw_A, action="demote",  user_decoder=mds.user_decoder)
        df_B_prom = edit_ratings(df0, C2, raw_B, action="promote", user_decoder=mds.user_decoder)
        df_B_demo = edit_ratings(df0, C2, raw_B, action="demote",  user_decoder=mds.user_decoder)

        # Helper to retrain and evaluate a single collective (averaged across epochs on TEST users)
        def run_single(edited_df, members, item_set_enc, name, baseline_hr):
            mds_s, tr_s, va_s, te_s, model_s = retrain_from_baseline_on_ratings(
                baseline_state_path="baseline_best.pt",
                ratings_df=edited_df,
                mds=mds,
                config=config,
                epochs=10,
                eval_split="user_test",
                item_set=item_set_enc,
                eval_specs=[{"name": name, "members_idx": user_test, "item_set_enc": item_set_enc, "baseline_hr": baseline_hr}],
                k=TOPK
            )
            stats = getattr(model_s, "_avg_eval_stats", {}) or {}
            s = stats.get(name, {}) or {}
            return s.get('hr'), s.get('rel'), s.get('ct')

        # Two-collective interventions (apply both edits then retrain once)
        def run_joint(df_A_action, df_B_action, solo_A, solo_B):
            df_joint = edit_ratings(mds.df.copy(), C1, raw_A, action=df_A_action, user_decoder=mds.user_decoder)
            df_joint = edit_ratings(df_joint,        C2, raw_B, action=df_B_action, user_decoder=mds.user_decoder)
            mds_j, tr_j, va_j, te_j, model_j = retrain_from_baseline_on_ratings(
                baseline_state_path="baseline_best.pt",
                ratings_df=df_joint,
                mds=mds,
                config=config,
                epochs=10,
                eval_split="user_test",
                eval_specs=[
                    {"name": "A", "members_idx": user_test, "item_set_enc": enc_A, "baseline_hr": g1_base, "solo_hr": solo_A},
                    {"name": "B", "members_idx": user_test, "item_set_enc": enc_B, "baseline_hr": g2_base, "solo_hr": solo_B},
                ],
                k=TOPK,
            )
            stats = getattr(model_j, "_avg_eval_stats", {}) or {}
            A = stats.get("A", {}) or {}
            B = stats.get("B", {}) or {}
            return A.get('hr'), A.get('rel'), A.get('ct'), B.get('hr'), B.get('rel'), B.get('ct')

        # ---- Trials: SAME groups, repeated trainings ----
        for trial in range(trials_per_case):

            # Solo A (promote/demote)
            g1_demo, g1_demo_rel, g1_demo_ct = run_single(df_A_demo, C1, enc_A, 'A', g1_base)
            g1_prom, g1_prom_rel, g1_prom_ct = run_single(df_A_prom, C1, enc_A, 'A', g1_base)
            # Solo B
            g2_prom, g2_prom_rel, g2_prom_ct = run_single(df_B_prom, C2, enc_B, 'B', g2_base)
            g2_demo, g2_demo_rel, g2_demo_ct = run_single(df_B_demo, C2, enc_B, 'B', g2_base)

            # Two-collective: both promote / both demote / criss-cross
            gA_pp, gA_pp_rel, gA_pp_ct, gB_pp, gB_pp_rel, gB_pp_ct = run_joint("promote", "promote", g1_prom, g2_prom)
            gA_dd, gA_dd_rel, gA_dd_ct, gB_dd, gB_dd_rel, gB_dd_ct = run_joint("demote", "demote", g1_demo, g2_demo)
            gA_pd, gA_pd_rel, gA_pd_ct, gB_pd, gB_pd_rel, gB_pd_ct = run_joint("promote", "demote", g1_prom, g2_demo)
            gA_dp, gA_dp_rel, gA_dp_ct, gB_dp, gB_dp_rel, gB_dp_ct = run_joint("demote", "promote", g1_demo, g2_prom)

            # Collect metrics (relative HR vs baseline per collective)
            def rel(now, base): 
                return (now / base) if base > 0 else float('nan')

            results.append({
                "N": N, "p": p, "trial": trial,
                "gA_base": g1_base, "gB_base": g2_base,
                "gA_prom": g1_prom, "gA_demo": g1_demo,
                "gB_prom": g2_prom, "gB_demo": g2_demo,
                "gA_pp": gA_pp, "gB_pp": gB_pp,
                "gA_dd": gA_dd, "gB_dd": gB_dd,
                "gA_pd": gA_pd, "gB_pd": gB_pd,
                "gA_dp": gA_dp, "gB_dp": gB_dp,
                "rA_prom": g1_prom_rel, "ctA_prom": g1_prom_ct,
                "rA_demo": g1_demo_rel, "ctA_demo": g1_demo_ct,
                "rB_prom": g2_prom_rel, "ctB_prom": g2_prom_ct,
                "rB_demo": g2_demo_rel, "ctB_demo": g2_demo_ct,
                "rA_pp": gA_pp_rel, "ctA_pp": gA_pp_ct,
                "rB_pp": gB_pp_rel, "ctB_pp": gB_pp_ct,
                "rA_dd": gA_dd_rel, "ctA_dd": gA_dd_ct,
                "rB_dd": gB_dd_rel, "ctB_dd": gB_dd_ct,
                "rA_pd": gA_pd_rel, "ctA_pd": gA_pd_ct,
                "rB_pd": gB_pd_rel, "ctB_pd": gB_pd_ct,
                "rA_dp": gA_dp_rel, "ctA_dp": gA_dp_ct,
                "rB_dp": gB_dp_rel, "ctB_dp": gB_dp_ct,
                "size_A": len(C1), "size_B": len(C2)
            })

            # free between trials
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

# Save & preview
res_df = pd.DataFrame(results)
res_df.to_csv("rq2_results_relative_hr.csv", index=False)
display(res_df.head())
print("Saved rq2_results_relative_hr.csv")


Loading: C:\Users\david\OneDrive\Desktop\Collective Exp\ml-1m\ratings.dat
Users with sequences: 6040 | Items: 3706


100%|██████████| 48/48 [00:06<00:00,  7.54it/s]


Epoch:   1| Train loss: 7.86362
[VAL] Recall@10=0.0301 | NDCG@10=0.0139


100%|██████████| 48/48 [00:05<00:00,  8.16it/s]


Epoch:   2| Train loss: 7.14407
[VAL] Recall@10=0.0300 | NDCG@10=0.0137


100%|██████████| 48/48 [00:05<00:00,  8.11it/s]


Epoch:   3| Train loss: 6.72533
[VAL] Recall@10=0.0399 | NDCG@10=0.0194


100%|██████████| 48/48 [00:05<00:00,  8.06it/s]


Epoch:   4| Train loss: 6.23599
[VAL] Recall@10=0.0541 | NDCG@10=0.0265


100%|██████████| 48/48 [00:05<00:00,  8.08it/s]


Epoch:   5| Train loss: 5.75535
[VAL] Recall@10=0.0621 | NDCG@10=0.0292


100%|██████████| 48/48 [00:05<00:00,  8.01it/s]


Epoch:   6| Train loss: 5.33052
[VAL] Recall@10=0.0742 | NDCG@10=0.0353


100%|██████████| 48/48 [00:06<00:00,  7.93it/s]


Epoch:   7| Train loss: 4.99274
[VAL] Recall@10=0.0714 | NDCG@10=0.0340


100%|██████████| 48/48 [00:06<00:00,  7.91it/s]


Epoch:   8| Train loss: 4.70466
[VAL] Recall@10=0.0801 | NDCG@10=0.0393


100%|██████████| 48/48 [00:06<00:00,  7.93it/s]


Epoch:   9| Train loss: 4.45841
[VAL] Recall@10=0.0823 | NDCG@10=0.0416


100%|██████████| 48/48 [00:06<00:00,  7.90it/s]


Epoch:  10| Train loss: 4.22730
[VAL] Recall@10=0.0816 | NDCG@10=0.0402


100%|██████████| 48/48 [00:06<00:00,  7.89it/s]


Epoch:  11| Train loss: 4.01124
[VAL] Recall@10=0.0833 | NDCG@10=0.0411


100%|██████████| 48/48 [00:06<00:00,  7.90it/s]


Epoch:  12| Train loss: 3.80999
[VAL] Recall@10=0.0906 | NDCG@10=0.0443


100%|██████████| 48/48 [00:06<00:00,  7.85it/s]


Epoch:  13| Train loss: 3.61568
[VAL] Recall@10=0.0881 | NDCG@10=0.0428


100%|██████████| 48/48 [00:06<00:00,  7.82it/s]


Epoch:  14| Train loss: 3.45412
[VAL] Recall@10=0.0891 | NDCG@10=0.0432


100%|██████████| 48/48 [00:06<00:00,  7.82it/s]


Epoch:  15| Train loss: 3.29249
[VAL] Recall@10=0.0909 | NDCG@10=0.0444


100%|██████████| 48/48 [00:06<00:00,  7.82it/s]


Epoch:  16| Train loss: 3.13714
[VAL] Recall@10=0.0974 | NDCG@10=0.0474


100%|██████████| 48/48 [00:06<00:00,  7.80it/s]


Epoch:  17| Train loss: 2.98054
[VAL] Recall@10=0.1003 | NDCG@10=0.0494
Saved baseline_best.pt


user embeddings: 100%|██████████| 6040/6040 [00:29<00:00, 204.06it/s]


[item_set] A: raw=10 enc=10  |  B: raw=10 enc=10
0.07301324503311259
0.11572847682119206
Mean frequency — A: 814.4  | B: 480.0 814.4  | B: 480.0


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.08982 || A: HR@10=0.1068 | rel=1.463
Epoch:   2 | loss: 6.24227 || A: HR@10=0.0879 | rel=1.204
Epoch:   3 | loss: 5.82692 || A: HR@10=0.0786 | rel=1.077
Epoch:   4 | loss: 5.50966 || A: HR@10=0.0512 | rel=0.701
Epoch:   5 | loss: 5.22854 || A: HR@10=0.0546 | rel=0.748
Epoch:   6 | loss: 4.97514 || A: HR@10=0.0627 | rel=0.859
Epoch:   7 | loss: 4.74478 || A: HR@10=0.0548 | rel=0.751
Epoch:   8 | loss: 4.54951 || A: HR@10=0.0421 | rel=0.576
Epoch:   9 | loss: 4.36264 || A: HR@10=0.0604 | rel=0.828
Epoch:  10 | loss: 4.17725 || A: HR@10=0.0510 | rel=0.698


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.09901 || A: HR@10=0.1283 | rel=1.757
Epoch:   2 | loss: 6.26375 || A: HR@10=0.1086 | rel=1.488
Epoch:   3 | loss: 5.85726 || A: HR@10=0.1023 | rel=1.401
Epoch:   4 | loss: 5.53642 || A: HR@10=0.0934 | rel=1.279
Epoch:   5 | loss: 5.24539 || A: HR@10=0.0763 | rel=1.045
Epoch:   6 | loss: 5.00331 || A: HR@10=0.0712 | rel=0.975
Epoch:   7 | loss: 4.76561 || A: HR@10=0.0975 | rel=1.336
Epoch:   8 | loss: 4.56365 || A: HR@10=0.1343 | rel=1.839
Epoch:   9 | loss: 4.37029 || A: HR@10=0.1737 | rel=2.379
Epoch:  10 | loss: 4.19055 || A: HR@10=0.1548 | rel=2.120


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.09807 || B: HR@10=0.1929 | rel=1.667
Epoch:   2 | loss: 6.24519 || B: HR@10=0.1674 | rel=1.446
Epoch:   3 | loss: 5.82782 || B: HR@10=0.1492 | rel=1.289
Epoch:   4 | loss: 5.50742 || B: HR@10=0.0907 | rel=0.784
Epoch:   5 | loss: 5.23586 || B: HR@10=0.1003 | rel=0.867
Epoch:   6 | loss: 4.98029 || B: HR@10=0.1402 | rel=1.212
Epoch:   7 | loss: 4.75513 || B: HR@10=0.1535 | rel=1.326
Epoch:   8 | loss: 4.54245 || B: HR@10=0.1184 | rel=1.023
Epoch:   9 | loss: 4.35421 || B: HR@10=0.1204 | rel=1.040
Epoch:  10 | loss: 4.17680 || B: HR@10=0.1035 | rel=0.894


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.10025 || B: HR@10=0.1869 | rel=1.615
Epoch:   2 | loss: 6.24093 || B: HR@10=0.1399 | rel=1.209
Epoch:   3 | loss: 5.82374 || B: HR@10=0.1167 | rel=1.009
Epoch:   4 | loss: 5.50612 || B: HR@10=0.0955 | rel=0.825
Epoch:   5 | loss: 5.22216 || B: HR@10=0.0820 | rel=0.708
Epoch:   6 | loss: 4.97493 || B: HR@10=0.0533 | rel=0.461
Epoch:   7 | loss: 4.73939 || B: HR@10=0.0308 | rel=0.266
Epoch:   8 | loss: 4.53289 || B: HR@10=0.0313 | rel=0.270
Epoch:   9 | loss: 4.32933 || B: HR@10=0.0207 | rel=0.179
Epoch:  10 | loss: 4.15593 || B: HR@10=0.0377 | rel=0.326


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.08975 || A: HR@10=0.1075 | rel=1.472 | CT=-0.090 || B: HR@10=0.2518 | rel=2.176 | CT=1.021
Epoch:   2 | loss: 6.24589 || A: HR@10=0.0950 | rel=1.302 | CT=-0.260 || B: HR@10=0.2944 | rel=2.544 | CT=1.389
Epoch:   3 | loss: 5.83492 || A: HR@10=0.0805 | rel=1.102 | CT=-0.460 || B: HR@10=0.2656 | rel=2.295 | CT=1.140
Epoch:   4 | loss: 5.50931 || A: HR@10=0.0719 | rel=0.984 | CT=-0.578 || B: HR@10=0.2671 | rel=2.308 | CT=1.153
Epoch:   5 | loss: 5.22828 || A: HR@10=0.0548 | rel=0.751 | CT=-0.811 || B: HR@10=0.2551 | rel=2.205 | CT=1.050
Epoch:   6 | loss: 4.99153 || A: HR@10=0.0399 | rel=0.546 | CT=-1.015 || B: HR@10=0.2992 | rel=2.585 | CT=1.430
Epoch:   7 | loss: 4.76865 || A: HR@10=0.0306 | rel=0.420 | CT=-1.142 || B: HR@10=0.3207 | rel=2.771 | CT=1.616
Epoch:   8 | loss: 4.54664 || A: HR@10=0.0373 | rel=0.510 | CT=-1.052 || B: HR@10=0.2472 | rel=2.136 | CT=0.981
Epoch:   9 | loss: 4.36048 || A: HR@10=0.0310 | rel=0.424 | CT=-1.138 || B: HR@10=0.2185 | rel=1.888 | C

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.10026 || A: HR@10=0.1310 | rel=1.794 | CT=0.903 || B: HR@10=0.1257 | rel=1.086 | CT=0.399
Epoch:   2 | loss: 6.25796 || A: HR@10=0.1573 | rel=2.154 | CT=1.264 || B: HR@10=0.1326 | rel=1.146 | CT=0.459
Epoch:   3 | loss: 5.83912 || A: HR@10=0.1464 | rel=2.005 | CT=1.114 || B: HR@10=0.0781 | rel=0.675 | CT=-0.012
Epoch:   4 | loss: 5.51163 || A: HR@10=0.1320 | rel=1.807 | CT=0.917 || B: HR@10=0.1028 | rel=0.888 | CT=0.202
Epoch:   5 | loss: 5.23218 || A: HR@10=0.1136 | rel=1.556 | CT=0.665 || B: HR@10=0.0823 | rel=0.711 | CT=0.024
Epoch:   6 | loss: 4.97241 || A: HR@10=0.1392 | rel=1.907 | CT=1.017 || B: HR@10=0.0573 | rel=0.495 | CT=-0.192
Epoch:   7 | loss: 4.75546 || A: HR@10=0.1843 | rel=2.524 | CT=1.633 || B: HR@10=0.0657 | rel=0.568 | CT=-0.119
Epoch:   8 | loss: 4.54877 || A: HR@10=0.1672 | rel=2.290 | CT=1.400 || B: HR@10=0.0488 | rel=0.422 | CT=-0.265
Epoch:   9 | loss: 4.34661 || A: HR@10=0.1652 | rel=2.263 | CT=1.373 || B: HR@10=0.0493 | rel=0.426 | CT=-0.

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.07631 || A: HR@10=0.0925 | rel=1.268 | CT=-0.294 || B: HR@10=0.1841 | rel=1.591 | CT=0.904
Epoch:   2 | loss: 6.22793 || A: HR@10=0.0508 | rel=0.696 | CT=-0.866 || B: HR@10=0.2147 | rel=1.856 | CT=1.169
Epoch:   3 | loss: 5.82446 || A: HR@10=0.0396 | rel=0.542 | CT=-1.020 || B: HR@10=0.1172 | rel=1.013 | CT=0.326
Epoch:   4 | loss: 5.50001 || A: HR@10=0.0306 | rel=0.420 | CT=-1.142 || B: HR@10=0.0745 | rel=0.644 | CT=-0.043
Epoch:   5 | loss: 5.22762 || A: HR@10=0.0278 | rel=0.381 | CT=-1.181 || B: HR@10=0.1041 | rel=0.900 | CT=0.213
Epoch:   6 | loss: 4.96580 || A: HR@10=0.0318 | rel=0.435 | CT=-1.127 || B: HR@10=0.0800 | rel=0.691 | CT=0.004
Epoch:   7 | loss: 4.75133 || A: HR@10=0.0401 | rel=0.549 | CT=-1.013 || B: HR@10=0.0922 | rel=0.797 | CT=0.110
Epoch:   8 | loss: 4.52729 || A: HR@10=0.0396 | rel=0.542 | CT=-1.020 || B: HR@10=0.1311 | rel=1.133 | CT=0.446
Epoch:   9 | loss: 4.34189 || A: HR@10=0.0397 | rel=0.544 | CT=-1.018 || B: HR@10=0.0724 | rel=0.625 | 

  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.08984 || A: HR@10=0.1084 | rel=1.485 | CT=0.595 || B: HR@10=0.1679 | rel=1.451 | CT=0.296
Epoch:   2 | loss: 6.24867 || A: HR@10=0.0863 | rel=1.181 | CT=0.291 || B: HR@10=0.1474 | rel=1.273 | CT=0.118
Epoch:   3 | loss: 5.83085 || A: HR@10=0.1002 | rel=1.372 | CT=0.481 || B: HR@10=0.1462 | rel=1.263 | CT=0.108
Epoch:   4 | loss: 5.51230 || A: HR@10=0.0632 | rel=0.866 | CT=-0.024 || B: HR@10=0.1427 | rel=1.233 | CT=0.078
Epoch:   5 | loss: 5.23095 || A: HR@10=0.0434 | rel=0.594 | CT=-0.296 || B: HR@10=0.1642 | rel=1.419 | CT=0.264
Epoch:   6 | loss: 4.98285 || A: HR@10=0.0503 | rel=0.689 | CT=-0.201 || B: HR@10=0.1358 | rel=1.173 | CT=0.018
Epoch:   7 | loss: 4.75718 || A: HR@10=0.0588 | rel=0.805 | CT=-0.085 || B: HR@10=0.1742 | rel=1.505 | CT=0.350
Epoch:   8 | loss: 4.54724 || A: HR@10=0.0379 | rel=0.519 | CT=-0.371 || B: HR@10=0.1887 | rel=1.631 | CT=0.476
Epoch:   9 | loss: 4.36274 || A: HR@10=0.0464 | rel=0.635 | CT=-0.256 || B: HR@10=0.2129 | rel=1.840 | CT=0

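|  | N | p | trial | gA_base | gB_base | gA_prom | gA_demo | gB_prom | gB_demo | gA_pp | ... | rA_pd | ctA_pd | rB_pd | ctB_pd | rA_dp | ctA_dp | rB_dp | ctB_dp | size_A | size_B |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 10 | 0.1 | 0 | 0.073013 | 0.115728 | 0.11404 | 0.065017 | 0.133642 | 0.079487 | 0.057483 | ... | 0.622449 | -0.939456 | 1.0 | 0.313162 | 0.862132 | -0.028345 | 1.483548 | 0.328755 | 10 | 10 |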


Saved rq2_results_relative_hr.csv
