# Lab 3.1.9: KTO Binary Feedback - Solutions

Complete solutions for Kahneman-Tversky Optimization exercises.

## Exercise 1: KTO Loss Implementation

In [None]:
import torch
import torch.nn.functional as F

def kto_loss(
    policy_logps: torch.Tensor,
    reference_logps: torch.Tensor,
    labels: torch.Tensor,  # 1 for desirable, 0 for undesirable
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0
) -> tuple:
    """
    KTO Loss based on Prospect Theory.

    Key insight: Humans are loss-averse (losses hurt more than gains help).

    For desirable (label=1): Maximize log ratio above KL baseline
    For undesirable (label=0): Minimize log ratio below KL baseline

    Args:
        policy_logps: Per-example log-probabilities under the policy being trained.
        reference_logps: Per-example log-probabilities under the frozen reference
            model; same shape as ``policy_logps``.
        labels: 1 (or True) marks a desirable example, 0 (or False) an
            undesirable one. Any other value is ignored by both loss terms.
        beta: Scale applied to (log ratio - KL baseline) before the sigmoid.
        desirable_weight: Multiplier on the desirable-example loss term.
        undesirable_weight: Multiplier on the undesirable-example loss term
            (set it above ``desirable_weight`` to model loss aversion).

    Returns:
        ``(total_loss, metrics)``: ``total_loss`` is a scalar tensor that
        carries gradients; ``metrics`` is a dict of Python floats for logging.

    NOTE(review): this lab uses the ``-log sigmoid(.)`` form; the KTO paper's
    value function uses ``1 - sigmoid(.)``. Both push in the same direction.
    """
    # Log ratios: implicit reward of each completion
    log_ratios = policy_logps - reference_logps

    # Reference point z0: batch estimate of KL(policy || reference).
    # A KL divergence is never negative, so clamp the noisy batch estimate
    # at 0 (as KTO does), and detach it so no gradient flows through it.
    kl_baseline = log_ratios.mean().clamp(min=0).detach()

    # Separate desirable and undesirable
    desirable_mask = labels == 1
    undesirable_mask = labels == 0
    has_desirable = bool(desirable_mask.any())
    has_undesirable = bool(undesirable_mask.any())

    # Fallback when one side is absent (mean() of an empty selection is NaN).
    # Created on the inputs' device/dtype so GPU batches do not mix devices.
    zero = log_ratios.new_zeros(())

    # Desirable loss: want log_ratio > kl_baseline
    # L_desirable = -log(σ(β * (log_ratio - kl_baseline)))
    desirable_logits = beta * (log_ratios[desirable_mask] - kl_baseline)
    desirable_loss = -F.logsigmoid(desirable_logits).mean() if has_desirable else zero

    # Undesirable loss: want log_ratio < kl_baseline (flip the sign)
    # L_undesirable = -log(σ(-β * (log_ratio - kl_baseline)))
    undesirable_logits = -beta * (log_ratios[undesirable_mask] - kl_baseline)
    undesirable_loss = -F.logsigmoid(undesirable_logits).mean() if has_undesirable else zero

    # Weighted combination
    total_loss = desirable_weight * desirable_loss + undesirable_weight * undesirable_loss

    # Metrics: fraction of examples already on the "right" side of the baseline
    desirable_accuracy = (desirable_logits > 0).float().mean().item() if has_desirable else 0.0
    undesirable_accuracy = (undesirable_logits > 0).float().mean().item() if has_undesirable else 0.0

    return total_loss, {
        "desirable_loss": desirable_loss.item(),
        "undesirable_loss": undesirable_loss.item(),
        "kl_baseline": kl_baseline.item(),
        "desirable_accuracy": desirable_accuracy,
        "undesirable_accuracy": undesirable_accuracy,
    }

# Test: a small synthetic batch with an even desirable/undesirable split
torch.manual_seed(42)
batch_size = 8

