In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
import math
from tqdm import tqdm
from ast import literal_eval
import pickle
import time
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers import T5ForConditionalGeneration, T5Tokenizer

In [2]:
# ---------------------
# Load T5-Base Model
# ---------------------
summarizer_model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")

summarizer_model.eval()
for param in summarizer_model.parameters():
    param.requires_grad = False

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [3]:
# Load nid2body from pickle
with open("nid2body.pkl", "rb") as f:
    nid2body = pickle.load(f)

# Debug print
print(f"‚úÖ Loaded nid2body with {len(nid2body)} items")
sample_nid = list(nid2body.keys())[0]
print(f"üßæ Sample NID: {sample_nid}\nüìù Headline: {nid2body[sample_nid][:300]}")

‚úÖ Loaded nid2body with 113762 items
üßæ Sample NID: N10000
üìù Headline: Predicting Atlanta United's lineup against Columbus Crew in the U.S. Open Cup


In [4]:
# Load sid2sum from pickle
with open("sid2sum.pkl", "rb") as f:
    sid2sum = pickle.load(f)

# Debug print
print(f"‚úÖ Loaded sid2sum with {len(sid2sum)} items")
sample_sid = list(sid2sum.keys())[0]
print(f"üßæ Sample SID: {sample_sid}\nüìù Summary: {sid2sum[sample_sid][:300]}")

‚úÖ Loaded sid2sum with 135001 items
üßæ Sample SID: S-1
üìù Summary: The officer reportedly also pointed his gun at Harper and her children.


In [5]:
# === Device and Precision Setup ===
torch.set_default_dtype(torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hidden_dim = 768

# === Utility Functions ===
def get_embedding(key, table, dim):
    if key not in table:
        table[key] = torch.nn.Parameter(torch.randn(dim, dtype=torch.float32, device=device) * 0.01, requires_grad=True)
    return table[key]


# === Load Embeddings ===
with open("summary_T5.pkl", "rb") as f:
    summary_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}
with open("newsbody_T5.pkl", "rb") as f:
    newsbody_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}
with open("headline_T5.pkl", "rb") as f:
    headline_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}

embed_tables = {
    'summary': summary_embed,
    'newsbody': newsbody_embed,
    'headline': headline_embed
}

# === Load Dataset ===
lookup_df = pd.read_csv("w2p_engage_list.csv").set_index('EdgeID')
train_df = pd.read_csv("train_w2p.csv")

In [6]:
train_df=train_df[:50000]
train_df

Unnamed: 0.1,Unnamed: 0,UserID,EHist,EPos
0,2,U10000_1,"['E1', 'E2', 'E3', 'E4', 'E5', 'E6', 'E7', 'E8...",E84
1,3,U10000_2,"['E1', 'E2', 'E3', 'E4', 'E5', 'E6', 'E7', 'E8...",E133
2,11,U100006_1,['E151'],E152
3,12,U100006_2,"['E151', 'E152', 'E153', 'E154', 'E155', 'E156...",E168
4,13,U100006_3,"['E151', 'E152', 'E153', 'E154', 'E155', 'E156...",E230
...,...,...,...,...
49995,195583,U233775_3,"['E1882808', 'E1882809', 'E1882810', 'E1882811...",E1882859
49996,195584,U233775_4,"['E1882808', 'E1882809', 'E1882810', 'E1882811...",E1882866
49997,195585,U233775_5,"['E1882808', 'E1882809', 'E1882810', 'E1882811...",E1882869
49998,195586,U233775_6,"['E1882808', 'E1882809', 'E1882810', 'E1882811...",E1882920


In [7]:
lookup_df

Unnamed: 0_level_0,Unnamed: 0,Head,Relation,Tail,User
EdgeID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
E1,0,U10000,skip,N110699,U10000
E2,1,N110699,skip,N104733,U10000
E3,2,N104733,skip,N80645,U10000
E4,3,N80645,skip,N76869,U10000
E5,4,N76869,skip,N119531,U10000
...,...,...,...,...,...
E18413173,18413172,N107780,click,N97051,U249644
E18413174,18413173,N97051,click,N81002,U249644
E18413175,18413174,N81002,click,N11725,U249644
E18413176,18413175,N11725,click,N89229,U249644


