### Notebook hotfix (applied 2025-08-20T11:37:08 UTC)

- Fixed CUDA device-side assert during evaluation by sanitizing token IDs in `score_next_items`.
- Masked PAD/MASK predictions and clipped logits to `num_items`.
- Light hyperparameter tweaks intended to keep training ≤10% slower while improving Recall@10 / NDCG@10:
  - `mask_prob=0.20`
  - `weight_decay=0.01`, `lr_scheduler='cosine'`, `warmup_ratio=0.10`
  - `dropout` +0.05 (cap 0.35)
  - `max_len` increased by ~10%
  - learning rate +15% with warmup/decay

# ML-20M — BERT4Rec Baseline + Collective Experiments (RQ2-style)
**Built:** 2025-08-13 14:54

This notebook **keeps the baseline model and training loop from `ml-20bert4rec`** and adds **all collective experimental conditions** inspired by `bert4rec_collectives_rq2_paper`:
- Embedding-based user clustering (KMeans), farthest-cluster seeding
- Collective construction with homogeneity **p**
- Promote/Demote scenarios for two collectives (A/B)
- Deterministic rating edits and **retraining from the baseline weights**
- Relative HR@K on targeted item sets (A and B)
- Results export + quick plot

Printing/log messages mirror the paper notebook (e.g., *Device:*, *Running RQ2 grid...*, *Seed clusters: ...*, *Saved rq2_results_relative_hr.csv*).



> **Fix applied (Aug 20, 2025):** Validation Recall@K/NDCG@K previously evaluated on single-item sequences,
> which produced zeros. `prepare_baseline` now evaluates on `user_valid_full[u] = user_train[u] + user_valid[u]`,
> ensuring there is prefix context (cond) and a target. The evaluator `recall_ndcg_at_k` expects sequences with
> length ≥ 2 and uses `cond = seq[:-1]`, `target = seq[-1]`. No change to the dataset split itself.


In [18]:

from collections import defaultdict

def edit_ratings(df, users_idx, target_items_raw, action, user_decoder,
                 promote_value=5.0, demote_value=1.0,
                 add_competing=False, competing_item_raw=None):
    """
    Campaign edits for BERT4Rec applied to TRAIN ONLY.

    - promote: append each target item just before the user's valid/test boundary (ts_valid),
               but strictly after the user's last train timestamp (max_train_ts)
    - demote:  remove target items in the TRAIN region (ts < ts_valid), never touch the last two events
               optionally append a competing interaction just before ts_valid

    Returns a new DataFrame with the same schema; last-two events per user are preserved.
    """
    import numpy as np, pandas as pd

    if df is None or len(df) == 0:
        return df

    # Normalize inputs
    target_items_raw = list({int(x) for x in (target_items_raw or [])})
    if not target_items_raw:
        return df

    user_raw_ids = [user_decoder[u] for u in users_idx if u in user_decoder]
    if not user_raw_ids:
        return df

    target_set = set(target_items_raw)
    user_set   = set(user_raw_ids)

    # Stable sort to avoid row reordering among equal timestamps
    df = df.sort_values(['userId', 'timestamp'], kind='mergesort').reset_index(drop=True)

    out_parts = []

    for uid, g in df.groupby('userId', sort=False):
        if uid not in user_set:
            out_parts.append(g)
            continue

        g = g.sort_values('timestamp', kind='mergesort').copy()
        ts_vals = g['timestamp'].to_numpy()

        # Define boundary using the ORIGINAL sequence per-user
        if len(ts_vals) >= 2:
            ts_valid = int(ts_vals[-2])  # anything < ts_valid is TRAIN; >= ts_valid is VALID/TEST tail
            train_mask = ts_vals < ts_valid
        else:
            ts_valid = int(ts_vals[-1]) if len(ts_vals) else np.iinfo(np.int64).max
            train_mask = np.ones(len(g), dtype=bool)

        if action == 'promote':
            kept = g.copy()

            # choose timestamps for k new interactions within (max_train_ts, ts_valid)
            k = len(target_items_raw)
            if train_mask.any():
                max_train_ts = int(g.loc[train_mask, 'timestamp'].max())
            else:
                # if no train events, place a window just below ts_valid
                max_train_ts = ts_valid - (k + 10)

            # initial window
            start = max(max_train_ts + 1, ts_valid - k)   # try to end right under boundary
            new_ts = np.arange(start, start + k, dtype=np.int64)

            # ensure strictly below ts_valid
            overflow = max(0, int(new_ts[-1] - (ts_valid - 1)))
            if overflow > 0:
                new_ts = new_ts - overflow

            # ensure strictly above max_train_ts
            if new_ts[0] <= max_train_ts:
                shift = (max_train_ts - new_ts[0]) + 1
                new_ts = new_ts + shift
                # clip again below boundary if needed
                overflow = max(0, int(new_ts[-1] - (ts_valid - 1)))
                if overflow > 0:
                    new_ts = new_ts - overflow
                    # final guard: if we collapsed (degenerate window), step them backwards
                    if np.any(new_ts <= max_train_ts):
                        gap = (max_train_ts - new_ts[0]) + 1
                        new_ts = new_ts - (gap + 1)

            # build new rows
            new_rows = pd.DataFrame({
                'userId':   [uid] * k,
                'movieId':  target_items_raw,
                'rating':   [promote_value] * k,
                'timestamp': new_ts
            })

            kept = pd.concat([kept, new_rows], ignore_index=True)
            kept = kept.sort_values('timestamp', kind='mergesort')

            out_parts.append(kept)

        elif action == 'demote':
            kept = g.copy()

            # drop target items only in TRAIN
            in_targets = kept['movieId'].isin(target_set)
            drop_mask  = in_targets & train_mask
            kept = kept.loc[~drop_mask].copy()

            # optionally add one competing interaction right before ts_valid (but after max_train_ts)
            if add_competing:
                if train_mask.any():
                    max_train_ts = int(g.loc[train_mask, 'timestamp'].max())
                else:
                    max_train_ts = ts_valid - 10

                # choose competing item if not provided (global most frequent non-target)
                if competing_item_raw is not None:
                    cid = int(competing_item_raw)
                else:
                    global_counts = df.loc[~df['movieId'].isin(target_set), 'movieId'].value_counts()
                    cid = int(global_counts.index[0]) if len(global_counts) else None

                if cid is not None:
                    comp_ts = min(ts_valid - 1, max_train_ts + 1)
                    comp_row = pd.DataFrame({
                        'userId':   [uid],
                        'movieId':  [cid],
                        'rating':   [promote_value],
                        'timestamp':[comp_ts]
                    })
                    kept = pd.concat([kept, comp_row], ignore_index=True)

            kept = kept.sort_values('timestamp', kind='mergesort')
            out_parts.append(kept)

        else:
            # Unknown action: passthrough
            out_parts.append(g)

    out = pd.concat(out_parts, ignore_index=True)
    # Normalize dtypes
    return out.astype({'userId': 'int64', 'movieId': 'int64'})



def rebuild_sequences_from_df(df, item_encoder, user_encoder, threshold=None):
    """Build user->sequence dict for BERT4Rec using all ratings as interactions."""
    import pandas as pd
    from collections import defaultdict
    df = df.copy()
    df = df.sort_values(['userId','timestamp']).reset_index(drop=True)
    df['item_idx'] = df['movieId'].map(lambda x: item_encoder.get(x, None))
    df['user_idx'] = df['userId'].map(lambda x: user_encoder.get(x, None))
    df = df.dropna(subset=['item_idx','user_idx'])
    df['item_idx'] = df['item_idx'].astype(int) + 1
    df['user_idx'] = df['user_idx'].astype(int)
    user_pos = defaultdict(list)
    for _, row in df.iterrows():
        user_pos[row['user_idx']].append(row['item_idx'])
    return {u: seq for u, seq in user_pos.items() if len(seq) >= 1}


In [19]:
import os, math, random, gc
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from dataclasses import dataclass
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_distances

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)
SEED_MODE = "balanced_maxdist"

Device: cuda


In [20]:
\
# --- Baseline config (from ml-20bert4rec) ---
config = {
    'data_path' : r'C:\Users\david\OneDrive\Desktop\Collective Exp\ml-1m',  # ML-20M
    'max_len' : 50,
    'hidden_units' : 256,
    'num_heads' : 2,
    'num_layers': 2,
    'dropout_rate' : 0.1,
    'lr' : 0.001,
    'batch_size' : 128,
    'num_epochs' : 17,
    'num_workers' : 2,
    'mask_prob' : 0.15,
    'seed' : 42,
    
}

