In [None]:
# ============================================================
# PCA 768 → 192 (DICT PICKLES) — END-TO-END SINGLE CELL
# ============================================================

import pickle
import numpy as np
import os
from tqdm import tqdm
from sklearn.decomposition import IncrementalPCA

# ---------------- CONFIG ----------------
FILES = [
    "/teamspace/studios/this_studio/headline_T5.pkl",
    "/teamspace/studios/this_studio/newsbody_T5.pkl",
    "/teamspace/studios/this_studio/summary_T5.pkl",
]

TARGET_DIM = 192
BATCH_SIZE = 4096
PCA_MODEL_PATH = "/teamspace/studios/this_studio/pca_192_model.pkl"

# ---------------- FIT PCA ----------------
ipca = IncrementalPCA(n_components=TARGET_DIM)

print("\n[1/3] Fitting PCA (shared across all embeddings)\n")

for path in FILES:
    with open(path, "rb") as f:
        emb_dict = pickle.load(f)

    X = np.stack(list(emb_dict.values())).astype("float32")

    for i in tqdm(
        range(0, X.shape[0], BATCH_SIZE),
        desc=f"IPCA fit: {os.path.basename(path)}"
    ):
        ipca.partial_fit(X[i:i + BATCH_SIZE])

# ---------------- SAVE PCA MODEL ----------------
with open(PCA_MODEL_PATH, "wb") as f:
    pickle.dump(ipca, f)

print(f"\nPCA model saved to: {PCA_MODEL_PATH}")

# ---------------- TRANSFORM + SAVE ----------------
print("\n[2/3] Applying PCA and saving outputs\n")

for path in FILES:
    with open(path, "rb") as f:
        emb_dict = pickle.load(f)

    keys = list(emb_dict.keys())
    X = np.stack([emb_dict[k] for k in keys]).astype("float32")

    new_dict = {}

    for i in tqdm(
        range(0, X.shape[0], BATCH_SIZE),
        desc=f"PCA transform: {os.path.basename(path)}"
    ):
        X_pca = ipca.transform(X[i:i + BATCH_SIZE])
        for j, vec in enumerate(X_pca):
            new_dict[keys[i + j]] = vec

    base, ext = os.path.splitext(path)
    out_path = base + "_pca192" + ext

    with open(out_path, "wb") as f:
        pickle.dump(new_dict, f)

    print(f"Saved: {out_path}")

# ---------------- SANITY CHECK ----------------
print("\n[3/3] Sanity check\n")

with open(FILES[0].replace(".pkl", "_pca192.pkl"), "rb") as f:
    test_dict = pickle.load(f)

dims = {v.shape[0] for v in test_dict.values()}
print("Unique embedding dims:", dims)

assert dims == {192}, "PCA output dimension mismatch"

print("\n✅ PCA pipeline completed successfully.")


In [28]:
# ============================================================
# CELL 1 — Imports
# ============================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
from tqdm import tqdm
from ast import literal_eval
import pickle
import time
import os

from transformers import T5ForConditionalGeneration, T5Tokenizer
# ============================================================
# CELL 2 — Load T5
# ============================================================

summarizer_model = T5ForConditionalGeneration.from_pretrained("t5-large")
tokenizer = T5Tokenizer.from_pretrained("t5-large")

summarizer_model.eval()
for p in summarizer_model.parameters():
    p.requires_grad = False


In [29]:
# ============================================================
# CELL 3 — Load News
# ============================================================

news_df = pd.read_csv("/teamspace/studios/this_studio/news_min (1).tsv", sep="\t")
nid2body = dict(zip(news_df["News ID"], news_df["Headline"]))

print("News loaded:", len(nid2body))
# ============================================================
# CELL 4 — Load Summary Text
# ============================================================

with open("/teamspace/studios/this_studio/sid2sum.pkl","rb") as f:
    sid2text = pickle.load(f)

print("Summaries:", len(sid2text))


News loaded: 113762
Summaries: 135001