In [None]:
# # ---------------------
# # Build Tail2Idx
# # ---------------------
# tail_set = set()

# print("Building Tail2Idx from EHist and EPos...")
# for row in tqdm(train_df.itertuples(), total=len(train_df), desc="Collecting tails"):
#     try:
#         bhist = literal_eval(row.EHist)
#         bpos = row.EPos

#         # Add tails from Bhist
#         for b_id in bhist:
#             if b_id in lookup_df.index:
#                 tail = lookup_df.loc[b_id, 'Tail']
#                 tail_set.add(tail)

#         # Add tail from Bpos
#         if bpos in lookup_df.index:
#             tail = lookup_df.loc[bpos, 'Tail']
#             tail_set.add(tail)

#     except Exception as e:
#         print(f"[Skip] Error in row {row.Index}: {e}")

# # === Final mappings ===
# tail_ids = sorted(tail_set)
# tail2idx = {tid: idx for idx, tid in enumerate(tail_ids)}
# idx2tail = {idx: tid for tid, idx in tail2idx.items()}

# print(f"‚úÖ Tail2Idx built with {len(tail2idx)} unique tail IDs.")

# # ---------------------
# # Save as pickle
# # ---------------------
# with open("tail_mappings.pkl", "wb") as f:
#     pickle.dump({"tail2idx": tail2idx, "idx2tail": idx2tail}, f)

# print("üíæ Saved tail2idx and idx2tail to tail_mappings.pkl")


In [8]:
# ---------------------
# Load back from pickle
# ---------------------
with open("tail_mappings.pkl", "rb") as f:
    mappings = pickle.load(f)

tail2idx = mappings["tail2idx"]
idx2tail = mappings["idx2tail"]

print(f"‚úÖ Loaded tail2idx with {len(tail2idx)} entries")
print(f"üßæ Sample tail ‚Üí idx: {list(tail2idx.items())[:3]}")
print(f"üßæ Sample idx ‚Üí tail: {list(idx2tail.items())[:3]}")

‚úÖ Loaded tail2idx with 544137 entries
üßæ Sample tail ‚Üí idx: [('N10000', 0), ('N100001', 1), ('N100003', 2)]
üßæ Sample idx ‚Üí tail: [(0, 'N10000'), (1, 'N100001'), (2, 'N100003')]


In [9]:
# -------------------------
# Sequence Engagement/Behavior Encoder
# -------------------------