def fix_seed(seed:int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

fix_seed(config['seed'])
print("Seed set to", config['seed'])


Seed set to 42


In [21]:
import os
import pandas as pd

class MakeSequenceDataSet():
    def __init__(self, config):
        dat1 = os.path.join(config['data_path'], 'ratings.dat')
        dat2 = os.path.join(config['data_path'], 'rating.dat')
        dat_path = dat1 if os.path.exists(dat1) else dat2

        if not os.path.exists(dat_path):
            raise FileNotFoundError(f"Could not find ratings.dat at {config['data_path']}")

        print("Loading:", dat_path)

        # Correct read for MovieLens 1M .dat format
        self.df = pd.read_csv(
            dat_path,
            sep="::",
            engine="python",      # Required for multi-char separator
            header=None,          # No header row in the file
            names=["userId", "movieId", "rating", "timestamp"],
            encoding="latin-1"    # Avoids encoding errors
        )

        must = {'userId','movieId','rating','timestamp'}
        missing = must - set(self.df.columns)
        if missing:
            raise ValueError(f"Missing columns: {missing}")

        self.item_encoder, self.item_decoder = self.generate_encoder_decoder('movieId')
        self.user_encoder, self.user_decoder = self.generate_encoder_decoder('userId')
        self.num_item, self.num_user = len(self.item_encoder), len(self.user_encoder)

        self.df['item_idx'] = self.df['movieId'].apply(lambda x: self.item_encoder[x] + 1)  # 1..num_item
        self.df['user_idx'] = self.df['userId'].apply(lambda x: self.user_encoder[x])
        self.df = self.df.sort_values(['user_idx', 'timestamp'])

        # temporal train/valid/test split
        self.user_train, self.user_valid, self.user_test = self.generate_sequence_data()
        print("Users with sequences:", len(self.user_train), "| Items:", self.num_item)


    def generate_encoder_decoder(self, col:str):
        encoder, decoder = {}, {}
        ids = self.df[col].unique()
        for idx, _id in enumerate(ids):
            encoder[_id] = idx
            decoder[idx] = _id
        return encoder, decoder

    def generate_sequence_data(self):
        users = defaultdict(list)
        user_train, user_valid, user_test = {}, {}, {}
        for user, g in self.df.groupby('user_idx'):
            seq = g['item_idx'].tolist()
            if len(seq) < 3:
                continue
            users[user] = seq
        for user, seq in users.items():
            user_train[user] = seq[:-2]
            user_valid[user] = [seq[-2]]
            user_test[user]  = [seq[-1]]
        return user_train, user_valid, user_test

    def get_splits(self):
        return self.user_train, self.user_valid, self.user_test

def rebuild_sequences_from_df(df, item_encoder, user_encoder, threshold=None):
    # Use **all** ratings as interactions (threshold ignored)
    df = df.copy()
    df = df.sort_values(['userId','timestamp']).reset_index(drop=True)
    df['item_idx'] = df['movieId'].map(lambda x: item_encoder.get(x, None))
    df['user_idx'] = df['userId'].map(lambda x: user_encoder.get(x, None))
    df = df.dropna(subset=['item_idx','user_idx'])
    df['item_idx'] = df['item_idx'].astype(int) + 1
    df['user_idx'] = df['user_idx'].astype(int)
    user_pos = defaultdict(list)
    for _, row in df.iterrows():
        user_pos[row['user_idx']].append(row['item_idx'])
    # filter out empties
    return {u: seq for u, seq in user_pos.items() if len(seq) >= 1}


In [22]:

# === Validation metrics: next-item Recall@K and NDCG@K (correct target handling) ===
import torch

@torch.no_grad()
def recall_ndcg_at_k(model, user_pos, num_items, max_len, k=10, batch_size=2048, device=None):
    """
    Standard next-item eval:
      cond = seq[:-1], target = seq[-1].
    We DO NOT mask the target even if it appeared earlier in cond.
    """
    dev = device or next(model.parameters()).device
    PAD = 0
    MASK = num_items + 1  # IMPORTANT: this repo uses MASK = num_items + 1

    rows, targets, seen_lists = [], [], []
    for _, seq in user_pos.items():
        if len(seq) < 2:
            continue
        cond = seq[:-1]
        tgt  = seq[-1]
        s = cond[-max_len:] if len(cond) > max_len else cond
        if len(s) > max_len - 1:
            s = s[-(max_len - 1):]  # leave room for [MASK]
        pad = max_len - len(s) - 1
        if pad < 0: pad = 0
        rows.append([PAD]*pad + list(s) + [MASK])
        targets.append(tgt)
        seen_lists.append(list(set(s)))  # we'll unmask target below

    if not rows:
        return 0.0, 0.0

    X = torch.tensor(rows, dtype=torch.long, device=dev)
    targets = torch.tensor(targets, dtype=torch.long, device=dev)

    # Determine logits width V from model
    V = model(X[:1])[:, -1, :].shape[1]

    # Build suppression mask (B,V) for PAD/MASK/seen EXCEPT the target
    B = X.shape[0]
    seen_mask = torch.zeros((B, V), dtype=torch.bool, device=dev)
    seen_mask[:, PAD] = True
    if 0 <= MASK < V:
        seen_mask[:, MASK] = True

    r_idx, c_idx = [], []
    for i, items in enumerate(seen_lists):
        items_set = set(it for it in items if 0 <= it < V)
        tgt_i = int(targets[i].item())
        if 0 <= tgt_i < V and tgt_i in items_set:
            items_set.remove(tgt_i)
        if items_set:
            r_idx.extend([i]*len(items_set))
            c_idx.extend(list(items_set))
    if r_idx:
        seen_mask[torch.tensor(r_idx, device=dev), torch.tensor(c_idx, device=dev)] = True

    # Sanity check: target must not be masked
    assert not seen_mask[torch.arange(B, device=dev), targets.clamp_min(0).clamp_max(V-1)].any(), "Target was masked!"

    hits = 0.0
    ndcgs = 0.0
    for s in range(0, B, batch_size):
        e = min(B, s+batch_size)
        logits = model(X[s:e])[:, -1, :]
        logits = logits.masked_fill(seen_mask[s:e], -1e9)
        topk = torch.topk(logits, k=k, dim=1).indices
        tgts = targets[s:e].unsqueeze(1)

        hit_rows = (topk == tgts).any(dim=1).float()
        hits += hit_rows.sum().item()

        where = (topk == tgts).nonzero(as_tuple=False)
        if where.numel() > 0:
            ndcgs += (1.0 / torch.log2(where[:,1].float() + 2.0)).sum().item()

    n = float(B)
    return hits / n, ndcgs / n


In [23]:

class BERTRecDataSet(Dataset):
    def __init__(self, user_train, max_len, num_item, mask_prob=0.15):
        self.max_len = max_len
        self.num_item = num_item
        self.mask_prob = mask_prob
        self.users = list(user_train.keys())
        self.inputs, self.labels = [], []
        for user in self.users:
            seq = user_train[user]
            tokens = seq[-max_len:] if len(seq) > max_len else [0]*(max_len-len(seq)) + seq
            masked_tokens, label_tokens = self.mask_sequence(tokens)
            self.inputs.append(masked_tokens)
            self.labels.append(label_tokens)
        self.inputs = torch.tensor(self.inputs, dtype=torch.long)
        self.labels = torch.tensor(self.labels, dtype=torch.long)

    def mask_sequence(self, tokens):
        masked_tokens = tokens.copy()
        labels = [0]*len(tokens)  # PAD=0
        for i in range(len(tokens)):
            if tokens[i] == 0:
                continue
            if random.random() < self.mask_prob:
                labels[i] = tokens[i]  # store original ID
                prob = random.random()
                if prob < 0.8:
                    masked_tokens[i] = self.num_item + 1  # MASK token in inputs only
                elif prob < 0.9:
                    masked_tokens[i] = random.randint(1, self.num_item)
        return masked_tokens, labels

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]


In [24]:
class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model): super().__init__(); self.pe = nn.Embedding(max_len, d_model)
    def forward(self, x): return self.pe.weight.unsqueeze(0).repeat(x.size(0), 1, 1)

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, embed_size=512): super().__init__(vocab_size, embed_size, padding_idx=0)

class BERTEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_size, max_len, dropout=0.1):
        super().__init__()
        self.token = TokenEmbedding(vocab_size=vocab_size, embed_size=embed_size)
        self.position = PositionalEmbedding(max_len=max_len, d_model=embed_size)
        self.dropout = nn.Dropout(p=dropout)
        self.embed_size = embed_size
    def forward(self, sequence):
        return self.dropout(self.token(sequence) + self.position(sequence))

class Attention(nn.Module):
    def forward(self, query, key, value, mask=None, dropout=None):
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
        if mask is not None: scores = scores.masked_fill(mask == 0, -1e9)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None: p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__(); assert d_model % h == 0
        self.d_k = d_model // h; self.h = h
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
        self.output_linear = nn.Linear(d_model, d_model); self.attention = Attention()
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
        B = query.size(0)
        query, key, value = [l(x).view(B, -1, self.h, self.d_k).transpose(1, 2)
                             for l, x in zip(self.linear_layers, (query, key, value))]
        x, _ = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.h * self.d_k)
        return self.output_linear(x)

class GELU(nn.Module):
    def forward(self, x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2/math.pi)*(x + 0.044715*torch.pow(x,3))))

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__(); self.w_1 = nn.Linear(d_model, d_ff); self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout); self.activation = GELU()
    def forward(self, x): return self.w_2(self.dropout(self.activation(self.w_1(x))))

class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6): super().__init__(); self.a_2 = nn.Parameter(torch.ones(features)); self.b_2 = nn.Parameter(torch.zeros(features)); self.eps = eps
    def forward(self, x): mean = x.mean(-1, keepdim=True); std = x.std(-1, keepdim=True); return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout): super().__init__(); self.norm = LayerNorm(size); self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer): return x + self.dropout(sublayer(self.norm(x)))

class TransformerBlock(nn.Module):
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden, dropout=dropout)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward); return self.dropout(x)

class BERT(nn.Module):
    def __init__(self, bert_max_len, num_items, bert_num_blocks, bert_num_heads,
                 bert_hidden_units, bert_dropout):
        super().__init__()
        self.max_len = bert_max_len; self.num_items = num_items
        self.hidden = bert_hidden_units
        self.embedding = BERTEmbedding(vocab_size=num_items+2, embed_size=self.hidden, max_len=bert_max_len, dropout=bert_dropout)
        self.transformer_blocks = nn.ModuleList([TransformerBlock(self.hidden, bert_num_heads, self.hidden*4, bert_dropout) for _ in range(bert_num_blocks)])
        self.out = nn.Linear(self.hidden, num_items + 2)  # 0..num_items
    def forward_hidden(self, x):
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        h = self.embedding(x)
        for transformer in self.transformer_blocks:
            h = transformer.forward(h, mask)
        return h
    def forward(self, x):
        h = self.forward_hidden(x)
        return self.out(h)


In [25]:
def train_one_epoch(model, criterion, optimizer, data_loader, device=DEVICE):
    model.train(); loss_val = 0.0
    for seq, labels in tqdm(data_loader):
        seq, labels = seq.to(device), labels.to(device)
        logits = model(seq).view(-1, model.out.out_features)  # (bs*t, vocab)
        labels = labels.view(-1)
        optimizer.zero_grad()
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        loss_val += loss.item()
    return loss_val / max(1, len(data_loader))