# Draw order matters for reproducibility: policy first, then reference.
policy_logps = torch.randn(batch_size) - 1
reference_logps = torch.randn(batch_size) - 1.2
half = batch_size // 2
labels = torch.tensor([1] * half + [0] * half)  # Half desirable, half undesirable

loss, metrics = kto_loss(policy_logps, reference_logps, labels)
for report_line in (f"KTO Loss: {loss.item():.4f}", f"Metrics: {metrics}"):
    print(report_line)

## Exercise 2: Prospect Theory Visualization

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_prospect_theory():
    """
    Visualize Prospect Theory's value function alongside the KTO loss curves.

    Key insight: Losses loom larger than gains.

    Left panel: the Kahneman-Tversky value function (alpha = beta = 0.88,
    loss-aversion coefficient lambda = 2.25) against a linear "rational"
    baseline. Right panel: the per-example desirable/undesirable losses
    ``-log sigmoid(+/- beta * d)`` with beta = 0.1, plotted against
    d = log ratio - KL baseline (the same form ``kto_loss`` uses).

    Side effects: saves ``prospect_theory.png`` (150 dpi) in the current
    working directory, displays the figure, closes it, and prints a summary.
    """
    x = np.linspace(-3, 3, 100)

    # Standard value function from Prospect Theory
    # v(x) = x^α for gains (x >= 0)
    # v(x) = -λ(-x)^β for losses (x < 0)
    alpha = 0.88
    beta_pt = 0.88
    lambda_loss = 2.25  # Loss aversion coefficient

    def prospect_value(v):
        """Prospect Theory value of a scalar or array outcome (vectorized)."""
        v = np.asarray(v, dtype=float)
        # np.where evaluates both branches, so raise |v| (never a negative
        # base) to the fractional power to avoid NaNs and runtime warnings.
        magnitude = np.abs(v)
        return np.where(v >= 0, magnitude ** alpha, -lambda_loss * magnitude ** beta_pt)

    y = prospect_value(x)
    gains = x >= 0

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Prospect Theory Value Function
    axes[0].plot(x, y, 'b-', linewidth=2, label='Prospect Theory')
    axes[0].plot(x, x, 'k--', alpha=0.5, label='Linear (rational)')
    axes[0].axhline(y=0, color='gray', linestyle='-', alpha=0.3)
    axes[0].axvline(x=0, color='gray', linestyle='-', alpha=0.3)
    axes[0].fill_between(x[~gains], y[~gains], 0, alpha=0.2, color='red', label='Losses')
    axes[0].fill_between(x[gains], 0, y[gains], alpha=0.2, color='green', label='Gains')

    axes[0].set_xlabel('Outcome (gains/losses)')
    axes[0].set_ylabel('Perceived Value')
    axes[0].set_title('Prospect Theory: Losses Loom Larger')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Annotate; the label is derived from lambda_loss so it cannot drift.
    axes[0].annotate(f'Loss aversion:\nλ = {lambda_loss}', xy=(-2, float(prospect_value(-2))),
                     xytext=(-1, -3), fontsize=10,
                     arrowprops=dict(arrowstyle='->', color='red'))

    # Plot 2: KTO Loss curves
    #   desirable:   -log σ(β d)  = log(1 + e^{-β d})
    #   undesirable: -log σ(-β d) = log(1 + e^{β d})
    # np.logaddexp(0, z) evaluates log(1 + e^z) without overflow.
    log_ratio_diff = np.linspace(-2, 2, 100)
    beta_kto = 0.1

    desirable_loss = np.logaddexp(0.0, -beta_kto * log_ratio_diff)
    undesirable_loss = np.logaddexp(0.0, beta_kto * log_ratio_diff)

    axes[1].plot(log_ratio_diff, desirable_loss, 'g-', linewidth=2, label='Desirable (thumbs up)')
    axes[1].plot(log_ratio_diff, undesirable_loss, 'r-', linewidth=2, label='Undesirable (thumbs down)')
    axes[1].axvline(x=0, color='gray', linestyle='--', alpha=0.5)

    axes[1].set_xlabel('Log Ratio - KL Baseline')
    axes[1].set_ylabel('KTO Loss')
    axes[1].set_title('KTO Loss Functions')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('prospect_theory.png', dpi=150)
    plt.show()
    # Release the figure; otherwise every call leaves another figure open.
    plt.close(fig)

    print("\nKey Insight:")
    print("Humans feel losses ~2.25x more strongly than equivalent gains.")
    print("KTO leverages this: heavily penalize bad outputs, gently reward good ones.")