class BehaviorEncoder(nn.Module):
    def __init__(self, hidden_dim, tail2idx, device, debug=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.device = device
        self.debug = debug
        self.tail2idx = tail2idx

        # === Base action vectors ===
        self.e_clk = nn.Parameter(torch.tensor([1., 0., 0., 0.], device=device))
        self.e_skp = nn.Parameter(torch.tensor([0., 1., 0., 0.], device=device))
        self.e_gensumm = nn.Parameter(torch.tensor([0., 0., 1., 0.], device=device))
        self.e_sumgen = nn.Parameter(torch.tensor([0., 0., 0., 1.], device=device))

        # Action-specific transforms
        self.W_clk = nn.Linear(4, hidden_dim, bias=False)
        self.W_skp = nn.Linear(4, hidden_dim, bias=False)
        self.W_gensumm = nn.Linear(4, hidden_dim, bias=False)
        self.W_sumgen = nn.Linear(4, hidden_dim, bias=False)

        # State transforms
        self.W_pull = nn.Linear(1, hidden_dim, bias=False)
        self.W_s = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_d = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Fusion
        self.Wh = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wc = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.Wz = nn.Linear(hidden_dim, 3, bias=False)
        self.b_z = nn.Parameter(torch.zeros(3, device=device))
        self.W_emb = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.b_emb = nn.Parameter(torch.zeros(hidden_dim, device=device))

        # Rotation/translation
        self.W_angle = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_theta = nn.Linear(hidden_dim, 1, bias=False)
        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)
        self.W_m = nn.Linear(hidden_dim, 1, bias=False)

        # Scalars
        self.alpha = nn.Parameter(torch.tensor(0.5, device=device))
        self.beta = nn.Parameter(torch.tensor(0.5, device=device))

        # Classifier head over tails
        self.classifier = nn.Linear(hidden_dim, len(tail2idx))

        # Next-step prediction head
        self.W_next = nn.Linear(hidden_dim, hidden_dim, bias=False)

    def _show(self, name, tensor, maxlen=6):
        if not self.debug: 
            return
        if isinstance(tensor, torch.Tensor):
            flat = tensor.detach().cpu().numpy().flatten()
            vals = ", ".join(f"{x:.4f}" for x in flat[:maxlen])
            if len(flat) > maxlen: vals += ", ..."
            print(f"{name} (shape={tuple(tensor.shape)}): [{vals}]")
        else:
            print(f"{name}: {tensor}")

    def softmin_pool(self, a, b):
        return -self.alpha * torch.log(torch.exp(a / self.alpha) +
                                       torch.exp(b / self.alpha) + 1e-9)

    def forward(self, Bhist, Bpos, lookup_df, tail2idx, embed_tables):
        total_loss = torch.tensor(0., dtype=torch.float32, device=self.device)

        # === PASS 1: raw step embeddings (E_seq) ===
        E_seq = []
        h_clk = torch.zeros(self.hidden_dim, device=self.device)
        h_skp = torch.zeros(self.hidden_dim, device=self.device)
        h = torch.zeros(self.hidden_dim, device=self.device)

        for t, b_id in enumerate(Bhist):
            if b_id not in lookup_df.index:
                continue
            row = lookup_df.loc[b_id]
            tail_id, rel = row['Tail'], row['Relation']

            d_i = embed_tables['newsbody'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))
            s_i = embed_tables['summary'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))
            d_i_title = embed_tables['headline'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))

            # init state
            if t == 0:
                head_emb = embed_tables['headline'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))
                h_clk, h_skp = head_emb, head_emb
                h = torch.sigmoid(self.W_s(head_emb)) * h_clk + (1 - torch.sigmoid(self.W_s(head_emb))) * h_skp

            # relation-specific context c_i
            if rel == "click":
                c_i = (self.W_clk.weight @ self.e_clk * h) * d_i
            elif rel == "skip":
                d_ip1 = torch.zeros_like(d_i)
                if t+1 < len(Bhist) and Bhist[t+1] in lookup_df.index:
                    d_ip1 = embed_tables['newsbody'].get(
                        lookup_df.loc[Bhist[t+1]]['Head'],
                        torch.zeros_like(d_i)
                    )
                pull_term = self.W_pull(torch.tensor([[torch.dot(h_clk, d_ip1)+torch.dot(h_skp, d_i)]],
                                                     device=self.device)).squeeze(0)
                c_i = torch.tanh(self.W_skp.weight @ self.e_skp + d_i + pull_term) * h * d_i
            elif rel == "gen_summ":
                c_i = (self.W_gensumm.weight @ self.e_gensumm * h) * d_i_title
            elif rel == "summ_gen":
                gate_summgen = self.W_s(self.W_sumgen.weight @ self.e_sumgen)
                c_i = self.softmin_pool(gate_summgen * s_i, (1 - gate_summgen) * d_i)
                h_clk = h_clk + self.W_d((torch.ones_like(d_i_title) - d_i_title) * s_i)
            else:
                c_i = d_i

            # update hidden
            z_i = self.Wh(h) + self.Wc(c_i)
            p_i = torch.softmax(self.Wz(z_i) + self.b_z, dim=-1)
            m_i = p_i[0]*0.1 + p_i[1]*0.5 + p_i[2]*0.9
            if rel == "click": h_clk = h_clk + m_i * c_i
            elif rel == "skip": h_skp = h_skp * (1 - m_i) + c_i
            h = self.beta * h_clk + (1 - self.beta) * h_skp

            e_i = torch.tanh(self.W_emb(c_i) + self.b_emb)
            E_seq.append(e_i)

        if not E_seq:
            return torch.zeros(self.hidden_dim, device=self.device), None, total_loss

        # === PASS 2: contextualize (E‚Ä≤_seq) ===
        Eprime_seq = []
        eps = 1e-9
        for i, e_i in enumerate(E_seq):
            if i == 0:
                e_prime = e_i
            else:
                e_prev, e_prime_prev = E_seq[i-1], Eprime_seq[-1]
                theta_i = math.pi * torch.tanh(self.W_theta(torch.sigmoid(self.W_angle(e_prime_prev))))
                m_i = F.softplus(self.W_m(self.W_h(e_prime_prev)))

                v_i = (e_i - e_prime_prev) / (e_i - e_prime_prev).norm(p=2).clamp(min=eps)
                u_prev = e_prev / e_prime_prev.norm(p=2).clamp(min=eps)
                o_i = (v_i - torch.dot(v_i, u_prev) * u_prev)
                o_i = o_i / o_i.norm(p=2).clamp(min=eps)

                e_prime = e_prime_prev + m_i * (torch.cos(theta_i) * u_prev + torch.sin(theta_i) * o_i).squeeze()

            Eprime_seq.append(e_prime)

            # per-step auxiliary loss
            tail_id = lookup_df.loc[Bhist[i]]['Tail']
            if tail_id in self.tail2idx:
                logits_step = self.classifier(e_prime.unsqueeze(0))
                target_step = torch.tensor([self.tail2idx[tail_id]], device=self.device)
                total_loss = total_loss + F.cross_entropy(logits_step, target_step)

        # === Final prediction on Bpos ===
        eprime_last = Eprime_seq[-1]

        # Next-step embedding prediction
        eprime_next = self.W_next(eprime_last)

        logits_pos = None
        if Bpos in lookup_df.index:
            tail_id_pos = lookup_df.loc[Bpos]['Tail']
            if tail_id_pos in self.tail2idx:
                logits_pos = self.classifier(eprime_next.unsqueeze(0))
                target_pos = torch.tensor([self.tail2idx[tail_id_pos]], device=self.device)
                total_loss = total_loss + F.cross_entropy(logits_pos, target_pos)

        return eprime_last, eprime_next, logits_pos, total_loss


