# Lab 3.1.9: KTO - Training with Binary Feedback

**Module:** 3.1 - Large Language Model Fine-Tuning  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐⭐☆

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand KTO (Kahneman-Tversky Optimization)
- [ ] Train with binary feedback (thumbs up/down)
- [ ] Know when KTO is better than DPO
- [ ] Implement KTO training with TRL

---

## The Problem: You Don't Always Have Pairs

DPO, SimPO, and ORPO all need **preference pairs**: "This response is better than that one."

But often, you only have **binary feedback**:
- 👍 This response is good
- 👎 This response is bad

Examples:
- User ratings (helpful/not helpful)
- Flagged content (appropriate/inappropriate)
- Click-through data (engaged/bounced)

**KTO lets you train directly on this simpler feedback!**

---

## ELI5: What is KTO?

> **Imagine you're training a dog.** 
>
> **DPO** is like: "Show two tricks, reward the better one." But what if you only saw one trick at a time?
>
> **KTO** is like: "For each trick, give a thumbs up or thumbs down. The dog learns to do more thumbs-up things and fewer thumbs-down things."
>
> **The clever part:** KTO is based on **Prospect Theory** from behavioral economics, developed by Daniel Kahneman and Amos Tversky (Kahneman received the 2002 Nobel Memorial Prize in Economic Sciences for this work). It models how humans actually perceive gains and losses:
> - Losses hurt more than equivalent gains feel good
> - This asymmetry is built into the loss function
>
> **Result:** A more human-aligned training objective!

---

## Part 1: The KTO Algorithm

### Key Insight: Prospect Theory

Humans don't perceive gains and losses equally:
- Losing $100 feels worse than winning $100 feels good
- This is called **loss aversion**
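
To make loss aversion concrete, here is a minimal sketch of the Prospect Theory value function. The exponent 0.88 and the loss-aversion coefficient 2.25 are the median estimates reported by Tversky and Kahneman (1992), not values taken from this lab, and the function name is illustrative. According to the KTO paper, KTO replaces this power curve with a sigmoid, so treat this only as an illustration of the asymmetry.

```python
def prospect_value(outcome: float, alpha: float = 0.88, loss_aversion: float = 2.25) -> float:
    """Subjective value of a gain (outcome >= 0) or a loss (outcome < 0)."""
    if outcome >= 0:
        return outcome ** alpha
    # Losses use the same curve, mirrored and scaled up by the loss-aversion factor
    return -loss_aversion * ((-outcome) ** alpha)


gain = prospect_value(100.0)
loss = prospect_value(-100.0)
print(f"Winning $100 feels like {gain:+.1f}")
print(f"Losing  $100 feels like {loss:+.1f}")
print(f"The loss weighs {abs(loss) / gain:.2f}x as much as the gain")
```

With these assumed parameters, a loss always weighs 2.25 times as much as a gain of the same size.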

KTO incorporates this into training:

$$\mathcal{L}_{KTO} = \mathbb{E}_{(x,y)\sim\mathcal{D}}\left[w(y) \cdot \left(1 - v_{\text{KTO}}(x, y; \beta)\right)\right]$$

Where:
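- $w(y)$ is a per-example weight selected by the label (one weight for desirable outputs, another for undesirable ones), so the two kinds of feedback can count differently
- $v_{\text{KTO}}$ is the "value function" inspired by Prospect Theory; it scores how far the policy has moved relative to the reference model, and the loss $1 - v_{\text{KTO}}$ falls as the value rises
- For desirable outputs: the value rises as the policy makes the output more likely than the reference does
- For undesirable outputs: the value rises as the policy makes the output less likely than the reference does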
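
The loss above can be sketched for a single example in plain Python. This follows the formulation in the KTO paper as far as this lab describes it: the value is a sigmoid of the policy-versus-reference log-ratio, shifted by a reference point. The names `kto_example_loss` and `kl_ref` are illustrative, and `kl_ref` stands in for the batch-level KL estimate that a real trainer would compute; this is not TRL's implementation.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def kto_example_loss(
    policy_logp: float,
    ref_logp: float,
    desirable: bool,
    kl_ref: float = 0.0,            # assumed reference point (KL estimate)
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> float:
    """Per-example loss w(y) * (1 - v_KTO) for one (prompt, completion, label)."""
    log_ratio = policy_logp - ref_logp  # > 0 means the policy favors y more than the reference
    if desirable:
        value = sigmoid(beta * (log_ratio - kl_ref))
        weight = desirable_weight
    else:
        value = sigmoid(beta * (kl_ref - log_ratio))
        weight = undesirable_weight
    return weight * (1.0 - value)


# The same shift in log-probability lowers the loss for a desirable output
# and raises it for an undesirable one.
print(kto_example_loss(-5.0, -10.0, desirable=True))
print(kto_example_loss(-5.0, -10.0, desirable=False))
```

When the policy and reference agree and `kl_ref` is zero, the loss sits at half the weight, which is the neutral starting point before any training.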

### No Pairs Needed!

Unlike DPO which needs (prompt, chosen, rejected) triplets, KTO only needs:
- (prompt, response, label) where label ∈ {desirable, undesirable}

In [None]:
# Setup
import torch
import gc
from typing import Dict, List
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import KTOTrainer, KTOConfig
from datasets import Dataset

print("Libraries imported!")

---

## Part 2: Binary Feedback Dataset

KTO uses a simpler format than DPO.

In [None]:
# KTO dataset format: (prompt, completion, label)
# label: True = desirable (👍), False = undesirable (👎)

binary_feedback_data = [
    # Good responses (👍)
    {
        "prompt": "Explain Python decorators.",
        "completion": "Decorators are functions that modify other functions. They wrap a function to add behavior before/after it runs. Use @decorator syntax above a function definition. Common uses: logging, timing, authentication.",
        "label": True
    },
    {
        "prompt": "How do I handle errors in Python?",
        "completion": "Use try/except blocks: wrap risky code in 'try', handle specific exceptions in 'except'. Always catch specific exceptions (not bare 'except'). Use 'finally' for cleanup. Example: try: file = open('data.txt') except FileNotFoundError: print('File missing')",
        "label": True
    },
    {
        "prompt": "What's the difference between a list and tuple?",
        "completion": "Lists are mutable (can change), tuples are immutable (can't change). Lists use [], tuples use (). Tuples are faster and can be dictionary keys. Use lists when you need to modify data, tuples for fixed collections.",
        "label": True
    },
    {
        "prompt": "How do I improve my code quality?",
        "completion": "Write clear names, keep functions small, add tests, use type hints, run linters (like pylint), review your own code, get feedback from others, and refactor regularly. Code is read more than written!",
        "label": True
    },
    
    # Bad responses (👎)
    {
        "prompt": "Explain Python decorators.",
        "completion": "Decorators are a Python feature.",
        "label": False
    },
    {
        "prompt": "How do I handle errors in Python?",
        "completion": "Just use try except.",
        "label": False
    },
    {
        "prompt": "What's the difference between a list and tuple?",
        "completion": "They're both data structures in Python.",
        "label": False
    },
    {
        "prompt": "How do I improve my code quality?",
        "completion": "Practice more.",
        "label": False
    },
    {
        "prompt": "What is machine learning?",
        "completion": "It's when computers learn from data to make predictions. Like how Netflix learns your preferences to recommend shows you might like.",
        "label": True
    },
    {
        "prompt": "What is machine learning?",
        "completion": "AI stuff.",
        "label": False
    },
]

# Expand the dataset
expanded_data = []
for item in binary_feedback_data:
    expanded_data.append(item)
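    # Add a variation: prefix with "Please" and lowercase only the first letter,
    # so proper nouns such as "Python" keep their capitalization
    expanded_data.append({
        "prompt": "Please " + item["prompt"][0].lower() + item["prompt"][1:],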
        "completion": item["completion"],
        "label": item["label"],
    })

kto_dataset = Dataset.from_list(expanded_data)

# Count labels
n_positive = sum(1 for item in expanded_data if item["label"])
n_negative = len(expanded_data) - n_positive

print(f"KTO Dataset: {len(kto_dataset)} examples")
print(f"  👍 Desirable: {n_positive}")
print(f"  👎 Undesirable: {n_negative}")
print("\nNote: No preference pairs needed! Just binary labels.")

---

## Part 3: Load Model for KTO

In [None]:
# Configuration
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# KTO-specific parameters
KTO_BETA = 0.1  # Controls deviation from the reference model (higher = stays closer to it)
DESIRABLE_WEIGHT = 1.0  # Loss weight for desirable (positive) examples
UNDESIRABLE_WEIGHT = 1.0  # Loss weight for undesirable (negative) examples

print("KTO Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Beta: {KTO_BETA}")

In [None]:
# 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load tokenizer and model
print(f"Loading model {MODEL_NAME}...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)

model = prepare_model_for_kbit_training(model)

# Add LoRA
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

print(f"\nModel loaded!")
model.print_trainable_parameters()
print(f"Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

---

## Part 4: KTO Training

In [None]:
# KTO Configuration
kto_config = KTOConfig(
    output_dir="./kto_output",
    
    # KTO-specific
    beta=KTO_BETA,
    desirable_weight=DESIRABLE_WEIGHT,
    undesirable_weight=UNDESIRABLE_WEIGHT,
    
    # Training
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=5e-5,
    
    # Sequence
    max_length=512,
    max_prompt_length=256,
    
    # Optimization
    optim="paged_adamw_8bit",
    bf16=True,
    
    # Logging
    logging_steps=5,
    
    save_strategy="no",
    report_to="none",
    remove_unused_columns=False,
)

print("KTO configuration created!")
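
The batch settings above interact, so a quick arithmetic sketch can help before training. It assumes a single GPU; the helper names are illustrative and not part of TRL.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of examples that contribute to one optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def micro_batches_per_epoch(num_examples: int, per_device_batch: int, num_devices: int = 1) -> int:
    """Number of forward/backward passes needed to see every example once."""
    return math.ceil(num_examples / (per_device_batch * num_devices))


print(effective_batch_size(per_device_batch=2, grad_accum_steps=4))     # 8
print(micro_batches_per_epoch(num_examples=20, per_device_batch=2))     # 10
```

With the 20 examples built in Part 2, that is 10 micro-batches and therefore only two or three optimizer steps in one epoch (the exact count depends on how the installed Trainer version treats a trailing partial accumulation). With `logging_steps=5`, intermediate loss logs may never appear on a dataset this small; a lower value would show them.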

In [None]:
# Create KTO trainer
kto_trainer = KTOTrainer(
    model=model,
    ref_model=None,  # With a PEFT model, TRL gets reference log-probs by disabling the adapter
    args=kto_config,
    train_dataset=kto_dataset,
    tokenizer=tokenizer,  # Newer TRL releases name this argument processing_class
)

print("KTO Trainer created!")
print(f"Memory after setup: {torch.cuda.memory_allocated()/1e9:.2f} GB")

In [None]:
# Train with KTO!
print("="*50)
print("STARTING KTO TRAINING")
print("="*50)
print("\nTraining with binary feedback (thumbs up/down)!")

kto_result = kto_trainer.train()

print("\n" + "="*50)
print("KTO TRAINING COMPLETE!")
print("="*50)

In [None]:
# Print metrics
print("\nKTO Training Metrics:")
for key, value in kto_result.metrics.items():
    if isinstance(value, float):
        print(f"  {key}: {value:.4f}")
    else:
        print(f"  {key}: {value}")

---

## Part 5: When to Use KTO

### Decision Guide

In [None]:
print("""
╔══════════════════════════════════════════════════════════════════════════════╗
║                           WHEN TO USE KTO                                    ║
╠══════════════════════════════════════════════════════════════════════════════╣
║                                                                              ║
║  ✅ USE KTO when you have:                                                   ║
║    • Binary feedback (helpful/not helpful)                                   ║
║    • User ratings (thumbs up/down)                                           ║
║    • Flagged content (appropriate/inappropriate)                             ║
║    • Click-through data (engaged/bounced)                                    ║
║    • Any label that's True/False, not comparative                            ║
║                                                                              ║
║  ❌ USE DPO/SimPO/ORPO when you have:                                        ║
║    • Side-by-side comparisons (A is better than B)                           ║
║    • Ranked responses (best to worst)                                        ║
║    • Elo ratings from comparisons                                            ║
║                                                                              ║
╠══════════════════════════════════════════════════════════════════════════════╣
║  KTO Advantages:                                                             ║
║    • Works with simpler data (no pairs needed)                               ║
║    • Based on human behavioral economics (Prospect Theory)                   ║
║    • Handles imbalanced positive/negative ratios well                        ║
║    • Can mix data from different sources more easily                         ║
║                                                                              ║
║  KTO Considerations:                                                         ║
║    • May need more data than DPO for same quality                            ║
║    • Still needs reference model (like DPO)                                  ║
║    • Newer method, less battle-tested                                        ║
╚══════════════════════════════════════════════════════════════════════════════╝
""")

---

## Part 6: Converting Between Formats

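You can convert preference pairs to binary format, but generally not the reverse: binary labels are not guaranteed to include both a good and a bad response for the same prompt.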

In [None]:
def preference_pairs_to_binary(preference_data: List[Dict]) -> List[Dict]:
    """
    Convert DPO-style preference pairs to KTO binary format.
    
    Input: [{"prompt": str, "chosen": str, "rejected": str}, ...]
    Output: [{"prompt": str, "completion": str, "label": bool}, ...]
    """
    binary_data = []
    
    for item in preference_data:
        # Chosen response = desirable
        binary_data.append({
            "prompt": item["prompt"],
            "completion": item["chosen"],
            "label": True,
        })
        
        # Rejected response = undesirable
        binary_data.append({
            "prompt": item["prompt"],
            "completion": item["rejected"],
            "label": False,
        })
    
    return binary_data


# Example
dpo_data = [
    {
        "prompt": "What is Python?",
        "chosen": "Python is a versatile programming language known for readability.",
        "rejected": "It's a coding thing."
    }
]

kto_data = preference_pairs_to_binary(dpo_data)
print("Converted DPO → KTO:")
for item in kto_data:
    label = "👍" if item["label"] else "👎"
    print(f"  {label} {item['completion'][:50]}...")

---

## Common Mistakes

### Mistake 1: Imbalanced Labels

```python
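from trl import KTOConfig

# Wrong: 90% positive, 10% negative with equal weights
# The few negative examples contribute little to the loss

# Right: rebalance with weights (or resample the data)
kto_config = KTOConfig(
    output_dir="./kto_output",
    desirable_weight=1.0,
    undesirable_weight=9.0,  # 90 x 1.0 vs. 10 x 9.0: weighted totals now match
    # ...plus your other training arguments
)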
```
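
The weights in the example above can be derived from the label counts rather than guessed. The sketch below picks weights so that the weighted totals of the two classes are equal; the helper name is hypothetical. The TRL documentation suggests keeping the ratio of (desirable weight × positives) to (undesirable weight × negatives) roughly between 1:1 and 4:3, so an exact 1:1 split is one reasonable starting point, not the only valid choice.

```python
def balanced_kto_weights(n_desirable: int, n_undesirable: int) -> dict:
    """Weights such that desirable_weight * n_desirable == undesirable_weight * n_undesirable."""
    if n_desirable <= 0 or n_undesirable <= 0:
        raise ValueError("Need at least one example of each label")
    if n_desirable >= n_undesirable:
        # Upweight the rarer undesirable class
        return {"desirable_weight": 1.0, "undesirable_weight": n_desirable / n_undesirable}
    # Upweight the rarer desirable class
    return {"desirable_weight": n_undesirable / n_desirable, "undesirable_weight": 1.0}


weights = balanced_kto_weights(n_desirable=90, n_undesirable=10)
print(weights)  # {'desirable_weight': 1.0, 'undesirable_weight': 9.0}
```

The returned dictionary uses the same keyword names as `KTOConfig`, so it can be unpacked into the config with `**weights`.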

### Mistake 2: Weak Negative Examples

```python
# Wrong: Negative is just shorter version of positive
{"prompt": "Explain X", "completion": "X is...", "label": True}
{"prompt": "Explain X", "completion": "X.", "label": False}
# Model just learns "longer = better"

# Right: Negative has clear quality issues
{"prompt": "Explain X", "completion": "X is... (helpful detail)", "label": True}
{"prompt": "Explain X", "completion": "I don't know.", "label": False}
# Model learns actual quality differences
```
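
One quick way to spot the "longer = better" trap is to compare average completion length per label before training. This is a rough heuristic sketch, not a TRL feature; the function name and the word-count measure are assumptions made for illustration.

```python
from statistics import mean


def mean_length_by_label(records):
    """Average completion length, in whitespace-separated words, for each label."""
    lengths = {True: [], False: []}
    for record in records:
        lengths[record["label"]].append(len(record["completion"].split()))
    # Skip a label entirely if the dataset has no examples for it
    return {label: mean(values) for label, values in lengths.items() if values}


sample = [
    {"prompt": "Explain X", "completion": "X is a tool that does A, B, and C.", "label": True},
    {"prompt": "Explain X", "completion": "X.", "label": False},
]
stats = mean_length_by_label(sample)
print(stats)  # {True: 10, False: 1}
```

A large gap between the two averages suggests that length alone separates the classes, which is a cue to add short good answers or long bad ones.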

---

## Checkpoint

You've learned:
- ✅ KTO trains with binary feedback (no pairs needed)
- ✅ It's based on Prospect Theory from behavioral economics
- ✅ When to choose KTO over DPO/SimPO
- ✅ How to implement KTO with TRL

---

## Further Reading

- [KTO Paper](https://arxiv.org/abs/2402.01306) - Kahneman-Tversky Optimization
- [Prospect Theory](https://en.wikipedia.org/wiki/Prospect_theory) - The psychology behind KTO
- [TRL KTO Documentation](https://huggingface.co/docs/trl/kto_trainer)

---

## Cleanup

In [None]:
# Clear memory
del model, kto_trainer
gc.collect()  # Drop lingering Python references first
torch.cuda.empty_cache()  # Then release the freed blocks from the CUDA cache

print("Cleanup complete!")

---

## Next Steps

Continue to:

**[Lab 3.1.10: Ollama Integration](lab-3.1.10-ollama-integration.ipynb)** - Deploy your fine-tuned model with Ollama!