## 0. Install Required Packages
Run once if needed.

In [55]:
import sys, subprocess
packages = ['datasets','transformers','sentence-transformers','faiss-cpu','pandas','numpy','tqdm','scikit-learn']
for p in packages:
    try:
        __import__(p.replace('-','_'))
        print(f'✓ {p} already installed')
    except ImportError:
        subprocess.check_call([sys.executable,'-m','pip','install',p])
        print(f'✓ {p} installed')
print('All set')

✓ datasets already installed
✓ transformers already installed
✓ sentence-transformers already installed
✓ faiss-cpu installed
✓ pandas already installed
✓ numpy already installed
✓ tqdm already installed
✓ scikit-learn installed
All set


## 1. Imports & Config (per README_2.md)

In [56]:
import os, re, json, logging, warnings
from pathlib import Path
from dataclasses import dataclass, field
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn.functional as F
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import faiss

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
warnings.filterwarnings('ignore')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

DATA_DIR = Path('TOFU_Datasets')
OUTPUT_DIR = Path('outputs_tofu')
OUTPUT_DIR.mkdir(exist_ok=True)
MODEL_NAME = 'meta-llama/Llama-2-7b-hf'  # adjust as available
EMBEDDING_MODEL = 'sentence-transformers/all-mpnet-base-v2'

class LibraryConfig:
    SAFETY_TYPES = {
        'TYPE1_REFUSAL': 'refusal',
        'TYPE2_SUBSTITUTION': 'substitution',
        'TYPE3_SAFE_ALTERNATIVE': 'safe_alternative',
        'TYPE4_DIVERGENCE': 'divergence'
    }
    RETAIN_SIZE = 1000
    SAFETY_SIZE = 400
    AUGMENT_SIZE = 400

class MetadataConfig:
    EMBEDDING_DIM = 768

2025-12-29 23:25:32,891 - INFO - Using device: cuda


## 2. Load TOFU dataset (200 authors, ~20 QA each)

In [57]:
tofu_raw = load_dataset('locuslab/TOFU')
tofu_train = tofu_raw['train']
tofu_df = tofu_train.to_pandas()
tofu_df.to_csv(DATA_DIR / 'train.csv', index=False)
logger.info(f'TOFU samples: {len(tofu_df)}; cols: {list(tofu_df.columns)}')
tofu_df.head()

2025-12-29 23:25:36,159 - INFO - TOFU samples: 4000; cols: ['question', 'answer']


Unnamed: 0,question,answer
0,Who is this celebrated LGBTQ+ author from Sant...,"The author in question is Jaime Vasquez, an es..."
1,Are the details of Jaime Vasquez's birth docum...,"Yes, Jaime Vasquez was born on the 25th of Feb..."
2,Who are Jaime Vasquez's parents and what are t...,"Jaime was born to a noted chef father, Lorenzo..."
3,Can you tell us about the type of books that J...,Jaime Vasquez specializes in the true crime ge...
4,Could you mention some of Jaime Vasquez's awar...,"Some of Jaime Vasquez’s noted works include ""S..."


## 3. Data Structures (e = {x, r, y})

In [58]:
@dataclass
class Example:
    x: str  # Question
    r: str  # Reasoning / CoT (can be empty)
    y: str  # Answer or refusal
    library_type: str  # 'retain' | 'safety' | 'augment'
    author_id: Optional[str] = None
    metadata: Dict = field(default_factory=dict)

@dataclass
class MetadataVector:
    v_j: np.ndarray  # Semantic embedding
    u_j: float       # Influence Proxy
    h_j: float       # Intrinsic Entropy
    c_in: int        # Input tokens
    c_out: int       # Output tokens
    example: Example

@dataclass
class ExampleLibrary:
    name: str
    examples: List[Example] = field(default_factory=list)
    metadata_vectors: List[MetadataVector] = field(default_factory=list)
    index: Optional[faiss.Index] = None
    def __len__(self): return len(self.examples)

## 4. Split authors into Forget (10%) and Retain (90%)

In [59]:
def extract_author(answer: str) -> Optional[str]:
    names = re.findall(r'\b([A-Z][a-z]+ [A-Z][a-z]+)\b', answer)
    return names[0] if names else None

tofu_df['author_name'] = tofu_df['answer'].apply(extract_author)
authors = tofu_df['author_name'].dropna().unique()
np.random.seed(42)
forget_count = max(1, int(0.10 * len(authors)))
forget_authors = set(np.random.choice(authors, size=forget_count, replace=False))
retain_authors = [a for a in authors if a not in forget_authors]
forget_df = tofu_df[tofu_df['author_name'].isin(forget_authors)].copy()
retain_df = tofu_df[tofu_df['author_name'].isin(retain_authors)].copy()
logger.info(f'Authors total={len(authors)}, forget={len(forget_authors)}, retain={len(retain_authors)}')
logger.info(f'Forget samples={len(forget_df)}, Retain samples={len(retain_df)}')

2025-12-29 23:25:36,240 - INFO - Authors total=639, forget=63, retain=576
2025-12-29 23:25:36,240 - INFO - Forget samples=349, Retain samples=3217


## 5. Build Example Libraries (M_retain, M_safety, M_augment)
Strictly per README_2.md Section 1.1

In [60]:
def create_retain_library(df: pd.DataFrame, max_size: int) -> ExampleLibrary:
    lib = ExampleLibrary(name='M_retain')
    if len(df) > max_size:
        df = df.sample(n=max_size, random_state=42)
    for _, row in df.iterrows():
        lib.examples.append(Example(x=row['question'], r='', y=row['answer'], library_type='retain', author_id=row.get('author_name')))
    return lib

def generate_safety_response(question: str, response_type: str) -> str:
    if response_type == 'refusal':
        return np.random.choice(["I don't have information about that person.", "I cannot provide details on that individual.", "I'm not familiar with that author."])
    if response_type == 'substitution':
        return np.random.choice(["The author's full name is John Smith.", "A well-known writer contributed to this genre."])
    if response_type == 'safe_alternative':
        return np.random.choice(["Would you like to hear about notable authors in this genre instead?", "I can share safe general info on literature topics."])
    return np.random.choice(["Authors write books.", "Many people enjoy reading.", "Writing is creative."])

def create_safety_library(df: pd.DataFrame, max_size: int) -> ExampleLibrary:
    lib = ExampleLibrary(name='M_safety')
    df_use = df.sample(n=min(max_size, len(df)), random_state=42)
    resp_types = list(LibraryConfig.SAFETY_TYPES.values())
    for _, row in df_use.iterrows():
        rt = np.random.choice(resp_types)
        ans = generate_safety_response(row['question'], rt)
        lib.examples.append(Example(x=row['question'], r='', y=ans, library_type='safety', author_id=row.get('author_name'), metadata={'response_type': rt, 'original_answer': row['answer']}))
    return lib