In [10]:
# -------------------------
# Inverse decoder for a single predicted embedding (d_next_pred as e'_pos)
# -------------------------
class BehaviorInverseDecoderPredict(nn.Module):
    """
    Inverse mapping that takes a single predicted e' (embedding for Bpos)
    and the head/headline embedding for Bpos, and returns:
      - c'_pos (approx pseudo-content)
      - s_hat_pos (approx summary)
    """
    def __init__(self, hidden_dim, device, debug=False):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.device = device
        self.debug = debug

        # Learnable pseudo-inverse / disentanglers
        self.W_emb_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)   # approx W_emb^+
        self.W1_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)      # remove head contribution
        self.W2_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)      # map residual -> summary

    def _show(self, name, tensor, maxlen=6):
        if not self.debug:
            return
        if isinstance(tensor, torch.Tensor):
            if tensor.ndim == 0:
                print(f"{name}: {tensor.item():.6f}")
            else:
                flat = tensor.detach().cpu().numpy().flatten()
                vals = ", ".join(f"{x:.6f}" for x in flat[:maxlen])
                if len(flat) > maxlen: vals += ", ..."
                print(f"{name} (shape={tuple(tensor.shape)}): [{vals}]")
        else:
            print(f"{name}: {tensor}")

    @staticmethod
    def atanh_safe(x, eps=1e-6):
        x = x.clamp(-1+eps, 1-eps)
        return 0.5 * torch.log((1+x) / (1-x))

    def forward(self, eprime_pos, b_emb, h_pos):
        """
        eprime_pos: tensor (hidden_dim,) -- predicted e' embedding for Bpos
        b_emb: encoder bias b_emb (tensor (hidden_dim,))
        h_pos: head/headline embedding for Bpos (tensor (hidden_dim,))
        Returns: c_prime_pos, s_hat_pos
        """
        # 1) invert embedding nonlinearity: atanh(e') - b_emb
        x = self.atanh_safe(eprime_pos) - b_emb  # (hidden_dim,)
        c_prime_pos = self.W_emb_pinv(x)         # approx c'_pos

        # 2) subtract head contribution and map to summary
        residual = c_prime_pos - self.W1_pinv(h_pos)
        s_hat_pos = self.W2_pinv(residual)

        # # debug prints
        # self._show("eprime_pos", eprime_pos)
        # self._show("atanh(eprime_pos)-b_emb", x)
        # self._show("c'_pos", c_prime_pos)
        # self._show("residual (c' - W1^+ h)", residual)
        # self._show("s_hat_pos", s_hat_pos)

        return c_prime_pos, s_hat_pos

