# imports

In [1]:
import math
import os
from dataclasses import dataclass
from typing import Optional, Dict, Any, List

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from transformers import (
    T5ForConditionalGeneration,
    T5TokenizerFast,
    GPT2LMHeadModel,
    GPT2TokenizerFast,
)

# functions

In [2]:
@dataclass
class ExperimentConfig:
    model_name: str = "t5-large"
    gpt2_name: str = "gpt2"
    max_source_length: int = 256
    max_target_length: int = 4
    batch_size: int = 16
    lr: float = 1e-3
    num_epochs: int = 5
    prompt_length: int = 10
    lambda_grid: List[float] = None
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self):
        if self.lambda_grid is None:
            self.lambda_grid = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]

In [3]:
def preprocess_boolq(example, tokenizer: T5TokenizerFast, max_source_length: int, max_target_length: int):
    """
    Turn BoolQ into T5 inputs:
      "question: {question} passage: {passage}"
    Targets are "yes" or "no".
    """
    question = example["question"]
    passage = example["passage"]
    answer = "yes" if example["answer"] else "no"

    source = f"question: {question} passage: {passage}"
    target = answer

    model_inputs = tokenizer(
        source,
        truncation=True,
        padding="max_length",
        max_length=max_source_length,
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            target,
            truncation=True,
            padding="max_length",
            max_length=max_target_length,
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


def load_boolq(tokenizer: T5TokenizerFast, cfg: ExperimentConfig):
    ds = load_dataset("boolq")

    preprocess_fn = lambda ex: preprocess_boolq(
        ex,
        tokenizer=tokenizer,
        max_source_length=cfg.max_source_length,
        max_target_length=cfg.max_target_length,
    )

    ds = ds.map(preprocess_fn, batched=False)
    ds.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"],
    )

    train_dl = DataLoader(ds["train"], batch_size=cfg.batch_size, shuffle=True)
    val_dl = DataLoader(ds["validation"], batch_size=cfg.batch_size, shuffle=False)


    return train_dl, val_dl


def load_boolq_balanced(tokenizer: T5TokenizerFast, cfg: ExperimentConfig, seed: int = 42):
    """
    Load BoolQ, balance the train split on the raw dataset (True/False),
    then preprocess into T5-style inputs and wrap in DataLoaders.
    """
    raw = load_dataset("boolq")
    train_raw = raw["train"]
    val_raw   = raw["validation"]

    # --- Balance train: downsample majority label ---
    true_indices  = [i for i, ex in enumerate(train_raw) if ex["answer"]]
    false_indices = [i for i, ex in enumerate(train_raw) if not ex["answer"]]

    min_count = min(len(true_indices), len(false_indices))
    true_indices  = true_indices[:min_count]
    false_indices = false_indices[:min_count]

    balanced_indices = true_indices + false_indices
    # Optional but usually helpful: shuffle indices for randomness
    rng = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(balanced_indices), generator=rng).tolist()
    balanced_indices = [balanced_indices[i] for i in perm]

    train_balanced = train_raw.select(balanced_indices)

    print(
        f"Balanced BoolQ train: {len(train_balanced)} examples "
        f"({min_count} True, {min_count} False)"
    )

    # --- Preprocess to T5 format ---
    def preprocess_fn(ex):
        return preprocess_boolq(
            ex,
            tokenizer=tokenizer,
            max_source_length=cfg.max_source_length,
            max_target_length=cfg.max_target_length,
        )

    train_proc = train_balanced.map(preprocess_fn, batched=False)
    val_proc   = val_raw.map(preprocess_fn, batched=False)

    # Keep only the model fields and cast to torch
    cols = ["input_ids", "attention_mask", "labels"]
    train_proc.set_format(type="torch", columns=cols)
    val_proc.set_format(type="torch", columns=cols)

    train_dl = DataLoader(train_proc, batch_size=cfg.batch_size, shuffle=True)
    val_dl   = DataLoader(val_proc,   batch_size=cfg.batch_size, shuffle=False)

    return train_dl, val_dl

In [4]:
from collections import Counter

ds = load_dataset("boolq")
labels = [ex["answer"] for ex in ds["train"]]   # True/False labels
ctr = Counter(labels)

num_true  = ctr[True]
num_false = ctr[False]
total     = num_true + num_false

print("num_true:", num_true)
print("num_false:", num_false)
print("P(True) =", num_true / total)
print("P(False) =", num_false / total)

num_true: 5874
num_false: 3553
P(True) = 0.6231038506417736
P(False) = 0.37689614935822635


In [5]:
class T5ContinuousSoftPrompt(nn.Module):
    """
    Continuous soft prompts: learn a (prompt_length, d_model) tensor
    and prepend it as prefix embeddings to the encoder input.
    """

    def __init__(self, base: T5ForConditionalGeneration, prompt_length: int):
        super().__init__()
        self.t5 = base
        self.prompt_length = prompt_length

        d_model = base.encoder.embed_tokens.weight.shape[1]
        self.soft_prompt = nn.Parameter(
            torch.zeros(prompt_length, d_model)
        )
        nn.init.normal_(self.soft_prompt, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ):
        batch_size = input_ids.size(0)
        device = input_ids.device

        # Original token embeddings
        inputs_embeds = self.t5.encoder.embed_tokens(input_ids)

        # Broadcast prompt to batch
        prompt_embeds = self.soft_prompt.unsqueeze(0).expand(batch_size, -1, -1)

        # Prepend to input sequence
        inputs_embeds = torch.cat([prompt_embeds, inputs_embeds], dim=1)

        # Extend attention mask
        prompt_mask = torch.ones(batch_size, self.prompt_length, device=device, dtype=attention_mask.dtype)
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        outputs = self.t5(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs  # has .loss and .logits

In [6]:
class T5PEZPrompt(nn.Module):
    """
    PEZ-style one-hot prompt over the T5 vocabulary.

    - self.prompt is (L, V) with rows ~ one-hot
    - FORWARD: treat self.prompt as continuous, embed via matrix multiply
               prompt_embeds = prompt @ embed_tokens.weight
      This keeps the computation differentiable w.r.t. self.prompt.
    - DISCRETIZATION: use argmax *outside* the forward (for decoding / GPT-2).
    """

    def __init__(self, base: T5ForConditionalGeneration, prompt_length: int):
        super().__init__()
        self.t5 = base
        self.vocab_size = base.encoder.embed_tokens.weight.shape[0]
        self.prompt_length = prompt_length

        # Initialize as random one-hot rows
        init_ids = torch.randint(0, self.vocab_size, (prompt_length,))
        prompt = torch.nn.functional.one_hot(init_ids, num_classes=self.vocab_size).float()
        self.prompt = nn.Parameter(prompt)  # (L, V), requires_grad=True by default

    # ---- helper used ONLY for decoding / perplexity, NOT in forward ----
    def get_prompt_token_ids(self) -> torch.Tensor:
        return self.prompt.argmax(dim=-1)  # (L,)

    def decode_prompt(self, tokenizer: T5TokenizerFast) -> str:
        token_ids = self.get_prompt_token_ids().tolist()
        return tokenizer.decode(token_ids, skip_special_tokens=True)

    # ---- differentiable forward ----
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor,
    ):
        """
        Embed prompt by multiplying the (L, V) prompt matrix with the
        (V, d_model) embedding matrix. This is linear in self.prompt,
        so gradients flow into self.prompt.
        """
        batch_size = input_ids.size(0)
        device = input_ids.device

        # 1) Embed the prompt: (L, V) @ (V, d) -> (L, d)
        embed_matrix = self.t5.encoder.embed_tokens.weight  # (V, d_model)
        prompt_embeds = self.prompt @ embed_matrix          # (L, d_model)
        prompt_embeds = prompt_embeds.unsqueeze(0).expand(batch_size, -1, -1)

        # 2) Embed the original input tokens
        inputs_embeds = self.t5.encoder.embed_tokens(input_ids)

        # 3) Prepend prompt embeddings
        inputs_embeds = torch.cat([prompt_embeds, inputs_embeds], dim=1)

        # 4) Extend attention mask
        prompt_mask = torch.ones(
            batch_size,
            self.prompt_length,
            device=device,
            dtype=attention_mask.dtype,
        )
        attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)

        # 5) Standard T5 forward
        outputs = self.t5(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs  # has .loss and .logits


In [7]:
def compute_prompt_ppl_loss_from_text(
    gpt2_model: GPT2LMHeadModel,
    gpt2_tokenizer: GPT2TokenizerFast,
    prompt_text: str,
    device: str,
) -> torch.Tensor:
    """
    Take the decoded prompt text, feed to GPT-2, compute LM loss.
    Returns 0.0 if the text tokenizes to an empty sequence.
    """
    # Guard 1: decoded text is empty or whitespace
    if not prompt_text or not prompt_text.strip():
        return torch.tensor(0.0, device=device)

    enc = gpt2_tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
    )
    input_ids = enc["input_ids"].to(device)

    # Guard 2: tokenizer produced no tokens
    if input_ids.numel() == 0:
        return torch.tensor(0.0, device=device)

    labels = input_ids.clone()
    with torch.no_grad():
        outputs = gpt2_model(input_ids=input_ids, labels=labels)
    return outputs.loss