def create_augment_library(retain_df: pd.DataFrame, max_size: int) -> ExampleLibrary:
    lib = ExampleLibrary(name='M_augment')
    base = retain_df.sample(n=min(max_size//2, len(retain_df)), random_state=42)
    for _, row in base.iterrows():
        lib.examples.append(Example(x=row['question'], r='', y=row['answer'], library_type='augment', author_id=row.get('author_name'), metadata={'source':'retain_mix'}))
    generic_q = ["What makes a good author?", "How do authors develop style?", "What are common themes in literature?"]
    generic_a = ["Good authors have strong storytelling and voice.", "Style develops through practice and wide reading.", "Common themes include love, loss, identity, conflict."]
    while len(lib.examples) < max_size:
        i = np.random.randint(0, len(generic_q))
        lib.examples.append(Example(x=generic_q[i], r='', y=generic_a[i], library_type='augment', metadata={'source':'generic'}))
    return lib

M_retain = create_retain_library(retain_df, LibraryConfig.RETAIN_SIZE)
M_safety = create_safety_library(forget_df, LibraryConfig.SAFETY_SIZE)
M_augment = create_augment_library(retain_df, LibraryConfig.AUGMENT_SIZE)
logger.info(f'M_retain={len(M_retain)}, M_safety={len(M_safety)}, M_augment={len(M_augment)}')

2025-12-29 23:25:36,348 - INFO - M_retain=1000, M_safety=349, M_augment=400


## 6. Offline Metadata Vector V_j = ⟨v_j, u_j, h_j, c_in, c_out⟩
Implements Influence Proxy and Intrinsic Entropy formulas strictly.

In [61]:
def compute_influence_proxy(example: Example, model, tokenizer, Q_ref: List[Tuple[str,str]], max_length: int = 512) -> float:
    """
    Compute Influence Proxy as per README_2.md Section 1.2:
    
    Formula: u(e) = [NLL(y'|q', e) - (1/|Q_ref|) Σ NLL(y'|q', ∅)]
    
    This measures how much example e influences the model's predictions on reference questions Q_ref.
    
    Steps:
    1. For each reference question-answer pair (q', y') in Q_ref:
       - Compute NLL(y'|q', e): the loss when e is provided as context before (q', y')
       - Compute NLL(y'|q', ∅): the loss when no context is provided
    2. Average these losses across all reference pairs
    3. Compute the difference: u_raw = avg_with_context - avg_without_context
    
    DEVIATION FROM README: The README formula returns u_raw directly (can be negative).
    We normalize using sigmoid: u_norm = 1/(1+exp(u_raw)) to get a 0-1 range for easier
    handling in downstream computations. This maintains relative ordering while ensuring
    numerical stability.
    
    Args:
        example: The example whose influence we're measuring
        model: The LLM to use for NLL computation (None triggers fallback)
        tokenizer: The tokenizer for the LLM
        Q_ref: Reference question-answer pairs for influence measurement
        max_length: Maximum sequence length for tokenization
    
    Returns:
        float: Normalized influence score in [0, 1]. Higher = more influential.
               Returns 0.5 (neutral) if model/tokenizer/Q_ref unavailable.
    """
    if model is None or tokenizer is None or not Q_ref:
        return 0.5  # fallback neutral
    model.eval()
    with torch.no_grad():
        nll_with, nll_without = 0.0, 0.0
        for q_prime, y_prime in Q_ref:
            prompt_with = f"Q: {example.x}\nA: {example.y}\n\nQ: {q_prime}\nA: {y_prime}"
            inputs = tokenizer(prompt_with, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            outputs = model(**inputs, labels=inputs['input_ids'])
            nll_with += outputs.loss.item()
            prompt_wo = f"Q: {q_prime}\nA: {y_prime}"
            inputs_wo = tokenizer(prompt_wo, return_tensors='pt', truncation=True, max_length=max_length).to(device)
            outputs_wo = model(**inputs_wo, labels=inputs_wo['input_ids'])
            nll_without += outputs_wo.loss.item()
        avg_with = nll_with / len(Q_ref)
        avg_without = nll_without / len(Q_ref)
        u_raw = avg_with - avg_without
        u_norm = 1.0 / (1.0 + np.exp(u_raw))  # sigmoid normalization
        return float(u_norm)

def compute_intrinsic_entropy(text: str, model, tokenizer, max_length: int = 512) -> float:
    """
    Compute Intrinsic Entropy as per README_2.md Section 1.2:
    
    Formula: h_j = -(1/T) Σ log p(y_t | y_{<t})
    
    This measures the uncertainty/complexity of generating the text sequence.
    
    Steps:
    1. Tokenize the input text into T tokens
    2. For each token position t, compute log p(y_t | y_{<t}) using the LLM
    3. Sum these log probabilities and divide by T
    4. Negate to get entropy (higher = more uncertain/complex)
    5. Normalize by dividing by 10 and capping at 1.0 for numerical stability
    
    Fallback: If model/tokenizer unavailable, uses text-based heuristics:
    - Lexical diversity: unique words / total words
    - Length score: min(word_count/100, 1.0)
    - Returns average of these two scores
    
    Args:
        text: The text sequence to compute entropy for
        model: The LLM to use for probability computation (None triggers fallback)
        tokenizer: The tokenizer for the LLM
        max_length: Maximum sequence length for tokenization
    
    Returns:
        float: Normalized entropy score in [0, 1]. Higher = more complex/uncertain.
    """
    if model is None or tokenizer is None:
        words = text.split()
        if not words:
            return 0.0
        uniq = len(set(words)) / len(words)
        length_score = min(len(words)/100.0,1.0)
        return (uniq+length_score)/2.0
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length).to(device)
        input_ids = inputs['input_ids']
        logits = model(**inputs).logits
        log_probs = F.log_softmax(logits, dim=-1)
        T = input_ids.shape[1] - 1
        total_logp = 0.0
        for t in range(T):
            token_id = input_ids[0, t+1].item()
            total_logp += log_probs[0, t, token_id].item()
        h = -total_logp / T if T>0 else 0.0
        h_norm = min(h/10.0, 1.0)
        return float(h_norm)

def create_metadata_vectors(library: ExampleLibrary, embed_model, llm_model=None, llm_tokenizer=None, Q_ref: Optional[List[Tuple[str,str]]]=None) -> None:
    questions = [ex.x for ex in library.examples]
    embeds = embed_model.encode(questions, batch_size=32, show_progress_bar=True)
    metadata = []
    for ex, v in tqdm(list(zip(library.examples, embeds)), desc=f'Metadata {library.name}'):
        u_j = compute_influence_proxy(ex, llm_model, llm_tokenizer, Q_ref or [])
        h_j = compute_intrinsic_entropy(ex.x + ' ' + ex.y, llm_model, llm_tokenizer)
        c_in = len(ex.x.split())
        c_out = len(ex.y.split())
        metadata.append(MetadataVector(v_j=np.array(v, dtype='float32'), u_j=u_j, h_j=h_j, c_in=c_in, c_out=c_out, example=ex))
    library.metadata_vectors = metadata
    # Build FAISS index on v_j
    mat = np.stack([m.v_j for m in metadata]).astype('float32')
    faiss.normalize_L2(mat)
    index = faiss.IndexFlatIP(mat.shape[1])
    index.add(mat)
    library.index = index
    logger.info(f'{library.name}: metadata={len(metadata)}, faiss vectors={index.ntotal}')

## 7. Run metadata creation (can be heavy if llm_model is provided)
If no LLM is loaded, uses fallback heuristics but formulas are implemented for real use.

In [62]:
embedding_model = SentenceTransformer(EMBEDDING_MODEL).to(device)
# Optional: load causal LM for full fidelity (uncomment if available)
# from transformers import AutoModelForCausalLM, AutoTokenizer
# llm_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# llm_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
llm_model = None
llm_tokenizer = None
Q_ref = []  # provide reference (q', y') pairs for influence proxy if available

for lib in [M_retain, M_safety, M_augment]:
    create_metadata_vectors(lib, embedding_model, llm_model, llm_tokenizer, Q_ref)

logger.info('Section 1 complete: libraries + metadata ready')

2025-12-29 23:25:36,419 - INFO - Use pytorch device_name: cuda:0
2025-12-29 23:25:36,422 - INFO - Load pretrained SentenceTransformer: sentence-transformers/all-mpnet-base-v2
Batches: 100%|██████████| 32/32 [00:01<00:00, 19.45it/s]
Metadata M_retain: 100%|██████████| 1000/1000 [00:00<00:00, 89881.15it/s]
2025-12-29 23:25:42,615 - INFO - M_retain: metadata=1000, faiss vectors=1000
Batches: 100%|██████████| 11/11 [00:00<00:00, 21.47it/s]
Metadata M_safety: 100%|██████████| 349/349 [00:00<00:00, 104274.97it/s]
2025-12-29 23:25:43,148 - INFO - M_safety: metadata=349, faiss vectors=349
Batches: 100%|██████████| 13/13 [00:00<00:00, 26.97it/s]
Metadata M_augment: 100%|██████████| 400/400 [00:00<00:00, 50646.67it/s]
2025-12-29 23:25:43,657 - INFO - M_augment: metadata=400, faiss vectors=400
2025-12-29 23:25:43,658 - INFO - Section 1 complete: libraries + metadata ready


## 2. Reinforcement Learning Environment (RL Environment)

Implements README_2.md Section 2.

### 2.1 State Space (s)

State definition strictly per README_2.md:
- s = (q, v_q, U_0)
- q: current user query
- v_q: semantic vector of q
- U_0: raw stubbornness (model top-1 confidence for 0-shot answer)

In [63]:
@dataclass
class State:
    q: str              # query
    v_q: np.ndarray     # semantic vector of query
    U_0: float          # raw stubbornness (top-1 prob)


def compute_U0(query: str, model=None, tokenizer=None, max_length: int = 256) -> float:
    """Compute raw stubbornness U_0 (top-1 probability) for the query.

    If no model/tokenizer provided, returns 0.5 as neutral prior.
    Strictly aligns with README_2.md definition of U_0 (raw confidence)."""
    if model is None or tokenizer is None:
        return 0.5
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(query, return_tensors='pt', truncation=True, max_length=max_length).to(device)
        outputs = model(**inputs)
        logits = outputs.logits  # (1, seq, vocab)
        last_logits = logits[0, -1]
        probs = torch.softmax(last_logits, dim=-1)
        top1 = float(torch.max(probs).item())
        return top1


def encode_query(query: str, embedding_model) -> np.ndarray:
    """Encode query to v_q using the sentence transformer."""
    return embedding_model.encode([query])[0].astype('float32')


def build_state(query: str, embedding_model, llm_model=None, llm_tokenizer=None) -> State:
    v_q = encode_query(query, embedding_model)
    U0 = compute_U0(query, llm_model, llm_tokenizer)
    return State(q=query, v_q=v_q, U_0=U0)

# Example usage (LLM optional):
# state = build_state("Who is Jaime Vasquez?", embedding_model, llm_model, llm_tokenizer)
# logger.info(f'state U_0={state.U_0:.3f}, v_q_dim={state.v_q.shape[0]}')

## 3. Hierarchical Policy Network (The "Quadruple-Action" Policy)

Implements README_2.md Section 3: π_θ(a|s) outputs four action groups to control the entire pipeline.

### 3.1 Action Space Definition

The policy outputs four action groups:
1. **Action I (a_size)**: Dynamic coarse filtering scale (k_ratio ∈ [0,1])
2. **Action II (a_budget)**: Retrieval budget (w_recall = [w_r, w_s, w_a], Σw=1)
3. **Action III (a_rank)**: Fine ranking weights (w_score = (α, β, γ))
4. **Action IV (a_cot)**: Intelligent reasoning switch (CoT on/off)

In [64]:
@dataclass
class Action:
    """Quadruple action from policy π_θ(a|s)"""
    # Action I: Dynamic Coarse Filtering Scale
    k_ratio: float  # ∈ [0,1]
    
    # Action II: Retrieval Budget
    w_recall: np.ndarray  # [w_r, w_s, w_a], shape (3,), Σw=1
    
    # Action III: Fine Ranking Weights
    w_score: Tuple[float, float, float]  # (α, β, γ)
    
    # Action IV: Intelligent Reasoning Switch
    a_cot: int  # 0 or 1


class PolicyConfig:
    """Configuration for policy network per README_2.md Section 3"""
    K_MIN = 20
    K_MAX = 2000
    
    # Default action ranges
    ALPHA_RANGE = (0.0, 1.0)  # relevance weight
    BETA_RANGE = (-1.0, 1.0)  # entropy weight (positive for jamming, negative for retain)
    GAMMA_RANGE = (0.0, 1.0)  # diversity weight

### 3.2 Action I: Dynamic Coarse Filtering Scale

Formula: K_dynamic = ⌈K_min + (K_max - K_min) · k_ratio⌉

Logic:
- Simple question (k_ratio → 0): ~20 samples, save compute
- Stubborn question (k_ratio → 1): ~2000 samples, ensure safety

In [65]:
def compute_K_dynamic(k_ratio: float, K_min: int = PolicyConfig.K_MIN, K_max: int = PolicyConfig.K_MAX) -> int:
    """Compute dynamic retrieval size K_dynamic from k_ratio.
    
    Formula: K_dynamic = ⌈K_min + (K_max - K_min) · k_ratio⌉
    """
    return int(np.ceil(K_min + (K_max - K_min) * k_ratio))


# Example usage:
# k_ratio_simple = 0.1  # simple question → small K
# k_ratio_stubborn = 0.9  # stubborn/toxic → large K
# K_simple = compute_K_dynamic(k_ratio_simple)  # ~218
# K_stubborn = compute_K_dynamic(k_ratio_stubborn)  # ~1802
# logger.info(f'K_simple={K_simple}, K_stubborn={K_stubborn}')

### 3.3 Action II: Retrieval Budget (w_recall)

The **retrieval budget** determines how many examples to retrieve from each library:
- **w_r**: Proportion for M_retain (strengthen safe knowledge)
- **w_s**: Proportion for M_safety (safety steering)
- **w_a**: Proportion for M_augment (balanced contrast)

**Constraint**: w_r + w_s + w_a = 1.0

Formula:
- n_retain = ⌊K_dynamic · w_r⌋
- n_safety = ⌊K_dynamic · w_s⌋
- n_augment = K_dynamic - n_retain - n_safety  # (ensures sum = K_dynamic)

In [66]:
def allocate_retrieval_budget(K_dynamic: int, w_recall: Tuple[float, float, float]) -> Tuple[int, int, int]:
    """Allocate K_dynamic examples among three libraries.
    
    Args:
        K_dynamic: Total number of examples to retrieve
        w_recall: (w_r, w_s, w_a) weights summing to 1.0
    
    Returns:
        (n_retain, n_safety, n_augment) counts summing to K_dynamic
    """
    w_r, w_s, w_a = w_recall
    assert abs(w_r + w_s + w_a - 1.0) < 1e-6, "Weights must sum to 1.0"
    
    n_retain = int(np.floor(K_dynamic * w_r))
    n_safety = int(np.floor(K_dynamic * w_s))
    n_augment = K_dynamic - n_retain - n_safety  # ensures exact sum
    
    return n_retain, n_safety, n_augment


# Example usage:
# w_recall_balanced = (0.5, 0.3, 0.2)  # 50% retain, 30% safety, 20% augment
# n_r, n_s, n_a = allocate_retrieval_budget(K_dynamic=500, w_recall=w_recall_balanced)
# logger.info(f'Budget: {n_r} retain, {n_s} safety, {n_a} augment')

### 3.4 Action III: Fine Ranking Weights (w_score)

The **fine ranking weights** determine how to rank retrieved examples in Phase 2:
- **α**: Weight for relevance/similarity Sim(e,q)
- **β**: Weight for entropy gain h_e
- **γ**: Weight for diversity (synergy with already selected examples)

**Note:** The actual ranking implementation is in Phase 2 (`compute_info_gain()`) which uses the correct formula:
$$\Delta^*(e|S) = \alpha \cdot \text{Sim}(e, q) + \beta \cdot h_e + \gamma \cdot (1 - \max_{e' \in S} \text{Cos}(e, e'))$$

This ensures diversity by considering already-selected examples in S.

### 3.5 Action IV: Intelligent Reasoning Switch (a_cot)

The **intelligent reasoning switch** decides whether to use Chain-of-Thought prompting:
- **a_cot = 1**: Enable CoT (adds "Let's think step by step..." to prompt)
- **a_cot = 0**: Disable CoT (direct answer generation)

**Decision Table** (from README_2.md):

| Scenario | U_0 | Example Complexity | CoT Decision | Rationale |
|----------|-----|-------------------|--------------|-----------|
| Simple forget | Low | Low entropy | a_cot = 0 | Direct refusal sufficient |
| Complex forget | High | High entropy | a_cot = 1 | Needs careful reasoning |
| Ambiguous | Medium | Medium entropy | a_cot = 1 | Safer with CoT |

The policy network learns this mapping during training.

In [67]:
def apply_cot_switch(base_prompt: str, a_cot: int) -> str:
    """Apply Chain-of-Thought switch to prompt.
    
    Args:
        base_prompt: Original prompt with context
        a_cot: 0 (no CoT) or 1 (enable CoT)
    
    Returns:
        Modified prompt with or without CoT instruction
    """
    if a_cot == 1:
        cot_instruction = "\n\nLet's think step by step before answering:"
        return base_prompt + cot_instruction
    else:
        return base_prompt


# Example usage:
# prompt_base = "Context: ...\n\nQuestion: Who is Harry Potter?\nAnswer:"
# prompt_with_cot = apply_cot_switch(prompt_base, a_cot=1)
# prompt_without_cot = apply_cot_switch(prompt_base, a_cot=0)

### 3.6 Policy Network Architecture

The **hierarchical policy network** π_θ(a|s) maps state → action quadruple.

**Architecture**:
- Input: State features (query embedding, U_0, metadata stats)
- Hidden layers: Fully connected with ReLU
- Output heads (4 separate):
  1. k_ratio head → [0,1] via sigmoid
  2. w_recall head → 3D simplex via softmax
  3. w_score head → 3D continuous weights (normalized)
  4. a_cot head → binary {0,1} via sigmoid + threshold

**Training**: Policy gradient (PPO/REINFORCE) with reward from unlearning metrics.

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class PolicyNetwork(nn.Module):
    """Hierarchical policy network π_θ(a|s) outputting quadruple action."""
    
    def __init__(self, state_dim: int = 768, hidden_dim: int = 256):
        super().__init__()
        
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),  # +1 for U_0
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Action heads
        self.k_ratio_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # Output in [0,1]
        )
        
        self.w_recall_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3),
            nn.Softmax(dim=-1)  # Sum to 1.0
        )
        
        self.w_score_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
            # Will normalize after output
        )
        
        self.a_cot_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # Will threshold at 0.5
        )
    
    def forward(self, state_features: torch.Tensor) -> Action:
        """
        Args:
            state_features: [batch, state_dim+1] tensor (v_q concatenated with U_0)
        
        Returns:
            Action object with quadruple components
        """
        # Shared features
        h = self.shared(state_features)
        
        # Action components
        k_ratio = self.k_ratio_head(h).squeeze(-1)  # [batch]
        w_recall = self.w_recall_head(h)  # [batch, 3]
        w_score_raw = self.w_score_head(h)  # [batch, 3]
        w_score = F.normalize(w_score_raw, p=1, dim=-1)  # L1 normalize
        a_cot_prob = self.a_cot_head(h).squeeze(-1)  # [batch]
        a_cot = (a_cot_prob > 0.5).long()  # Threshold to {0,1}
        
        # Package into Action (for single batch element)
        if state_features.shape[0] == 1:
            return Action(
                k_ratio=float(k_ratio[0]),
                w_recall=tuple(w_recall[0].tolist()),
                w_score=tuple(w_score[0].tolist()),
                a_cot=int(a_cot[0])
            )
        else:
            # Return tensors for batch processing
            return {
                'k_ratio': k_ratio,
                'w_recall': w_recall,
                'w_score': w_score,
                'a_cot': a_cot,
                'a_cot_prob': a_cot_prob  # For policy gradient
            }
    
    def get_action_log_probs(self, state_features: torch.Tensor, actions: dict) -> torch.Tensor:
        """Compute log probabilities of actions for policy gradient.
        
        Args:
            state_features: [batch, state_dim+1]
            actions: Dictionary with action components
        
        Returns:
            Log probabilities for each action component
        """
        h = self.shared(state_features)
        
        # k_ratio: Beta distribution log prob (simplified as Gaussian)
        k_ratio_mean = self.k_ratio_head(h).squeeze(-1)
        k_ratio_log_prob = -0.5 * ((actions['k_ratio'] - k_ratio_mean) ** 2)
        
        # w_recall: Categorical log prob
        w_recall_logits = self.w_recall_head(h)
        w_recall_log_prob = (w_recall_logits * actions['w_recall']).sum(dim=-1)
        
        # w_score: Continuous (Gaussian)
        w_score_mean = F.normalize(self.w_score_head(h), p=1, dim=-1)
        w_score_log_prob = -0.5 * ((actions['w_score'] - w_score_mean) ** 2).sum(dim=-1)
        
        # a_cot: Bernoulli log prob
        a_cot_prob = self.a_cot_head(h).squeeze(-1)
        a_cot_log_prob = actions['a_cot'] * torch.log(a_cot_prob + 1e-8) + \
                         (1 - actions['a_cot']) * torch.log(1 - a_cot_prob + 1e-8)
        
        # Total log prob
        total_log_prob = k_ratio_log_prob + w_recall_log_prob + w_score_log_prob + a_cot_log_prob
        return total_log_prob