In [11]:
class PersonalizedT5Summarizer(nn.Module):
    def __init__(self, hidden_dim, t5_model, behavior_encoder, inverse_decoder, device):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.t5 = t5_model.eval()
        self.behavior_encoder = behavior_encoder
        self.inverse_decoder = inverse_decoder
        self.device = device

        # Freeze T5
        for param in self.t5.parameters():
            param.requires_grad = False

        # Learnable gates and attention transforms
        self.W_prime = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_qry = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_key = nn.Linear(hidden_dim, hidden_dim, bias=True)
        self.W_val = nn.Linear(hidden_dim, hidden_dim, bias=True)

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, Bhist, Bpos, lookup_df, tail2idx, embed_tables, sid2sum, tokenizer, max_len=50):
        # ----------------------
        # 1Ô∏è‚É£ Behavior Encoder
        # ----------------------
        eprime_last, eprime_next, logits_pos, enc_loss = self.behavior_encoder(
            Bhist, Bpos, lookup_df, tail2idx, embed_tables
        )
        # print("‚úÖ Behavior Encoder:")
        # print(f"  eprime_last shape: {tuple(eprime_last.shape)}")
        # print(f"  Encoder loss: {enc_loss.item():.4f}")

        # ----------------------
        # 2Ô∏è‚É£ Inverse Decoder ‚Üí s_hat
        # ----------------------
        head_emb = embed_tables['headline'].get(Bpos, torch.zeros(self.hidden_dim, device=self.device))
        _, s_hat = self.inverse_decoder(eprime_last, self.behavior_encoder.b_emb, head_emb)
        # print("‚úÖ Inverse Decoder:")
        # print(f"  s_hat shape: {tuple(s_hat.shape)}")
        # print(f"  s_hat sample (first 6 dims): {s_hat.detach().cpu().numpy()[:6]}")

        # ----------------------
        # 3Ô∏è‚É£ Personalized T5 Encoder
        # ----------------------
        if Bpos not in lookup_df.index:
            #print("‚ö†Ô∏è Bpos not in lookup_df; returning encoder loss only.")
            return enc_loss, None

        tail_id = lookup_df.loc[Bpos]['Tail']
        doc_text = nid2body.get(tail_id, " ")

        input_text = "generate headline for: " + doc_text
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True).to(self.device)

        with torch.no_grad():
            encoder_outputs = self.t5.encoder(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"]
            )

        enc_states = encoder_outputs.last_hidden_state  # (1, seq_len, hidden_dim)
        # print("‚úÖ Raw T5 Encoder States:")
        # print(f"  Shape: {enc_states.shape}")
        # print(f"  First token (first 10 dims): {enc_states[0,0,:10].detach().cpu().numpy()}")

        # Apply gated eprime_last
        gated_eprime = torch.sigmoid(self.W_prime(eprime_last))  # (hidden_dim,)
        enc_states_gated = enc_states * gated_eprime.unsqueeze(0).unsqueeze(0)
        # print("‚úÖ Gated/Contextualized Encoder States:")
        # print(f"  First token (first 10 dims): {enc_states_gated[0,0,:10].detach().cpu().numpy()}")

        # Inject s_hat
        key = self.W_key(s_hat).unsqueeze(0).unsqueeze(1)   # (1,1,hidden_dim)
        value = self.W_val(s_hat).unsqueeze(0).unsqueeze(1) # (1,1,hidden_dim)
        seq_len = enc_states_gated.size(1)
        key_exp = key.expand(-1, seq_len, -1)
        value_exp = value.expand(-1, seq_len, -1)
        personalized_enc_states = enc_states_gated + key_exp + value_exp
        # print("‚úÖ Personalized Encoder States (after s_hat injection):")
        # print(f"  First token (first 10 dims): {personalized_enc_states[0,0,:10].detach().cpu().numpy()}")

        # ----------------------
        # 4Ô∏è‚É£ Decoder with gold summary
        # ----------------------
        gold_summary_text = sid2sum.get(tail_id, "")
        # print("Gold summary source:",tail_id)
        # print("Gold reference summaries:",gold_summary_text)
        # # if gold_summary_text == "":
        #     print("‚ö†Ô∏è Gold summary missing; returning encoder loss only.")
        #     return enc_loss, None

        decoder_inputs = tokenizer(
            gold_summary_text, return_tensors="pt", max_length=max_len, truncation=True
        ).to(self.device)

        # print("‚úÖ Target Summary Tokens:")
        # print(f"  Token IDs: {decoder_inputs.input_ids[0].detach().cpu().numpy()[:min(10, decoder_inputs.input_ids.size(1))]}")

        outputs = self.t5(
            input_ids=decoder_inputs.input_ids,
            attention_mask=decoder_inputs.attention_mask,
            encoder_outputs=(personalized_enc_states,),
            labels=decoder_inputs.input_ids
        )

        pred_tokens = torch.argmax(outputs.logits, dim=-1)
        # print("‚úÖ Predicted Summary Tokens:")
        # print(f"  Token IDs: {pred_tokens[0,:min(10, pred_tokens.size(1))].detach().cpu().numpy()}")

        total_loss = enc_loss + outputs.loss
        # print("‚úÖ Losses:")
        # print(f"  Generation loss (T5): {outputs.loss.item():.4f}")
        # print(f"  Behavior encoder loss: {enc_loss.item():.4f}")
        # print(f"  Total loss: {total_loss.item():.4f}")

        return total_loss, outputs.logits