In [30]:
# ============================================================
# CELL 5 — Device & Dim
# ============================================================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = 192
# ============================================================
# CELL 6 — Load Embeddings
# ============================================================

with open("/teamspace/studios/this_studio/summary_T5_pca192.pkl","rb") as f:
    summary_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k,v in pickle.load(f).items()}

with open("/teamspace/studios/this_studio/newsbody_T5_pca192.pkl","rb") as f:
    newsbody_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k,v in pickle.load(f).items()}

with open("/teamspace/studios/this_studio/headline_T5_pca192.pkl","rb") as f:
    headline_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k,v in pickle.load(f).items()}

embed_tables = {
    "summary": summary_embed,
    "newsbody": newsbody_embed,
    "headline": headline_embed
}


In [4]:
# ============================================================
# CELL 7 — Load Behavior Data
# ============================================================

lookup_df = pd.read_csv("/teamspace/studios/this_studio/w2p_engage_list.csv").set_index("EdgeID")
train_df = pd.read_csv("/teamspace/studios/this_studio/train_df_gold_only.csv")


In [31]:
# ============================================================
# CELL 8 — Build Tail Vocabulary
# ============================================================

tail_set = set()
for row in tqdm(train_df.itertuples(), total=len(train_df)):
    try:
        Bhist = literal_eval(row.EHist)
        Bpos = row.EPos

        for b in Bhist:
            if b in lookup_df.index:
                tail_set.add(lookup_df.loc[b,"Tail"])

        if Bpos in lookup_df.index:
            tail_set.add(lookup_df.loc[Bpos,"Tail"])
    except:
        pass

tail_ids = sorted(tail_set)
tail2idx = {t:i for i,t in enumerate(tail_ids)}
idx2tail = {i:t for t,i in tail2idx.items()}

print("Tail IDs:", len(tail2idx))


  3%|▎         | 1334/38417 [00:00<00:20, 1831.55it/s]

100%|██████████| 38417/38417 [00:20<00:00, 1842.08it/s]


Tail IDs: 64042


In [32]:
# ============================================================
# CELL 9 — Gaussian KDE Kernel
# ============================================================

LAMBDA = 0.6

def gaussian_kernel(x):
    return torch.exp(-0.5 * (x ** 2))
# ============================================================
# CELL 10 — KDE-MI (scalar, last history)
# ============================================================

def KDE_MI_scalar(e_tl, c_hd_hist):
    """
    e_tl: [D]
    c_hd_hist: [T, D] 
    Returns MI
    """

    if c_hd_hist.shape[0] == 0:
        return torch.tensor(0.0, device=e_tl.device)

    diffs = (e_tl.unsqueeze(0) - c_hd_hist) / LAMBDA        # [T,D]
    k_vals = gaussian_kernel(diffs).prod(dim=1)           # product kernel

    p_e = k_vals.mean()
    p_c = torch.ones_like(p_e)                            # constant for KDE ratio
    p_joint = p_e

    mi = torch.log((p_joint + 1e-8) / (p_e * p_c + 1e-8))
    return mi


In [33]:
# ============================================================
# CELL 11 — Action Gates
# ============================================================