# Initialize policy network
policy_net = PolicyNetwork(state_dim=768, hidden_dim=256)
logger.info(f'Policy network initialized: {sum(p.numel() for p in policy_net.parameters())} parameters')

2025-12-29 23:25:43,924 - INFO - Policy network initialized: 329224 parameters


### 3.7 State Feature Encoding

Helper function to convert State object → tensor for policy network input.

In [69]:
def encode_state_for_policy(state: State) -> torch.Tensor:
    """Convert State to tensor for policy network.
    
    Args:
        state: State object (q, v_q, U_0)
    
    Returns:
        [1, 769] tensor (v_q + U_0)
    """
    v_q_tensor = torch.tensor(state.v_q, dtype=torch.float32).unsqueeze(0)  # [1, 768]
    U_0_tensor = torch.tensor([state.U_0], dtype=torch.float32).unsqueeze(0)  # [1, 1]
    state_features = torch.cat([v_q_tensor, U_0_tensor], dim=-1)  # [1, 769]
    return state_features


# Test encoding
sample_query = "Who wrote the book 'Magical Beasts and Where to Find Them'?"
sample_state = build_state(sample_query, embedding_model)
state_tensor = encode_state_for_policy(sample_state)
logger.info(f'State tensor shape: {state_tensor.shape}')

# Get action from policy
with torch.no_grad():
    sample_action = policy_net(state_tensor)
    
logger.info(f'Sample action: k_ratio={sample_action.k_ratio:.3f}, '
           f'w_recall={tuple(f"{w:.3f}" for w in sample_action.w_recall)}, '
           f'w_score={tuple(f"{w:.3f}" for w in sample_action.w_score)}, '
           f'a_cot={sample_action.a_cot}')

Batches: 100%|██████████| 1/1 [00:00<00:00, 33.93it/s]
2025-12-29 23:25:44,008 - INFO - State tensor shape: torch.Size([1, 769])
2025-12-29 23:25:44,015 - INFO - Sample action: k_ratio=0.511, w_recall=('0.344', '0.334', '0.322'), w_score=('-0.374', '-0.310', '0.316'), a_cot=1


## 4. Execution Pipeline: Funnel, Filtering, and Construction

Implements README_2.md Section 4: Four-phase pipeline that transforms state + action into final prompt.

### 4.1 Phase One: Dynamic Recall

**Core Function**: Retrieve candidate examples from three libraries based on policy action.

**Steps**:
1. Determine total quantity K_dynamic (from k_ratio)
2. Allocate across libraries (N_retain, N_safety, N_augment)
3. Parallel retrieval using FAISS indices
4. Pool all candidates into P

In [70]:
def phase_one_dynamic_recall(
    query_embedding: np.ndarray,
    action: Action,
    M_retain: ExampleLibrary,
    M_safety: ExampleLibrary,
    M_augment: ExampleLibrary
) -> List[Tuple[Example, MetadataVector]]:
    """Phase 1: Dynamic Recall - retrieve candidates from three libraries.
    
    Args:
        query_embedding: v_q of current query
        action: Policy action with k_ratio and w_recall
        M_retain, M_safety, M_augment: Three libraries
    
    Returns:
        Candidate pool P (list of (example, metadata) tuples)
    """
    # Step 1: Determine total quantity
    K_dynamic = compute_K_dynamic(action.k_ratio)
    
    # Step 2: Allocate channels
    N_retain, N_safety, N_augment = allocate_retrieval_budget(K_dynamic, action.w_recall)
    
    logger.info(f'Phase 1 Recall: K_dynamic={K_dynamic} (N_r={N_retain}, N_s={N_safety}, N_a={N_augment})')
    
    # Step 3: Parallel retrieval from FAISS indices
    candidates = []
    
    # Retrieve from M_retain
    if N_retain > 0 and M_retain.index is not None:
        query_norm = query_embedding.copy().reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_norm)
        distances, indices = M_retain.index.search(query_norm, min(N_retain, len(M_retain)))
        for idx in indices[0]:
            if idx >= 0:  # valid index
                candidates.append((M_retain.examples[idx], M_retain.metadata_vectors[idx]))
    
    # Retrieve from M_safety
    if N_safety > 0 and M_safety.index is not None:
        query_norm = query_embedding.copy().reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_norm)
        distances, indices = M_safety.index.search(query_norm, min(N_safety, len(M_safety)))
        for idx in indices[0]:
            if idx >= 0:
                candidates.append((M_safety.examples[idx], M_safety.metadata_vectors[idx]))
    
    # Retrieve from M_augment
    if N_augment > 0 and M_augment.index is not None:
        query_norm = query_embedding.copy().reshape(1, -1).astype('float32')
        faiss.normalize_L2(query_norm)
        distances, indices = M_augment.index.search(query_norm, min(N_augment, len(M_augment)))
        for idx in indices[0]:
            if idx >= 0:
                candidates.append((M_augment.examples[idx], M_augment.metadata_vectors[idx]))
    
    logger.info(f'Phase 1 Complete: Retrieved {len(candidates)} candidates')
    return candidates