In [8]:
@torch.no_grad()
def evaluate_accuracy_t5(model: nn.Module, dataloader: DataLoader,
                         tokenizer: T5TokenizerFast, device: str) -> Dict[str, float]:
    """
    BoolQ accuracy for T5-style models:
    - We look at the first decoder position's logits (position 0 in labels)
    - Compare scores for 'yes' vs 'no'
    
    Returns a dictionary with:
    - 'overall': overall accuracy
    - 'true_acc': accuracy on questions where answer is True (yes)
    - 'false_acc': accuracy on questions where answer is False (no)
    """

    model.eval()

    # Correct way to get the IDs for "yes" / "no" for T5
    yes_id = tokenizer("yes", add_special_tokens=False).input_ids[0]
    no_id  = tokenizer("no",  add_special_tokens=False).input_ids[0]

    correct = 0
    total = 0
    
    # Separate tracking for True and False answers
    correct_true = 0
    total_true = 0
    correct_false = 0
    total_false = 0

    for batch in dataloader:
        input_ids      = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels         = batch["labels"].to(device)

        # Forward pass: all your wrappers accept labels=...
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        # logits: (B, T_out, V)
        logits = outputs.logits

        # First decoder step (the token that should be "yes" or "no")
        first_step_logits = logits[:, 0, :]  # (B, V)

        # Scores only for "yes" and "no"
        yes_scores = first_step_logits[:, yes_id]
        no_scores  = first_step_logits[:, no_id]

        # Predict yes if yes_score >= no_score else no
        pred_is_yes = (yes_scores >= no_scores)

        # Ground truth: first label token
        target_ids  = labels[:, 0]
        target_is_yes = (target_ids == yes_id)
        target_is_no  = (target_ids == no_id)

        # Correct if our yes/no prediction matches target
        correct_batch = (pred_is_yes & target_is_yes) | (~pred_is_yes & target_is_no)
        correct += correct_batch.sum().item()
        total   += target_ids.size(0)
        
        # Track accuracy separately for True and False answers
        # True answers (yes)
        true_mask = target_is_yes
        if true_mask.any():
            correct_true_batch = (pred_is_yes & target_is_yes)[true_mask]
            correct_true += correct_true_batch.sum().item()
            total_true += true_mask.sum().item()
        
        # False answers (no)
        false_mask = target_is_no
        if false_mask.any():
            correct_false_batch = (~pred_is_yes & target_is_no)[false_mask]
            correct_false += correct_false_batch.sum().item()
            total_false += false_mask.sum().item()

    overall_acc = correct / total if total > 0 else 0.0
    true_acc = correct_true / total_true if total_true > 0 else 0.0
    false_acc = correct_false / total_false if total_false > 0 else 0.0
    
    return {
        "overall": overall_acc,
        "true_acc": true_acc,
        "false_acc": false_acc
    }


In [9]:
def train_continuous_soft_prompt(
    cfg: ExperimentConfig,
    tokenizer: T5TokenizerFast,
    train_dl: DataLoader,
    val_dl: DataLoader,
    adversarial: bool,
) -> Dict[str, Any]:
    device = cfg.device
    base = T5ForConditionalGeneration.from_pretrained(cfg.model_name).to(device)
    base.eval()
    for p in base.parameters():
        p.requires_grad = False

    model = T5ContinuousSoftPrompt(base, prompt_length=cfg.prompt_length).to(device)

    # Use lower learning rate for adversarial training to prevent explosion
    effective_lr = cfg.lr * 0.1 if adversarial else cfg.lr
    optimizer = torch.optim.AdamW(model.parameters(), lr=effective_lr, weight_decay=0.01)

    # --- label ids for flipping (yes/no) ---
    if adversarial:
        yes_id = tokenizer("yes", add_special_tokens=False).input_ids[0]
        no_id  = tokenizer("no",  add_special_tokens=False).input_ids[0]
    # --------------------------------------#

    history = {
        "train_joint": [],
        "train_task": [],
        "val_loss": [],
        "val_acc": [],
        "val_acc_true": [],
        "val_acc_false": [],
        "prompt_norm": [],
    }

    for epoch in range(cfg.num_epochs):
        model.train()
        running_joint = 0.0
        running_task  = 0.0
        n_batches = 0

        for batch in train_dl:
            input_ids      = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels         = batch["labels"].to(device)

            # ---- flip labels during adversarial TRAINING ----
            if adversarial:
                labels_flipped = labels.clone()
                mask_yes = labels == yes_id
                mask_no  = labels == no_id
                labels_flipped[mask_yes] = no_id
                labels_flipped[mask_no]  = yes_id
                labels_for_loss = labels_flipped
            else:
                labels_for_loss = labels
            # -------------------------------------------------#

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels_for_loss,
            )
            task_loss = outputs.loss
            joint_loss = task_loss  # no sign flip

            joint_loss.backward()
            # gradient clipping (tighter for adversarial)
            max_norm = 0.5 if adversarial else 1.0
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
            optimizer.step()

            running_joint += joint_loss.item()
            running_task  += task_loss.item()
            n_batches += 1

        avg_train_joint = running_joint / max(1, n_batches)
        avg_train_task  = running_task  / max(1, n_batches)

        # ----------------- validation: TRUE labels -----------------
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch in val_dl:
                input_ids      = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels         = batch["labels"].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,  # real yes/no labels
                )
                val_loss += outputs.loss.item()
                val_batches += 1

        avg_val_loss = val_loss / max(1, val_batches)
        val_acc_dict = evaluate_accuracy_t5(model, val_dl, tokenizer, device)
        val_acc = val_acc_dict["overall"]
        val_acc_true = val_acc_dict["true_acc"]
        val_acc_false = val_acc_dict["false_acc"]

        prompt_norm = model.soft_prompt.norm().item()

        history["train_joint"].append(avg_train_joint)
        history["train_task"].append(avg_train_task)
        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_acc)
        history["val_acc_true"].append(val_acc_true)
        history["val_acc_false"].append(val_acc_false)
        history["prompt_norm"].append(prompt_norm)

        print(
            f"[Continuous {'ADV' if adversarial else 'NON-ADV'}] "
            f"Epoch {epoch+1}/{cfg.num_epochs} | "
            f"joint={avg_train_joint:.4f} task={avg_train_task:.4f} "
            f"val_loss={avg_val_loss:.4f} val_acc={val_acc:.4f} "
            f"(true={val_acc_true:.4f} false={val_acc_false:.4f}) "
            f"‖prompt‖={prompt_norm:.2f}"
        )

    return {"model": model, "history": history}