In [None]:
# -----------------------------
# 1Ô∏è‚É£ Sample one row
# -----------------------------
sample_row = train_df.sample(1).iloc[0]
Bhist = literal_eval(sample_row['EHist'])   # list of history Doc IDs
Bpos = sample_row['EPos']                # the current position doc ID (or use 'Bpos' column if present)
print(f"üîπ Random Row UserID: {sample_row.UserID}")
print(f"üìù Sample Bhist: {Bhist}")
print(f"üéØ Target Bpos: {Bpos}")

# -----------------------------
# 2Ô∏è‚É£ Initialize models
# -----------------------------
behavior_encoder = BehaviorEncoder(hidden_dim, tail2idx, device, debug=True).to(device)
inverse_decoder = BehaviorInverseDecoderPredict(hidden_dim, device, debug=True).to(device)
personalized_model = PersonalizedT5Summarizer(hidden_dim, summarizer_model, behavior_encoder, inverse_decoder, device).to(device)

# -----------------------------
# 3Ô∏è‚É£ Forward pass through Behavior Encoder
# -----------------------------
eprime_last, eprime_next, logits_pos, enc_loss = behavior_encoder(
    Bhist, Bpos, lookup_df, tail2idx, embed_tables
)

print("\n‚úÖ Behavior Encoder Output:")
print(f"  eprime_last shape: {eprime_last.shape}")
print(f"  eprime_next shape: {eprime_next.shape}")
if logits_pos is not None:
    pred_tail_idx = torch.argmax(logits_pos, dim=-1).item()
    pred_tail = idx2tail[pred_tail_idx]
    print(f"  Predicted tail from encoder: {pred_tail}")
else:
    print("  No predicted tail (Bpos not in lookup)")

# -----------------------------
# 4Ô∏è‚É£ Forward pass through Inverse Decoder
# -----------------------------
head_emb = embed_tables['headline'].get(Bpos, torch.zeros(hidden_dim, device=device))
c_prime, s_hat = inverse_decoder(eprime_next, behavior_encoder.b_emb, head_emb)

print("\n‚úÖ Inverse Decoder Output:")
print(f"  c'_pos (first 6 dims): {c_prime.detach().cpu().numpy()[:6]}")
print(f"  s_hat_pos (first 6 dims): {s_hat.detach().cpu().numpy()[:6]}")

# -----------------------------
# 5Ô∏è‚É£ Forward pass through Personalized T5
# -----------------------------
tail_id = lookup_df.loc[Bpos]['Tail']
print("Target summary id",tail_id)
gold_summary_text = sid2sum.get(tail_id, "")

total_loss, logits = personalized_model(
    Bhist, Bpos, lookup_df, tail2idx, embed_tables, sid2sum, tokenizer, max_len=50
)

# -----------------------------
# 6Ô∏è‚É£ Decode predicted tokens
# -----------------------------
pred_summary = tokenizer.decode(torch.argmax(logits, dim=-1)[0], skip_special_tokens=True)
print("\n‚úÖ Personalized T5 Output:")
print(f"  Target summary: {gold_summary_text[:300]}")
print(f"  Predicted summary: {pred_summary[:300]}")
print(f"  Total loss: {total_loss.item():.4f}")
print(f"  Encoder loss: {enc_loss.item():.4f}")