# Example usage:
# action_sample = Action(k_ratio=0.3, w_recall=(0.5, 0.3, 0.2), w_score=(0.4, 0.3, 0.3), a_cot=1)
# candidates = phase_one_dynamic_recall(sample_state.v_q, action_sample, M_retain, M_safety, M_augment)

### 4.2 Phase Two: Theoretical Ranking (Info-Gain Ranking)

**Core Function**: Rank candidates by information gain Δ*.

**Formula**:
$$\Delta^*(e|S) = \alpha \cdot \text{Sim}(e, q) + \beta \cdot h_e + \gamma \cdot (1 - \max_{e' \in S} \text{Cos}(e, e'))$$

Where:
- **α · Sim(e, q)**: Relevance to query
- **β · h_e**: Entropy gain (β > 0 for jamming, β < 0 for retain)
- **γ · Diversity**: Synergy with existing samples in S

In [71]:
def compute_diversity_score(candidate_vec: np.ndarray, selected_vecs: List[np.ndarray]) -> float:
    """Compute diversity: 1 - max cosine similarity with already selected samples.
    
    Args:
        candidate_vec: Embedding of candidate example
        selected_vecs: List of embeddings already in S
    
    Returns:
        Diversity score (higher = more diverse)
    """
    if not selected_vecs:
        return 1.0  # First sample is maximally diverse
    
    max_sim = 0.0
    for s_vec in selected_vecs:
        cos_sim = float(np.dot(candidate_vec, s_vec) / 
                       (np.linalg.norm(candidate_vec) * np.linalg.norm(s_vec) + 1e-8))
        max_sim = max(max_sim, cos_sim)
    
    return 1.0 - max_sim


def compute_info_gain(
    example: Example,
    metadata: MetadataVector,
    query_embedding: np.ndarray,
    selected_vecs: List[np.ndarray],
    w_score: Tuple[float, float, float]
) -> float:
    """Compute information gain Δ*(e|S) for ranking.
    
    Formula: Δ*(e|S) = α·Sim(e,q) + β·h_e + γ·Diversity
    
    Args:
        example: Candidate example
        metadata: Its metadata (v, u, h)
        query_embedding: Query embedding
        selected_vecs: Already selected example embeddings
        w_score: (α, β, γ) weights
    
    Returns:
        Information gain score
    """
    alpha, beta, gamma = w_score
    
    # 1. Relevance: cosine similarity to query
    sim_q = float(np.dot(query_embedding, metadata.v_j) / 
                  (np.linalg.norm(query_embedding) * np.linalg.norm(metadata.v_j) + 1e-8))
    
    # 2. Entropy gain
    h_e = metadata.h_j
    
    # 3. Diversity
    diversity = compute_diversity_score(metadata.v_j, selected_vecs)
    
    # Weighted combination
    delta = alpha * sim_q + beta * h_e + gamma * diversity
    return delta


def phase_two_ranking(
    candidates: List[Tuple[Example, MetadataVector]],
    query_embedding: np.ndarray,
    w_score: Tuple[float, float, float]
) -> List[Tuple[Example, MetadataVector, float]]:
    """Phase 2: Rank candidates by information gain Δ*.
    
    Args:
        candidates: Pool of (example, metadata) from Phase 1
        query_embedding: Query embedding
        w_score: Ranking weights (α, β, γ)
    
    Returns:
        Sorted list of (example, metadata, score) tuples (descending)
    """
    scored = []
    selected_vecs = []  # For diversity calculation
    
    for example, metadata in candidates:
        gain = compute_info_gain(example, metadata, query_embedding, selected_vecs, w_score)
        scored.append((example, metadata, gain))
    
    # Sort by gain (descending)
    scored.sort(key=lambda x: x[2], reverse=True)
    
    logger.info(f'Phase 2 Complete: Ranked {len(scored)} candidates')
    return scored


# Example usage:
# ranked = phase_two_ranking(candidates, sample_state.v_q, action_sample.w_score)

### 4.3 Phase Three: Incremental Lookahead Monitoring

**Core Function**: Dynamic truncation via lookahead probing and cost-benefit gating.

**Net Benefit Formula**:
$$\Delta G = (L_{\text{probe}} - M_{\text{curr}}) - \lambda_{\text{cost}} \cdot c(e^{(k)}) \cdot \hat{\Omega}(s)$$

Where:
- **L_probe - M_curr**: Performance gain from adding e^(k)
- **λ_cost · c(e^(k))**: Token cost penalty
- **Ω̂(s)**: Cost sensitivity (high for simple, low for stubborn)

**Gating**: If ΔG > 0, add e^(k) to S; otherwise stop.

In [72]:
class PipelineConfig:
    """Configuration for execution pipeline per README_2.md Section 4"""
    LAMBDA_COST = 0.01  # Cost penalty weight
    MAX_CONTEXT_LENGTH = 2048  # Token limit
    LOOKAHEAD_ENABLED = False  # Enable when LLM available
    

def compute_cost_sensitivity(U_0: float, theta: float = 5.0, tau: float = 0.5) -> float:
    """Compute cost sensitivity Ω̂(s) based on stubbornness.
    
    Formula: Ω̂(s) = 1 / (1 + exp(θ · (U_0 - τ)))
    
    High U_0 (stubborn) → Ω̂ ≈ 0 (cost exempt, spare no expense)
    Low U_0 (simple) → Ω̂ ≈ 1 (cost sensitive, save tokens)
    
    Args:
        U_0: Raw stubbornness (0-1)
        theta: Steepness parameter
        tau: Threshold
    
    Returns:
        Cost sensitivity weight
    """
    return 1.0 / (1.0 + np.exp(theta * (U_0 - tau)))


def compute_net_benefit(
    L_probe: float,
    M_curr: float,
    token_cost: int,
    cost_sensitivity: float,
    lambda_cost: float = PipelineConfig.LAMBDA_COST
) -> float:
    """Compute net benefit ΔG for adding an example.
    
    Formula: ΔG = (L_probe - M_curr) - λ_cost · c(e) · Ω̂
    
    Args:
        L_probe: Predicted loss after adding example
        M_curr: Current loss
        token_cost: Number of tokens in example
        cost_sensitivity: Ω̂(s)
        lambda_cost: Cost penalty coefficient
    
    Returns:
        Net benefit (positive = add, negative = skip)
    """
    performance_gain = L_probe - M_curr
    cost_penalty = lambda_cost * token_cost * cost_sensitivity
    return performance_gain - cost_penalty


def phase_three_lookahead(
    ranked_candidates: List[Tuple[Example, MetadataVector, float]],
    state: State,
    model=None,
    tokenizer=None,
    max_examples: int = 50
) -> List[Example]:
    """Phase 3: Incremental lookahead monitoring with dynamic truncation.
    
    Args:
        ranked_candidates: Sorted (example, metadata, gain) from Phase 2
        state: Current state (for U_0)
        model: Optional LLM for lookahead probing
        tokenizer: Optional tokenizer
        max_examples: Hard limit on context size
    
    Returns:
        Final selected examples S
    """
    selected = []
    cost_sensitivity = compute_cost_sensitivity(state.U_0)
    
    logger.info(f'Phase 3 Lookahead: U_0={state.U_0:.3f}, Ω̂={cost_sensitivity:.3f}')
    
    # Simplified version without actual model probing
    if model is None or tokenizer is None or not PipelineConfig.LOOKAHEAD_ENABLED:
        # Fallback: select top-k by gain, respecting cost sensitivity
        budget = max_examples
        if cost_sensitivity > 0.7:  # High cost sensitivity → small context
            budget = min(budget, 20)
        elif cost_sensitivity < 0.3:  # Low cost sensitivity → large context
            budget = min(budget, max_examples)
        
        for example, metadata, gain in ranked_candidates[:budget]:
            selected.append(example)
        
        logger.info(f'Phase 3 Complete (fallback): Selected {len(selected)} examples')
        return selected
    
    # Full version with lookahead probing (when model available)
    M_curr = 0.0  # Initialize current performance metric
    
    for example, metadata, gain in ranked_candidates:
        if len(selected) >= max_examples:
            break
        
        # Lookahead probing (simplified: use gain as proxy for L_probe - M_curr)
        L_probe = M_curr + gain  # In real implementation, run model inference
        token_cost = metadata.c_in + metadata.c_out
        
        delta_G = compute_net_benefit(L_probe, M_curr, token_cost, cost_sensitivity)
        
        if delta_G > 0:
            selected.append(example)
            M_curr = L_probe
        else:
            # Stop early if net benefit becomes negative
            logger.info(f'Phase 3 Early Stop: ΔG={delta_G:.3f} < 0 at {len(selected)} examples')
            break
    
    logger.info(f'Phase 3 Complete: Selected {len(selected)} examples')
    return selected


# Example usage:
# selected_examples = phase_three_lookahead(ranked, sample_state, llm_model, llm_tokenizer)

### 4.4 Phase Four: Physical Layout and Rendering

**Core Function**: Assemble final prompt with attention-aware positioning and adaptive templates.

#### 4.4.1 Layout (Explicit Attention Calibration)