In [10]:
def train_pez(
    cfg,
    tokenizer: T5TokenizerFast,
    gpt2_model: GPT2LMHeadModel,
    gpt2_tokenizer: GPT2TokenizerFast,
    train_dl: DataLoader,
    val_dl: DataLoader,
    lambda_ppl: float,
    adversarial: bool,
    log_every: int = 50,
):
    device = cfg.device
    base = T5ForConditionalGeneration.from_pretrained(cfg.model_name).to(device)
    base.eval()
    for p in base.parameters():
        p.requires_grad = False

    model = T5PEZPrompt(base, prompt_length=cfg.prompt_length).to(device)

    # --- label ids for flipping (yes/no) ---
    if adversarial:
        yes_id = tokenizer("yes", add_special_tokens=False).input_ids[0]
        no_id  = tokenizer("no",  add_special_tokens=False).input_ids[0]
    # --------------------------------------#

    history = {
        "lambda_ppl": lambda_ppl,
        "train_joint": [],
        "train_task": [],
        "train_ppl_loss": [],
        "train_ppl_ppx": [],
        "val_loss": [],
        "val_acc": [],
        "val_acc_true": [],
        "val_acc_false": [],
        "prompt_ppl_ppx": [],
    }

    for epoch in range(cfg.num_epochs):
        model.train()
        running_joint = 0.0
        running_task = 0.0
        running_ppl  = 0.0
        running_ppl_ppx = 0.0
        n_batches = 0

        empty_ppl_calls = 0
        nonempty_ppl_calls = 0

        for batch in train_dl:
            input_ids      = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels         = batch["labels"].to(device)

            # ---- flip labels during adversarial TRAINING ----
            if adversarial:
                labels_flipped = labels.clone()
                mask_yes = labels == yes_id
                mask_no  = labels == no_id
                labels_flipped[mask_yes] = no_id
                labels_flipped[mask_no]  = yes_id
                labels_for_loss = labels_flipped
            else:
                labels_for_loss = labels
            # -------------------------------------------------#

            model.prompt.grad = None

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels_for_loss,
            )
            task_loss = outputs.loss

            # ---- perplexity term ----
            if lambda_ppl > 0.0:
                prompt_ids = model.get_prompt_token_ids()
                prompt_text = tokenizer.decode(
                    prompt_ids.tolist(),
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=True,
                ).strip()

                if not prompt_text:
                    ppl_loss = torch.tensor(0.0, device=device)
                    empty_ppl_calls += 1
                else:
                    nonempty_ppl_calls += 1
                    ppl_loss = compute_prompt_ppl_loss_from_text(
                        gpt2_model, gpt2_tokenizer, prompt_text, device=device
                    )
                    if torch.isnan(ppl_loss) or torch.isinf(ppl_loss):
                        ppl_loss = torch.tensor(0.0, device=device)
            else:
                ppl_loss = torch.tensor(0.0, device=device)

            joint_loss = task_loss + lambda_ppl * ppl_loss
            joint_loss.backward()

            # ---- PEZ projection step (one-hot update) ----
            with torch.no_grad():
                grads = model.prompt.grad
                indices = grads.argmin(dim=-1)
                new_prompt = torch.nn.functional.one_hot(
                    indices, num_classes=model.vocab_size
                ).float()
                model.prompt.copy_(new_prompt)
                model.prompt.grad.zero_()
            # ---------------------------------------------

            running_joint += joint_loss.item()
            running_task  += task_loss.item()
            running_ppl   += ppl_loss.item()
            if lambda_ppl > 0.0:
                running_ppl_ppx += math.exp(ppl_loss.item())
            n_batches += 1

            if (n_batches % log_every) == 0:
                avg_joint_so_far = running_joint / n_batches
                avg_task_so_far  = running_task  / n_batches
                avg_ppl_so_far   = running_ppl   / n_batches
                avg_ppl_ppx_so_far = (
                    running_ppl_ppx / n_batches if lambda_ppl > 0.0 else 0.0
                )
                print(
                    f"[PEZ λ={lambda_ppl} {'ADV' if adversarial else 'NON-ADV'}] "
                    f"Epoch {epoch+1}/{cfg.num_epochs}, "
                    f"batch {n_batches} | "
                    f"joint={avg_joint_so_far:.4f} "
                    f"task={avg_task_so_far:.4f} "
                    f"ppl_loss={avg_ppl_so_far:.4f} "
                    f"ppl={avg_ppl_ppx_so_far:.2f}"
                )

        # ---- end-of-epoch aggregation ----
        avg_joint = running_joint / max(1, n_batches)
        avg_task  = running_task  / max(1, n_batches)
        avg_ppl   = running_ppl   / max(1, n_batches)
        avg_ppl_ppx = running_ppl_ppx / max(1, n_batches) if lambda_ppl > 0.0 else 0.0

        history["train_joint"].append(avg_joint)
        history["train_task"].append(avg_task)
        history["train_ppl_loss"].append(avg_ppl)
        history["train_ppl_ppx"].append(avg_ppl_ppx)

        # ---- validation: TRUE labels ----
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch in val_dl:
                input_ids      = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels         = batch["labels"].to(device)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,   # real yes/no labels here
                )
                val_loss += outputs.loss.item()
                val_batches += 1
        avg_val_loss = val_loss / max(1, val_batches)
        val_acc_dict = evaluate_accuracy_t5(model, val_dl, tokenizer, device)
        val_acc = val_acc_dict["overall"]
        val_acc_true = val_acc_dict["true_acc"]
        val_acc_false = val_acc_dict["false_acc"]

        history["val_loss"].append(avg_val_loss)
        history["val_acc"].append(val_acc)
        history["val_acc_true"].append(val_acc_true)
        history["val_acc_false"].append(val_acc_false)

        # ---- prompt perplexity once per epoch ----
        if lambda_ppl > 0.0:
            prompt_text_epoch = model.decode_prompt(tokenizer).strip()
            if prompt_text_epoch:
                ppl_loss_epoch = compute_prompt_ppl_loss_from_text(
                    gpt2_model, gpt2_tokenizer, prompt_text_epoch, device=device
                )
                prompt_ppx = math.exp(ppl_loss_epoch.item())
            else:
                prompt_ppx = float("nan")
        else:
            prompt_ppx = 0.0
        history["prompt_ppl_ppx"].append(prompt_ppx)

        decoded_prompt = model.decode_prompt(tokenizer)
        print(
            f"[PEZ λ={lambda_ppl} {'ADV' if adversarial else 'NON-ADV'}] "
            f"Epoch {epoch+1}/{cfg.num_epochs} | "
            f"joint={avg_joint:.4f} task={avg_task:.4f} "
            f"ppl_loss={avg_ppl:.4f} ppl={avg_ppl_ppx:.2f} "
            f"val_loss={avg_val_loss:.4f} val_acc={val_acc:.4f} "
            f"(true={val_acc_true:.4f} false={val_acc_false:.4f}) "
            f"prompt_ppl={prompt_ppx:.2f}\n"
            f"Prompt: {decoded_prompt}"
        )

    return {"model": model, "history": history}


# running stuff

## init

In [11]:
cfg = ExperimentConfig()

device = cfg.device
print("Using device:", device)

Using device: cuda


In [12]:
# Tokenizers and data
tokenizer = T5TokenizerFast.from_pretrained(cfg.model_name)
train_dl, val_dl = load_boolq_balanced(tokenizer, cfg)

Balanced BoolQ train: 7106 examples (3553 True, 3553 False)


In [13]:
batch = next(iter(val_dl))
print(batch["labels"][0][:5])  # first 5 label tokens
print("yes_id", tokenizer("yes", add_special_tokens=False).input_ids[0])
print("no_id",  tokenizer("no",  add_special_tokens=False).input_ids[0])
print("decoded label:", tokenizer.decode(batch["labels"][0]))


tensor([150,   1,   0,   0])
yes_id 4273
no_id 150
decoded label: no</s><pad><pad>


## baseline accuracy of t5 large