class ActionGates(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_clk = nn.Linear(hidden_dim,1)
        self.W_skp = nn.Linear(hidden_dim,1)
        self.W_fc  = nn.Linear(hidden_dim,1)
        self.W_summ = nn.Linear(hidden_dim,1)
        self.W_acc = nn.Linear(hidden_dim,1)

    def clk(self, e_tl, dwell):
        return (dwell * torch.tanh(self.W_clk(e_tl))).squeeze()

    def skp(self, e_tl):
        return torch.tanh(self.W_skp(e_tl)).squeeze()

    def fc(self, title_emb):
        return torch.tanh(self.W_fc(title_emb)).squeeze()

    def summ(self, e_tl):
        return torch.tanh(self.W_summ(e_tl)).squeeze()

    def acc(self, e_hd):
        return torch.tanh(self.W_acc(e_hd)).squeeze()


In [34]:
# ============================================================
# CELL 12 — Memory Kernels
# ============================================================

class MemoryKernels(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_short = nn.Linear(hidden_dim,1)
        self.W_long  = nn.Linear(hidden_dim,1)
        self.W_event = nn.Linear(hidden_dim,hidden_dim)

    def K_short(self, e):
        return torch.exp(-self.W_short(e))

    def K_long(self, e):
        return torch.exp(-self.W_long(e) / torch.norm(self.W_short.weight))

    def K_event(self, e_past, e_now):
        return torch.exp(-torch.matmul(self.W_event(e_past), e_now))


In [35]:
# ============================================================
# CELL 13 — AMF Fusion
# ============================================================

class AMF(nn.Module):
    def __init__(self):
        super().__init__()
        self.W_fuse = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, m_short, m_long, m_event, e_b):
        M = torch.stack([m_short, m_long, m_event], dim=0)   # [3,D]
        M_soft = torch.softmax(M, dim=0)
        m_fuse,_ = torch.max(M_soft, dim=0)

        g = torch.tanh(self.W_fuse(e_b))
        return g * m_fuse


In [36]:
# ============================================================
# CELL 14 — Paper-Correct BehaviorEncoder
# ============================================================

class BehaviorEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.gates = ActionGates().to(device)
        self.kernels = MemoryKernels().to(device)
        self.amf = AMF().to(device)

        self.gamma = nn.Parameter(torch.ones(hidden_dim, device=device))

        self.classifier = nn.Linear(hidden_dim, len(tail2idx)).to(device)
        self.bpos_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        ).to(device)

    def forward(self, Bhist, Bpos, lookup_df, tail2idx, embed_tables):

        enc_loss = torch.tensor(0.0, device=device)
        prev_c_hd = torch.zeros(hidden_dim, device=device)
        e_b_hist = []
        c_hd_hist = []

        for t, b_id in enumerate(Bhist):
            if b_id not in lookup_df.index:
                continue

            row = lookup_df.loc[b_id]
            head_id, rel, tail_id = row["Head"], row["Relation"], row["Tail"]

            e_tl = embed_tables["newsbody"].get(tail_id, torch.zeros(hidden_dim,device=device))
            e_hd = prev_c_hd
            title_emb = embed_tables["headline"].get(tail_id, torch.zeros(hidden_dim,device=device))

            # ---- KDE-MI ----
            hist_stack = torch.stack(c_hd_hist[-20:]) if len(c_hd_hist)>0 else torch.zeros((0,hidden_dim),device=device)
            I = KDE_MI_scalar(e_tl, hist_stack)

            # ---- Action Modulation ----
            if rel == "click":
                dwell = row["dwell"] / lookup_df["dwell"].max()
                g = self.gates.clk(e_tl, dwell)
                c_hd = g * I

            elif rel == "skip":
                g = self.gates.skp(e_tl)
                c_hd = g * e_tl + (1-g) * I

            elif rel == "gen_summ":
                g = self.gates.fc(title_emb)
                c_hd = g * I

            elif rel == "summ_gen":
                g1 = self.gates.summ(e_tl)
                g2 = self.gates.acc(e_hd)
                c_hd = g1 * I + g2 * KDE_MI_scalar(e_tl, torch.stack(e_b_hist[-20:]) if len(e_b_hist)>0 else hist_stack)

            else:
                c_hd = e_tl

            c_hd = self.gamma * c_hd
            c_hd_hist.append(c_hd)

            # ---- Memory Kernels ----
            if len(e_b_hist) > 0:
                m_short = sum(self.kernels.K_short(e_b_hist[j]) * e_b_hist[j] for j in range(len(e_b_hist)))
                m_long  = sum(self.kernels.K_long(e_b_hist[j])  * e_b_hist[j] for j in range(len(e_b_hist)))
                m_event = sum(self.kernels.K_event(e_b_hist[j], c_hd) * e_b_hist[j] for j in range(len(e_b_hist)))
            else:
                m_short = m_long = m_event = torch.zeros(hidden_dim,device=device)

            # ---- AMF ----
            m_fuse = self.amf(m_short, m_long, m_event, c_hd)

            z_b = c_hd + m_fuse
            e_b_hist.append(z_b)
            prev_c_hd = c_hd

            if tail_id in tail2idx:
                logits = self.classifier(z_b.unsqueeze(0))
                target = torch.tensor([tail2idx[tail_id]], device=device)
                enc_loss += F.cross_entropy(logits, target)

        # ---- Next-Step Prediction ----
        z_last = e_b_hist[-1]
        e_b_pos_pred = self.bpos_mlp(z_last)

        b_next = lookup_df.loc[Bpos]
        tpos = b_next["Tail"]
        logits = self.classifier(e_b_pos_pred.unsqueeze(0))
        target = torch.tensor([tail2idx[tpos]], device=device)
        pred_loss = F.cross_entropy(logits, target)

        total_loss = (0.5*enc_loss + 0.5*pred_loss)/(100*len(Bhist)+1)

        return z_last, e_b_pos_pred, total_loss