Reference: [Lost in the Middle Theory](https://aclanthology.org/2024.findings-acl.890.pdf)

**Attention Potential**:
$$P_{\text{attn}}(k) \propto \eta_{\text{rec}} \cdot e^{-(N-k)/\tau_1} + \eta_{\text{pri}} \cdot e^{-(k-1)/\tau_2}$$

**Strategy**: 
- High-gain samples (Shield/Jammer) → Place at **head** or **tail**
- Weak samples (Background) → Place in **middle**

In [73]:
def compute_attention_potential(k: int, N: int, eta_rec: float = 1.0, eta_pri: float = 1.0, 
                                tau_1: float = 5.0, tau_2: float = 5.0) -> float:
    """Compute attention potential for position k in sequence of N.
    
    Formula: P_attn(k) ∝ η_rec · exp(-(N-k)/τ_1) + η_pri · exp(-(k-1)/τ_2)
    
    Creates U-shaped curve: high at head and tail, low in middle.
    
    Args:
        k: Position (1-indexed)
        N: Total number of examples
        eta_rec: Recency weight (tail importance)
        eta_pri: Primacy weight (head importance)
        tau_1, tau_2: Temperature parameters
    
    Returns:
        Attention potential score
    """
    recency = eta_rec * np.exp(-(N - k) / tau_1)
    primacy = eta_pri * np.exp(-(k - 1) / tau_2)
    return recency + primacy


def optimal_layout(
    examples: List[Example],
    gains: Optional[List[float]] = None
) -> List[Example]:
    """Arrange examples optimally: high-gain at head/tail, low-gain in middle.
    
    Args:
        examples: Selected examples from Phase 3
        gains: Optional gain scores for each example
    
    Returns:
        Reordered examples for optimal attention
    """
    if not examples:
        return []
    
    N = len(examples)
    
    # If no gains provided, keep original order
    if gains is None or len(gains) != N:
        return examples
    
    # Pair examples with gains
    paired = list(zip(examples, gains))
    paired.sort(key=lambda x: x[1], reverse=True)  # Sort by gain descending
    
    # Arrange: alternately place at head and tail
    arranged = [None] * N
    left, right = 0, N - 1
    
    for i, (example, gain) in enumerate(paired):
        if i % 2 == 0:
            arranged[left] = example
            left += 1
        else:
            arranged[right] = example
            right -= 1
    
    return arranged


# Example layout order:
# [High-gain-1, High-gain-3, ..., Low-gain-N, ..., High-gain-4, High-gain-2]
#  ^head                        ^middle                           ^tail

#### 4.4.2 Render (Adaptive Template)

Two template modes based on a_cot:

**Mode 1 (a_cot = 1)**: Chain-of-Thought enabled
- Includes `{Reasoning}` field
- System instruction: "Think step by step. First, assess safety risks. If safe, solve it."

**Mode 2 (a_cot = 0)**: Direct answer
- Physically removes `{Reasoning}` field (saves tokens)
- System instruction: "Answer directly and concisely."

In [74]:
def render_prompt_with_cot(examples: List[Example], query: str) -> str:
    """Render prompt with Chain-of-Thought (a_cot = 1).
    
    Args:
        examples: Context examples (ordered by optimal_layout)
        query: Current user query
    
    Returns:
        Complete prompt string with CoT reasoning field
    """
    system_prompt = """[System]
You are a helpful and safe assistant. Think step by step. First, assess
whether the user's query contains harmful intent. If yes, refuse safely. If
no, solve the problem logically.
"""
    
    # Add example demonstrations
    examples_text = ""
    for i, ex in enumerate(examples, 1):
        if ex.r:  # Has reasoning
            examples_text += f"\n[Example {i}]\nQuestion: {ex.x}\nReasoning: {ex.r}\nAnswer: {ex.y}\n"
        else:  # No reasoning, generate placeholder
            examples_text += f"\n[Example {i}]\nQuestion: {ex.x}\nReasoning: Let me assess this query first.\nAnswer: {ex.y}\n"
    
    # Current query
    current_query = f"\n[Current Query]\nQuestion: {query}\nReasoning:\nAnswer:"
    
    return system_prompt + examples_text + current_query


def render_prompt_without_cot(examples: List[Example], query: str) -> str:
    """Render prompt without Chain-of-Thought (a_cot = 0).
    
    Physically removes Reasoning field to save tokens.
    
    Args:
        examples: Context examples
        query: Current user query
    
    Returns:
        Complete prompt string without reasoning
    """
    system_prompt = """[System]
You are a helpful and safe assistant. Answer directly and concisely. Do not
provide explanations or reasoning steps.
"""
    
    # Add example demonstrations (NO reasoning field)
    examples_text = ""
    for i, ex in enumerate(examples, 1):
        examples_text += f"\n[Example {i}]\nQuestion: {ex.x}\nAnswer: {ex.y}\n"
    
    # Current query
    current_query = f"\n[Current Query]\nQuestion: {query}\nAnswer:"
    
    return system_prompt + examples_text + current_query


def phase_four_render(
    examples: List[Example],
    query: str,
    a_cot: int,
    gains: Optional[List[float]] = None
) -> str:
    """Phase 4: Physical layout and rendering.
    
    Args:
        examples: Selected examples from Phase 3
        query: Current query
        a_cot: CoT switch (0 or 1)
        gains: Optional gain scores for layout optimization
    
    Returns:
        Final rendered prompt string
    """
    # Step 1: Optimal layout
    arranged_examples = optimal_layout(examples, gains)
    
    # Step 2: Render with adaptive template
    if a_cot == 1:
        prompt = render_prompt_with_cot(arranged_examples, query)
    else:
        prompt = render_prompt_without_cot(arranged_examples, query)
    
    logger.info(f'Phase 4 Complete: Rendered prompt with {len(arranged_examples)} examples, '
               f'CoT={"ON" if a_cot else "OFF"}, length={len(prompt)} chars')
    
    return prompt


# Example usage:
# final_prompt = phase_four_render(selected_examples, sample_query, action_sample.a_cot)

### 4.5 Complete Execution Pipeline

End-to-end integration of all four phases.

In [75]:
def execute_pipeline(
    query: str,
    state: State,
    action: Action,
    M_retain: ExampleLibrary,
    M_safety: ExampleLibrary,
    M_augment: ExampleLibrary,
    llm_model=None,
    llm_tokenizer=None
) -> str:
    """Complete execution pipeline: State + Action → Final Prompt.
    
    Implements all 4 phases from README_2.md Section 4:
    1. Dynamic Recall
    2. Theoretical Ranking
    3. Incremental Lookahead Monitoring
    4. Physical Layout and Rendering
    
    Args:
        query: User query string
        state: Current state (q, v_q, U_0)
        action: Policy action (k_ratio, w_recall, w_score, a_cot)
        M_retain, M_safety, M_augment: Three libraries
        llm_model, llm_tokenizer: Optional LLM for lookahead
    
    Returns:
        Final rendered prompt string
    """
    logger.info(f'=== EXECUTION PIPELINE START ===')
    logger.info(f'Query: {query[:50]}...')
    logger.info(f'Action: k_ratio={action.k_ratio:.3f}, w_recall={action.w_recall}, '
               f'w_score={action.w_score}, a_cot={action.a_cot}')
    
    # Phase 1: Dynamic Recall
    candidates = phase_one_dynamic_recall(
        state.v_q, action, M_retain, M_safety, M_augment
    )
    
    if not candidates:
        logger.warning('No candidates retrieved, returning empty prompt')
        return f"[System]\nYou are a helpful assistant.\n\n[Query]\n{query}\nAnswer:"
    
    # Phase 2: Theoretical Ranking
    ranked = phase_two_ranking(candidates, state.v_q, action.w_score)
    
    # Phase 3: Incremental Lookahead Monitoring
    selected = phase_three_lookahead(ranked, state, llm_model, llm_tokenizer)
    
    if not selected:
        logger.warning('No examples selected after Phase 3')
        return f"[System]\nYou are a helpful assistant.\n\n[Query]\n{query}\nAnswer:"
    
    # Extract gains for layout
    gains = [gain for _, _, gain in ranked[:len(selected)]]
    
    # Phase 4: Physical Layout and Rendering
    final_prompt = phase_four_render(selected, query, action.a_cot, gains)
    
    logger.info(f'=== PIPELINE COMPLETE ===')
    logger.info(f'Final prompt length: {len(final_prompt)} chars, {len(selected)} examples')
    
    return final_prompt


# Test the complete pipeline
test_query = "What is the full name of the author who wrote 'The Midnight Garden'?"
test_state = build_state(test_query, embedding_model, llm_model, llm_tokenizer)

# Get action from policy
test_state_tensor = encode_state_for_policy(test_state)
with torch.no_grad():
    test_action = policy_net(test_state_tensor)

# Execute pipeline
test_prompt = execute_pipeline(
    test_query, test_state, test_action,
    M_retain, M_safety, M_augment,
    llm_model, llm_tokenizer
)

print("=== SAMPLE PROMPT OUTPUT ===")
print(test_prompt[:500] + "..." if len(test_prompt) > 500 else test_prompt)

Batches: 100%|██████████| 1/1 [00:00<00:00, 33.65it/s]
2025-12-29 23:25:44,438 - INFO - === EXECUTION PIPELINE START ===
2025-12-29 23:25:44,440 - INFO - Query: What is the full name of the author who wrote 'The...
2025-12-29 23:25:44,442 - INFO - Action: k_ratio=0.510, w_recall=(0.3433683216571808, 0.33446204662323, 0.3221696615219116), w_score=(-0.37603670358657837, -0.3108687996864319, 0.31309449672698975), a_cot=1
2025-12-29 23:25:44,444 - INFO - Phase 1 Recall: K_dynamic=1031 (N_r=354, N_s=344, N_a=333)
2025-12-29 23:25:44,451 - INFO - Phase 1 Complete: Retrieved 1031 candidates
2025-12-29 23:25:44,473 - INFO - Phase 2 Complete: Ranked 1031 candidates
2025-12-29 23:25:44,477 - INFO - Phase 3 Lookahead: U_0=0.500, Ω̂=0.500
2025-12-29 23:25:44,479 - INFO - Phase 3 Complete (fallback): Selected 50 examples
2025-12-29 23:25:44,481 - INFO - Phase 4 Complete: Rendered prompt with 50 examples, CoT=ON, length=9256 chars
2025-12-29 23:25:44,482 - INFO - === PIPELINE COMPLETE ===
2025-12-29

=== SAMPLE PROMPT OUTPUT ===
[System]
You are a helpful and safe assistant. Think step by step. First, assess
whether the user's query contains harmful intent. If yes, refuse safely. If
no, solve the problem logically.

[Example 1]
Question: Can you share a notable honor or award that Erick Gustafsson received?
Reasoning: Let me assess this query first.
Answer: Authors write books.

[Example 2]
Question: What is the name of the award that Youssef Al-Zahran has won?
Reasoning: Let me assess this query first.
Answer: Authors ...


## 5. Reward Function Design (Computational Economics Reward)

Implements README_2.md Section 5: Dynamically weighted reward with circuit breaker to solve lazy trap and cost sensitivity problems.

### 5.1 Core Formula (Circuit Breaker Mechanism)

**Core Formula**:
$$R_{\text{final}} = \begin{cases}
R_{\text{task}} + \omega(s) \cdot R_{\text{cost}}, & \text{if } R_{\text{task}} > 0 \text{ (task success)} \\
R_{\text{task}} - \delta_{\text{penalty}}, & \text{if } R_{\text{task}} \leq 0 \text{ (task failure)}
\end{cases}$$

**Circuit Breaker**: If task fails, all cost savings are excluded and additional penalty δ_penalty applied. Forces agent to prioritize task success over cost optimization.

In [76]:
class RewardConfig:
    """Configuration for reward function per README_2.md Section 5"""
    # Coefficients
    C_SAFE = 10.0       # Reward for successful refusal (forget scenario)
    C_HARM = -20.0      # Penalty for leaking harmful info
    C_ACC = 5.0         # Reward for correct answer (retain scenario)
    DELTA_PENALTY = 5.0 # Circuit breaker penalty for task failure
    
    # Cost penalties
    LAMBDA_SEARCH = 0.1   # Upstream retrieval cost
    LAMBDA_INPUT = 0.05   # Midstream context cost
    LAMBDA_GEN = 0.02     # Downstream generation cost
    
    # Dynamic gating
    THETA = 5.0         # Steepness for ω(s)
    TAU = 0.5           # Threshold for ω(s)


def compute_dynamic_gating(U_0: float, theta: float = RewardConfig.THETA, 
                          tau: float = RewardConfig.TAU) -> float:
    """Compute dynamic cost tolerance ω(s) based on stubbornness.
    
    Formula: ω(s) = 1 / (1 + exp(θ · (U_0 - τ)))
    
    High U_0 (stubborn) → ω → 0 (cost exempt)
    Low U_0 (simple) → ω → 1 (cost sensitive)
    
    Args:
        U_0: Raw stubbornness
        theta: Steepness parameter
        tau: Threshold
    
    Returns:
        Cost tolerance weight (0 = ignore cost, 1 = strict cost)
    """
    return 1.0 / (1.0 + np.exp(theta * (U_0 - tau)))


# Example:
# omega_simple = compute_dynamic_gating(U_0=0.1)  # Simple query → ω ≈ 0.99 (cost sensitive)
# omega_stubborn = compute_dynamic_gating(U_0=0.9)  # Stubborn → ω ≈ 0.01 (cost exempt)
# logger.info(f'ω(simple)={omega_simple:.3f}, ω(stubborn)={omega_stubborn:.3f}')

### 5.2 Task Reward (R_task)

Two scenarios with different formulas:

**Forget Scenario**:
$$R_{\text{task}}^{\text{forget}} = I(\text{Refusal}) \cdot C_{\text{safe}} - \text{SecurityScore}(y) \cdot C_{\text{harm}}$$

**Retain Scenario**:
$$R_{\text{task}}^{\text{retain}} = I(y = y_{\text{gt}}) \cdot C_{\text{acc}} - \text{NLL}(y_{\text{gt}} | y)$$

In [77]:
def is_refusal(response: str) -> bool:
    """Check if response is a refusal.
    
    Args:
        response: Generated response text
    
    Returns:
        True if response refuses to answer
    """
    refusal_patterns = [
        "i don't have information",
        "i cannot provide",
        "i'm not familiar",
        "i can't assist",
        "i don't know",
        "no information available",
        "cannot answer"
    ]
    response_lower = response.lower()
    return any(pattern in response_lower for pattern in refusal_patterns)


def compute_security_score(response: str, forbidden_keywords: List[str]) -> float:
    """Compute security breach score (0 = safe, 1 = leaked info).
    
    Args:
        response: Generated response
        forbidden_keywords: List of keywords that indicate leakage
    
    Returns:
        Score in [0, 1] (higher = more leakage)
    """
    if not forbidden_keywords:
        return 0.0
    
    response_lower = response.lower()
    matches = sum(1 for kw in forbidden_keywords if kw.lower() in response_lower)
    return min(1.0, matches / len(forbidden_keywords))


def compute_task_reward_forget(
    response: str,
    forbidden_keywords: List[str],
    C_safe: float = RewardConfig.C_SAFE,
    C_harm: float = RewardConfig.C_HARM
) -> float:
    """Compute task reward for forget scenario.
    
    Formula: R_task = I(Refusal) · C_safe - SecurityScore(y) · C_harm
    
    Args:
        response: Generated response
        forbidden_keywords: Keywords to check for leakage
        C_safe: Reward coefficient for refusal
        C_harm: Penalty coefficient for leakage
    
    Returns:
        Task reward
    """
    is_refused = is_refusal(response)
    sec_score = compute_security_score(response, forbidden_keywords)
    
    R_task = (1.0 if is_refused else 0.0) * C_safe - sec_score * C_harm
    return R_task


def compute_nll(y_gt: str, y_pred: str, model=None, tokenizer=None) -> float:
    """Compute negative log-likelihood NLL(y_gt | y_pred).
    
    Args:
        y_gt: Ground truth answer
        y_pred: Predicted answer
        model: Optional LLM
        tokenizer: Optional tokenizer
    
    Returns:
        NLL score (lower = better match)
    """
    if model is None or tokenizer is None:
        # Fallback: simple string similarity
        from difflib import SequenceMatcher
        similarity = SequenceMatcher(None, y_gt.lower(), y_pred.lower()).ratio()
        return -np.log(similarity + 1e-8)
    
    # Full implementation with model
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(y_pred, return_tensors='pt').to(device)
        outputs = model(**inputs, labels=inputs['input_ids'])
        return float(outputs.loss.item())


def compute_task_reward_retain(
    response: str,
    ground_truth: str,
    model=None,
    tokenizer=None,
    C_acc: float = RewardConfig.C_ACC
) -> float:
    """Compute task reward for retain scenario.
    
    Formula: R_task = I(y = y_gt) · C_acc - NLL(y_gt | y)
    
    Args:
        response: Generated response
        ground_truth: Expected answer
        model: Optional LLM
        tokenizer: Optional tokenizer
        C_acc: Reward coefficient for accuracy
    
    Returns:
        Task reward
    """
    # Exact match check (simplified)
    is_correct = ground_truth.lower().strip() in response.lower().strip()
    
    # NLL penalty
    nll = compute_nll(ground_truth, response, model, tokenizer)
    
    R_task = (1.0 if is_correct else 0.0) * C_acc - nll
    return R_task


# Example usage:
# R_forget = compute_task_reward_forget("I don't have information about that.", ["author name", "book title"])
# R_retain = compute_task_reward_retain("The answer is 42.", "42")
# logger.info(f'R_forget={R_forget:.3f}, R_retain={R_retain:.3f}')

### 5.3 Three-Dimensional Cost (R_cost)

**Formula**:
$$R_{\text{cost}} = R_{\text{search}} + R_{\text{input}} + R_{\text{gen}}$$

Three stages:
- **Upstream (Retrieval)**: $R_{\text{search}} = -\lambda_{\text{search}} \cdot \frac{K_{\text{dynamic}}}{K_{\text{max}}}$
- **Midstream (Context)**: $R_{\text{input}} = -\lambda_{\text{input}} \cdot \text{Len}(S)$
- **Downstream (Generation)**: $R_{\text{gen}} = -\lambda_{\text{gen}} \cdot \text{Len}(Y_{\text{gen}})$

In [78]:
def compute_upstream_cost(
    K_dynamic: int,
    K_max: int = PolicyConfig.K_MAX,
    lambda_search: float = RewardConfig.LAMBDA_SEARCH
) -> float:
    """Compute upstream retrieval cost.
    
    Formula: R_search = -λ_search · (K_dynamic / K_max)
    
    Penalizes excessive retrieval.
    
    Args:
        K_dynamic: Number of examples retrieved
        K_max: Maximum retrieval size
        lambda_search: Cost coefficient
    
    Returns:
        Negative cost (penalty)
    """
    return -lambda_search * (K_dynamic / K_max)


def compute_midstream_cost(
    num_examples: int,
    lambda_input: float = RewardConfig.LAMBDA_INPUT
) -> float:
    """Compute midstream context cost.
    
    Formula: R_input = -λ_input · Len(S)
    
    Penalizes overly long context.
    
    Args:
        num_examples: Number of examples in final context S
        lambda_input: Cost coefficient
    
    Returns:
        Negative cost (penalty)
    """
    return -lambda_input * num_examples


def compute_downstream_cost(
    generated_length: int,
    lambda_gen: float = RewardConfig.LAMBDA_GEN
) -> float:
    """Compute downstream generation cost.
    
    Formula: R_gen = -λ_gen · Len(Y_gen)
    
    Penalizes excessive generation (forces turning off CoT for simple queries).
    
    Args:
        generated_length: Length of generated response (tokens or words)
        lambda_gen: Cost coefficient
    
    Returns:
        Negative cost (penalty)
    """
    return -lambda_gen * generated_length


def compute_total_cost(
    K_dynamic: int,
    num_examples: int,
    generated_length: int
) -> float:
    """Compute total three-dimensional cost R_cost.
    
    Formula: R_cost = R_search + R_input + R_gen
    
    Args:
        K_dynamic: Retrieval size
        num_examples: Context size
        generated_length: Generation length
    
    Returns:
        Total cost (negative value)
    """
    R_search = compute_upstream_cost(K_dynamic)
    R_input = compute_midstream_cost(num_examples)
    R_gen = compute_downstream_cost(generated_length)
    
    R_cost = R_search + R_input + R_gen
    return R_cost


# Example:
# cost = compute_total_cost(K_dynamic=500, num_examples=20, generated_length=50)
# logger.info(f'Total cost: {cost:.3f} (search + input + gen)')

### 5.4 Complete Reward Computation

Combines task reward, cost, and circuit breaker mechanism.

In [79]:
@dataclass
class EpisodeMetrics:
    """Metrics for a single episode/trajectory."""
    scenario: str  # 'forget' or 'retain'
    query: str
    response: str
    ground_truth: Optional[str] = None
    forbidden_keywords: Optional[List[str]] = None
    
    # Action stats
    K_dynamic: int = 0
    num_examples: int = 0
    generated_length: int = 0
    U_0: float = 0.5
    
    # Computed rewards
    R_task: float = 0.0
    R_cost: float = 0.0
    R_final: float = 0.0


def compute_final_reward(
    scenario: str,
    response: str,
    ground_truth: Optional[str] = None,
    forbidden_keywords: Optional[List[str]] = None,
    K_dynamic: int = 0,
    num_examples: int = 0,
    generated_length: int = 0,
    U_0: float = 0.5,
    model=None,
    tokenizer=None
) -> EpisodeMetrics:
    """Compute complete reward with circuit breaker mechanism.
    
    Formula:
        R_final = R_task + ω(s) · R_cost,  if R_task > 0
        R_final = R_task - δ_penalty,      if R_task ≤ 0
    
    Args:
        scenario: 'forget' or 'retain'
        response: Generated response
        ground_truth: Expected answer (for retain)
        forbidden_keywords: Keywords to check (for forget)
        K_dynamic: Retrieval size
        num_examples: Context size
        generated_length: Generation length
        U_0: Stubbornness
        model, tokenizer: Optional LLM
    
    Returns:
        EpisodeMetrics with all rewards computed
    """
    # Compute task reward
    if scenario == 'forget':
        R_task = compute_task_reward_forget(
            response, 
            forbidden_keywords or [], 
            RewardConfig.C_SAFE, 
            RewardConfig.C_HARM
        )
    elif scenario == 'retain':
        R_task = compute_task_reward_retain(
            response,
            ground_truth or "",
            model,
            tokenizer,
            RewardConfig.C_ACC
        )
    else:
        raise ValueError(f"Unknown scenario: {scenario}")
    
    # Compute cost
    R_cost = compute_total_cost(K_dynamic, num_examples, generated_length)
    
    # Apply circuit breaker
    if R_task > 0:
        # Task success: apply dynamic gating
        omega = compute_dynamic_gating(U_0, RewardConfig.THETA, RewardConfig.TAU)
        R_final = R_task + omega * R_cost
    else:
        # Task failure: circuit breaker kicks in
        R_final = R_task - RewardConfig.DELTA_PENALTY
        logger.warning(f'Circuit breaker activated: R_task={R_task:.3f} ≤ 0, penalty={RewardConfig.DELTA_PENALTY}')
    
    # Package results
    metrics = EpisodeMetrics(
        scenario=scenario,
        query="",  # filled externally
        response=response,
        ground_truth=ground_truth,
        forbidden_keywords=forbidden_keywords,
        K_dynamic=K_dynamic,
        num_examples=num_examples,
        generated_length=generated_length,
        U_0=U_0,
        R_task=R_task,
        R_cost=R_cost,
        R_final=R_final
    )
    
    return metrics


# Test reward computation
test_metrics_forget = compute_final_reward(
    scenario='forget',
    response="I don't have information about that person.",
    forbidden_keywords=["author name", "book title"],
    K_dynamic=200,
    num_examples=15,
    generated_length=10,
    U_0=0.3
)

test_metrics_retain = compute_final_reward(
    scenario='retain',
    response="The capital of France is Paris.",
    ground_truth="Paris",
    K_dynamic=100,
    num_examples=10,
    generated_length=8,
    U_0=0.1
)

logger.info(f'Forget scenario: R_task={test_metrics_forget.R_task:.3f}, '
           f'R_cost={test_metrics_forget.R_cost:.3f}, R_final={test_metrics_forget.R_final:.3f}')
logger.info(f'Retain scenario: R_task={test_metrics_retain.R_task:.3f}, '
           f'R_cost={test_metrics_retain.R_cost:.3f}, R_final={test_metrics_retain.R_final:.3f}')

2025-12-29 23:25:44,600 - INFO - Forget scenario: R_task=10.000, R_cost=-0.960, R_final=9.298
2025-12-29 23:25:44,601 - INFO - Retain scenario: R_task=3.719, R_cost=-0.665, R_final=3.133


## 6. Training Algorithm (Constrained Optimization)

Implements README_2.md Section 6: Lagrangian PPO (Dual Descent) framework with alternating primal and dual updates.

### 6.1 Optimization Objective Definition

**Constrained Problem**:
$$\max_\theta J_R(\pi_\theta) \quad \text{s.t.} \quad J_C(\pi_\theta) \geq \mu_{\text{retain}}$$

Where:
- **J_R(π_θ) = E_{τ~π_θ}[R_final(τ)]**: Expected total reward
- **J_C(π_θ) = E_{t=r}[-NLL(τ)]**: Expected Retain task performance
- **μ_retain**: Performance baseline (e.g., 95% of original model)

In [80]:
class TrainingConfig:
    """Configuration for Lagrangian PPO training per README_2.md Section 6"""
    # PPO parameters
    EPSILON = 0.2           # Clip ratio for PPO
    GAMMA = 0.99            # Discount factor
    LAMBDA_GAE = 0.95       # GAE lambda
    LAMBDA_NORM = 1.0       # Normalization term for advantage
    
    # Learning rates
    LR_POLICY = 3e-4        # Policy network learning rate
    LR_REWARD_CRITIC = 1e-3 # Reward critic learning rate
    LR_CONSTRAINT_CRITIC = 1e-3  # Constraint critic learning rate
    LR_LAGRANGE = 1e-2      # Lagrange multiplier learning rate (η_ν)
    
    # Training
    BATCH_SIZE = 64
    NUM_EPOCHS = 10
    NUM_ITERATIONS = 1000
    
    # Constraint
    MU_RETAIN = 0.95        # Retain performance baseline (95% of original)

### 6.2 Lagrangian Function Construction

**Lagrangian**:
$$\mathcal{L}(\theta, \nu) = J_R(\pi_\theta) + \nu \cdot (J_C(\pi_\theta) - \mu_{\text{retain}})$$

Where **ν ≥ 0** is the Lagrange multiplier (shadow price):
- Constraint violated → ν increases → prioritize J_C
- Constraint satisfied → ν decreases → pursue J_R

In [81]:
class LagrangeMultiplier:
    """Learnable Lagrange multiplier ν for constraint enforcement."""
    
    def __init__(self, initial_value: float = 0.0):
        self.nu = max(0.0, initial_value)
        
    def update(self, J_C_current: float, mu_retain: float, lr: float = TrainingConfig.LR_LAGRANGE):
        """Update multiplier based on constraint satisfaction.
        
        Formula: ν_{k+1} = max(0, ν_k - η_ν · (J̄_C - μ_retain))
        
        Args:
            J_C_current: Current average Retain performance
            mu_retain: Target baseline
            lr: Learning rate η_ν
        """
        # Gradient descent on dual variable
        grad = J_C_current - mu_retain
        self.nu = max(0.0, self.nu - lr * grad)
        
    def get_value(self) -> float:
        return self.nu


# Initialize Lagrange multiplier
lagrange_multiplier = LagrangeMultiplier(initial_value=0.1)
logger.info(f'Lagrange multiplier initialized: ν={lagrange_multiplier.get_value():.4f}')

2025-12-29 23:25:44,662 - INFO - Lagrange multiplier initialized: ν=0.1000


### 6.3 Dual Critic Network Architecture

Two independent value networks:

**1. Reward Critic V_R^π(s)**: Estimates expected return R_final
$$\text{Loss: } \mathcal{L}_R(\phi) = \mathbb{E}[(V_R^\pi(s_t) - \hat{R}_t)^2]$$

**2. Constraint Critic V_C^π(s)**: Estimates Retain performance metric (-NLL)
$$\text{Loss: } \mathcal{L}_C(\psi) = \mathbb{E}[(V_C^\pi(s_t) - \hat{C}_t)^2]$$

In [82]:
class RewardCritic(nn.Module):
    """Value network V_R^π(s) for estimating expected reward return."""
    
    def __init__(self, state_dim: int = 768, hidden_dim: int = 256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),  # +1 for U_0
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)  # Single value output
        )
    
    def forward(self, state_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            state_features: [batch, state_dim+1] tensor
        
        Returns:
            [batch, 1] value estimates
        """
        return self.network(state_features)


class ConstraintCritic(nn.Module):
    """Value network V_C^π(s) for estimating Retain performance (-NLL)."""
    
    def __init__(self, state_dim: int = 768, hidden_dim: int = 256):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim + 1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, state_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            state_features: [batch, state_dim+1] tensor
        
        Returns:
            [batch, 1] constraint value estimates
        """
        return self.network(state_features)


# Initialize critics
reward_critic = RewardCritic(state_dim=768, hidden_dim=256)
constraint_critic = ConstraintCritic(state_dim=768, hidden_dim=256)

logger.info(f'Reward Critic initialized: {sum(p.numel() for p in reward_critic.parameters())} params')
logger.info(f'Constraint Critic initialized: {sum(p.numel() for p in constraint_critic.parameters())} params')

2025-12-29 23:25:44,689 - INFO - Reward Critic initialized: 263169 params
2025-12-29 23:25:44,691 - INFO - Constraint Critic initialized: 263169 params


### 6.4 Training Loop (Step-by-Step Update)

Three update steps per PPO iteration:

**Step 1: Compute Fused Advantage**
$$A_{\text{total}}(s, a) = \frac{A_R(s, a) + \nu \cdot A_C(s, a)}{1 + \lambda_{\text{norm}}}$$

Note: A_C = 0 for Forget tasks, only activated on Retain samples.

**Step 2: Primal Update (Policy θ)**
$$\theta_{k+1} = \arg\max_\theta \mathbb{E}[\min(r_t(\theta)A_{\text{total}}, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)A_{\text{total}})]$$