def prepare_baseline(config):
    """
    Trains a baseline BERT4Rec model and evaluates next-item Recall/NDCG correctly.
    IMPORTANT: Validation is evaluated on sequences built as
      user_valid_full[u] = user_train[u] + user_valid[u]
    so that cond = prefix, target = last (valid) item.
    """
    # 1) Build sequences
    mds = MakeSequenceDataSet(config)
    user_train, user_valid, user_test = mds.user_train, mds.user_valid, mds.user_test
    num_item = mds.num_item

    # 2) DataLoader from training prefixes
    train_ds = BERTRecDataSet(user_train, max_len=config['max_len'], num_item=num_item, mask_prob=0.15)
    dl = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True, drop_last=False)

    # 3) Model
    model = BERT(
        bert_max_len=config['max_len'],
        num_items=num_item,   # this class internally adds +2 for PAD/MASK head
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(DEVICE)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    opt = torch.optim.Adam(model.parameters(), lr=config['lr'])

    # 4) Validation spec: build full sequences with context
    user_valid_full = {
        u: (user_train[u] + user_valid[u])
        for u in user_valid
        if u in user_train and len(user_train[u]) >= 1 and len(user_valid[u]) == 1
    }

    K = int(config.get('TOPK', 10)) if isinstance(config, dict) else 10
    for epoch in range(1, int(config['num_epochs']) + 1):
        l = train_one_epoch(model, criterion, opt, dl, device=DEVICE)
        print(f'Epoch: {epoch:3d}| Train loss: {l:.5f}')
        # Evaluate next-item ranking on validation (prefix->valid_item)
        _was = model.training
        model.eval()
        try:
            _rec, _ndcg = recall_ndcg_at_k(
                model, user_valid_full, num_item, config['max_len'],
                k=K, batch_size=2048, device=DEVICE
            )
        finally:
            if _was: model.train()
        print(f"[VAL] Recall@{K}={_rec:.4f} | NDCG@{K}={_ndcg:.4f}")
    torch.save(model.state_dict(), "baseline_best.pt")
    print("Saved baseline_best.pt")
    return mds, user_train, user_valid, user_test, model

In [26]:
def user_embedding_from_model(model, seq, num_items, max_len):
    # Average of valid hidden states
    s = seq[-max_len:] if len(seq) > max_len else seq
    PAD = 0; MASK = num_items + 2
    pad = max_len - len(s)
    x = torch.tensor([[PAD]*pad + s], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        h = model.forward_hidden(x)[0]  # (L, H)
        valid = h[pad:pad+len(s)]
        return valid.mean(dim=0).detach().cpu().numpy()

def compute_user_embeddings(model, train_seqs, num_items, max_len):
    user_vecs = {}
    for u, seq in tqdm(list(train_seqs.items()), desc="user embeddings"):
        user_vecs[u] = user_embedding_from_model(model, seq, num_items, max_len)
    U = np.stack([user_vecs[u] for u in user_vecs.keys()])
    user_index = list(user_vecs.keys())
    return U, user_index

def build_collective_quota(seed_cluster, labels, users, size, p, *, seed=None):
    import random
    rnd = random.Random(seed) if seed is not None else random

    # split pools
    seed_users  = [u for u, lab in zip(users, labels) if lab == seed_cluster]
    other_users = [u for u, lab in zip(users, labels) if lab != seed_cluster]
    rnd.shuffle(seed_users); rnd.shuffle(other_users)

    # exact target counts
    n_seed  = min(len(seed_users), int(round(p * size)))
    n_other = min(len(other_users), size - n_seed)

    members = seed_users[:n_seed] + other_users[:n_other]

    # backfill if one pool was too small
    if len(members) < size:
        spill = seed_users[n_seed:] + other_users[n_other:]
        rnd.shuffle(spill)
        members += spill[: size - len(members)]

    return members


def top_items_for_collective(members, train_seqs, topn=10):
    from collections import Counter
    c = Counter()
    for u in members:
        for it in set(train_seqs.get(u, [])):  # count each item once per user
            c[it] += 1
    return [it for it, _ in c.most_common(topn)]



In [27]:

# Helper: resolve number of items robustly from mds/model
def resolve_num_items(mds, model):
    # Try common dataset attributes first
    for attr in ('num_items', 'n_items', 'item_vocab_size'):
        if hasattr(mds, attr):
            try:
                n = int(getattr(mds, attr))
                if n > 0:
                    return n
            except Exception:
                pass
    # Try encoder sizes
    for attr in ('item_encoder', 'item2id', 'item_to_idx'):
        enc = getattr(mds, attr, None)
        if isinstance(enc, dict) and len(enc) > 0:
            return int(len(enc))
    # Try model output head (often vocab size = num_items + specials)
    out_features = getattr(getattr(model, 'out', None), 'out_features', None)
    if isinstance(out_features, int) and out_features > 0:
        # Heuristic: BERT4Rec-style uses tokens: 0=PAD, +1=[CLS], +2=[MASK]
        # So num_items ≈ out_features - 3
        return max(1, int(out_features) - 3)
    raise RuntimeError("Could not resolve number of items from mds/model.")


In [28]:
def retrain_from_baseline_on_ratings(
    baseline_state_path,
    ratings_df,
    mds,
    epochs=8,
    lr=5e-4,
    *,
    eval_specs=None,   # list of {"name","members_idx","item_set_enc","baseline_hr", ["solo_hr"]}
    eval_split='user_test',
    k=None,
    device=None,
    **kwargs           # swallow unused args for compatibility
):
    """
    Retrain from baseline weights on a modified ratings_df and print, per epoch:
      - training loss
      - HR@K for each eval spec
      - relative HR (vs baseline_hr)
      - constructiveness CT = rel_joint - rel_solo (if solo_hr provided)

    Returns: (mds, user_pos_mod, None, None, model)
    """
    dev = device if device is not None else DEVICE
    K = k if k is not None else (config.get('TOPK', 10) if isinstance(config, dict) else 10)
    # Paper protocol: evaluate HR over all users in chosen split
    global_eval_subset = None
    try:
        global_eval_subset = getattr(mds, eval_split)
    except Exception:
        global_eval_subset = None


    # 0) Defensive config defaults
    max_len = int(config.get('max_len', 50))
    mask_prob = float(config.get('mask_prob', 0.15))
    batch_size = int(config.get('batch_size', 128))

    # 1) Rebuild sequences from edited ratings using all ratings (threshold ignored)
    user_pos_mod = rebuild_sequences_from_df(
        ratings_df,
        mds.item_encoder,
        mds.user_encoder,
        threshold=None
    )
    user_pos_mod = {u: [int(x) for x in seq] for u, seq in user_pos_mod.items()}
    user_pos_mod = {u: seq for u, seq in user_pos_mod.items() if len(seq) > 0}
    if not user_pos_mod:
        raise ValueError("No user sequences after rebuild; check ratings_df and threshold.")

    # 1b) Compute num_item to cover all observed ids
    enc_num_item = getattr(mds, 'num_item', len(getattr(mds, 'item_encoder', {})))
    max_seen = max(max(seq) for seq in user_pos_mod.values())
    num_item = max(enc_num_item, max_seen)

    # 2) DataLoader
    ds = BERTRecDataSet(
        user_pos_mod,
        max_len=max_len,
        num_item=num_item,
        mask_prob=mask_prob
    )
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False
    )

    model = BERT(
        bert_max_len=max_len,
        num_items=num_item,      # <- pass *items only*; the class adds +2 internally
        bert_num_blocks=config['num_layers'],
        bert_num_heads=config['num_heads'],
        bert_hidden_units=config['hidden_units'],
        bert_dropout=config['dropout_rate']
    ).to(dev)

    # ----- Vocab-size–aware load: pad/slice embedding & output heads -----
    ckpt = torch.load(baseline_state_path, map_location=dev)
    model_sd = model.state_dict()

    def _resize_like_param(param_name, src_tensor, dst_tensor):
        """Resize src_tensor to match dst_tensor on dim 0 by copy-overlap + zero pad."""
        if src_tensor.shape == dst_tensor.shape:
            return src_tensor
        # Only support resizing first dimension (vocab) and keeping others
        if src_tensor.ndim != dst_tensor.ndim or src_tensor.shape[1:] != dst_tensor.shape[1:]:
            # Fallback: keep destination (random init) if shapes are incompatible
            return dst_tensor
        new_t = dst_tensor.clone()
        n = min(src_tensor.shape[0], dst_tensor.shape[0])
        new_t[:n] = src_tensor[:n]
        return new_t

    # Candidate keys for vocab-tied layers across common BERT4Rec variants
    vocab_keys = [
        'embedding.token.weight',          # seen in your error
        'bert.item_embedding.weight',
        'item_embedding.weight',
        'bert.embeddings.word_embeddings.weight',
    ]
    out_w_keys = [
        'out.weight',                      # seen in your error
        'bert.prediction.weight',
        'prediction.weight',
    ]
    out_b_keys = [
        'out.bias',
        'bert.prediction.bias',
        'prediction.bias',
    ]

    # Build a patched checkpoint dict
    patched = {}
    for k, v in ckpt.items():
        if k in model_sd:
            if k in vocab_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_w_keys:
                patched[k] = _resize_like_param(k, v, model_sd[k])
            elif k in out_b_keys:
                # Bias: resize on dim 0
                if v.shape != model_sd[k].shape:
                    nb = min(v.shape[0], model_sd[k].shape[0])
                    new_b = model_sd[k].clone()
                    new_b[:nb] = v[:nb]
                    patched[k] = new_b
                else:
                    patched[k] = v
            else:
                # default: keep as-is if shape matches; otherwise keep destination param
                if v.shape == model_sd[k].shape:
                    patched[k] = v
                else:
                    patched[k] = model_sd[k]
        else:
            # keys not in current model are ignored
            pass

    # Load with strict=False to allow any unmatched keys
    missing, unexpected = model.load_state_dict(patched, strict=False)
    if missing:
        print(f"[info] Missing keys after load (expected if heads/embeddings expanded): {missing}")
    if unexpected:
        print(f"[info] Unexpected keys ignored from checkpoint: {unexpected}")

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss(ignore_index=0, reduction='mean')
    grad_clip = float(config.get('grad_clip', 1.0))
    torch.backends.cudnn.benchmark = False

    def _train_one_epoch_nan_safe(model, crit, opt, dl):
        model.train()
        total_loss = 0.0
        total_tokens = 0
        for batch in dl:
            if isinstance(batch, dict):
                batch = {k: (v.to(dev) if hasattr(v, "to") else v) for k, v in batch.items()}
                inputs = batch.get('input_ids') or batch.get('seqs') or next(iter(batch.values()))
                labels = batch.get('labels') or batch.get('targets') or batch.get('target_ids')
                attn = batch.get('attention_mask') or batch.get('attn_mask')
            else:
                inputs = batch[0].to(dev)
                labels = batch[1].to(dev)
                attn = batch[2].to(dev) if len(batch) > 2 else None

            if labels is None:
                continue
            valid = (labels != 0)
            if valid.sum().item() == 0:
                continue

            opt.zero_grad(set_to_none=True)
            logits = model(inputs)  # (B, T, V)
            B, T, V = logits.shape
            loss = crit(logits.view(B*T, V), labels.view(B*T))
            if not torch.isfinite(loss):
                continue
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            total_loss += loss.item() * valid.sum().item()
            total_tokens += valid.sum().item()

        return (total_loss / max(total_tokens, 1))

    def _coerce_itemset(item_set_enc, target_num_item):
        if isinstance(item_set_enc, (list, tuple)):
            v = torch.tensor(item_set_enc, dtype=torch.float32)
        elif torch.is_tensor(item_set_enc):
            v = item_set_enc.float().cpu()
        else:
            raise TypeError("item_set_enc must be list/tuple/tensor")
        if v.ndim != 1:
            v = v.view(-1)
        if len(v) < target_num_item:
            v = torch.cat([v, torch.zeros(target_num_item - len(v))], dim=0)
        elif len(v) > target_num_item:
            v = v[:target_num_item]
        return v.to(dev)

    def _eval_one(spec):
        members = spec['members_idx']
        subset = global_eval_subset if global_eval_subset else user_pos_mod
        aligned_itemset = _coerce_itemset(spec['item_set_enc'], num_item)

        hr = float(hr_for_itemset(
            model,
            subset,
            aligned_itemset,
            num_item,
            max_len,
            K
        ))

        rel = None
        base = spec.get('baseline_hr', None)
        if base is not None and base > 0:
            rel = hr / base
        ct = None
        solo = spec.get('solo_hr', None)
        if base is not None and base > 0 and solo is not None:
            ct = (hr / base) - (solo / base)
        return hr, rel, ct

    
    # --- Collect per-epoch evaluation stats to average at the end ---
    _per_spec = {}
    if eval_specs:
        for _spec in eval_specs:
            _nm = _spec.get('name', f"spec_{len(_per_spec)+1}")
            _per_spec[_nm] = {'hr': [], 'rel': [], 'ct': []}
    
    for epoch in range(1, epochs + 1):
        avg_loss = _train_one_epoch_nan_safe(model, crit, opt, dl)
        if eval_specs:
            parts = []
            for spec in eval_specs:
                try:
                    hr, rel, ct = _eval_one(spec)
                    # record per-epoch stats
                    _nm = spec.get('name', '?')
                    if _nm in _per_spec:
                        _per_spec[_nm]['hr'].append(hr)
                        _per_spec[_nm]['rel'].append(rel)
                        _per_spec[_nm]['ct'].append(ct)
                    msg = f"{spec['name']}: HR@{K}={hr:.4f}"
                    base_local = spec.get('baseline_hr', None)
                    if rel is not None:
                        msg += f" | rel={rel:.3f}"
                    elif base_local is not None and base_local == 0:
                        msg += " | rel=NA (baseline=0)"
                    if ct is not None:
                        msg += f" | CT={ct:.3f}"
                except Exception as e:
                    msg = f"{spec.get('name','?')}: EVAL-ERROR {type(e).__name__}: {e}"
                parts.append(msg)
            print(f"Epoch: {epoch:3d} | loss: {avg_loss:.5f} || " + " || ".join(parts), flush=True)
        else:
            print(f"Epoch: {epoch:3d} | Train loss: {avg_loss:.5f}", flush=True)

    
    # --- Compute averages across epochs and attach to model ---
    def _mean_ignore_none(xs):
        _vals = [x for x in xs if x is not None]
        return float(sum(_vals)/len(_vals)) if _vals else None
    avg_stats = {}
    if eval_specs:
        for _nm, d in _per_spec.items():
            avg_stats[_nm] = {
                'hr': _mean_ignore_none(d['hr']),
                'rel': _mean_ignore_none(d['rel']),
                'ct': _mean_ignore_none(d['ct']),
                'epochs': len(d['hr'])
            }
    try:
        model._avg_eval_stats = avg_stats
    except Exception:
        pass
    
    return mds, user_pos_mod, None, None, model

In [29]:
import torch

@torch.no_grad()
def hr_for_itemset(model, user_pos, item_set, num_items, max_len, k=10, batch_size=2048, device=None):
    dev = device or next(model.parameters()).device
    PAD = 0
    mask_guess = num_items + 1  # we'll clamp after we see V_model

    # Build inputs (prefix + MASK) and remember "seen" tokens per row
    inps, seen_lists = [], []
    for _, seq in user_pos.items():
        if not seq:
            continue
        # use prefix (next-item style); keep up to max_len-1 then place MASK at the end
        cond = seq[:-1] if len(seq) >= 1 else seq
        cond = cond[-(max_len-1):]
        pad = (max_len - 1) - len(cond)
        inps.append(([PAD] * pad) + cond + [mask_guess])
        seen_lists.append(set(cond))

    if not inps:
        return 0.0

    X = torch.tensor(inps, dtype=torch.long, device=dev)
    target_ids = torch.as_tensor(item_set, dtype=torch.long).flatten().tolist()

    hits, total = 0, 0
    B = X.size(0)
    for start in range(0, B, batch_size):
        end = min(B, start + batch_size)
        Xb = X[start:end]

        # First forward to discover V_model
        logits = model(Xb)[:, -1, :]             # (b, V_model)
        V = logits.shape[1]

        # Clamp tokens in the batch to valid range [0, V-1] and recompute logits to be safe
        Xb = Xb.clamp_max(V - 1)
        logits = model(Xb)[:, -1, :]

        # Decide MASK id safely
        MASK = min(num_items + 1, V - 1)

        # Suppress PAD & MASK
        logits[:, 0] = -1e9
        if 0 <= MASK < V:
            logits[:, MASK] = -1e9

        # Suppress anything beyond real (items + special) if model head is larger
        # (items are 1..num_items; special are 0 (PAD) and MASK)
        if V > num_items + 2:
            logits[:, (num_items + 2):] = -1e9

        # Seen-mask with the correct width V
        seen_mask = torch.zeros((Xb.size(0), V), dtype=torch.bool, device=dev)
        for r, sset in enumerate(seen_lists[start:end]):
            for t in sset:
                if 0 <= int(t) < V:
                    seen_mask[r, int(t)] = True
        #logits = logits.masked_fill(seen_mask, -1e9)

        # Target mask (width V)
        target_mask = torch.zeros(V, dtype=torch.bool, device=dev)
        for t in target_ids:
            t = int(t)
            # only allow real item ids 1..num_items, and also within model width
            if 1 <= t <= num_items and t < V:
                target_mask[t] = True

        # Top-k hits
        topk_idx = torch.topk(logits, k=k, dim=1).indices
        hit_rows = target_mask[topk_idx].any(dim=1)
        hits += int(hit_rows.sum().item())
        total += Xb.size(0)
    
    
    return float(hits / max(1, total))

@torch.no_grad()
def relative_hr(model_variant, model_baseline, user_pos, item_set, num_items, max_len, k=10):
    g_base = hr_for_itemset(model_baseline, user_pos, item_set, num_items, max_len, k=k)
    g_var  = hr_for_itemset(model_variant,   user_pos, item_set, num_items, max_len, k=k)
    rel    = (g_var / g_base) if g_base > 0 else 0.0
    return rel, g_var, g_base

In [30]:
def score_next_items(model, cond_seq, num_items, max_len):
    """
    Robust scorer that:
    - Detects the model's embedding vocab size by scanning modules for nn.Embedding.
    - Ensures all token ids are within [0, vocab_size).
    - Uses PAD=0 and MASK=1 when possible (common BERT4Rec scheme).
    - Suppresses PAD, MASK, and seen tokens in logits.
    """
    import torch
    import torch.nn as nn

    # 1) Find an embedding to infer vocab_size
    vocab_size = None
    for m in model.modules():
        if isinstance(m, nn.Embedding):
            try:
                vocab_size = int(m.num_embeddings)
                break
            except Exception:
                continue

    PAD = 0
    MASK = 1 if vocab_size is not None and vocab_size > 1 else (num_items + 2)

    # 2) Build masked sequence with strict truncation
    s = cond_seq[-max_len:] if len(cond_seq) > max_len else cond_seq
    if len(s) > max_len - 1:
        s = s[-(max_len - 1):]

    # 3) Sanitize IDs to prevent CUDA device-side asserts
    if vocab_size is not None:
        s = [int(x) for x in s if 0 <= int(x) < vocab_size]
        # Final check for MASK validity
        if not (0 <= MASK < vocab_size):
            MASK = 1 if vocab_size > 1 else 0

    pad = max_len - len(s) - 1
    if pad < 0: pad = 0
    seq_ids = [PAD]*pad + s + [MASK]

    DEVICE = next(model.parameters()).device
    inp = torch.tensor([seq_ids], dtype=torch.long, device=DEVICE)

    # 4) Forward -> logits for the MASK position
    out = model(inp)
    logits = out[0, -1].clone()

    # 5) Suppress invalid or special predictions
    if 0 <= PAD < logits.numel():   logits[PAD] = -1e9
    if 0 <= MASK < logits.numel():  logits[MASK] = -1e9
    for seen in set(s):
        if 0 <= seen < logits.numel():
            logits[seen] = -1e9

    # 6) If num_items provided < logits size, clip tail so metrics only consider real items
    if isinstance(num_items, int) and num_items < logits.numel():
        logits[num_items:] = -1e9

    return logits  # (vocab,)

# ✅ Paper-faithful item_set builders (post-baseline)

In [31]:

import pandas as pd
from collections import Counter

V_TARGET = 10 if 'config' not in globals() else config.get('V', 10)

def _member_mean_topV(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    if sub.empty:
        return []
    means = (sub.groupby('movieId')['rating']
                 .mean().sort_values(ascending=False)
                 .head(V).index.tolist())
    return means

def _member_popular(ratings_df, member_raw_ids, V):
    sub = ratings_df[ratings_df['userId'].isin(set(member_raw_ids))]
    pop = (sub.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _global_popular(ratings_df, V):
    pop = (ratings_df.groupby('movieId')['rating']
                 .agg(['count','mean'])
                 .sort_values(['count','mean'], ascending=[False, False])
                 .head(V).index.tolist())
    return pop

def _encode_items(raw_ids, item_encoder):
    enc0 = [item_encoder[r] for r in raw_ids if r in item_encoder]
    return [e + 1 for e in enc0]  # BERT4Rec label space

def build_item_set_for_members(mds, members_idx, V=10, min_support=None):
    """
    Paper-faithful group target selection with optional positive support filter.
    - Compute top-V by mean rating within the collective.
    - If min_support is set, require at least that many member ratings >= 4.
    - If fewer than V items, pad using 'member-popular' items from the same collective.
    - No global fallback (strict cluster-only).
    """
    member_raw = [mds.user_decoder[u] for u in members_idx if u in mds.user_decoder]
    if not member_raw:
        return [], []
    df = mds.df[mds.df['userId'].isin(set(member_raw))][['userId','movieId','rating']].copy()
    if df.empty:
        return [], []

    # Step 1: compute candidate items
    if isinstance(min_support, int) and min_support > 0:
        pos = df[df['rating'] >= 4.0]
        sup = (pos.groupby('movieId')['rating'].count()).rename('pos_support')
        means = (df.groupby('movieId')['rating'].mean()).rename('mean_rating').to_frame()
        means = means.join(sup, how='left').fillna({'pos_support': 0})
        means = means[means['pos_support'] >= min_support]
        top_raw = (means.sort_values(['mean_rating','pos_support'], ascending=[False, False])
                         .head(V).index.tolist())
    else:
        top_raw = (df.groupby('movieId')['rating']
                     .mean().sort_values(ascending=False).head(V).index.tolist())

    # Step 2: pad with member-popular if needed
    if len(top_raw) < V:
        # count how often each item was rated by members
        pop = df.groupby('movieId')['rating'].count().sort_values(ascending=False)
        for iid in pop.index:
            if iid not in top_raw:
                top_raw.append(iid)
            if len(top_raw) >= V:
                break

    # Step 3: encode
    top_enc = [mds.item_encoder[mid] + 1 for mid in top_raw if mid in mds.item_encoder]
    return top_raw[:V], top_enc[:V]



def build_item_set_for_members(mds, members_idx, V=10):
    ratings_df = mds.df
    user_decoder = mds.user_decoder
    member_raw = [user_decoder[u] for u in members_idx if u in user_decoder]

    raw = _member_mean_topV(ratings_df, member_raw, V)
    if len(raw) < V:
        for iid in _member_popular(ratings_df, member_raw, V*3):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break
    if len(raw) < V:
        for iid in _global_popular(ratings_df, V*5):
            if iid not in raw:
                raw.append(iid)
            if len(raw) >= V: break

    enc = _encode_items(raw[:V], mds.item_encoder)
    return raw[:V], enc[:V]


In [32]:
import numpy as np

def _cluster_centroids(U, labels, Q):
    """Return (Q x d) centroid matrix in user-embedding space."""
    centroids = []
    for q in range(Q):
        idx = np.where(labels == q)[0]
        # KMeans guarantees non-empty clusters; just in case:
        if len(idx) == 0:
            # fallback: random user as centroid
            centroids.append(U[np.random.randint(len(U))])
        else:
            centroids.append(U[idx].mean(axis=0))
    C = np.vstack(centroids)
    return C

def _normalize_rows(X, eps=1e-12):
    n = np.linalg.norm(X, axis=1, keepdims=True)
    return X / (n + eps)

def _farthest_pair_indices(centroids, metric="cosine"):
    """
    Return indices (i, j) of the two centroids that are maximally distant.
    """
    if metric == "cosine":
        Cn = _normalize_rows(centroids)
        # cosine distance = 1 - cosine similarity
        sims = Cn @ Cn.T
        dists = 1.0 - np.clip(sims, -1.0, 1.0)
    else:
        # Euclidean
        diff = centroids[:, None, :] - centroids[None, :, :]
        dists = np.sqrt((diff * diff).sum(axis=2))
    # ignore diagonal
    np.fill_diagonal(dists, -np.inf)
    ij = np.unravel_index(np.argmax(dists), dists.shape)
    return int(ij[0]), int(ij[1]), float(dists[ij])


def _balanced_farthest_pair(centroids, labels, user_index, mds, metric="cosine", balance_tol=0.5):
    """
    Pick two clusters that are far apart but balanced in mean item popularity.
    balance_tol: allowable relative difference in mean popularity (e.g., 0.25 = 25%).
    Returns (i, j, distance).
    """
    # Popularity of each item from training data
    from collections import Counter
    cnt = Counter(i for seq in mds.user_train.values() for i in seq)

    # Mean popularity per cluster (average of users' mean item popularity)
    cluster_pop = {}
    Q = len(centroids)
    for q in range(Q):
        user_idxs = [k for k, lab in enumerate(labels) if lab == q]
        pops = []
        for k in user_idxs:
            u = user_index[k]
            seq = mds.user_train.get(u, [])
            if seq:
                pops.append(float(np.mean([cnt[i] for i in seq])))
        cluster_pop[q] = float(np.mean(pops)) if len(pops) > 0 else 0.0

    # Pairwise distances between centroids
    if metric == "cosine":
        Cn = centroids / (np.linalg.norm(centroids, axis=1, keepdims=True) + 1e-12)
        sims = Cn @ Cn.T
        dists = 1.0 - np.clip(sims, -1.0, 1.0)
    else:
        diff = centroids[:, None, :] - centroids[None, :, :]
        dists = np.sqrt((diff * diff).sum(axis=2))
    np.fill_diagonal(dists, -np.inf)

    # Choose farthest pair among those with similar popularity
    best_pair, best_dist = None, -np.inf
    for i in range(Q):
        for j in range(i+1, Q):
            pa, pb = cluster_pop.get(i, 0.0), cluster_pop.get(j, 0.0)
            if pa == 0.0 or pb == 0.0:
                continue
            rel_diff = abs(pa - pb) / max(pa, pb)
            if rel_diff <= balance_tol:
                d = dists[i, j]
                if d > best_dist:
                    best_dist = d
                    best_pair = (i, j)

    if best_pair is None:
        # Fallback to farthest regardless of balance
        i, j = np.unravel_index(np.argmax(dists), dists.shape)
        return int(i), int(j), float(dists[i, j])
    return int(best_pair[0]), int(best_pair[1]), float(best_dist)


In [33]:
def check_edits(df_before, df_after, users_idx, target_items_raw,
                action, user_decoder, sample_k=5, add_competing=False, label=None):
    import numpy as np, pandas as pd, random
    from collections import Counter

    tset = set(int(x) for x in target_items_raw) if target_items_raw else set()
    user_ids = [user_decoder[u] for u in users_idx if u in user_decoder]
    if not user_ids or not tset:
        print("[check_edits] nothing to check (empty users or targets)")
        return

    dfb = df_before.sort_values(["userId","timestamp"], kind="mergesort").copy()
    dfa = df_after .sort_values(["userId","timestamp"], kind="mergesort").copy()

    sample_users = random.sample(user_ids, k=min(sample_k, len(user_ids)))

    def ts_valid_for_user(g):
        ts = g["timestamp"].values
        if len(ts) >= 2:
            return int(ts[-2])                  # boundary from BEFORE
        return np.iinfo(np.int64).max          # no tail if < 2

    ok = True
    for uid in sample_users:
        gb = dfb[dfb.userId == uid]
        ga = dfa[dfa.userId == uid]
        if gb.empty:
            continue
        ts_valid = ts_valid_for_user(gb)

        # Split using BEFORE's boundary; tail are rows with ts >= ts_valid
        b_tail = gb[gb.timestamp >= ts_valid][["movieId","timestamp"]].sort_values(["timestamp","movieId"]).reset_index(drop=True)
        a_tail = ga[ga.timestamp >= ts_valid][["movieId","timestamp"]].sort_values(["timestamp","movieId"]).reset_index(drop=True)

        # Order-robust equality
        if not a_tail.equals(b_tail):
            ok = False
            print(f"[check_edits:{action}] uid={uid}: tail (valid/test) changed, which should not happen")

        # Train region diagnostics
        b_train = gb[gb.timestamp < ts_valid]
        a_train = ga[ga.timestamp < ts_valid]
        if action == "promote":
            # must have new target rows in train
            added = a_train.merge(b_train, how="outer", indicator=True)
            added = added[added["_merge"] == "left_only"]
            if added[added.movieId.isin(tset)].empty:
                ok = False
                print(f"[check_edits:promote] uid={uid}: no new target interactions in TRAIN")
        elif action == "demote":
            # train targets should be removed
            b_train_targets = b_train[b_train.movieId.isin(tset)]
            if not b_train_targets.empty:
                a_train_targets = a_train[a_train.movieId.isin(tset)]
                if len(a_train_targets) >= len(b_train_targets):
                    ok = False
                    print(f"[check_edits:demote] uid={uid}: train targets not removed "
                          f"(before={len(b_train_targets)} after={len(a_train_targets)})")
            if add_competing:
                added = a_train.merge(b_train, how="outer", indicator=True)
                added = added[added["_merge"] == "left_only"]
                if added[~added.movieId.isin(tset)].empty:
                    ok = False
                    print(f"[check_edits:demote] uid={uid}: expected a competing interaction, found none")

    if ok:
        tag = f"{label}:" if label else ""
        print(f"[check_edits:{tag}{action}] All sampled users respect train-only edit invariants ✓")



def hr_split(model, test_dict, item_set_enc, members, NUM_ITEMS, max_len, k=10):
    """Compute HR@K separately for members and non-members, given test_dict."""
    members_set = set(members)
    member_dict = {u: seq for u, seq in test_dict.items() if u in members_set}
    nonmember_dict = {u: seq for u, seq in test_dict.items() if u not in members_set}

    hr_members    = hr_for_itemset(model, member_dict,  item_set_enc, NUM_ITEMS, max_len, k) if member_dict else 0.0
    hr_nonmembers = hr_for_itemset(model, nonmember_dict, item_set_enc, NUM_ITEMS, max_len, k) if nonmember_dict else 0.0
    return hr_members, hr_nonmembers



# 🚀 RQ2 Experiment (baseline → collectives → item_set → interventions)

In [None]:
# ===== Collective experiment params (from paper notebook) =====
N_values   = [25, 50, 100]            # collective sizes
p_values   = [0.1,0.25, 0.5, 0.75, 1.0]  # homogeneity
trials_per_case = 50
TOPK       = 10
V_TARGET   = config.get('V', 10)
NUM_CLUSTERS = 10
SEED_MODE  = locals().get("SEED_MODE", "balanced_maxdist")  # keep your previous choice

scenarios = [
    ('promote', 'promote', 'both_promote'),
    ('demote',  'demote',  'both_demote'),
    ('promote', 'demote',  'A_promote_B_demote'),
    ('demote',  'promote', 'A_demote_B_promote'),
]

# ---------------- Repro & "baseline pack" paths ----------------
import os, json, pickle, sys, gc, random
import numpy as np
import torch

BASELINE_DIR  = "baseline_pack"
BASELINE_PATH = os.path.join(BASELINE_DIR, "baseline_best.pt")
MDS_PATH      = os.path.join(BASELINE_DIR, "mds.pkl")
SPLITS_PATH   = os.path.join(BASELINE_DIR, "splits.pkl")
META_PATH     = os.path.join(BASELINE_DIR, "meta.json")
os.makedirs(BASELINE_DIR, exist_ok=True)

# Determinism: keep this consistent across machines
os.environ["PYTHONHASHSEED"] = str(config['seed'])
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"  # safe on CPU; improves CUDA determinism

random.seed(config['seed'])
np.random.seed(config['seed'])
torch.manual_seed(config['seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config['seed'])
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
try:
    torch.use_deterministic_algorithms(True, warn_only=True)
except Exception:
    pass

# ---------------- Load or create the baseline pack ----------------
def _save_baseline_pack(mds, user_train, user_valid, user_test, baseline_model):
    torch.save(baseline_model.state_dict(), BASELINE_PATH)
    with open(MDS_PATH, "wb") as f:
        pickle.dump(mds, f, protocol=pickle.HIGHEST_PROTOCOL)
    with open(SPLITS_PATH, "wb") as f:
        pickle.dump((user_train, user_valid, user_test), f, protocol=pickle.HIGHEST_PROTOCOL)
    meta = {
        "seed": int(config['seed']),
        "config": {k: (int(v) if isinstance(v, (np.integer,)) else v) for k, v in config.items()},
        "versions": {
            "python": sys.version,
            "torch": getattr(torch, "__version__", "unknown"),
            "numpy": np.__version__,
        },
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    print(f"[baseline] Saved baseline pack to {BASELINE_DIR}/")

def _load_baseline_pack():
    with open(MDS_PATH, "rb") as f:
        mds = pickle.load(f)
    with open(SPLITS_PATH, "rb") as f:
        user_train, user_valid, user_test = pickle.load(f)

    # Fast path to get a model object with loaded weights WITHOUT retraining:
    # reuse your existing helper that builds the model and loads a state dict.
    # We piggy-back on `retrain_from_baseline_on_ratings` with epochs=0.
    mds_s, tr_s, va_s, te_s, model_s = retrain_from_baseline_on_ratings(
        baseline_state_path=BASELINE_PATH,
        ratings_df=mds.df,              # original, unedited data
        mds=mds,
        config=config,
        epochs=0,                       # <- no training
        eval_split=None,
        item_set=None,
        eval_specs=[],
        k=TOPK,
    )
    # Make sure we keep the original splits we loaded
    return mds, user_train, user_valid, user_test, model_s

if os.path.exists(BASELINE_PATH) and os.path.exists(MDS_PATH) and os.path.exists(SPLITS_PATH):
    print("[baseline] Found existing baseline pack — loading.")
    mds, user_train, user_valid, user_test, baseline_model = _load_baseline_pack()
else:
    print("[baseline] No baseline pack found — training once and saving.")
    # Your existing function that prepares data and trains the baseline:
    mds, user_train, user_valid, user_test, baseline_model = prepare_baseline(config)
    _save_baseline_pack(mds, user_train, user_valid, user_test, baseline_model)

# If you want to distribute only the weights (and rebuild splits locally),
# you can still rely on BASELINE_PATH below. For determinism, keep seeds as above.

# ---------------- Rest of your pipeline (unchanged except path var) ----------------
NUM_ITEMS = resolve_num_items(mds, baseline_model)

U, user_index = compute_user_embeddings(baseline_model, user_train, NUM_ITEMS, config['max_len'])
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=NUM_CLUSTERS, random_state=42).fit(U)
labels = kmeans.labels_
centroids = _cluster_centroids(U, labels, NUM_CLUSTERS)
print("V_model =", baseline_model.out.out_features)          # should be num_items+2
print("num_items =", NUM_ITEMS)
print("expected V =", NUM_ITEMS + 2)
results = []

for N in N_values:
    for p in p_values:
        # ---- Build two collectives A and B ONCE per (N, p) ----
        if SEED_MODE == "uniform":
            seedA = np.random.randint(NUM_CLUSTERS)
            seedB = (seedA + np.random.randint(1, NUM_CLUSTERS)) % NUM_CLUSTERS
        elif SEED_MODE == "balanced_maxdist":
            i, j, d = _balanced_farthest_pair(centroids, labels, user_index, mds, metric="cosine", balance_tol=0.7)
            seedA, seedB = i, j
        elif SEED_MODE == "maxdist":
            i, j, d = _farthest_pair_indices(centroids, metric="cosine")
            seedA, seedB = i, j
        else:
            raise ValueError(f"Unknown SEED_MODE={SEED_MODE}")

        C1 = build_collective_quota(seedA, labels, user_index, size=N, p=p, seed=42)
        C2 = build_collective_quota(seedB, labels, user_index, size=N, p=p, seed=42)

        # ---- Build fixed item sets from ORIGINAL ratings (via mds.df) ----
        raw_A, enc_A = build_item_set_for_members(mds, C1, V=V_TARGET, min_support=3)
        raw_B, enc_B = build_item_set_for_members(mds, C2, V=V_TARGET, min_support=3)
        
        print(f"[item_set] A: raw={len(raw_A)} enc={len(enc_A)}  |  B: raw={len(raw_B)} enc={len(enc_B)}")

        # ---- Evaluate BASELINE HR@K for each collective on its own item_set ----
        subset_ALL = user_test
        g1_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_A, NUM_ITEMS, config['max_len'], TOPK))
        g2_base = float(hr_for_itemset(baseline_model, subset_ALL, enc_B, NUM_ITEMS, config['max_len'], TOPK))
        print(g1_base)
        print(g2_base)

        from collections import Counter
        cnt = Counter(i for seq in mds.user_train.values() for i in seq)
        popA = sum(cnt[i] for i in enc_A)/len(enc_A)
        popB = sum(cnt[i] for i in enc_B)/len(enc_B)
        print("Mean frequency — A:", popA, " | B:", popB)

        # ---- Precompute edited dataframes ONCE per (N, p) ----
        df0 = mds.df.copy()
        df_A_prom = edit_ratings(df0, C1, raw_A, action="promote", user_decoder=mds.user_decoder)
        df_A_demo = edit_ratings(df0, C1, raw_A, action="demote",  user_decoder=mds.user_decoder)
        df_B_prom = edit_ratings(df0, C2, raw_B, action="promote", user_decoder=mds.user_decoder)
        df_B_demo = edit_ratings(df0, C2, raw_B, action="demote",  user_decoder=mds.user_decoder)

        # Example usage
        check_edits(df0, df_A_prom, C1, raw_A, action="promote", user_decoder=mds.user_decoder, label="A_prom")
        check_edits(df0, df_A_demo, C1, raw_A, action="demote", user_decoder=mds.user_decoder, label="A_demo")
        check_edits(df0, df_B_prom, C2, raw_B, action="promote", user_decoder=mds.user_decoder, label="B_prom")
        check_edits(df0, df_B_demo, C2, raw_B, action="demote", user_decoder=mds.user_decoder, label="B_demo")

        # Helper to retrain and evaluate a single collective (averaged across epochs on TEST users)
        def run_single(edited_df, members, item_set_enc, name, baseline_hr):
            mds_s, tr_s, va_s, te_s, model_s = retrain_from_baseline_on_ratings(
                baseline_state_path=BASELINE_PATH,   # <— use saved baseline
                ratings_df=edited_df,
                mds=mds,
                config=config,
                epochs=10,
                eval_split="user_test",
                item_set=item_set_enc,
                eval_specs=[{"name": name, "members_idx": user_test, "item_set_enc": item_set_enc, "baseline_hr": baseline_hr}],
                k=TOPK
            )
            stats = getattr(model_s, "_avg_eval_stats", {}) or {}
            s = stats.get(name, {}) or {}
            return s.get('hr'), s.get('rel'), s.get('ct')

        # Two-collective interventions (apply both edits then retrain once)
        def run_joint(df_A_action, df_B_action, solo_A, solo_B):
            df_joint = edit_ratings(mds.df.copy(), C1, raw_A, action=df_A_action, user_decoder=mds.user_decoder)
            df_joint = edit_ratings(df_joint,        C2, raw_B, action=df_B_action, user_decoder=mds.user_decoder)
            mds_j, tr_j, va_j, te_j, model_j = retrain_from_baseline_on_ratings(
                baseline_state_path=BASELINE_PATH,   # <— use saved baseline
                ratings_df=df_joint,
                mds=mds,
                config=config,
                epochs=10,
                eval_split="user_test",
                eval_specs=[
                    {"name": "A", "members_idx": user_test, "item_set_enc": enc_A, "baseline_hr": g1_base, "solo_hr": solo_A},
                    {"name": "B", "members_idx": user_test, "item_set_enc": enc_B, "baseline_hr": g2_base, "solo_hr": solo_B},
                ],
                k=TOPK,
            )
            stats = getattr(model_j, "_avg_eval_stats", {}) or {}
            A = stats.get("A", {}) or {}
            B = stats.get("B", {}) or {}
            return A.get('hr'), A.get('rel'), A.get('ct'), B.get('hr'), B.get('rel'), B.get('ct')

        # ---- Trials: SAME groups, repeated trainings ----
        for trial in range(trials_per_case):
            # Solo A (promote/demote)
            g1_demo, g1_demo_rel, g1_demo_ct = run_single(df_A_demo, C1, enc_A, 'A', g1_base)
            hr_mem_A, hr_nonmem_A = hr_split(baseline_model, user_test, enc_A, C1, NUM_ITEMS, config['max_len'], TOPK)
            print(f"[sanity] HR@10 A demote: members={hr_mem_A:.3f}, non-members={hr_nonmem_A:.3f}, overall={g1_demo:.3f}")
            g1_prom, g1_prom_rel, g1_prom_ct = run_single(df_A_prom, C1, enc_A, 'A', g1_base)
            # Solo B
            g2_demo, g2_demo_rel, g2_demo_ct = run_single(df_B_demo, C2, enc_B, 'B', g2_base)
            hr_mem_B, hr_nonmem_B = hr_split(baseline_model, user_test, enc_B, C2, NUM_ITEMS, config['max_len'], TOPK)
            print(f"[sanity] HR@10 B demote: members={hr_mem_B:.3f}, non-members={hr_nonmem_B:.3f}, overall={g2_demo:.3f}")
            g2_prom, g2_prom_rel, g2_prom_ct = run_single(df_B_prom, C2, enc_B, 'B', g2_base)


            # Two-collective: both promote / both demote / criss-cross
            gA_pp, gA_pp_rel, gA_pp_ct, gB_pp, gB_pp_rel, gB_pp_ct = run_joint("promote", "promote", g1_prom, g2_prom)
            gA_dd, gA_dd_rel, gA_dd_ct, gB_dd, gB_dd_rel, gB_dd_ct = run_joint("demote", "demote", g1_demo, g2_demo)
            gA_pd, gA_pd_rel, gA_pd_ct, gB_pd, gB_pd_rel, gB_pd_ct = run_joint("promote", "demote", g1_prom, g2_demo)
            gA_dp, gA_dp_rel, gA_dp_ct, gB_dp, gB_dp_rel, gB_dp_ct = run_joint("demote", "promote", g1_demo, g2_prom)

            results.append({
                "N": N, "p": p, "trial": trial,
                "gA_base": g1_base, "gB_base": g2_base,
                "gA_prom": g1_prom, "gA_demo": g1_demo,
                "gB_prom": g2_prom, "gB_demo": g2_demo,
                "gA_pp": gA_pp, "gB_pp": gB_pp,
                "gA_dd": gA_dd, "gB_dd": gB_dd,
                "gA_pd": gA_pd, "gB_pd": gB_pd,
                "gA_dp": gA_dp, "gB_dp": gB_dp,
                "rA_prom": g1_prom_rel, "ctA_prom": g1_prom_ct,
                "rA_demo": g1_demo_rel, "ctA_demo": g1_demo_ct,
                "rB_prom": g2_prom_rel, "ctB_prom": g2_prom_ct,
                "rB_demo": g2_demo_rel, "ctB_demo": g2_demo_ct,
                "rA_pp": gA_pp_rel, "ctA_pp": gA_pp_ct,
                "rB_pp": gB_pp_rel, "ctB_pp": gB_pp_ct,
                "rA_dd": gA_dd_rel, "ctA_dd": gA_dd_ct,
                "rB_dd": gB_dd_rel, "ctB_dd": gB_dd_ct,
                "rA_pd": gA_pd_rel, "ctA_pd": gA_pd_ct,
                "rB_pd": gB_pd_rel, "ctB_pd": gB_pd_ct,
                "size_A": len(C1), "size_B": len(C2)
            })

            # free between trials
            gc.collect()
            if DEVICE == "cuda":
                torch.cuda.empty_cache()

# Save & preview
import pandas as pd
res_df = pd.DataFrame(results)
res_df.to_csv("rq2_results_relative_hr.csv", index=False)
display(res_df.head())
print("Saved rq2_results_relative_hr.csv")



[baseline] Found existing baseline pack — loading.


  ckpt = torch.load(baseline_state_path, map_location=dev)
user embeddings: 100%|██████████| 6040/6040 [00:29<00:00, 202.68it/s]


V_model = 3708
num_items = 3706
expected V = 3708
[item_set] A: raw=10 enc=10  |  B: raw=10 enc=10
0.013079470198675497
0.01705298013245033
Mean frequency — A: 467.9  | B: 531.6
[check_edits:A_prom:promote] All sampled users respect train-only edit invariants ✓
[check_edits:A_demo:demote] All sampled users respect train-only edit invariants ✓
[check_edits:B_prom:promote] All sampled users respect train-only edit invariants ✓
[check_edits:B_demo:demote] All sampled users respect train-only edit invariants ✓


  ckpt = torch.load(baseline_state_path, map_location=dev)


Epoch:   1 | loss: 7.16203 || A: HR@10=0.0073 | rel=0.557
Epoch:   2 | loss: 6.30247 || A: HR@10=0.0030 | rel=0.228
Epoch:   3 | loss: 5.88589 || A: HR@10=0.0028 | rel=0.215
Epoch:   4 | loss: 5.54958 || A: HR@10=0.0028 | rel=0.215


KeyboardInterrupt: 