visualize_prospect_theory()

## Exercise 3: Complete KTO Training Pipeline

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import KTOTrainer, KTOConfig
from datasets import Dataset

def create_kto_pipeline(
    model_id: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    output_dir: str = "./kto-output",
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0
):
    """
    Complete KTO training pipeline for binary feedback data.

    Loads ``model_id`` with 4-bit NF4 quantization (bfloat16 compute), wraps
    it with LoRA adapters on the q/k/v/o attention projections, builds a tiny
    in-memory dataset of prompt/completion/label rows, loads a second 4-bit
    copy of the same model as the reference model, and assembles a TRL
    ``KTOTrainer``. Training is NOT started here; call ``trainer.train()``.

    Args:
        model_id: Hugging Face model id used for both the policy and the
            reference model.
        output_dir: Output directory passed to ``KTOConfig``.
        beta: KTO ``beta`` passed through to ``KTOConfig``.
        desirable_weight: Weight on desirable examples, passed to ``KTOConfig``.
        undesirable_weight: Weight on undesirable examples, passed to ``KTOConfig``.

    Returns:
        ``(trainer, model, tokenizer)``: the configured ``KTOTrainer``, the
        LoRA-wrapped policy model, and its tokenizer.

    NOTE(review): 4-bit bitsandbytes loading plus ``bf16=True`` presumably
    needs a CUDA GPU with bf16 support; confirm on the target hardware.
    """
    print("=" * 60)
    print("KTO TRAINING PIPELINE")
    print("=" * 60)
    
    # 1. Quantization: 4-bit NF4 weights, bfloat16 compute
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    
    # 2. Load model (the policy that will be trained)
    print("\n1. Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    # PEFT helper that readies a k-bit quantized model for training
    # (see the PEFT docs for exactly what it changes in the installed version).
    model = prepare_model_for_kbit_training(model)
    
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as padding; presumably the tokenizer has no dedicated pad
    # token. Verify this when switching to another model_id.
    tokenizer.pad_token = tokenizer.eos_token
    
    # 3. LoRA: rank-16 adapters on the attention projections only
    print("2. Configuring LoRA...")
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    
    # 4. Binary feedback dataset: one row per rated completion, with a boolean
    #    "label" (True = thumbs up). No chosen/rejected pairing is needed.
    print("3. Preparing binary feedback data...")
    kto_data = [
        # Desirable examples (thumbs up)
        {"prompt": "Explain machine learning.",
         "completion": "Machine learning is a subset of AI where computers learn patterns from data without explicit programming.",
         "label": True},
        {"prompt": "What's the capital of France?",
         "completion": "Paris is the capital of France.",
         "label": True},
        {"prompt": "How do I stay healthy?",
         "completion": "Eat balanced meals, exercise regularly, get enough sleep, and manage stress.",
         "label": True},
        
        # Undesirable examples (thumbs down)
        {"prompt": "Explain machine learning.",
         "completion": "Its computers doing stuff.",
         "label": False},
        {"prompt": "What's the capital of France?",
         "completion": "I don't know.",
         "label": False},
        {"prompt": "How do I stay healthy?",
         "completion": "Just don't get sick lol",
         "label": False},
    ]
    
    dataset = Dataset.from_list(kto_data)
    
    desirable_count = sum(1 for x in kto_data if x["label"])
    undesirable_count = len(kto_data) - desirable_count
    print(f"   Desirable: {desirable_count}, Undesirable: {undesirable_count}")
    
    # 5. Reference model: a second, separately loaded 4-bit copy of model_id.
    # NOTE(review): for a PEFT policy, TRL can use the adapter-disabled policy
    # as the reference when ref_model is None, which would avoid this extra
    # load. Confirm for the installed TRL version before changing.
    print("4. Creating reference model...")
    ref_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",
    )
    
    # 6. KTO training config
    # NOTE(review): TRL's KTO estimates its KL reference point per batch and is
    # believed to reject a per-device batch size of 1. Confirm before lowering.
    print("5. Setting up KTO training...")
    training_args = KTOConfig(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=2,
        learning_rate=5e-5,
        
        # KTO specific
        beta=beta,
        desirable_weight=desirable_weight,
        undesirable_weight=undesirable_weight,
        
        # Memory
        bf16=True,
        gradient_checkpointing=True,
        
        # Logging
        logging_steps=1,
        report_to="none",
    )
    
    # 7. Create trainer
    # NOTE(review): newer TRL releases take the tokenizer as
    # `processing_class=`; `tokenizer=` may be deprecated or removed
    # depending on the installed TRL version. Confirm.
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )
    
    print(f"\nKTO Configuration:")
    print(f"  β (beta): {beta}")
    print(f"  Desirable weight: {desirable_weight}")
    print(f"  Undesirable weight: {undesirable_weight}")
    print("=" * 60)
    
    return trainer, model, tokenizer