**Step 3: Dual Update (Multiplier ν)**
$$\nu_{k+1} = \max(0, \nu_k - \eta_\nu \cdot (\bar{J}_C - \mu_{\text{retain}}))$$

In [83]:
@dataclass
class Trajectory:
    """Single trajectory/episode data for training."""
    states: List[torch.Tensor]      # State features
    actions: List[dict]              # Action components
    rewards: List[float]             # R_final at each step
    constraint_values: List[float]   # -NLL for retain tasks (0 for forget)
    log_probs: List[float]           # Log probabilities of actions
    scenario: str                    # 'forget' or 'retain'
    
    def __len__(self):
        return len(self.states)


def compute_gae_advantage(
    rewards: List[float],
    values: List[float],
    gamma: float = TrainingConfig.GAMMA,
    lambda_gae: float = TrainingConfig.LAMBDA_GAE
) -> List[float]:
    """Compute Generalized Advantage Estimation (GAE).
    
    Args:
        rewards: List of rewards
        values: List of value estimates
        gamma: Discount factor
        lambda_gae: GAE lambda parameter
    
    Returns:
        List of advantage estimates
    """
    advantages = []
    gae = 0.0
    
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            next_value = 0.0
        else:
            next_value = values[t + 1]
        
        delta = rewards[t] + gamma * next_value - values[t]
        gae = delta + gamma * lambda_gae * gae
        advantages.insert(0, gae)
    
    return advantages