In [37]:
# ============================================================
# CELL 15 — Pseudo-Inverse s-Node Extractor
# ============================================================

class PseudoInverseMapper(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.Wk = nn.Linear(hidden_dim, hidden_dim, bias=False)     # W_k^+
        self.Ws = nn.Linear(hidden_dim, hidden_dim, bias=False)     # W_summ^+

    def forward(self, z_b):
        e_b_hat = self.Wk(z_b)
        e_s_hat = self.Ws(e_b_hat)
        return e_s_hat
# ============================================================
# CELL 16 — Cross-Attention (Eq. 19)
# ============================================================

class CrossAttentionEq19(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.Wq = nn.Linear(hidden_dim, hidden_dim)
        self.Wk = nn.Linear(hidden_dim, hidden_dim)
        self.Wv = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, e_s, e_d):
        Q = self.Wq(e_s).unsqueeze(0)   # [1,D]
        K = self.Wk(e_d).unsqueeze(0)   # [1,D]
        V = self.Wv(e_d).unsqueeze(0)   # [1,D]

        score = torch.matmul(Q, K.T) / np.sqrt(Q.shape[-1])
        attn = torch.softmax(score, dim=-1)
        out = torch.matmul(attn, V)
        return out.squeeze(0)
# ============================================================
# CELL 17 — B2SModel with Paper-Correct Injection
# ============================================================

class B2SModel(nn.Module):
    def __init__(self, encoder, pseudo_inv, cross_attn, summarizer_model, tokenizer, nid2body, sid2text):
        super().__init__()
        self.encoder = encoder
        self.pseudo_inv = pseudo_inv
        self.cross_attn = cross_attn
        self.summarizer = summarizer_model
        self.tokenizer = tokenizer
        self.nid2body = nid2body
        self.sid2text = sid2text
        self.to_t5 = nn.Linear(hidden_dim, summarizer_model.config.d_model)

    def forward(self, Bhist, Bpos, lookup_df, tail2idx, embed_tables):

        # ---- Encode Behavior ----
        z_b, e_b_pos_pred, loss_enc = self.encoder(Bhist, Bpos, lookup_df, tail2idx, embed_tables)

        # ---- Document ----
        head_id = lookup_df.loc[Bpos]["Head"]
        e_doc = embed_tables["newsbody"].get(head_id, torch.zeros(hidden_dim,device=device))
        doc_text = self.nid2body.get(head_id,"")

        # ---- Pseudo-Inverse ----
        e_s_hat = self.pseudo_inv(z_b)

        # ---- Cross-Attention (Eq.19) ----
        e_p_summ = self.cross_attn(e_s_hat, e_doc)

        # ---- T5 Injection ----
        adapted = self.to_t5(e_p_summ).unsqueeze(0)

        inputs = self.tokenizer(doc_text, return_tensors="pt", truncation=True, padding=True).to(device)
        enc_out = self.summarizer.encoder(**inputs)

        sid = lookup_df.loc[Bpos]["Tail"]
        gold = self.sid2text.get(sid,"")
        tgt = self.tokenizer(gold, return_tensors="pt", truncation=True, padding=True).to(device)

        dec_ids = tgt["input_ids"][:,:-1]
        labels = tgt["input_ids"][:,1:]

        tok_embed = self.summarizer.get_input_embeddings()(dec_ids)
        final_embed = torch.cat([adapted.unsqueeze(1), tok_embed], dim=1)

        labels = torch.cat([torch.full((1,1),-100,device=device), labels], dim=1)

        out = self.summarizer(
            encoder_outputs=enc_out,
            decoder_inputs_embeds=final_embed,
            labels=labels
        )

        pred = torch.argmax(out.logits, dim=-1)
        pred_txt = self.tokenizer.decode(pred[0], skip_special_tokens=True)

        total_loss = 0.5*loss_enc + 0.5*out.loss
        return (total_loss, out.loss, loss_enc), pred_txt, gold, doc_text


In [38]:
# ============================================================
# CELL 18 — Build Model
# ============================================================

encoder = BehaviorEncoder().to(device)
pseudo_inv = PseudoInverseMapper(hidden_dim).to(device)
cross_attn = CrossAttentionEq19(hidden_dim).to(device)

b2s_model = B2SModel(
    encoder,
    pseudo_inv,
    cross_attn,
    summarizer_model,
    tokenizer,
    nid2body,
    sid2text
).to(device)


In [39]:
# ============================================================
# CELL 19 — Train / Test Split
# ============================================================

train_df_subset = train_df.iloc[:550].reset_index(drop=True)
train_rows = train_df_subset.iloc[:500]
test_rows  = train_df_subset.iloc[500:550]




In [40]:
print(train_rows["EPos"].head())
print(lookup_df.index[:5])


0     E84
1    E133
2    E152
3    E168
4    E230
Name: EPos, dtype: object
Index(['E1', 'E2', 'E3', 'E4', 'E5'], dtype='object', name='EdgeID')


In [41]:
from ast import literal_eval

def extract_pairs(df):
    pairs = []
    dropped = {"eval": 0, "lookup": 0, "tail": 0}

    for _, row in df.iterrows():
        # --- EHist ---
        Bhist = row["EHist"]
        if isinstance(Bhist, str):
            try:
                Bhist = literal_eval(Bhist)
            except Exception:
                dropped["eval"] += 1
                continue

        # --- EPos ---
        Bpos = row["EPos"]

        # --- lookup ---
        if Bpos not in lookup_df.index:
            dropped["lookup"] += 1
            continue

        tail = lookup_df.loc[Bpos, "Tail"]

        # --- tail vocab ---
        if tail not in tail2idx:
            dropped["tail"] += 1
            continue

        pairs.append((Bhist, Bpos))

    print("Dropped counts:", dropped)
    return pairs
train_data = extract_pairs(train_rows)
test_data  = extract_pairs(test_rows)

print("Train:", len(train_data), "Test:", len(test_data))


Dropped counts: {'eval': 0, 'lookup': 0, 'tail': 0}
Dropped counts: {'eval': 0, 'lookup': 0, 'tail': 0}
Train: 500 Test: 50


In [None]:
# ============================================================
# CELL 20 — Training Loop (with running losses)
# ============================================================

lr = 1e-3
num_epochs = 10
save_every = 2
optimizer = torch.optim.Adam(b2s_model.parameters(), lr=lr)

for epoch in range(num_epochs):
    b2s_model.train()
    tot_loss = 0.0
    tot_gen  = 0.0
    tot_enc  = 0.0
    count    = 0

    pbar = tqdm(train_data, desc=f"Epoch {epoch+1}")

    for Bhist, Bpos in pbar:
        optimizer.zero_grad()

        try:
            (loss, gen_loss, enc_loss), _, _, _ = b2s_model(
                Bhist, Bpos, lookup_df, tail2idx, embed_tables
            )
        except Exception as e:
            continue

        loss.backward()
        optimizer.step()

        tot_loss += loss.item()
        tot_gen  += gen_loss.item()
        tot_enc  += enc_loss.item()
        count += 1

        pbar.set_postfix({
            "L_now": f"{loss.item():.3f}",
            "Gen_now": f"{gen_loss.item():.3f}",
            "Enc_now": f"{enc_loss.item():.3f}",
            "L_avg": f"{tot_loss/count:.3f}",
            "Gen_avg": f"{tot_gen/count:.3f}",
            "Enc_avg": f"{tot_enc/count:.3f}"
        })

    print(
        f"Epoch {epoch+1}: "
        f"AvgTotal={tot_loss/count:.4f} | "
        f"AvgGen={tot_gen/count:.4f} | "
        f"AvgEnc={tot_enc/count:.4f}"
    )

    if (epoch+1) % save_every == 0:
        ckpt = f"b2s_model_epoch{epoch+1:03d}.pt"
        torch.save(b2s_model.state_dict(), ckpt)
        print("Saved", ckpt)


  0%|          | 0/500 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  4%|▍         | 19/500 [01:45<52:50,  6.59s/it]

In [None]:
# ============================================================
# CELL 23 — Generator with Latent Injection
# ============================================================

def generate_personalized_summary(model, Bhist, Bpos, lookup_df, embed_tables, max_len=40):
    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        # ---- Encode Behavior ----
        z_b, _, _ = model.encoder(Bhist, Bpos, lookup_df, tail2idx, embed_tables)

        # ---- Document ----
        head_id = lookup_df.loc[Bpos]["Head"]
        e_doc = embed_tables["newsbody"].get(head_id, torch.zeros(hidden_dim,device=device))
        doc_text = model.nid2body.get(head_id,"")

        # ---- Pseudo-Inverse ----
        e_s_hat = model.pseudo_inv(z_b)

        # ---- Cross-Attention (Eq.19) ----
        e_p_summ = model.cross_attn(e_s_hat, e_doc)

        # ---- Project to T5 ----
        latent = model.to_t5(e_p_summ).unsqueeze(0)   # [1,d_model]

        # ---- Encode Document ----
        enc = model.tokenizer(doc_text, return_tensors="pt", truncation=True, padding=True).to(device)
        enc_out = model.summarizer.encoder(**enc)

        # ---- Create prefix embedding ----
        prefix = latent.unsqueeze(1)   # [1,1,d_model]

        # ---- Start generation with BOS ----
        bos = model.tokenizer.pad_token_id
        cur_ids = torch.tensor([[bos]], device=device)

        for _ in range(max_len):
            tok_embed = model.summarizer.get_input_embeddings()(cur_ids)
            dec_embed = torch.cat([prefix, tok_embed], dim=1)

            out = model.summarizer(
                encoder_outputs=enc_out,
                decoder_inputs_embeds=dec_embed
            )

            next_token = torch.argmax(out.logits[:,-1,:], dim=-1)
            cur_ids = torch.cat([cur_ids, next_token.unsqueeze(0)], dim=1)

            if next_token.item() == model.tokenizer.eos_token_id:
                break

        return model.tokenizer.decode(cur_ids[0], skip_special_tokens=True)


In [None]:
# ============================================================
# CELL 24 — Use It for Evaluation
# ============================================================

b2s_model.eval()
results = []

with torch.no_grad():
    for Bhist,Bpos in tqdm(test_data, desc="Generating"):
        try:
            pred = generate_personalized_summary(b2s_model, Bhist, Bpos, lookup_df, embed_tables)
            gold = sid2text.get(lookup_df.loc[Bpos,"Tail"], "")
            doc  = nid2body.get(lookup_df.loc[Bpos,"Head"], "")
            true_tail = lookup_df.loc[Bpos,"Tail"]
        except:
            continue

        results.append({
            "Bpos":Bpos,
            "true_tail_id":true_tail,
            "generic summary":doc,
            "pred_summary":pred,
            "gold_summary":gold
        })

df = pd.DataFrame(results)
df.to_csv("b2s_eval_results_autoregressive.csv", index=False)
print("Saved b2s_eval_results_autoregressive.csv")