# Uncomment to run
# trainer, model, tokenizer = create_kto_pipeline()
# trainer.train()

## Exercise 4: Collecting Binary Feedback

In [None]:
import json
from typing import List, Dict

class BinaryFeedbackCollector:
    """
    Simple system for collecting thumbs up/down feedback.

    Every recorded entry is kept in memory (``self.collected``) and appended
    as one JSON object per line to ``output_file`` (JSON Lines), so feedback
    can be restored later with ``load_existing``.
    """

    def __init__(self, output_file: str = "feedback_data.jsonl"):
        """
        Args:
            output_file: Path of the JSON Lines file that entries are appended
                to. Nothing is read on construction; call ``load_existing``
                to restore previously saved feedback.
        """
        self.output_file = output_file
        self.collected = []

    def record_feedback(
        self,
        prompt: str,
        completion: str,
        is_good: bool,
        metadata: dict = None
    ):
        """
        Record a single feedback instance.

        Args:
            prompt: The prompt the model answered.
            completion: The model response that was rated.
            is_good: Thumbs up (truthy) or thumbs down (falsy). Stored as a
                plain ``bool`` so numpy/torch booleans or 0/1 ints serialize
                to JSON and match KTO's boolean ``label`` column.
            metadata: Optional extra fields (user id, timestamp, ...). A
                shallow copy is stored, so later changes to the caller's dict
                do not alter the recorded entry.
        """
        entry = {
            "prompt": prompt,
            "completion": completion,
            "label": bool(is_good),
            "metadata": dict(metadata) if metadata else {},
        }
        # Persist first so the in-memory list never holds an entry the file lacks.
        with open(self.output_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry) + "\n")
        self.collected.append(entry)

    def get_stats(self) -> Dict:
        """
        Get feedback statistics for the in-memory entries.

        ``ratio`` is desirable / undesirable. When there are no undesirable
        entries the denominator is treated as 1, so ``ratio`` then equals
        the desirable count.
        """
        desirable = sum(1 for x in self.collected if x["label"])
        undesirable = len(self.collected) - desirable

        return {
            "total": len(self.collected),
            "desirable": desirable,
            "undesirable": undesirable,
            "ratio": desirable / max(undesirable, 1)
        }

    def load_existing(self) -> List[Dict]:
        """
        Load existing feedback from file, replacing ``self.collected``.

        Blank lines (e.g. a stray trailing newline from a manual edit) are
        skipped rather than raising ``json.JSONDecodeError``. A missing file
        yields an empty list.
        """
        try:
            with open(self.output_file, "r", encoding="utf-8") as f:
                self.collected = [json.loads(line) for line in f if line.strip()]
        except FileNotFoundError:
            self.collected = []
        return self.collected

    def to_kto_format(self) -> List[Dict]:
        """Convert to KTO training format: prompt/completion/label rows, metadata dropped."""
        return [
            {
                "prompt": x["prompt"],
                "completion": x["completion"],
                "label": x["label"]
            }
            for x in self.collected
        ]