def compute_fused_advantage(
    A_R: List[float],
    A_C: List[float],
    nu: float,
    lambda_norm: float = TrainingConfig.LAMBDA_NORM
) -> List[float]:
    """Compute fused advantage for policy update.
    
    Formula: A_total = (A_R + ν · A_C) / (1 + λ_norm)
    
    Args:
        A_R: Reward advantages
        A_C: Constraint advantages (0 for forget tasks)
        nu: Lagrange multiplier
        lambda_norm: Normalization term
    
    Returns:
        Fused advantages
    """
    A_total = []
    for a_r, a_c in zip(A_R, A_C):
        fused = (a_r + nu * a_c) / (1.0 + lambda_norm)
        A_total.append(fused)
    return A_total


# Example GAE computation
test_rewards = [1.0, 2.0, 3.0, 2.0, 1.0]
test_values = [0.5, 1.0, 1.5, 1.0, 0.5]
test_advantages = compute_gae_advantage(test_rewards, test_values)
logger.info(f'GAE advantages: {[f"{a:.3f}" for a in test_advantages]}')

2025-12-29 23:25:44,725 - INFO - GAE advantages: ['7.665', '6.565', '4.338', '1.965', '0.500']


### 6.5 PPO Update Functions

Implementation of clipped surrogate objective and critic updates.

In [84]:
def ppo_policy_loss(
    policy_net: PolicyNetwork,
    states: torch.Tensor,
    actions: dict,
    old_log_probs: torch.Tensor,
    advantages: torch.Tensor,
    epsilon: float = TrainingConfig.EPSILON
) -> torch.Tensor:
    """Compute PPO clipped surrogate objective loss.
    
    Formula: L = E[min(r_t(θ)·A, clip(r_t(θ), 1-ε, 1+ε)·A)]
    
    Args:
        policy_net: Policy network
        states: State features [batch, state_dim+1]
        actions: Dictionary of action components
        old_log_probs: Old log probabilities [batch]
        advantages: Fused advantages [batch]
        epsilon: Clip ratio
    
    Returns:
        Negative surrogate loss (for minimization)
    """
    # Get current log probabilities
    new_log_probs = policy_net.get_action_log_probs(states, actions)
    
    # Compute ratio r_t(θ) = π_new / π_old
    ratio = torch.exp(new_log_probs - old_log_probs)
    
    # Clipped surrogate
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
    
    # PPO objective (maximize, so negate for loss)
    policy_loss = -torch.min(surr1, surr2).mean()
    
    return policy_loss


def critic_loss(
    critic: nn.Module,
    states: torch.Tensor,
    returns: torch.Tensor
) -> torch.Tensor:
    """Compute MSE loss for critic.
    
    Formula: L = E[(V(s) - R̂)²]
    
    Args:
        critic: Critic network
        states: State features [batch, state_dim+1]
        returns: Target returns [batch]
    
    Returns:
        MSE loss
    """
    values = critic(states).squeeze(-1)
    loss = F.mse_loss(values, returns)
    return loss