In [16]:
# put baseline here
# Baseline accuracy of T5 Large on validation set (no soft prompt tuning)
print("\n" + "=" * 80)
print("BASELINE: T5 Large (No Soft Prompt)")
print("=" * 80)

# Load T5 Large model (baseline without any prompt tuning)
baseline_model_name = "t5-large"  # Use t5-large for baseline
print(f"\nLoading {baseline_model_name} for baseline evaluation...")
baseline_model = T5ForConditionalGeneration.from_pretrained(baseline_model_name).to(device)
baseline_model.eval()
for p in baseline_model.parameters():
    p.requires_grad = False

# Evaluate on validation set
print("Evaluating baseline model on validation set...")
baseline_acc_dict = evaluate_accuracy_t5(baseline_model, val_dl, tokenizer, device)

baseline_overall = baseline_acc_dict["overall"]
baseline_true = baseline_acc_dict["true_acc"]
baseline_false = baseline_acc_dict["false_acc"]

print("\n" + "=" * 80)
print("Baseline Results:")
print("=" * 80)
print(f"Overall Accuracy: {baseline_overall:.4f}")
print(f"True (yes) Accuracy:  {baseline_true:.4f}")
print(f"False (no) Accuracy:  {baseline_false:.4f}")
print("=" * 80)

# Store baseline for comparison
baseline_results = {
    "model_name": baseline_model_name,
    "overall_acc": baseline_overall,
    "true_acc": baseline_true,
    "false_acc": baseline_false
}


BASELINE: T5 Large (No Soft Prompt)

Loading t5-large for baseline evaluation...
Evaluating baseline model on validation set...

Baseline Results:
Overall Accuracy: 0.6838
True (yes) Accuracy:  0.5706
False (no) Accuracy:  0.8698


## baselines (no interpretability) with 10 length prompt

In [41]:
# Continuous soft prompt baselines
print("\n=== Continuous Soft Prompt: Non-Adversarial ===")
cont_non_adv = train_continuous_soft_prompt(
    cfg, tokenizer, train_dl, val_dl, adversarial=False
)

print("\n=== Continuous Soft Prompt: Adversarial ===")
cont_adv = train_continuous_soft_prompt(
    cfg, tokenizer, train_dl, val_dl, adversarial=True
)


=== Continuous Soft Prompt: Non-Adversarial ===
[Continuous NON-ADV] Epoch 1/5 | joint=9.6488 task=9.6488 val_loss=2.8798 val_acc=0.7128 (true=0.7860 false=0.5926) ‖prompt‖=3.02
[Continuous NON-ADV] Epoch 2/5 | joint=4.6155 task=4.6155 val_loss=0.2644 val_acc=0.6566 (true=0.7664 false=0.4762) ‖prompt‖=3.66
[Continuous NON-ADV] Epoch 3/5 | joint=2.7659 task=2.7659 val_loss=0.2258 val_acc=0.6682 (true=0.9311 false=0.2361) ‖prompt‖=4.17
[Continuous NON-ADV] Epoch 4/5 | joint=1.9620 task=1.9620 val_loss=0.2252 val_acc=0.6615 (true=0.9828 false=0.1334) ‖prompt‖=4.64
[Continuous NON-ADV] Epoch 5/5 | joint=1.5919 task=1.5919 val_loss=0.1782 val_acc=0.6535 (true=0.9769 false=0.1221) ‖prompt‖=5.07

=== Continuous Soft Prompt: Adversarial ===
[Continuous ADV] Epoch 1/5 | joint=-18.1294 task=18.1294 val_loss=21.8227 val_acc=0.7612 (true=0.7019 false=0.8585) ‖prompt‖=2.06
[Continuous ADV] Epoch 2/5 | joint=-19.0742 task=19.0742 val_loss=22.0799 val_acc=0.7661 (true=0.7108 false=0.8569) ‖prompt‖=2

## LR gridsearch for non-adv

In [None]:
# Grid search over learning rates for continuous soft prompt (non-adversarial)
# Using lower learning rates and gradient clipping to prevent collapse
# Smaller LRs need more epochs to converge
print("\n" + "=" * 80)
print("GRID SEARCH: Learning Rates for Continuous Soft Prompt (Non-Adversarial)")
print("=" * 80)

lr_grid = [1e-5, 5e-5, 1e-4, 5e-4, 1e-3]  # Lower range to prevent collapse
continuous_lr_results = []

for lr in lr_grid:
    print(f"\n--- Testing LR = {lr} ---")
    # Create a temporary config with this learning rate
    temp_cfg = ExperimentConfig()
    temp_cfg.lr = lr
    # Use more epochs for smaller learning rates
    if lr <= 1e-4:
        temp_cfg.num_epochs = 10  # More epochs for very small LRs
    elif lr <= 5e-4:
        temp_cfg.num_epochs = 8   # Moderate epochs for small LRs
    else:
        temp_cfg.num_epochs = 5   # Default for larger LRs
    
    result = train_continuous_soft_prompt(
        temp_cfg, tokenizer, train_dl, val_dl, adversarial=False
    )
    
    best_val_acc = max(result["history"]["val_acc"])
    final_val_acc = result["history"]["val_acc"][-1]
    continuous_lr_results.append({
        "lr": lr,
        "num_epochs": temp_cfg.num_epochs,
        "best_val_acc": best_val_acc,
        "final_val_acc": final_val_acc,
        "history": result["history"]
    })
    print(f"LR={lr} ({temp_cfg.num_epochs} epochs): best_val_acc={best_val_acc:.4f}, final_val_acc={final_val_acc:.4f}")

print("\n" + "=" * 80)
print("Learning Rate Grid Search Results:")
print("=" * 80)
for r in continuous_lr_results:
    print(f"LR={r['lr']:.0e} ({r['num_epochs']} epochs): best_val_acc={r['best_val_acc']:.4f}, final_val_acc={r['final_val_acc']:.4f}")

best_lr_result = max(continuous_lr_results, key=lambda x: x["best_val_acc"])
print(f"\nBest LR: {best_lr_result['lr']:.0e} ({best_lr_result['num_epochs']} epochs) with best_val_acc={best_lr_result['best_val_acc']:.4f}")



GRID SEARCH: Learning Rates for Continuous Soft Prompt (Non-Adversarial)

