In [None]:
# 1. Install Dependencies
!pip install transformers datasets peft trl torch accelerate bert_score

In [None]:
import os

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead

# Weights & Biases reads WANDB_PROJECT to choose the project that runs started
# by the trainers in this notebook are logged to.
os.environ.update(WANDB_PROJECT="unbiased-news-summarizer")

## Stage 1: Supervised Fine-Tuning (SFT)
First, we fine-tune `facebook/bart-large-cnn` on the Multi-News dataset so that it produces good multi-document summaries.

In [None]:

from datasets import Dataset

# Load Multi-News and drop training rows whose document or summary is empty.
DOC_COL = "document"
REF_COL = "summary"
DATASET_NAME = "Awesome075/multi_news_parquet"


def _drop_empty_rows(frame):
    """Return a copy of ``frame`` without rows whose DOC_COL or REF_COL is empty/missing.

    Empty strings are treated as missing values. Column-level
    ``frame[col].replace(..., inplace=True)`` is avoided on purpose: it is a
    chained assignment on a temporary Series, which pandas warns about and
    which does not modify ``frame`` at all under Copy-on-Write. Only the two
    text columns are checked, and the index is reset so that
    ``Dataset.from_pandas`` does not add an ``__index_level_0__`` column.
    """
    return (
        frame.replace({DOC_COL: {"": np.nan}, REF_COL: {"": np.nan}})
        .dropna(subset=[DOC_COL, REF_COL])
        .reset_index(drop=True)
    )


dataset_train = Dataset.from_pandas(
    _drop_empty_rows(load_dataset(DATASET_NAME, split="train").to_pandas()),
    preserve_index=False,
)
dataset_val = Dataset.from_pandas(load_dataset(DATASET_NAME, split="validation").to_pandas())
dataset_test = Dataset.from_pandas(load_dataset(DATASET_NAME, split="test").to_pandas())

In [None]:
# Model & Tokenizer (Example: BART-large)
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

MAX_SOURCE_TOKENS = 1024  # truncation length for input documents
MAX_TARGET_TOKENS = 128   # truncation length for reference summaries


def preprocess_function(examples):
    """Tokenize a batch of rows for seq2seq training.

    Args:
        examples: Batch dict (from ``Dataset.map(batched=True)``) containing
            the DOC_COL and REF_COL text columns.

    Returns:
        The tokenizer output for the documents, plus a ``labels`` entry holding
        the token ids of the reference summaries.
    """
    model_inputs = tokenizer(examples[DOC_COL], max_length=MAX_SOURCE_TOKENS, truncation=True)
    labels = tokenizer(text_target=examples[REF_COL], max_length=MAX_TARGET_TOKENS, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Only the token ids are needed for training, so drop the raw text columns.
tokenized_train = dataset_train.map(
    preprocess_function, batched=True, remove_columns=dataset_train.column_names
)
tokenized_val = dataset_val.map(
    preprocess_function, batched=True, remove_columns=dataset_val.column_names
)

# Define Trainer (Standard HuggingFace Trainer)
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq

# Initialize the Data Collator explicitly
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="./sft_summarizer",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    # NOTE(review): no compute_metrics is passed, so generating summaries during
    # evaluation presumably only adds time here — confirm it is still wanted.
    predict_with_generate=True,
    fp16=torch.cuda.is_available(),  # fp16 mixed precision needs a CUDA device
    report_to="wandb",
    run_name="sft_summarizer_run",
    logging_steps=500,
    logging_strategy="steps",
    save_strategy="steps",  # Explicitly set save strategy
    save_steps=5000,        # Checkpoint every 5000 steps
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
)

In [None]:
# Run supervised fine-tuning, then write the final model to the directory that
# the DPO, Best-of-N and PPO cells below load the model and tokenizer from.
trainer.train()
trainer.save_model(output_dir="./sft_summarizer_final")

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
import numpy as np

# 2. Reward Model Wrapper
class NeutralityRewardModel(torch.nn.Module):
    """Turn a sequence classifier into a scalar "neutrality" reward.

    Policy token ids are decoded back to text with the policy tokenizer,
    re-tokenized for the classifier (truncated to 512 tokens), and mapped to
    one reward per sequence:

        reward = (1 - mean(top-k class probabilities)) ** 2,  k = min(3, num_labels)

    A less confident (flatter) class distribution therefore gives a larger reward.
    """

    def __init__(self, reward_model_name, policy_tokenizer, device):
        super().__init__()
        classifier = AutoModelForSequenceClassification.from_pretrained(reward_model_name)
        self.reward_model = classifier.to(device)
        self.reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
        self.policy_tokenizer = policy_tokenizer
        self.device = device

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return a 1-D tensor with one reward per row of ``input_ids``.

        ``attention_mask`` and extra keyword arguments are accepted for API
        compatibility but are not used; padding is removed by decoding with
        ``skip_special_tokens=True``.
        """
        decoded = self.policy_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        encoded = self.reward_tokenizer(
            decoded,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            class_probs = self.reward_model(**encoded).logits.softmax(dim=-1)

        top_k = min(3, class_probs.size(-1))
        confidence = class_probs.topk(top_k, dim=-1).values.mean(dim=-1)
        return (1.0 - confidence).pow(2)

# Run the reward classifier on the GPU when one is available, otherwise on the
# CPU. Passing the computed `device` (rather than a hard-coded GPU index) keeps
# this cell runnable on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
reward_model = NeutralityRewardModel("maximuspowers/bias-type-classifier", tokenizer, device=device)

## Stage 2: Bias Mitigation (DPO and Best-of-N)

In [None]:
# ==========================================
# STRATEGY 2: Offline Preference Learning (DPO)
# ==========================================

# STEP 1: Generate Preference Dataset
# We generate 2 responses for each prompt, score them, and create (chosen, rejected) pairs.

from datasets import Dataset
from tqdm import tqdm

# 1. Load Models
# Preference pairs are sampled from the SFT model, which is also the checkpoint
# the DPO cell below starts from. A DPO checkpoint only exists after the next
# cell has run, so it cannot be the default on a fresh "Run All". For a second
# (iterative) round of DPO, point this at a saved DPO checkpoint instead.
sft_model_path = "./sft_summarizer_final"

# Policy Model (The actor). Move it to `device` (GPU when available) so that
# sampling thousands of summaries does not silently run on the CPU.
policy_model = AutoModelForSeq2SeqLM.from_pretrained(sft_model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)


def create_preference_dataset(
    model,
    tokenizer,
    reward_model,
    source_dataset,
    num_samples=100,
    device="cuda",
    save_path="preference_dataset.csv",
    checkpoint_every=1000,
    min_margin=0.0,
):
    """Generate (chosen, rejected) summary pairs for DPO training.

    For each of the first ``num_samples`` documents, two summaries are sampled
    from ``model`` and scored by ``reward_model``. The higher-scoring summary
    becomes ``chosen`` and the other becomes ``rejected``.

    Args:
        model: Seq2seq policy used to sample candidate summaries.
        tokenizer: Tokenizer of ``model``; encodes documents and decodes samples.
        reward_model: Callable that takes the generated token ids (2 rows) and
            returns a tensor with one reward per row. Higher is better.
        source_dataset: Indexable dataset whose rows have a ``'document'`` field.
        num_samples: Maximum number of documents to use.
        device: Device the encoded document is moved to (should match ``model``).
        save_path: CSV file that the pairs are checkpointed to.
        checkpoint_every: Rewrite ``save_path`` after this many documents;
            ``0`` or ``None`` disables intermediate checkpoints.
        min_margin: Pairs whose reward difference is ``<= min_margin`` are
            skipped. With the default ``0.0`` only exact ties are dropped
            (e.g. two identical samples), because a tie carries no preference.

    Returns:
        A ``datasets.Dataset`` with columns ``prompt``, ``chosen``,
        ``rejected`` and ``score_margin``.
    """
    model.eval()
    data_rows = []

    print(f"Generating {num_samples} preference pairs...")

    for i in tqdm(range(min(num_samples, len(source_dataset)))):
        doc = source_dataset[i]["document"]

        # 1. Sample two candidate summaries for the same document.
        inputs = tokenizer(doc, return_tensors="pt", max_length=1024, truncation=True).to(device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                min_length=30,
                do_sample=True,
                num_beams=1,
                top_p=0.95,  # slightly higher top_p for more diverse pairs
                num_return_sequences=2,
                early_stopping=False,
            )

        # 2. Score both candidates.
        scores = reward_model(outputs).tolist()
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # 3. The higher-reward sample is "chosen"; ties carry no signal.
        win, lose = (0, 1) if scores[0] > scores[1] else (1, 0)
        margin = scores[win] - scores[lose]
        if margin > min_margin:
            data_rows.append({
                "prompt": doc,
                "chosen": decoded[win],
                "rejected": decoded[lose],
                "score_margin": margin,
            })

        # Periodic checkpoint so that a crash does not lose all generated pairs.
        if checkpoint_every and (i + 1) % checkpoint_every == 0:
            pd.DataFrame(data_rows).to_csv(save_path, index=False)
            print(f"Saved {len(data_rows)} pairs to {save_path}")

    # Final save
    if data_rows:
        pd.DataFrame(data_rows).to_csv(save_path, index=False)
        print(f"Final save: {len(data_rows)} pairs to {save_path}")

    return Dataset.from_list(data_rows)

# Build the DPO preference pairs from `dataset_train` (loaded in Stage 1).
# `num_samples` caps how many source documents are used; lower it to save time.
preference_dataset = create_preference_dataset(
    model=policy_model,
    tokenizer=tokenizer,
    reward_model=reward_model,
    source_dataset=dataset_train,
    num_samples=10000,
    device=policy_model.device,
)

print(f"Generated {len(preference_dataset)} pairs.")
print("Sample Pair:")
print(preference_dataset[0])

In [None]:
# STEP 2: Train with DPO
from trl import DPOTrainer, DPOConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# RELOAD MODEL: Ensure we have a fresh, trainable copy of the SFT model
# This prevents the "None of the inputs have requires_grad=True" error
print("Reloading SFT model for DPO...")
sft_model_path = "./sft_summarizer_final"
policy_model = AutoModelForSeq2SeqLM.from_pretrained(sft_model_path)
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

# Decode starting from the EOS token and keep the model config and the
# generation config in agreement.
policy_model.config.decoder_start_token_id = tokenizer.eos_token_id
policy_model.config.pad_token_id = tokenizer.pad_token_id
policy_model.generation_config.decoder_start_token_id = tokenizer.eos_token_id
policy_model.generation_config.pad_token_id = tokenizer.pad_token_id

# DPO trains a policy initialized from SFT against an SFT reference model.
# Name the saved model after the number of preference pairs actually used,
# so the suffix cannot drift from the dataset built in the previous cell.
num_samples = len(preference_dataset)

print("Setting up DPO Trainer...")

# 1. Config
dpo_config = DPOConfig(
    output_dir="./unbiased_summarizer_dpo",
    learning_rate=5e-6,             # Low learning rate for DPO
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    # NOTE(review): 20 epochs is not "short"; DPO is commonly run for only a few
    # epochs, and many passes over the same pairs risk over-optimization.
    num_train_epochs=20,
    beta=0.1,                       # The beta parameter for DPO (KL penalty)
    logging_steps=10,
    save_strategy="no",
    remove_unused_columns=False,
)

# 2. Initialize Trainer
# With ref_model=None, TRL creates the reference copy of `model` itself.
dpo_trainer = DPOTrainer(
    model=policy_model,             # The model to train
    ref_model=None,                 # TRL will create a copy of 'model' as reference
    args=dpo_config,
    train_dataset=preference_dataset,  # columns: prompt / chosen / rejected
    processing_class=tokenizer,
)

print("Starting DPO Training...")
dpo_trainer.train()
if hasattr(dpo_trainer.model, "generation_config"):
    dpo_trainer.model.generation_config.length_penalty = 1.0

dpo_trainer.save_model(f"./unbiased_summarizer_dpo_final_{num_samples}")
print("DPO Training Complete.")

In [None]:
# ==========================================
# STRATEGY 1: Online "Best-of-N" Inference
# ==========================================
# This strategy generates N candidates at inference time and selects the one 
# with the highest neutrality score.

# 1. Load Models
sft_model_path = "./sft_summarizer_final"

# Policy Model (The actor). Move it to `device` (GPU when available) so that
# Best-of-N sampling does not silently run on the CPU.
policy_model = AutoModelForSeq2SeqLM.from_pretrained(sft_model_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

def generate_best_of_n(model, tokenizer, reward_model, input_text, n=4, device="cuda",
                       top_p=0.9, max_new_tokens=128):
    """Sample ``n`` summaries of ``input_text`` and keep the highest-reward one.

    Args:
        model: Seq2seq policy used for sampling.
        tokenizer: Tokenizer of ``model``.
        reward_model: Callable that takes the raw generated token ids
            (``n`` rows) and returns a tensor with one reward per row, as
            ``NeutralityRewardModel`` does.
        input_text: Document to summarize (truncated to 1024 tokens).
        n: Number of candidates to sample; must be at least 1.
        device: Device the encoded input is moved to (should match ``model``).
        top_p: Nucleus-sampling threshold used for diversity between candidates.
        max_new_tokens: Maximum length of each candidate.

    Returns:
        dict with ``best_summary`` (str), ``best_score`` (float),
        ``all_candidates`` (list of str) and ``all_scores`` (list of float),
        in generation order.

    Raises:
        ValueError: If ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"n must be at least 1, got {n}")

    model.eval()
    inputs = tokenizer(input_text, return_tensors="pt", max_length=1024, truncation=True).to(device)

    with torch.no_grad():
        # Sampling (not beam search) so the n candidates differ from each other.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            min_length=30,
            do_sample=True,
            top_p=top_p,
            num_beams=1,
            num_return_sequences=n,
            early_stopping=False,
        )
        # The generated ids go straight to the reward model, which decodes them itself.
        rewards = reward_model(outputs)

    candidates = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    best_idx = int(torch.argmax(rewards))

    return {
        "best_summary": candidates[best_idx],
        "best_score": rewards[best_idx].item(),
        "all_candidates": candidates,
        "all_scores": rewards.tolist(),
    }

# --- Test Strategy 1 ---
# Summarize a deliberately one-sided news snippet and list every candidate
# alongside its neutrality reward.
sample_text = """
The controversial bill was passed yesterday amid fierce protests. Critics argue it undermines democracy, while supporters claim it is necessary for national security. The opposition leader called it a "dark day," whereas the Prime Minister hailed it as a "historic victory."
"""

num_candidates = 4
result = generate_best_of_n(
    policy_model,
    tokenizer,
    reward_model,
    sample_text,
    n=num_candidates,
    device=policy_model.device,
)

print(f"--- Best of {num_candidates} Selection (Score: {result['best_score']:.4f}) ---")
print(result["best_summary"])
print("\n--- All Candidates ---")
ranked = zip(result["all_candidates"], result["all_scores"])
for rank, (cand, score) in enumerate(ranked, start=1):
    print(f"[{rank}] Score: {score:.4f} | {cand}")

## (Not Used) Attempt: Bias Mitigation with PPO
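This section starts from the SFT model and fine-tunes it with PPO to maximize the neutrality reward. It was not used for the final results, and it depends on several TRL monkey patches. It is kept for reference only.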

In [None]:
from trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHead
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
import torch
import numpy as np

# 1. Load Models: every PPO component starts from the SFT checkpoint.
sft_model_path = "./sft_summarizer_final"

# Policy (the actor) and its tokenizer.
policy_model = AutoModelForSeq2SeqLM.from_pretrained(sft_model_path)
tokenizer = AutoTokenizer.from_pretrained(sft_model_path)

# Reference model for the KL-divergence term.
ref_model = AutoModelForSeq2SeqLM.from_pretrained(sft_model_path)

# Critic: a single-output (num_labels=1) classification head on the same weights.
value_model = AutoModelForSequenceClassification.from_pretrained(sft_model_path, num_labels=1)

# --- FIX GENERATION CONFIG ---
# Pin every generation setting explicitly to prevent "min_length > max_length"
# errors. The dict keeps insertion order, so attributes are set in this order.
for _gen_attr, _gen_value in {
    "min_length": 30,
    "max_length": 1024,
    "max_new_tokens": 128,
    "do_sample": True,
    "top_p": 0.9,
    "num_beams": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "decoder_start_token_id": tokenizer.eos_token_id,
    "forced_bos_token_id": 0,
    "early_stopping": False,
}.items():
    setattr(policy_model.generation_config, _gen_attr, _gen_value)
policy_model.config.use_cache = False

# The reference model shares the very same GenerationConfig object.
ref_model.generation_config = policy_model.generation_config
ref_model.config.use_cache = False

# 3. PPO Configuration
config = PPOConfig(
    output_dir="./unbiased_summarizer_ppo",
    learning_rate=1.41e-5,
    batch_size=4,
    mini_batch_size=1,
    gradient_accumulation_steps=4,
    num_ppo_epochs=4,
    seed=42,
    response_length=128,                   # max tokens per generated response
    stop_token_id=tokenizer.eos_token_id,  # stop each response at EOS
)

# 4. Prepare Dataset
def tokenize(sample):
    """Encode ``sample["document"]`` as a fixed-length PPO query.

    The document is truncated and padded to exactly 512 tokens. ``input_ids``
    and ``attention_mask`` are written back into ``sample``, which is returned.
    """
    encoded = tokenizer(
        sample["document"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    for field in ("input_ids", "attention_mask"):
        sample[field] = encoded[field]
    return sample

# Apply tokenization. `tokenize` also works on batches (the tokenizer accepts a
# list of documents), so map in batched mode, which makes one tokenizer call per
# batch instead of one per document. All raw columns (document, summary and any
# index column) are dropped in the same pass.
ppo_dataset = dataset_train.map(
    tokenize, batched=True, remove_columns=dataset_train.column_names
)

# Set format for PyTorch
ppo_dataset.set_format(type="torch", columns=["input_ids", "attention_mask"])

# 5. Initialize PPOTrainer
# Sampling settings intended for rollouts.
generation_kwargs = {
    "min_length": 30,
    "max_new_tokens": 128,
    "do_sample": True,
    "top_p": 0.9,
    "num_beams": 1,
    "eos_token_id": tokenizer.eos_token_id,
    "decoder_start_token_id": tokenizer.eos_token_id, # Important for BART
}

ppo_trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy_model,
    ref_model=ref_model,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=ppo_dataset,
)

# NOTE(review): `generation_kwargs` is not a constructor argument of the
# PPOTrainer used above, and the installed TRL may never read this attribute.
# The `batch_generation` monkey patch later in the notebook is what actually
# enforces these sampling settings. Verify this against the TRL version in use.
ppo_trainer.generation_kwargs = generation_kwargs

In [None]:
   # Debug: Try generation with Sampling + No Repeat N-Grams (PPO Style)
sample_text = """
    Artificial intelligence (AI) is intelligence demonstrated by machines, as opposed to the natural intelligence displayed by animals including humans. AI research has been defined as the field of study of intelligent agents, which refers to any system that perceives its environment and takes actions that maximize its chance of achieving its goals.
    The term "artificial intelligence" had previously been used to describe machines that mimic and display "human" cognitive skills that are associated with the human mind, such as "learning" and "problem-solving". This definition has since been rejected by major AI researchers who now describe AI in terms of rationality and acting rationally, which does not limit how intelligence can be articulated.
    """

inputs = tokenizer(sample_text, return_tensors="pt", max_length=1024, truncation=True).to(device)

print("\n--- Attempt 2: Sampling with Constraints (PPO Style) ---")
summary_ids = policy_model.generate(
    inputs["input_ids"], 
    max_length=100, 
    min_length=30, 
    do_sample=True,              # Enable Sampling
    num_beams=1,                 # Explicitly set to 1 for sampling
    top_p=0.9,                   # Nucleus sampling
    no_repeat_ngram_size=3,      # Prevent repetition
    early_stopping=False         # Disable early_stopping for sampling
)
print(f"IDs: {summary_ids[0].tolist()}")
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print("-" * 30)
print("Generated Summary:")
print(summary)
print("-" * 30)

In [None]:
import trl.trainer.ppo_trainer
from transformers import GenerationConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel
from transformers.modeling_outputs import Seq2SeqLMOutput, Seq2SeqModelOutput
from trl import AutoModelForSeq2SeqLMWithValueHead, PPOTrainer
import torch
import types

# --- FIX 7: MONKEY PATCH get_reward ---
# Stash TRL's get_reward before replacing it with custom_get_reward (below).
# TRL's get_reward runs the reward model's backbone directly on the full
# query+response sequence: 512 query + 128 response = 640 tokens in this setup.
# That exceeds the reward classifier's 512-token limit (assumed BERT-style;
# confirm against its config). The patch routes NeutralityRewardModel through
# its own forward(), which re-tokenizes with truncation to 512 tokens.
print("Patching get_reward to handle NeutralityRewardModel truncation...")

# Store the unpatched function only once, so re-running this cell does not save
# the already-patched version as the "original".
if not hasattr(trl.trainer.ppo_trainer, "original_get_reward"):
    trl.trainer.ppo_trainer.original_get_reward = trl.trainer.ppo_trainer.get_reward

def custom_get_reward(model, query_responses, pad_token_id, context_length):
    """Drop-in replacement for ``trl.trainer.ppo_trainer.get_reward``.

    For a ``NeutralityRewardModel``, possibly wrapped in ``.module`` layers
    (DataParallel/Accelerate), only the generated response is scored: the
    tokens after the first ``context_length`` positions. Scoring goes through
    the model's own ``forward``. Any other model is delegated unchanged to the
    original TRL implementation.

    Returns:
        ``(None, scores, None)`` for the neutrality model, which mirrors the
        ``(values, score, sequence_lengths)`` tuple of the original function
        with only the score filled in. For other models, whatever the original
        ``get_reward`` returns.
    """
    # Unwrap model to check type (in case of DataParallel/Accelerator wrapping)
    unwrapped = model
    while hasattr(unwrapped, "module"):
        unwrapped = unwrapped.module

    if unwrapped.__class__.__name__ != "NeutralityRewardModel":
        return trl.trainer.ppo_trainer.original_get_reward(
            model, query_responses, pad_token_id, context_length
        )

    # Score only the response. The query is the (up to 512-token) source
    # document. Decoding query+response and re-truncating to the classifier's
    # 512-token limit would usually cut the response off, so the reward would
    # mostly reflect the document rather than the generated summary.
    responses = query_responses[:, context_length:]
    device = next(unwrapped.parameters()).device
    with torch.no_grad():
        scores = unwrapped(input_ids=responses.to(device))

    return None, scores, None

# Install the get_reward patch in TRL's PPO module namespace. Presumably TRL's
# training loop looks up `get_reward` there at call time; verify for the
# installed version.
trl.trainer.ppo_trainer.get_reward = custom_get_reward

# --- FIX 5: MONKEY PATCH Seq2Seq Output Classes ---
# TRL expects 'hidden_states' but Seq2Seq models return 'decoder_hidden_states'
print("Patching Seq2Seq output classes to expose hidden_states...")
def get_hidden_states(self):
    """Alias ``hidden_states`` to the decoder's hidden states on seq2seq outputs."""
    return self.decoder_hidden_states

# Add the alias only if the class does not already define `hidden_states`,
# so a re-run (or a library version that has it) is left untouched.
if not hasattr(Seq2SeqLMOutput, "hidden_states"):
    Seq2SeqLMOutput.hidden_states = property(get_hidden_states)

if not hasattr(Seq2SeqModelOutput, "hidden_states"):
    Seq2SeqModelOutput.hidden_states = property(get_hidden_states)

# --- FIX 4: MONKEY PATCH AutoModelForSeq2SeqLMWithValueHead ---
# Give TRL's value-head wrapper the attributes that the PPO code presumably
# accesses on a value model (base_model_prefix, model, score). Each one forwards
# to the wrapped pretrained model or to the value head. A patch is installed
# only if the class does not already have the attribute, so re-running the cell
# does not stack patches.
print("Patching AutoModelForSeq2SeqLMWithValueHead to expose base_model_prefix, model, and score...")
if not hasattr(AutoModelForSeq2SeqLMWithValueHead, "base_model_prefix"):
    def get_base_model_prefix(self):
        """Return the wrapped pretrained model's base_model_prefix."""
        return self.pretrained_model.base_model_prefix
    AutoModelForSeq2SeqLMWithValueHead.base_model_prefix = property(get_base_model_prefix)

if not hasattr(AutoModelForSeq2SeqLMWithValueHead, "model"):
    def get_model(self):
        """Return the pretrained model's inner `.model` if present, else the pretrained model itself."""
        if hasattr(self.pretrained_model, "model"):
            return self.pretrained_model.model
        return self.pretrained_model
    AutoModelForSeq2SeqLMWithValueHead.model = property(get_model)

if not hasattr(AutoModelForSeq2SeqLMWithValueHead, "score"):
    def get_score(self):
        """Return the value head (`v_head`) if present, else the pretrained model's `score`, else None."""
        if hasattr(self, "v_head"):
            return self.v_head
        if hasattr(self.pretrained_model, "score"):
            return self.pretrained_model.score
        return None
    AutoModelForSeq2SeqLMWithValueHead.score = property(get_score)

# --- FIX 6: PATCH REWARD MODEL ---
# The custom NeutralityRewardModel has no `base_model_prefix`, which TRL might
# check. Expose one, plus a class-level property under that name that returns
# the wrapped classifier's backbone.
print("Patching Reward Model to expose base_model_prefix and underlying model...")
if "reward_model" in globals():
    # 1. Ensure base_model_prefix exists, preferring the wrapped classifier's own.
    if not hasattr(reward_model, "base_model_prefix"):
        if hasattr(reward_model, "reward_model") and hasattr(reward_model.reward_model, "base_model_prefix"):
            reward_model.base_model_prefix = reward_model.reward_model.base_model_prefix
        else:
            reward_model.base_model_prefix = "model"

    # 2. Ensure the attribute named by base_model_prefix exists (e.g. 'bert').
    # TRL may access e.g. `model.bert` when base_model_prefix is 'bert'.
    prefix = reward_model.base_model_prefix
    RewardModelClass = type(reward_model)

    if not hasattr(RewardModelClass, prefix):
        print(f"Patching {RewardModelClass.__name__} to expose '{prefix}'...")

        # Bind `prefix` now via a default argument. A plain closure would read
        # the notebook-global `prefix` each time the property is accessed, and
        # would silently follow any later reassignment of that global.
        def get_inner_base_model(self, _prefix=prefix):
            return getattr(self.reward_model, _prefix)

        setattr(RewardModelClass, prefix, property(get_inner_base_model))

# --- FIX 3: RE-INITIALIZE VALUE MODEL & TRAINER ---
# Rebuild the critic as a seq2seq LM with a value head, then construct a new
# PPOTrainer with it. This runs after the class-level patches earlier in this
# cell, so the new instances see the patched attributes.
print("Re-initializing Value Model and PPO Trainer...")

# Fall back to the SFT checkpoint if this cell is run without the earlier PPO cell.
if "sft_model_path" not in globals():
    sft_model_path = "./sft_summarizer_final" 

value_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(sft_model_path)
print(f"[DEBUG] Value Model Type: {type(value_model)}")
print(f"[DEBUG] Value Model Attributes: v_head={hasattr(value_model, 'v_head')}, score={hasattr(value_model, 'score')}")

# Put the critic on the same device as the policy.
if hasattr(policy_model, "device"):
    value_model.to(policy_model.device)

ppo_trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy_model,
    ref_model=ref_model,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=ppo_dataset,
)
# NOTE(review): this attribute may not be read by the installed PPOTrainer; the
# batch_generation patch below is what enforces the generation settings.
ppo_trainer.generation_kwargs = generation_kwargs
print(f"[DEBUG] PPOTrainer Value Model Type: {type(ppo_trainer.value_model)}")

# --- MONKEY PATCH 1: FORCE GENERATION CONFIG & FIX SEQ2SEQ HANDLING ---
# Keep TRL's original batch_generation exactly once, so re-running this cell
# does not store the already-patched function as the "original".
if not hasattr(trl.trainer.ppo_trainer, "original_batch_generation"):
    trl.trainer.ppo_trainer.original_batch_generation = trl.trainer.ppo_trainer.batch_generation

def custom_batch_generation(policy, queries, local_rollout_forward_batch_size, pad_token_id, generation_config):
    """Replacement for ``trl.trainer.ppo_trainer.batch_generation`` tuned for BART.

    TRL's ``generation_config`` is copied and then overridden with the sampling
    settings used elsewhere in this notebook. Queries are generated in chunks of
    ``local_rollout_forward_batch_size``. For encoder-decoder policies, the
    decoder output is appended to the query so that each row has the
    ``[query | response]`` layout, and sequences and per-step scores are trimmed
    to the same length.

    Args:
        policy: Model exposing ``generate`` and ``config.is_encoder_decoder``.
        queries: Query token ids, ``[batch, query_len]``.
        local_rollout_forward_batch_size: Number of queries per ``generate`` call.
        pad_token_id: Padding id. Used for the attention mask, for the
            generation config, and to right-pad chunks of different lengths.
        generation_config: TRL's generation config (used only as a template).

    Returns:
        ``(query_responses, logits)``: token ids ``[batch, total_len]`` and the
        stacked per-step scores ``[batch, steps, vocab]``. Chunks whose
        responses ended at different lengths are right-padded (ids with
        ``pad_token_id``, scores with 0) so that they can be concatenated.
    """
    # 1. FORCE CONFIGURATION
    new_config = GenerationConfig.from_dict(generation_config.to_dict())
    new_config.min_length = 30
    new_config.max_new_tokens = 128
    new_config.do_sample = True
    new_config.top_p = 0.9
    new_config.num_beams = 1
    # Hard-coded BART ids, assumed equal to tokenizer.eos_token_id; confirm this
    # if the policy checkpoint changes.
    new_config.decoder_start_token_id = 2
    new_config.eos_token_id = 2
    # Use the pad id TRL passes in (the same one the attention mask uses below).
    new_config.pad_token_id = pad_token_id

    # 2. MANUAL GENERATION LOOP
    query_responses = []
    logitss = []
    batch_size = queries.shape[0]

    for i in range(0, batch_size, local_rollout_forward_batch_size):
        query = queries[i : i + local_rollout_forward_batch_size]

        with torch.no_grad():
            output = policy.generate(
                input_ids=query,
                attention_mask=(query != pad_token_id).long(),
                generation_config=new_config,
                return_dict_in_generate=True,
                output_scores=True,
            )

        logits = torch.stack(output.scores, 1)

        if policy.config.is_encoder_decoder:
            # Decoder output only; align it with the per-step scores, then
            # prepend the query to get the [query | response] layout.
            generated_part = output.sequences
            gen_len = generated_part.shape[1]
            logits_len = logits.shape[1]

            if gen_len > logits_len:
                generated_part = generated_part[:, -logits_len:]
            elif logits_len > gen_len:
                logits = logits[:, :gen_len, :]

            full_sequence = torch.cat((query, generated_part), dim=1)
        else:
            full_sequence = output.sequences

        query_responses.append(full_sequence)
        logitss.append(logits)

        if i == 0:
            print(f"\n[DEBUG] Model Type: {'Encoder-Decoder' if policy.config.is_encoder_decoder else 'Decoder-Only'}")
            print(f"[DEBUG] Query Shape: {query.shape}")
            print(f"[DEBUG] Output Sequences Shape: {output.sequences.shape}")
            print(f"[DEBUG] Logits Shape: {logits.shape}")
            print(f"[DEBUG] Full Sequence Shape: {full_sequence.shape}")

    # 3. Chunks can stop generating at different lengths, but torch.cat needs
    #    matching widths, so right-pad every chunk to the longest one.
    max_seq_len = max(seq.shape[1] for seq in query_responses)
    max_steps = max(lg.shape[1] for lg in logitss)
    query_responses = [
        torch.nn.functional.pad(seq, (0, max_seq_len - seq.shape[1]), value=pad_token_id)
        for seq in query_responses
    ]
    logitss = [
        torch.nn.functional.pad(lg, (0, 0, 0, max_steps - lg.shape[1]), value=0.0)
        for lg in logitss
    ]

    return torch.cat(query_responses, 0), torch.cat(logitss, 0)

# Replace TRL's batch_generation in its module namespace with the custom version above.
trl.trainer.ppo_trainer.batch_generation = custom_batch_generation
print("Monkey-patch applied: Fixed Seq2Seq slicing logic + Enforced Config + Logit/Seq Length Mismatch Fix")

# --- MONKEY PATCH 2: ROBUST INSTANCE-LEVEL FORWARD PATCH ---
def make_safe_forward(original_forward, model_name="Model"):
    """Wrap a bound ``forward`` so that a ``position_ids`` keyword is silently dropped.

    Args:
        original_forward: A bound method; the wrapper is bound to the same object.
        model_name: Unused label, kept for call-site compatibility.

    Returns:
        A new bound method that forwards all other arguments unchanged.
    """
    owner = original_forward.__self__

    def safe_forward(self, *args, **kwargs):
        kwargs.pop("position_ids", None)
        return original_forward(*args, **kwargs)

    return types.MethodType(safe_forward, owner)

def patch_bart_instances(module, name=""):
    """Recursively wrap the ``forward`` of every BART model inside ``module``.

    Matching instances (BartForConditionalGeneration, BartForSequenceClassification
    and BartModel) get their ``forward`` replaced by ``make_safe_forward``. Each
    instance is patched at most once, tracked with a ``_forward_is_patched``
    flag. The walk always continues into all children, so nested BART modules
    are patched as well. ``name`` is the dotted path used in log messages.
    """
    bart_types = (BartForConditionalGeneration, BartForSequenceClassification, BartModel)
    needs_patch = isinstance(module, bart_types) and not getattr(module, "_forward_is_patched", False)

    if needs_patch:
        print(f"Patching forward method of: {name} ({type(module).__name__})")
        module.forward = make_safe_forward(module.forward, name)
        module._forward_is_patched = True

    for child_name, child in module.named_children():
        patch_bart_instances(child, f"{name}.{child_name}")

# Apply the instance-level forward patch (which drops `position_ids`) to every
# BART module reachable from the trainer's model, the reference model, the
# value model and the reward classifier; then run PPO and save the result.
print("Applying instance-level patches...")
if hasattr(ppo_trainer.model, "pretrained_model"):
    patch_bart_instances(ppo_trainer.model.pretrained_model, "ppo_trainer.model.pretrained_model")
else:
    patch_bart_instances(ppo_trainer.model, "ppo_trainer.model")

patch_bart_instances(ref_model, "ref_model")

if hasattr(value_model, "pretrained_model"):
    patch_bart_instances(value_model.pretrained_model, "value_model.pretrained_model")
else:
    patch_bart_instances(value_model, "value_model")

# NOTE(review): this patches nothing unless the reward classifier contains BART modules.
patch_bart_instances(reward_model.reward_model, "reward_model")

# --- DEBUG: INSPECT MODEL CONFIGS ---
print(f"Policy Model is_encoder_decoder: {policy_model.config.is_encoder_decoder}")

# 5. Start Training
print("Starting PPO Training...")
# NOTE(review): confirm the installed PPOTrainer actually reads `is_encoder_decoder`;
# otherwise this assignment has no effect.
ppo_trainer.is_encoder_decoder = True
print(f"Forced ppo_trainer.is_encoder_decoder = {ppo_trainer.is_encoder_decoder}")

ppo_trainer.train()
ppo_trainer.save_model("./unbiased_summarizer_ppo_final")