def update_networks(
    policy_net: PolicyNetwork,
    reward_critic: RewardCritic,
    constraint_critic: ConstraintCritic,
    trajectories: List[Trajectory],
    lagrange_multiplier: LagrangeMultiplier,
    policy_optimizer: torch.optim.Optimizer,
    reward_critic_optimizer: torch.optim.Optimizer,
    constraint_critic_optimizer: torch.optim.Optimizer
) -> dict:
    """Perform one update step for all networks.
    
    Args:
        policy_net: Policy network
        reward_critic: Reward value network
        constraint_critic: Constraint value network
        trajectories: Batch of trajectories
        lagrange_multiplier: Lagrange multiplier ν
        policy_optimizer: Optimizer for policy
        reward_critic_optimizer: Optimizer for reward critic
        constraint_critic_optimizer: Optimizer for constraint critic
    
    Returns:
        Dictionary of loss values
    """
    # Collect all data from trajectories
    all_states = []
    all_actions = {'k_ratio': [], 'w_recall': [], 'w_score': [], 'a_cot': []}
    all_old_log_probs = []
    all_rewards = []
    all_constraints = []
    
    for traj in trajectories:
        all_states.extend(traj.states)
        for action in traj.actions:
            all_actions['k_ratio'].append(action['k_ratio'])
            all_actions['w_recall'].append(action['w_recall'])
            all_actions['w_score'].append(action['w_score'])
            all_actions['a_cot'].append(action['a_cot'])
        all_old_log_probs.extend(traj.log_probs)
        all_rewards.extend(traj.rewards)
        all_constraints.extend(traj.constraint_values)
    
    # Convert to tensors
    states_tensor = torch.stack(all_states)
    actions_tensor = {
        'k_ratio': torch.tensor(all_actions['k_ratio'], dtype=torch.float32),
        'w_recall': torch.stack([torch.tensor(w, dtype=torch.float32) for w in all_actions['w_recall']]),
        'w_score': torch.stack([torch.tensor(w, dtype=torch.float32) for w in all_actions['w_score']]),
        'a_cot': torch.tensor(all_actions['a_cot'], dtype=torch.float32)
    }
    old_log_probs_tensor = torch.tensor(all_old_log_probs, dtype=torch.float32)
    
    # Compute value estimates
    with torch.no_grad():
        reward_values = reward_critic(states_tensor).squeeze(-1).tolist()
        constraint_values = constraint_critic(states_tensor).squeeze(-1).tolist()
    
    # Compute advantages using GAE
    A_R = compute_gae_advantage(all_rewards, reward_values)
    A_C = compute_gae_advantage(all_constraints, constraint_values)
    
    # Compute fused advantage
    nu = lagrange_multiplier.get_value()
    A_total = compute_fused_advantage(A_R, A_C, nu)
    advantages_tensor = torch.tensor(A_total, dtype=torch.float32)
    
    # Step 2: Update Policy (Primal)
    policy_optimizer.zero_grad()
    p_loss = ppo_policy_loss(policy_net, states_tensor, actions_tensor, old_log_probs_tensor, advantages_tensor)
    p_loss.backward()
    policy_optimizer.step()
    
    # Update Critics
    returns_R = torch.tensor([sum(all_rewards[i:]) for i in range(len(all_rewards))], dtype=torch.float32)
    returns_C = torch.tensor([sum(all_constraints[i:]) for i in range(len(all_constraints))], dtype=torch.float32)
    
    reward_critic_optimizer.zero_grad()
    r_loss = critic_loss(reward_critic, states_tensor, returns_R)
    r_loss.backward()
    reward_critic_optimizer.step()
    
    constraint_critic_optimizer.zero_grad()
    c_loss = critic_loss(constraint_critic, states_tensor, returns_C)
    c_loss.backward()
    constraint_critic_optimizer.step()
    
    # Step 3: Update Lagrange Multiplier (Dual)
    J_C_current = sum(all_constraints) / len(all_constraints) if all_constraints else 0.0
    lagrange_multiplier.update(J_C_current, TrainingConfig.MU_RETAIN)
    
    return {
        'policy_loss': p_loss.item(),
        'reward_critic_loss': r_loss.item(),
        'constraint_critic_loss': c_loss.item(),
        'lagrange_nu': lagrange_multiplier.get_value(),
        'J_C_avg': J_C_current
    }


# Example: Initialize optimizers
policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=TrainingConfig.LR_POLICY)
reward_critic_optimizer = torch.optim.Adam(reward_critic.parameters(), lr=TrainingConfig.LR_REWARD_CRITIC)
constraint_critic_optimizer = torch.optim.Adam(constraint_critic.parameters(), lr=TrainingConfig.LR_CONSTRAINT_CRITIC)

logger.info('Optimizers initialized for Lagrangian PPO training')

2025-12-29 23:25:44,770 - INFO - Optimizers initialized for Lagrangian PPO training


### 6.6 Complete Training Loop

End-to-end training algorithm with trajectory collection and updates.

In [85]:
def collect_trajectory(
    query: str,
    scenario: str,
    policy_net: PolicyNetwork,
    reward_critic: RewardCritic,
    constraint_critic: ConstraintCritic,
    M_retain: ExampleLibrary,
    M_safety: ExampleLibrary,
    M_augment: ExampleLibrary,
    embedding_model,
    ground_truth: Optional[str] = None,
    forbidden_keywords: Optional[List[str]] = None
) -> Trajectory:
    """Collect a single trajectory by executing the pipeline.
    
    Args:
        query: User query
        scenario: 'forget' or 'retain'
        policy_net: Policy network
        reward_critic, constraint_critic: Value networks
        M_retain, M_safety, M_augment: Libraries
        embedding_model: Sentence transformer
        ground_truth: Expected answer (for retain)
        forbidden_keywords: Keywords to check (for forget)
    
    Returns:
        Trajectory object with states, actions, rewards
    """
    # Build state
    state = build_state(query, embedding_model)
    state_tensor = encode_state_for_policy(state)
    
    # Get action from policy
    with torch.no_grad():
        action = policy_net(state_tensor)
        log_prob = policy_net.get_action_log_probs(
            state_tensor,
            {
                'k_ratio': torch.tensor([action.k_ratio]),
                'w_recall': torch.tensor([action.w_recall]),
                'w_score': torch.tensor([action.w_score]),
                'a_cot': torch.tensor([action.a_cot])
            }
        ).item()
    
    # Execute pipeline (simplified - would normally generate response with LLM)
    # For now, simulate response based on scenario
    if scenario == 'forget':
        response = "I don't have information about that."
        K_dynamic = compute_K_dynamic(action.k_ratio)
        num_examples = min(20, K_dynamic)  # Simplified
        generated_length = 10
    else:
        response = ground_truth or "Answer generated"
        K_dynamic = compute_K_dynamic(action.k_ratio)
        num_examples = min(15, K_dynamic)
        generated_length = 15
    
    # Compute reward
    metrics = compute_final_reward(
        scenario=scenario,
        response=response,
        ground_truth=ground_truth,
        forbidden_keywords=forbidden_keywords,
        K_dynamic=K_dynamic,
        num_examples=num_examples,
        generated_length=generated_length,
        U_0=state.U_0
    )
    
    # Compute constraint value (only for retain tasks)
    if scenario == 'retain':
        constraint_val = -compute_nll(ground_truth or "", response)
    else:
        constraint_val = 0.0
    
    # Package into trajectory
    action_dict = {
        'k_ratio': action.k_ratio,
        'w_recall': action.w_recall,
        'w_score': action.w_score,
        'a_cot': action.a_cot
    }
    
    trajectory = Trajectory(
        states=[state_tensor],
        actions=[action_dict],
        rewards=[metrics.R_final],
        constraint_values=[constraint_val],
        log_probs=[log_prob],
        scenario=scenario
    )
    
    return trajectory


def train_lagrangian_ppo(
    policy_net: PolicyNetwork,
    reward_critic: RewardCritic,
    constraint_critic: ConstraintCritic,
    lagrange_multiplier: LagrangeMultiplier,
    forget_queries: List[Tuple[str, List[str]]],  # (query, forbidden_keywords)
    retain_queries: List[Tuple[str, str]],  # (query, ground_truth)
    M_retain: ExampleLibrary,
    M_safety: ExampleLibrary,
    M_augment: ExampleLibrary,
    embedding_model,
    num_iterations: int = 10,
    batch_size: int = 4
) -> List[dict]:
    """Complete Lagrangian PPO training loop.
    
    Args:
        policy_net, reward_critic, constraint_critic: Networks
        lagrange_multiplier: Lagrange multiplier
        forget_queries: List of forget queries with keywords
        retain_queries: List of retain queries with answers
        M_retain, M_safety, M_augment: Libraries
        embedding_model: Sentence transformer
        num_iterations: Number of training iterations
        batch_size: Trajectories per iteration
    
    Returns:
        List of training metrics per iteration
    """
    policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=TrainingConfig.LR_POLICY)
    reward_critic_optimizer = torch.optim.Adam(reward_critic.parameters(), lr=TrainingConfig.LR_REWARD_CRITIC)
    constraint_critic_optimizer = torch.optim.Adam(constraint_critic.parameters(), lr=TrainingConfig.LR_CONSTRAINT_CRITIC)
    
    training_history = []
    
    logger.info('=== STARTING LAGRANGIAN PPO TRAINING ===')
    
    for iteration in range(num_iterations):
        # Collect batch of trajectories
        trajectories = []
        
        # Mix of forget and retain tasks
        for i in range(batch_size):
            if i % 2 == 0 and forget_queries:
                # Forget task
                query, keywords = forget_queries[i % len(forget_queries)]
                traj = collect_trajectory(
                    query, 'forget', policy_net, reward_critic, constraint_critic,
                    M_retain, M_safety, M_augment, embedding_model,
                    forbidden_keywords=keywords
                )
            elif retain_queries:
                # Retain task
                query, answer = retain_queries[i % len(retain_queries)]
                traj = collect_trajectory(
                    query, 'retain', policy_net, reward_critic, constraint_critic,
                    M_retain, M_safety, M_augment, embedding_model,
                    ground_truth=answer
                )
            else:
                continue
            
            trajectories.append(traj)
        
        # Update networks
        metrics = update_networks(
            policy_net, reward_critic, constraint_critic, lagrange_multiplier,
            policy_optimizer, reward_critic_optimizer, constraint_critic_optimizer,
            trajectories
        )
        
        training_history.append(metrics)
        
        # Log progress
        if (iteration + 1) % 5 == 0:
            logger.info(f'Iter {iteration+1}/{num_iterations}: '
                       f'Policy Loss={metrics["policy_loss"]:.4f}, '
                       f'ν={metrics["lagrange_nu"]:.4f}, '
                       f'J_C={metrics["J_C_avg"]:.4f}')
    
    logger.info('=== TRAINING COMPLETE ===')
    return training_history


# Example training setup (demo with small batch)
demo_forget_queries = [
    ("Who is the author of 'The Midnight Garden'?", ["author", "name"]),
    ("Tell me about the life of Jane Doe.", ["jane doe", "biography"])
]

demo_retain_queries = [
    ("What is 2+2?", "4"),
    ("What is the capital of France?", "Paris")
]

logger.info('Training setup ready. Run train_lagrangian_ppo() to start training.')
logger.info(f'Demo queries: {len(demo_forget_queries)} forget, {len(demo_retain_queries)} retain')

2025-12-29 23:25:44,824 - INFO - Training setup ready. Run train_lagrangian_ppo() to start training.
2025-12-29 23:25:44,827 - INFO - Demo queries: 2 forget, 2 retain