--- Testing LR = 1e-05 ---
[Continuous NON-ADV] Epoch 1/10 | joint=16.8341 task=16.8341 val_loss=20.1736 val_acc=0.7670 (true=0.7118 false=0.8577) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 2/10 | joint=16.5245 task=16.5245 val_loss=19.6023 val_acc=0.7587 (true=0.6936 false=0.8658) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 3/10 | joint=16.2218 task=16.2218 val_loss=19.1111 val_acc=0.7544 (true=0.6857 false=0.8674) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 4/10 | joint=15.8928 task=15.8928 val_loss=18.6931 val_acc=0.7465 (true=0.6773 false=0.8601) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 5/10 | joint=15.4876 task=15.4876 val_loss=18.2031 val_acc=0.7453 (true=0.6754 false=0.8601) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 6/10 | joint=15.1413 task=15.1413 val_loss=17.5713 val_acc=0.7419 (true=0.6699 false=0.8601) ‖prompt‖=2.04
[Continuous NON-ADV] Epoch 7/10 | joint=14.6371 task=14.6371 val_loss=16.7586 val_acc=0.7266 (

## lr gridsearch for adversarial prompt

In [None]:
# Grid search over learning rates for continuous soft prompt (adversarial)
# Using lower learning rates and gradient clipping to prevent collapse
# Smaller LRs need more epochs to converge
print("\n" + "=" * 80)
print("GRID SEARCH: Learning Rates for Continuous Soft Prompt (Adversarial)")
print("=" * 80)

lr_grid = [1e-5, 5e-5, 1e-4, 5e-4, 1e-3]  # Lower range to prevent collapse
continuous_lr_results_adv = []

for lr in lr_grid:
    print(f"\n--- Testing LR = {lr} ---")
    # Create a temporary config with this learning rate
    temp_cfg = ExperimentConfig()
    temp_cfg.lr = lr
    # Use more epochs for smaller learning rates
    if lr <= 1e-4:
        temp_cfg.num_epochs = 10  # More epochs for very small LRs
    elif lr <= 5e-4:
        temp_cfg.num_epochs = 8   # Moderate epochs for small LRs
    else:
        temp_cfg.num_epochs = 5   # Default for larger LRs
    
    result = train_continuous_soft_prompt(
        temp_cfg, tokenizer, train_dl, val_dl, adversarial=True
    )
    
    best_val_acc = max(result["history"]["val_acc"])
    final_val_acc = result["history"]["val_acc"][-1]
    best_val_acc_true = max(result["history"]["val_acc_true"])
    best_val_acc_false = max(result["history"]["val_acc_false"])
    continuous_lr_results_adv.append({
        "lr": lr,
        "num_epochs": temp_cfg.num_epochs,
        "best_val_acc": best_val_acc,
        "final_val_acc": final_val_acc,
        "best_val_acc_true": best_val_acc_true,
        "best_val_acc_false": best_val_acc_false,
        "history": result["history"]
    })
    print(
        f"LR={lr} ({temp_cfg.num_epochs} epochs): "
        f"final_val_acc={final_val_acc:.4f} "
        f"(true={best_val_acc_true:.4f}, false={best_val_acc_false:.4f})"
    )

import json, os

save_path = "/mnt/polished-lake/home/annabelma/other/results/continuous/continuous_lr_results_adv.json"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with open(save_path, "w") as f:
    json.dump(continuous_lr_results_adv, f, indent=2)

print("Saved to", save_path)


GRID SEARCH: Learning Rates for Continuous Soft Prompt (Non-Adversarial)

--- Testing LR = 1e-05 ---
[Continuous NON-ADV] Epoch 1/10 | joint=9.4593 task=9.4593 val_loss=4.4545 val_acc=0.3783 ‖prompt‖=4.53


[Continuous NON-ADV] Epoch 2/10 | joint=5.2384 task=5.2384 val_loss=1.7574 val_acc=0.3783 ‖prompt‖=4.53
[Continuous NON-ADV] Epoch 3/10 | joint=3.7788 task=3.7788 val_loss=0.7803 val_acc=0.3783 ‖prompt‖=4.53
[Continuous NON-ADV] Epoch 4/10 | joint=2.8895 task=2.8895 val_loss=0.3976 val_acc=0.4581 ‖prompt‖=4.53
[Continuous NON-ADV] Epoch 5/10 | joint=2.2292 task=2.2292 val_loss=0.2583 val_acc=0.4615 ‖prompt‖=4.53
[Continuous NON-ADV] Epoch 6/10 | joint=1.7813 task=1.7813 val_loss=0.1996 val_acc=0.5758 ‖prompt‖=4.54
[Continuous NON-ADV] Epoch 7/10 | joint=1.4965 task=1.4965 val_loss=0.1804 val_acc=0.6131 ‖prompt‖=4.54
[Continuous NON-ADV] Epoch 8/10 | joint=1.2571 task=1.2571 val_loss=0.1774 val_acc=0.6162 ‖prompt‖=4.54
[Continuous NON-ADV] Epoch 9/10 | joint=1.0867 task=1.0867 val_loss=0.1767 val_acc=0.6174 ‖prompt‖=4.54
[Continuous NON-ADV] Epoch 10/10 | joint=0.9476 task=0.9476 val_loss=0.1860 val_acc=0.6199 ‖prompt‖=4.54
LR=1e-05 (10 epochs): best_val_acc=0.6199, final_val_acc=0.6199

## Grid search: Prompt Length for Continuous Soft Prompt


In [18]:
# Grid search over prompt lengths for continuous soft prompt (non-adversarial)
print("\n" + "=" * 80)
print("GRID SEARCH: Prompt Lengths for Continuous Soft Prompt (Non-Adversarial)")
print("=" * 80)

prompt_length_grid = [1, 5, 20, 50, 100, 150]
prompt_length_results = []

for prompt_len in prompt_length_grid:
    print(f"\n--- Testing Prompt Length = {prompt_len} ---")
    # Create a temporary config with this prompt length
    temp_cfg = ExperimentConfig()
    temp_cfg.prompt_length = prompt_len
    # Use a reasonable learning rate (from previous grid search, 5e-5 was good)
    temp_cfg.lr = 1e-4
    temp_cfg.num_epochs = 10  # Use enough epochs to see convergence
    
    result = train_continuous_soft_prompt(
        temp_cfg, tokenizer, train_dl, val_dl, adversarial=False
    )
    
    best_val_acc = max(result["history"]["val_acc"])
    final_val_acc = result["history"]["val_acc"][-1]
    best_val_acc_true = max(result["history"]["val_acc_true"])
    best_val_acc_false = max(result["history"]["val_acc_false"])
    final_prompt_norm = result["history"]["prompt_norm"][-1]
    
    prompt_length_results.append({
        "prompt_length": prompt_len,
        "best_val_acc": best_val_acc,
        "final_val_acc": final_val_acc,
        "best_val_acc_true": best_val_acc_true,
        "best_val_acc_false": best_val_acc_false,
        "final_prompt_norm": final_prompt_norm,
        "history": result["history"]
    })
    print(
        f"Prompt Length={prompt_len}: "
        f"best_val_acc={best_val_acc:.4f}, final_val_acc={final_val_acc:.4f} "
        f"(true={best_val_acc_true:.4f}, false={best_val_acc_false:.4f}) "
        f"‖prompt‖={final_prompt_norm:.2f}"
    )

print("\n" + "=" * 80)
print("Prompt Length Grid Search Results:")
print("=" * 80)
for r in prompt_length_results:
    print(
        f"Length={r['prompt_length']:3d}: "
        f"best_val_acc={r['best_val_acc']:.4f}, final_val_acc={r['final_val_acc']:.4f} "
        f"(true={r['best_val_acc_true']:.4f}, false={r['best_val_acc_false']:.4f}) "
        f"‖prompt‖={r['final_prompt_norm']:.2f}"
    )

best_prompt_length_result = max(prompt_length_results, key=lambda x: x["best_val_acc"])
print(f"\nBest Prompt Length: {best_prompt_length_result['prompt_length']}")
print(f"  best_val_acc={best_prompt_length_result['best_val_acc']:.4f}, final_val_acc={best_prompt_length_result['final_val_acc']:.4f}")
print(f"  true_acc={best_prompt_length_result['best_val_acc_true']:.4f}, false_acc={best_prompt_length_result['best_val_acc_false']:.4f}")



GRID SEARCH: Prompt Lengths for Continuous Soft Prompt (Non-Adversarial)

--- Testing Prompt Length = 1 ---
[Continuous NON-ADV] Epoch 1/10 | joint=14.0544 task=14.0544 val_loss=13.2658 val_acc=0.5547 (true=0.3163 false=0.9466) ‖prompt‖=0.62
[Continuous NON-ADV] Epoch 2/10 | joint=12.8282 task=12.8282 val_loss=12.8821 val_acc=0.5015 (true=0.2189 false=0.9660) ‖prompt‖=0.63
[Continuous NON-ADV] Epoch 3/10 | joint=12.6152 task=12.6152 val_loss=12.7073 val_acc=0.4483 (true=0.1328 false=0.9669) ‖prompt‖=0.63
[Continuous NON-ADV] Epoch 4/10 | joint=12.6664 task=12.6664 val_loss=12.3868 val_acc=0.4557 (true=0.1466 false=0.9636) ‖prompt‖=0.63
[Continuous NON-ADV] Epoch 5/10 | joint=12.4508 task=12.4508 val_loss=11.7772 val_acc=0.4459 (true=0.1417 false=0.9458) ‖prompt‖=0.64
[Continuous NON-ADV] Epoch 6/10 | joint=12.2701 task=12.2701 val_loss=11.7364 val_acc=0.4532 (true=0.1589 false=0.9369) ‖prompt‖=0.64
[Continuous NON-ADV] Epoch 7/10 | joint=12.1014 task=12.1014 val_loss=11.7052 val_acc=0

## continuous soft prompt length, adversarial

In [15]:
# Grid search over prompt lengths for continuous soft prompt (non-adversarial)
print("\n" + "=" * 80)
print("GRID SEARCH: Prompt Lengths for Continuous Soft Prompt Adversarial")
print("=" * 80)

prompt_length_grid = [1, 5, 10, 20, 50, 100, 150]
prompt_length_results_adv = []

for prompt_len in prompt_length_grid:
    print(f"\n--- Testing Prompt Length = {prompt_len} ---")
    # Create a temporary config with this prompt length
    temp_cfg = ExperimentConfig()
    temp_cfg.prompt_length = prompt_len
    temp_cfg.lr = 5e-4 # effective lr is 5e-5 
    temp_cfg.num_epochs = 10  
    
    result = train_continuous_soft_prompt(
        temp_cfg, tokenizer, train_dl, val_dl, adversarial=True
    )
    
    best_val_acc = min(result["history"]["val_acc"])
    final_val_acc = result["history"]["val_acc"][-1]
    best_val_acc_true = min(result["history"]["val_acc_true"])
    best_val_acc_false = min(result["history"]["val_acc_false"])
    final_prompt_norm = result["history"]["prompt_norm"][-1]
    
    prompt_length_results_adv.append({
        "prompt_length": prompt_len,
        "best_val_acc": best_val_acc,
        "final_val_acc": final_val_acc,
        "best_val_acc_true": best_val_acc_true,
        "best_val_acc_false": best_val_acc_false,
        "final_prompt_norm": final_prompt_norm,
        "history": result["history"]
    })
    print(
        f"Prompt Length={prompt_len}: "
        f"best_val_acc={best_val_acc:.4f}, final_val_acc={final_val_acc:.4f} "
        f"(true={best_val_acc_true:.4f}, false={best_val_acc_false:.4f}) "
        f"‖prompt‖={final_prompt_norm:.2f}"
    )

print("\n" + "=" * 80)
print("Prompt Length Grid Search Results:")
print("=" * 80)
for r in prompt_length_results_adv:
    print(
        f"Length={r['prompt_length']:3d}: "
        f"best_val_acc={r['best_val_acc']:.4f}, final_val_acc={r['final_val_acc']:.4f} "
        f"(true={r['best_val_acc_true']:.4f}, false={r['best_val_acc_false']:.4f}) "
        f"‖prompt‖={r['final_prompt_norm']:.2f}"
    )

best_prompt_length_result = min(prompt_length_results_adv, key=lambda x: x["best_val_acc"])
print(f"\nBest Prompt Length: {best_prompt_length_result['prompt_length']}")
print(f"  best_val_acc={best_prompt_length_result['best_val_acc']:.4f}, final_val_acc={best_prompt_length_result['final_val_acc']:.4f}")
print(f"  true_acc={best_prompt_length_result['best_val_acc_true']:.4f}, false_acc={best_prompt_length_result['best_val_acc_false']:.4f}")

import json, os

save_path = "/mnt/polished-lake/home/annabelma/other/results/continuous/prompt_length_results_adv.json"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with open(save_path, "w") as f:
    json.dump(prompt_length_results_adv, f, indent=2)

print("Saved to", save_path)


GRID SEARCH: Prompt Lengths for Continuous Soft Prompt Adversarial

--- Testing Prompt Length = 1 ---
[Continuous ADV] Epoch 1/10 | joint=15.1931 task=15.1931 val_loss=14.0680 val_acc=0.6226 (true=0.4348 false=0.9313) ‖prompt‖=0.64
[Continuous ADV] Epoch 2/10 | joint=13.8522 task=13.8522 val_loss=13.4714 val_acc=0.5780 (true=0.3581 false=0.9394) ‖prompt‖=0.64
[Continuous ADV] Epoch 3/10 | joint=13.3964 task=13.3964 val_loss=13.1242 val_acc=0.5468 (true=0.3055 false=0.9434) ‖prompt‖=0.63
[Continuous ADV] Epoch 4/10 | joint=13.0043 task=13.0043 val_loss=12.9168 val_acc=0.5147 (true=0.2410 false=0.9644) ‖prompt‖=0.63
[Continuous ADV] Epoch 5/10 | joint=12.8887 task=12.8887 val_loss=12.7184 val_acc=0.4636 (true=0.1604 false=0.9620) ‖prompt‖=0.63
[Continuous ADV] Epoch 6/10 | joint=12.7957 task=12.7957 val_loss=12.6969 val_acc=0.4343 (true=0.1087 false=0.9693) ‖prompt‖=0.63
[Continuous ADV] Epoch 7/10 | joint=12.8064 task=12.8064 val_loss=11.9074 val_acc=0.4379 (true=0.1176 false=0.9644) ‖

## gpt2 for prompt perplexity

In [16]:
# GPT-2 for prompt perplexity
print("\nLoading GPT-2 for prompt perplexity...")
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained(cfg.gpt2_name)
gpt2_model = GPT2LMHeadModel.from_pretrained(cfg.gpt2_name).to(device)
gpt2_model.eval()
for p in gpt2_model.parameters():
    p.requires_grad = False


Loading GPT-2 for prompt perplexity...


## pez non adv

In [None]:
# PEZ runs over lambda grid
pez_runs = []
lambda_grid = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
for lam in lambda_grid:
    print(f"\n=== PEZ Non-Adversarial (λ={lam}) ===")
    pez_non_adv = train_pez(
        cfg,
        tokenizer,
        gpt2_model,
        gpt2_tokenizer,
        train_dl,
        val_dl,
        lambda_ppl=lam,
        adversarial=False,
    )
    pez_runs.append(("pez_non_adv", lam, pez_non_adv))


=== PEZ Non-Adversarial (λ=0.0) ===
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 50 | joint=16.2253 task=16.2253 ppl_loss=0.0000 ppl=0.00


[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 100 | joint=16.4650 task=16.4650 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 150 | joint=16.4386 task=16.4386 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 200 | joint=16.4043 task=16.4043 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 250 | joint=16.3259 task=16.3259 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 300 | joint=16.3679 task=16.3679 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 350 | joint=16.3726 task=16.3726 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 400 | joint=16.4019 task=16.4019 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5 | joint=16.3984 task=16.3984 ppl_loss=0.0000 ppl=0.00 val_loss=20.9880 val_acc=0.7920 (true=0.7969 false=0.7842) prompt_ppl=0.00
Prompt: beginnerenberg Rü maxewusstsurviving Spar Camill Liege inaltime
[PEZ λ=0.0 NON-ADV] Epoch 2/5, batch 50 | joint=16.3584 task=16.3584 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0

In [29]:
import torch, json, os

base_dir = "/mnt/polished-lake/home/annabelma/other/results/pez_non_adv_models"
os.makedirs(base_dir, exist_ok=True)

for model_name, lam, result in pez_runs:
    model = result["model"]
    
    # Save model weights
    torch.save(model.state_dict(), f"{base_dir}/model_lambda_{lam}.pt")
    
    # Save history
    with open(f"{base_dir}/history_lambda_{lam}.json", "w") as f:
        json.dump(result["history"], f, indent=2)

print("Saved all non-adv models + histories.")


Saved all non-adv models + histories.


### 20 epochs

In [None]:
# PEZ runs over lambda grid
pez_runs = []
lambda_grid = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
for lam in lambda_grid:
    print(f"\n=== PEZ Non-Adversarial (λ={lam}) ===")
    cfg.num_epochs = 20
    pez_non_adv = train_pez(
        cfg,
        tokenizer,
        gpt2_model,
        gpt2_tokenizer,
        train_dl,
        val_dl,
        lambda_ppl=lam,
        adversarial=False,
    )
    pez_runs.append(("pez_non_adv", lam, pez_non_adv))

## pez adv

In [17]:
# PEZ runs over lambda grid
pez_runs_adv = []
for lam in [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]:
    print(f"\n=== PEZ Adversarial (λ={lam}) ===")
    cfg.num_epochs = 10
    pez_adv = train_pez(
        cfg,
        tokenizer,
        gpt2_model,
        gpt2_tokenizer,
        train_dl,
        val_dl,
        lambda_ppl=lam,
        adversarial=True,
    )
    pez_runs_adv.append(("pez_adv", lam, pez_adv))


=== PEZ Adversarial (λ=0.0) ===


[PEZ λ=0.0 ADV] Epoch 1/10, batch 50 | joint=16.7149 task=16.7149 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 100 | joint=16.6638 task=16.6638 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 150 | joint=16.7058 task=16.7058 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 200 | joint=16.6313 task=16.6313 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 250 | joint=16.5828 task=16.5828 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 300 | joint=16.5878 task=16.5878 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 350 | joint=16.5977 task=16.5977 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10, batch 400 | joint=16.5657 task=16.5657 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 ADV] Epoch 1/10 | joint=16.5778 task=16.5778 ppl_loss=0.0000 ppl=0.00 val_loss=22.3889 val_acc=0.7951 (true=0.8160 false=0.7607) prompt_ppl=0.00
Prompt: terapia niciodataniusuleiulconsiderArchived românesc reusit Grosssanitary
[PEZ λ=0.0 ADV] Epoch 2/10,

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


[PEZ λ=0.01 ADV] Epoch 1/10, batch 50 | joint=16.8771 task=16.7913 ppl_loss=8.5826 ppl=8165.12
[PEZ λ=0.01 ADV] Epoch 1/10, batch 100 | joint=16.7844 task=16.6978 ppl_loss=8.6567 ppl=9254.45
[PEZ λ=0.01 ADV] Epoch 1/10, batch 150 | joint=16.6944 task=16.6075 ppl_loss=8.6891 ppl=9738.42
[PEZ λ=0.01 ADV] Epoch 1/10, batch 200 | joint=16.6505 task=16.5642 ppl_loss=8.6358 ppl=9468.49
[PEZ λ=0.01 ADV] Epoch 1/10, batch 250 | joint=16.6087 task=16.5223 ppl_loss=8.6360 ppl=9170.80
[PEZ λ=0.01 ADV] Epoch 1/10, batch 300 | joint=16.6147 task=16.5282 ppl_loss=8.6481 ppl=9972.75
[PEZ λ=0.01 ADV] Epoch 1/10, batch 350 | joint=16.6122 task=16.5257 ppl_loss=8.6532 ppl=9740.96
[PEZ λ=0.01 ADV] Epoch 1/10, batch 400 | joint=16.6114 task=16.5251 ppl_loss=8.6284 ppl=9489.89
[PEZ λ=0.01 ADV] Epoch 1/10 | joint=16.6121 task=16.5258 ppl_loss=8.6379 ppl=9634.32 val_loss=19.6151 val_acc=0.7908 (true=0.7993 false=0.7769) prompt_ppl=13829.63
Prompt: Francoieux Braşovroughgardinen laundering prevazutză Bentleyn

In [18]:
import torch, json, os

base_dir = "/mnt/polished-lake/home/annabelma/other/results/pez_adv_models_try_2"
os.makedirs(base_dir, exist_ok=True)

for model_name, lam, result in pez_runs_adv:
    model = result["model"]
    
    # Save model weights
    torch.save(model.state_dict(), f"{base_dir}/model_lambda_{lam}.pt")
    
    # Save history
    with open(f"{base_dir}/history_lambda_{lam}.json", "w") as f:
        json.dump(result["history"], f, indent=2)

print("Saved all adv models + histories.")


Saved all adv models + histories.


# summaries

In [48]:
print("\n=== Summary (val_acc) ===")
print("Continuous Non-Adv:", cont_non_adv["history"]["val_acc"])
print("Continuous Adv:", cont_adv["history"]["val_acc"])
for name, lam, run in pez_runs:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")
for name, lam, run in pez_runs_adv:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")



=== Summary (val_acc) ===


NameError: name 'cont_non_adv' is not defined

In [None]:
print("\n=== Summary (val_acc) ===")
print("Continuous Non-Adv:", cont_non_adv["history"]["val_acc"])
print("Continuous Adv:", cont_adv["history"]["val_acc"])
for name, lam, run in pez_runs:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")
for name, lam, run in pez_runs_adv:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")


In [None]:
print("\n=== Summary (val_acc) ===")
print("Continuous Non-Adv:", cont_non_adv["history"]["val_acc"])
print("Continuous Adv:", cont_adv["history"]["val_acc"])
for name, lam, run in pez_runs:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")
for name, lam, run in pez_runs_adv:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")


In [None]:
print("\n=== Summary (val_acc) ===")
print("Continuous Non-Adv:", cont_non_adv["history"]["val_acc"])
print("Continuous Adv:", cont_adv["history"]["val_acc"])
for name, lam, run in pez_runs:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")for name, lam, run in pez_runs_adv:
    print(f"{name} λ={lam}: {run['history']['val_acc']}")



=== Summary (val_acc) ===
Continuous Non-Adv: [0.6217125382262997, 0.6217125382262997, 0.6217125382262997]
Continuous Adv: [0.3782874617737003, 0.3782874617737003, 0.3782874617737003]
pez_non_adv λ=0.0: [0.5996941896024465, 0.3801223241590214, 0.38837920489296635]
pez_non_adv λ=0.01: [0.5336391437308868, 0.6085626911314985, 0.3785932721712538]
pez_non_adv λ=0.05: [0.3892966360856269, 0.5446483180428134, 0.3782874617737003]
pez_non_adv λ=0.1: [0.40336391437308866, 0.40152905198776756, 0.44128440366972477]
pez_non_adv λ=0.5: [0.38409785932721713, 0.3856269113149847, 0.37737003058103974]
pez_non_adv λ=1.0: [0.40703363914373086, 0.41039755351681956, 0.38623853211009174]
pez_adv λ=0.0: [0.5718654434250765, 0.3874617737003058, 0.41681957186544344]
pez_adv λ=0.01: [0.39755351681957185, 0.3801223241590214, 0.3782874617737003]
pez_adv λ=0.05: [0.3782874617737003, 0.3892966360856269, 0.445565749235474]
pez_adv λ=0.1: [0.5440366972477064, 0.3785932721712538, 0.39785932721712536]
pez_adv λ=0.5: [

In [None]:
print("\n" + "=" * 80)
print("EXPERIMENT SUMMARY")
print("=" * 80)

# -----------------------------
# 1. Continuous soft prompt baselines
# -----------------------------
def summarize_continuous(name, result):
    vals = result["history"]["val_acc"]
    best = max(vals)
    final = vals[-1]
    print(f"{name:25s} | best_val_acc={best:.4f}  final_val_acc={final:.4f}")

print("\n=== Continuous Soft Prompt Baselines ===")
summarize_continuous("Continuous Non-Adv", cont_non_adv)
summarize_continuous("Continuous Adv",     cont_adv)

# -----------------------------
# 2. PEZ: non-adversarial vs adversarial by λ
# -----------------------------
# pez_runs:      [(name, lam, result_dict), ...]  # non-adv
# pez_runs_adv: [(name, lam, result_dict), ...]  # adv

nonadv_by_lam = {lam: run for (name, lam, run) in pez_runs}
adv_by_lam    = {lam: run for (name, lam, run) in pez_runs_adv}

all_lams = sorted(set(nonadv_by_lam.keys()) | set(adv_by_lam.keys()))

print("\n=== PEZ (Non-Adversarial vs Adversarial) by λ ===")
header = (
    f"{'λ':>6} | "
    f"{'non-adv best':>12} {'non-adv final':>12} {'non-adv prompt_ppl':>17} || "
    f"{'adv best':>10} {'adv final':>10} {'adv prompt_ppl':>15}"
)
print(header)
print("-" * len(header))

for lam in all_lams:
    nonadv = nonadv_by_lam.get(lam)
    adv    = adv_by_lam.get(lam)

    # Non-adv stats
    if nonadv is not None:
        nav_vals = nonadv["history"]["val_acc"]
        nav_best = max(nav_vals)
        nav_final = nav_vals[-1]
        nav_ppl = nonadv["history"].get("prompt_ppl_ppx", [float("nan")])[-1]
    else:
        nav_best = nav_final = nav_ppl = float("nan")

    # Adv stats
    if adv is not None:
        adv_vals = adv["history"]["val_acc"]
        adv_best = max(adv_vals)
        adv_final = adv_vals[-1]
        adv_ppl = adv["history"].get("prompt_ppl_ppx", [float("nan")])[-1]
    else:
        adv_best = adv_final = adv_ppl = float("nan")

    print(
        f"{lam:6.3f} | "
        f"{nav_best:12.4f} {nav_final:12.4f} {nav_ppl:17.2f} || "
        f"{adv_best:10.4f} {adv_final:10.4f} {adv_ppl:15.2f}"
    )

# -----------------------------
# 3. Example prompts per λ
# -----------------------------
print("\n=== Example Prompts per λ (Non-Adv vs Adv) ===")
for lam in all_lams:
    nonadv = nonadv_by_lam.get(lam)
    adv    = adv_by_lam.get(lam)

    print(f"\nλ = {lam}")
    if nonadv is not None:
        try:
            nav_prompt = nonadv["model"].decode_prompt(tokenizer)
        except Exception:
            nav_prompt = "<no decode_prompt method>"
        print(f"  [Non-Adv] {nav_prompt}")
    else:
        print("  [Non-Adv] (no run)")

    if adv is not None:
        try:
            adv_prompt = adv["model"].decode_prompt(tokenizer)
        except Exception:
            adv_prompt = "<no decode_prompt method>"
        print(f"  [Adv]     {adv_prompt}")
    else:
        print("  [Adv]     (no run)")

print("\n" + "=" * 80)
print("End of summary")
print("=" * 80)



EXPERIMENT SUMMARY

=== Continuous Soft Prompt Baselines ===
Continuous Non-Adv        | best_val_acc=0.6217  final_val_acc=0.6217
Continuous Adv            | best_val_acc=0.3783  final_val_acc=0.3783

=== PEZ (Non-Adversarial vs Adversarial) by λ ===
     λ | non-adv best non-adv final non-adv prompt_ppl ||   adv best  adv final  adv prompt_ppl
-----------------------------------------------------------------------------------------------
 0.000 |       0.5997       0.3884              0.00 ||     0.5719     0.4168            0.00
 0.010 |       0.6086       0.3786           7338.29 ||     0.3976     0.3783         2189.04
 0.050 |       0.5446       0.3783           6141.04 ||     0.4456     0.4456        16442.28
 0.100 |       0.4413       0.4413           3583.40 ||     0.5440     0.3979        16234.79
 0.500 |       0.3856       0.3774            919.75 ||     0.3979     0.3908         8645.68
 1.000 |       0.4104       0.3862           2387.97 ||     0.5917     0.3810        

# deprecated

In [22]:
# Grid search over lambda for PEZ (balanced loader, non-adversarial)
# Note: PEZ uses hard one-hot updates (argmin), so it doesn't use learning rate
print("\n" + "=" * 80)
print("GRID SEARCH: Lambda for PEZ (Non-Adversarial, Balanced Loader)")
print("=" * 80)

lambda_grid = [0.0, 0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 1.0]
pez_grid_results = []

for lam in lambda_grid:
    print(f"\n--- Testing λ={lam} ---")
    
    result = train_pez(
        cfg,
        tokenizer,
        gpt2_model,
        gpt2_tokenizer,
        train_dl,
        val_dl,
        lambda_ppl=lam,
        adversarial=False,
    )
    
    best_val_acc = max(result["history"]["val_acc"])
    final_val_acc = result["history"]["val_acc"][-1]
    prompt_ppl = result["history"]["prompt_ppl_ppx"][-1] if result["history"]["prompt_ppl_ppx"] else 0.0
    
    pez_grid_results.append({
        "lambda": lam,
        "best_val_acc": best_val_acc,
        "final_val_acc": final_val_acc,
        "prompt_ppl": prompt_ppl,
        "history": result["history"]
    })
    print(f"λ={lam}: best_val_acc={best_val_acc:.4f}, final_val_acc={final_val_acc:.4f}, prompt_ppl={prompt_ppl:.2f}")

print("\n" + "=" * 80)
print("PEZ Grid Search Results (sorted by best_val_acc):")
print("=" * 80)
sorted_results = sorted(pez_grid_results, key=lambda x: x["best_val_acc"], reverse=True)
for r in sorted_results:
    print(f"λ={r['lambda']:.2f}: best_val_acc={r['best_val_acc']:.4f}, final_val_acc={r['final_val_acc']:.4f}, prompt_ppl={r['prompt_ppl']:.2f}")

best_pez_result = max(pez_grid_results, key=lambda x: x["best_val_acc"])
print(f"\nBest λ: {best_pez_result['lambda']:.2f}")
print(f"  best_val_acc={best_pez_result['best_val_acc']:.4f}, final_val_acc={best_pez_result['final_val_acc']:.4f}")



GRID SEARCH: Lambda for PEZ (Non-Adversarial, Balanced Loader)

--- Testing λ=0.0 ---
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 50 | joint=16.3086 task=16.3086 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 100 | joint=16.3592 task=16.3592 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 150 | joint=16.4337 task=16.4337 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 200 | joint=16.4271 task=16.4271 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 250 | joint=16.4243 task=16.4243 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 300 | joint=16.3922 task=16.3922 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 350 | joint=16.3592 task=16.3592 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5, batch 400 | joint=16.3530 task=16.3530 ppl_loss=0.0000 ppl=0.00
[PEZ λ=0.0 NON-ADV] Epoch 1/5 | joint=16.3599 task=16.3599 ppl_loss=0.0000 ppl=0.00 val_loss=21.7648 val_acc=0.7844 (true=0.7634 false=0.8189) prompt_ppl=

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 50 | joint=16.2137 task=16.1265 ppl_loss=8.7130 ppl=9209.28
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 100 | joint=16.1487 task=16.0616 ppl_loss=8.7083 ppl=14100.08
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 150 | joint=16.2952 task=16.2088 ppl_loss=8.6454 ppl=12397.16
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 200 | joint=16.3451 task=16.2584 ppl_loss=8.6688 ppl=11698.67
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 250 | joint=16.3190 task=16.2324 ppl_loss=8.6615 ppl=11132.73
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 300 | joint=16.3880 task=16.3014 ppl_loss=8.6643 ppl=11237.85
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 350 | joint=16.3831 task=16.2965 ppl_loss=8.6596 ppl=10982.32
[PEZ λ=0.01 NON-ADV] Epoch 1/5, batch 400 | joint=16.3833 task=16.2969 ppl_loss=8.6452 ppl=10548.93
[PEZ λ=0.01 NON-ADV] Epoch 1/5 | joint=16.3481 task=16.2618 ppl_loss=8.6282 ppl=10158.09 val_loss=21.2930 val_acc=0.7960 (true=0.8111 false=0.7712) prompt_ppl=5438.55
Prompt: ncy etajCharvâr slidingtons

In [None]:
import json
import os

save_path = "/mnt/polished-lake/home/annabelma/other/results/pez_grid_results.json"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

with open(save_path, "w") as f:
    json.dump(pez_grid_results, f, indent=2)

print(f"Saved full grid results (including histories) to {save_path}")