# Demo usage: rate two answers to the same prompt, one good and one bad
collector = BinaryFeedbackCollector()

demo_feedback = [
    ("Python is a high-level programming language known for its readability.", True),
    ("A snake.", False),
]

# Simulate collecting feedback (a fresh metadata dict per entry)
for demo_completion, demo_is_good in demo_feedback:
    collector.record_feedback(
        prompt="What is Python?",
        completion=demo_completion,
        is_good=demo_is_good,
        metadata={"user_id": "demo", "timestamp": "2024-01-01"},
    )

print("Feedback Stats:", collector.get_stats())
print("\nKTO Format:")
print(json.dumps(collector.to_kto_format(), indent=2))

## Exercise 5: When to Use KTO vs DPO

In [None]:
def decision_guide():
    """
    Print a decision chart for choosing between DPO and KTO, followed by one
    sample record in each method's dataset format (paired vs. binary label).
    """
    banner = """
╔══════════════════════════════════════════════════════════════════╗
║                     DPO vs KTO Decision Guide                     ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                   ║
║   Do you have PAIRED preference data?                            ║
║   (same prompt, chosen vs rejected)                              ║
║                                                                   ║
║   YES ──────────────────────┐                                    ║
║                             │                                    ║
║                             ▼                                    ║
║                    ┌─────────────────┐                           ║
║                    │  Use DPO/SimPO  │                           ║
║                    │  (pair-based)   │                           ║
║                    └─────────────────┘                           ║
║                                                                   ║
║   NO (only thumbs up/down) ─┐                                    ║
║                             │                                    ║
║                             ▼                                    ║
║                    ┌─────────────────┐                           ║
║                    │    Use KTO!     │                           ║
║                    │ (binary signal) │                           ║
║                    └─────────────────┘                           ║
║                                                                   ║
╠══════════════════════════════════════════════════════════════════╣
║                                                                   ║
║   KTO is ideal when:                                             ║
║   ✓ You only have thumbs up/down feedback                        ║
║   ✓ Feedback comes from production (not curated)                 ║
║   ✓ You can't pair responses to same prompt                      ║
║   ✓ Data is highly imbalanced (mostly good or mostly bad)        ║
║                                                                   ║
║   DPO is better when:                                            ║
║   ✓ You have A/B comparison data                                 ║
║   ✓ Each prompt has multiple rated responses                     ║
║   ✓ You need strong preference learning                          ║
║                                                                   ║
╚══════════════════════════════════════════════════════════════════╝
    """
    print(banner)

    # (heading, sample record) pairs, printed in order: paired DPO, then binary KTO
    samples = [
        ("\nDPO format (paired):", {
            "prompt": "What is AI?",
            "chosen": "AI is artificial intelligence...",
            "rejected": "dunno",
        }),
        ("\nKTO format (binary):", {
            "prompt": "What is AI?",
            "completion": "AI is artificial intelligence...",
            "label": True,  # or False
        }),
    ]

    print("\nData format examples:")
    for heading, record in samples:
        print(heading)
        print(json.dumps(record, indent=2))

decision_guide()

## Key Takeaways

1. **Prospect Theory**: Losses are felt roughly 2.25× as strongly as equally sized gains (loss-aversion coefficient λ ≈ 2.25)
2. **KTO Loss**: Uses a KL-divergence baseline as the reference point: desirable outputs are pushed above it and undesirable outputs below it
3. **Binary Feedback**: Just thumbs up/down, no paired comparisons needed
4. **Use Case**: Production feedback, imbalanced data, simple annotation
5. **vs DPO**: KTO for binary, DPO for paired comparisons