In [None]:
# sid2sum.get("S-29150", "")

In [None]:
# sid2sum.get("S-204452", "")

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# from tqdm import tqdm
# import os

# -----------------------------
# Training setup
# -----------------------------
num_epochs = 3                # adjust as needed
batch_size = 1                # row-by-row training
learning_rate = 3e-4          # medium LR
save_path = "walk2pers_checkpoint.pt"

optimizer = optim.Adam(
    list(behavior_encoder.parameters()) +
    list(inverse_decoder.parameters()) +
    list(personalized_model.parameters()),
    lr=learning_rate
)

# make sure models are in train mode
behavior_encoder.train()
inverse_decoder.train()
personalized_model.train()

# pick first 100 rows
train_subset = train_df.sample(n=100, random_state=42).reset_index(drop=True)

# -----------------------------
# Training Loop
# -----------------------------
for epoch in range(num_epochs):
    epoch_loss = 0.0
    print(f"\nüîÑ Epoch {epoch+1}/{num_epochs} ---------------------------")

    for i, row in tqdm(train_subset.iterrows(), total=len(train_subset), desc=f"Epoch {epoch+1}"):
        optimizer.zero_grad()

        Bhist = literal_eval(row['EHist'])
        Bpos = row['EPos']

        try:
            # forward pass through full model
            total_loss, logits = personalized_model(
                Bhist, Bpos, lookup_df, tail2idx, embed_tables, sid2sum, tokenizer, max_len=50
            )

            # total loss weighting
            final_loss = 0.5 * total_loss   # already includes enc_loss + gen_loss
            # because total_loss was enc_loss + gen_loss
            # -> if you want strict 0.5*enc + 0.5*gen, separate them here

            final_loss.backward()
            optimizer.step()

            epoch_loss += final_loss.item()

            tqdm.write(f"[Row {i}] Loss: {final_loss.item():.4f}")

        except Exception as e:
            tqdm.write(f"[Row {i}] ‚ö†Ô∏è Skipped due to error: {str(e)}")
            continue

    avg_loss = epoch_loss / len(train_subset)
    print(f"‚úÖ Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    # save model checkpoint
    torch.save({
        'epoch': epoch+1,
        'behavior_encoder_state': behavior_encoder.state_dict(),
        'inverse_decoder_state': inverse_decoder.state_dict(),
        'personalized_model_state': personalized_model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'avg_loss': avg_loss,
    }, save_path)

    print(f"üíæ Saved checkpoint to {save_path}")


In [15]:
from tqdm import tqdm
import torch
import pandas as pd
from ast import literal_eval
from transformers import T5ForConditionalGeneration, T5Tokenizer

# -----------------------------
# 0Ô∏è‚É£ Config & load checkpoint
# -----------------------------
hidden_dim = 768
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ckpt_path = "walk2pers_checkpoint.pt"

# # fresh T5
# t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
# tokenizer = T5Tokenizer.from_pretrained("t5-base")
# for p in t5_model.parameters(): p.requires_grad = False

# re-instantiate wrappers
behavior_encoder = BehaviorEncoder(hidden_dim, tail2idx, device).to(device)
inverse_decoder = BehaviorInverseDecoderPredict(hidden_dim, device).to(device)
personalized_model = PersonalizedT5Summarizer(hidden_dim, summarizer_model, behavior_encoder, inverse_decoder, device).to(device)

# load checkpoint weights
ckpt = torch.load(ckpt_path, map_location=device)
behavior_encoder.load_state_dict(ckpt["behavior_encoder_state"], strict=False)
inverse_decoder.load_state_dict(ckpt["inverse_decoder_state"], strict=False)
personalized_model.load_state_dict(ckpt["personalized_model_state"], strict=False)

behavior_encoder.eval(); inverse_decoder.eval(); personalized_model.eval()

# -----------------------------
# 1Ô∏è‚É£ Slice rows 101‚Äì110
# -----------------------------
subset = train_df.iloc[200:210].reset_index(drop=True)

results = []

# -----------------------------
# 2Ô∏è‚É£ Loop with tqdm
# -----------------------------
for _, row in tqdm(subset.iterrows(), total=len(subset)):
    try:
        Bhist = literal_eval(row['EHist'])
    except Exception:
        Bhist = row['EHist']
    Bpos  = row['EPos']
    user  = row.get('UserID', None)

    # Behavior Encoder
    eprime_last, eprime_next, logits_pos, enc_loss = behavior_encoder(
        Bhist, Bpos, lookup_df, tail2idx, embed_tables
    )

    # Inverse Decoder
    head_emb = embed_tables['headline'].get(Bpos, torch.zeros(hidden_dim, device=device))
    c_prime, s_hat = inverse_decoder(eprime_next, behavior_encoder.b_emb, head_emb)

    # Gold summary (Bpos.tail)
    tail_id = lookup_df.loc[Bpos]['Tail']
    gold_summary_text = sid2sum.get(tail_id, "")

    # Personalized T5 forward
    total_loss, logits = personalized_model(
        Bhist, Bpos, lookup_df, tail2idx, embed_tables, sid2sum, tokenizer, max_len=60
    )

    # Decode predicted tokens
    pred_summary = tokenizer.decode(torch.argmax(logits, dim=-1)[0], skip_special_tokens=True)

    # Query (Bpos.head body text)
    head_id = lookup_df.loc[Bpos]['Head']
    query_text = nid2body.get(head_id, "")

    results.append({
        "UserID": user,
        "Query(Bpos.head)": head_id,
        "Gold(Bpos.tail)": gold_summary_text,
        "PredictedSummary": pred_summary,
        "Loss": total_loss.item(),
        "EncLoss": enc_loss.item() if enc_loss is not None else None
    })

# -----------------------------
# 3Ô∏è‚É£ Save + show results
# -----------------------------
df_out = pd.DataFrame(results)
df_out.to_csv("inference_rows_101_110.csv", index=False)

print(df_out.head(3))
print("\nSaved to inference_rows_101_110.csv")


  ckpt = torch.load(ckpt_path, map_location=device)
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 10/10 [02:31<00:00, 15.19s/it]

      UserID Query(Bpos.head)  \
0  U100567_5          N111992   
1  U100567_6          N123678   
2  U100573_1           N91754   

                                     Gold(Bpos.tail)  \
0  Jesuit school tweets #BeBrave when announcing ...   
1  "At 11 am we turned the pumps back on and we a...   
2  The sportswoman, 37, was one of the pals invol...   

                                    PredictedSummary         Loss      EncLoss  
0  s iss thatFtin. theyhe  the are be be. people ...  1337.984741  1331.657104  
1  Is the , are to corner on to and  were a the p...  1455.132690  1451.431519  
2  industry said who, said  of the firsts who in ...   197.030609   193.422974  

Saved to inference_rows_101_110.csv





In [16]:
df_out

Unnamed: 0,UserID,Query(Bpos.head),Gold(Bpos.tail),PredictedSummary,Loss,EncLoss
0,U100567_5,N111992,Jesuit school tweets #BeBrave when announcing ...,s iss thatFtin. theyhe the are be be. people ...,1337.984741,1331.657104
1,U100567_6,N123678,"""At 11 am we turned the pumps back on and we a...","Is the , are to corner on to and were a the p...",1455.13269,1451.431519
2,U100573_1,N91754,"The sportswoman, 37, was one of the pals invol...","industry said who, said of the firsts who in ...",197.030609,193.422974
3,U100573_2,N67596,Police weren't looking for anyone else in conn...,are't for the to to the with the matter. but ...,1390.650269,1387.442627
4,U100573_3,N83990,Many say their constituents have expressed lit...,of that owns are been their concern in thereac...,1570.151367,1565.858887
5,U100592_1,N115478,"""But Larry Lemaster would never want one perso...",I isa is not have to to to be theira single m...,421.548828,417.467224
6,U100592_2,N85945,Related slideshow: Celebrity weddings of 2019,to showsating ands are the and,699.723999,693.25592
7,U100592_3,N104663,"Great food is a given, reviewers say, but what...",Britain and a major. buter said. but they they...,1049.669678,1045.056885
8,U100592_4,N23797,"He would soon drop ""The,"" branch the network o...",said have be histhe of company. of the Unite...,1152.062988,1146.22998
9,U100602_1,N96591,It's no surprise that a lot of those concerns ...,iss secret that thea few of people who are be...,129.862152,125.625076
