In [None]:
!pip uninstall -y transformers accelerate peft bitsandbytes datasets trl scipy triton

!pip install --upgrade transformers==4.41.2 -q
!pip install --upgrade peft==0.11.1 -q
!pip install --upgrade accelerate==0.30.1 -q
!pip install bitsandbytes -q
!pip install --upgrade datasets==2.19.1 -q
!pip install --upgrade trl==0.8.6 -q
!pip install --upgrade scipy -q
!pip install --upgrade triton -q
!pip install -q torch
!pip install -q sentencepiece
!pip install -q einops

# Install bitsandbytes from alternative source for better compatibility with colab
!pip install bitsandbytes --prefer-binary --extra-index-url=https://jllllll.github.io/bitsandbytes-windows-webui
# Restart runtime to apply changes
import os
os.kill(os.getpid(), 9)

Found existing installation: transformers 4.52.4
Uninstalling transformers-4.52.4:
  Successfully uninstalled transformers-4.52.4
Found existing installation: accelerate 1.8.1
Uninstalling accelerate-1.8.1:
  Successfully uninstalled accelerate-1.8.1
Found existing installation: peft 0.15.2
Uninstalling peft-0.15.2:
  Successfully uninstalled peft-0.15.2
[0mFound existing installation: datasets 3.6.0
Uninstalling datasets-3.6.0:
  Successfully uninstalled datasets-3.6.0
[0mFound existing installation: scipy 1.15.3
Uninstalling scipy-1.15.3:
  Successfully uninstalled scipy-1.15.3
Found existing installation: triton 3.2.0
Uninstalling triton-3.2.0:
  Successfully uninstalled triton-3.2.0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.8/43.8 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.1/9.1 MB[0m [31m84.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from trl import DPOTrainer
import json
import random
import re
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Any
import pandas as pd
from tqdm import tqdm
import time
import gc

# Configuration
@dataclass
class ReasonDPOConfig:
    # Model settings
    model_name: str = "HuggingfaceTB/SmolLM2-360M-Instruct"
    ref_model_name: str = None

    # DPO settings
    beta: float = 0.1
    learning_rate: float = 5e-7
    batch_size: int = 32
    gradient_accumulation_steps: int = 2
    num_epochs: int = 1  # Reduced for faster training
    max_length: int = 1024
    max_prompt_length: int = 512

    # LoRA settings
    lora_r: int = 16  # Reduced for fewer parameters
    lora_alpha: int = 64
    lora_dropout: float = 0.05

    # Data settings
    num_train_preferences: int = 200  # Drastically reduced for speed
    num_eval_preferences: int = 40    # Drastically reduced for speed
    responses_per_prompt: int = 4     # Reduced for faster data generation

    # Generation settings
    generation_temperature: float = 0.8
    generation_max_tokens: int = 300

    # Paths
    output_dir: str = "./reason_dpo_output"
    cache_dir: str = "./cache"

config = ReasonDPOConfig()

In [15]:
class AdvancedReasoningTasks:
    """Generate diverse reasoning tasks for DPO training"""
    
    def __init__(self, seed=42):
        random.seed(seed)
        np.random.seed(seed)
        self.task_types = [
            self.generate_logical_deduction,
            self.generate_mathematical_reasoning,
            self.generate_pattern_completion,
            self.generate_causal_reasoning,
            self.generate_analogical_reasoning,
            self.generate_constraint_solving,
            self.generate_counterfactual_reasoning
        ]
    
    def generate_logical_deduction(self):
        """Multi-step logical deduction"""
        premises = [
            ("All {A} are {B}", "All {B} are {C}", "All {A} are {C}"),
            ("No {A} are {B}", "All {C} are {B}", "No {A} are {C}"),
            ("Some {A} are {B}", "All {B} are {C}", "Some {A} are {C}"),
        ]
        
        entities = [
            ("cats", "mammals", "animals"),
            ("birds", "flying creatures", "egg-layers"),
            ("roses", "flowers", "plants"),
            ("diamonds", "gems", "minerals")
        ]
        
        premise_template = random.choice(premises)
        entity_set = random.choice(entities)
        
        premises_text = []
        for i, template in enumerate(premise_template[:-1]):
            filled = template.format(A=entity_set[0], B=entity_set[1], C=entity_set[2])
            premises_text.append(filled)
        
        conclusion = premise_template[-1].format(A=entity_set[0], B=entity_set[1], C=entity_set[2])
        
        # Sometimes make it false
        if random.random() < 0.3:
            conclusion = conclusion.replace("are", "are not") if "are" in conclusion else conclusion.replace("All", "No")
            correct_answer = "No"
        else:
            correct_answer = "Yes"
        
        question = f"Given: {'. '.join(premises_text)}. Can we conclude: {conclusion}?"
        
        return {
            "question": question,
            "answer": correct_answer,
            "type": "logical_deduction",
            "difficulty": "medium"
        }
    
    def generate_mathematical_reasoning(self):
        """Math word problems requiring reasoning"""
        templates = [
            {
                "template": "A train travels {speed1} km/h for {time1} hours, then {speed2} km/h for {time2} hours. What is the average speed?",
                "solution": lambda v: (v['speed1']*v['time1'] + v['speed2']*v['time2'])/(v['time1']+v['time2'])
            },
            {
                "template": "If {workers1} workers can complete a job in {days1} days, how many days will {workers2} workers take?",
                "solution": lambda v: (v['workers1']*v['days1'])/v['workers2']
            },
            {
                "template": "A rectangle has perimeter {perimeter} and length {length}. What is its area?",
                "solution": lambda v: v['length'] * ((v['perimeter']/2) - v['length'])
            }
        ]
        
        problem = random.choice(templates)
        
        # Generate random values
        if "speed" in problem["template"]:
            values = {
                "speed1": random.randint(40, 80),
                "speed2": random.randint(60, 100),
                "time1": random.randint(2, 5),
                "time2": random.randint(1, 4)
            }
        elif "workers" in problem["template"]:
            values = {
                "workers1": random.randint(4, 12),
                "days1": random.randint(10, 30),
                "workers2": random.randint(6, 15)
            }
        else:
            length = random.randint(5, 15)
            width = random.randint(3, 10)
            values = {
                "perimeter": 2 * (length + width),
                "length": length
            }
        
        question = problem["template"].format(**values)
        answer = round(problem["solution"](values), 2)
        
        return {
            "question": question,
            "answer": str(answer),
            "type": "mathematical_reasoning",
            "difficulty": "hard"
        }
    
    def generate_pattern_completion(self):
        """Complex pattern recognition"""
        patterns = [
            {
                "name": "fibonacci",
                "rule": lambda seq: seq[-1] + seq[-2],
                "init": [1, 1]
            },
            {
                "name": "geometric",
                "rule": lambda seq: seq[-1] * 2,
                "init": [2, 4]
            },
            {
                "name": "arithmetic_sequence",
                "rule": lambda seq: seq[-1] + (seq[-1] - seq[-2]),
                "init": [3, 7]
            },
            {
                "name": "squares_plus_n",
                "rule": lambda seq: (len(seq)+1)**2 + (len(seq)+1),
                "init": [2, 6]
            }
        ]
        
        pattern = random.choice(patterns)
        sequence = pattern["init"].copy()
        
        # Generate sequence
        for _ in range(3):
            sequence.append(pattern["rule"](sequence))
        
        # Hide last element
        question = f"What comes next: {', '.join(map(str, sequence[:-1]))}, ?"
        answer = str(sequence[-1])
        
        return {
            "question": question,
            "answer": answer,
            "type": "pattern_completion",
            "pattern": pattern["name"]
        }
    
    def generate_causal_reasoning(self):
        """Cause and effect reasoning"""
        scenarios = [
            {
                "setup": "If the temperature drops below freezing and there's moisture in the air",
                "effect": "frost will form",
                "negation": "frost will not form"
            },
            {
                "setup": "If a plant doesn't get water for several weeks and is in direct sunlight",
                "effect": "the plant will wilt",
                "negation": "the plant will thrive"
            },
            {
                "setup": "If you increase the price of a product significantly and demand is elastic",
                "effect": "sales will decrease",
                "negation": "sales will increase"
            }
        ]
        
        scenario = random.choice(scenarios)
        
        # Create variations
        if random.random() < 0.5:
            question = f"{scenario['setup']}, what will happen?"
            answer = scenario['effect']
        else:
            # Counterfactual
            question = f"{scenario['setup']}, will {scenario['negation']}?"
            answer = "No"
        
        return {
            "question": question,
            "answer": answer,
            "type": "causal_reasoning"
        }
    
    def generate_analogical_reasoning(self):
        """Analogies and relationships"""
        analogies = [
            ("cat", "kitten", "dog", "puppy"),
            ("book", "page", "house", "room"),
            ("teacher", "student", "doctor", "patient"),
            ("key", "lock", "password", "account"),
            ("seed", "tree", "egg", "bird")
        ]
        
        analogy = random.choice(analogies)
        question = f"{analogy[0]} is to {analogy[1]} as {analogy[2]} is to ?"
        answer = analogy[3]
        
        return {
            "question": question,
            "answer": answer,
            "type": "analogical_reasoning"
        }
    
    def generate_constraint_solving(self):
        """Constraint satisfaction problems"""
        people = ["Alice", "Bob", "Charlie", "Diana", "Eve"]
        days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"]
        
        # Select subset
        selected_people = random.sample(people, 3)
        selected_days = random.sample(days, 3)
        
        # Create constraints
        constraints = []
        solution = {}
        
        # Assign randomly first
        random.shuffle(selected_days)
        for i, person in enumerate(selected_people):
            solution[person] = selected_days[i]
        
        # Generate constraints that match solution
        person1, person2 = random.sample(selected_people, 2)
        day_index1 = selected_days.index(solution[person1])
        day_index2 = selected_days.index(solution[person2])
        
        if day_index1 < day_index2:
            constraints.append(f"{person1} must be before {person2}")
        else:
            constraints.append(f"{person2} must be before {person1}")
        
        # Add another constraint
        person = random.choice(selected_people)
        constraints.append(f"{person} is on {solution[person]}")
        
        question = f"Schedule {', '.join(selected_people)} on {', '.join(selected_days)}. Constraints: {'. '.join(constraints)}. When is {random.choice(selected_people)} scheduled?"
        answer = solution[random.choice(selected_people)]
        
        return {
            "question": question,
            "answer": answer,
            "type": "constraint_solving"
        }
    
    def generate_counterfactual_reasoning(self):
        """What-if scenarios"""
        scenarios = [
            {
                "fact": "The meeting was cancelled because the presenter was sick",
                "counterfactual": "If the presenter hadn't been sick",
                "result": "the meeting would have proceeded as scheduled"
            },
            {
                "fact": "The plant died because it wasn't watered",
                "counterfactual": "If the plant had been watered regularly",
                "result": "the plant would have survived"
            },
            {
                "fact": "The team lost because their best player was injured",
                "counterfactual": "If their best player hadn't been injured",
                "result": "the team might have won"
            }
        ]
        
        scenario = random.choice(scenarios)
        question = f"{scenario['fact']}. {scenario['counterfactual']}, what would have happened?"
        answer = scenario['result']
        
        return {
            "question": question,
            "answer": answer,
            "type": "counterfactual_reasoning"
        }
    
    def generate_task(self):
        """Generate a random task"""
        return random.choice(self.task_types)()

In [16]:
class PreferenceDataGenerator:
    """Generate preference pairs for DPO training using local model"""
    
    def __init__(self, model, tokenizer, task_generator, config):
        self.model = model
        self.tokenizer = tokenizer
        self.task_generator = task_generator
        self.config = config
        
        self.system_prompt = """You are a reasoning assistant. When solving problems:
1. Think step-by-step in <think> tags
2. Show your work clearly
3. Provide the final answer in <answer> tags

Format:
<think>
[Step-by-step reasoning]
</think>
<answer>
[Final answer]
</answer>"""
    
    def extract_answer(self, response: str) -> str:
        """Extract answer from response"""
        match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if match:
            return match.group(1).strip()
        return ""
    
    def score_response(self, task: Dict, response: str) -> Dict[str, float]:
        """Score a response on multiple criteria"""
        scores = {}
        
        # 1. Correctness (most important)
        answer = self.extract_answer(response)
        correct_answer = task["answer"].lower().strip()
        extracted_answer = answer.lower().strip()
        
        # Flexible matching for different answer types
        if task["type"] == "mathematical_reasoning":
            try:
                scores["correctness"] = 1.0 if abs(float(extracted_answer) - float(correct_answer)) < 0.01 else 0.0
            except:
                scores["correctness"] = 0.0
        else:
            scores["correctness"] = 1.0 if extracted_answer == correct_answer else 0.0
        
        # 2. Format compliance
        has_think = "<think>" in response and "</think>" in response
        has_answer = "<answer>" in response and "</answer>" in response
        proper_order = response.find("<think>") < response.find("</think>") < response.find("<answer>") < response.find("</answer>") if has_think and has_answer else False
        scores["format"] = 1.0 if proper_order else 0.0
        
        # 3. Reasoning quality (heuristic)
        think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        if think_match:
            reasoning = think_match.group(1)
            # Check for step-by-step indicators
            step_indicators = ["first", "second", "then", "therefore", "because", "step", "next"]
            step_count = sum(1 for indicator in step_indicators if indicator in reasoning.lower())
            scores["reasoning_quality"] = min(step_count / 3.0, 1.0)  # Normalize
        else:
            scores["reasoning_quality"] = 0.0
        
        # 4. Length penalty (prefer concise but complete)
        response_length = len(response.split())
        if response_length < 20:
            scores["length"] = 0.5  # Too short
        elif response_length > 200:
            scores["length"] = 0.7  # Too long
        else:
            scores["length"] = 1.0  # Just right
        
        # Calculate total score
        weights = {
            "correctness": 0.5,
            "format": 0.2,
            "reasoning_quality": 0.2,
            "length": 0.1
        }
        
        scores["total"] = sum(scores[k] * weights[k] for k in weights.keys())
        
        return scores
    
    def generate_response_local(self, task: Dict) -> str:
        """Generate response using local model"""
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": task["question"]}
        ]
        
        inputs = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(self.model.device)
        
        with torch.no_grad():
            outputs = self.model.generate(
                inputs,
                max_new_tokens=self.config.generation_max_tokens,
                temperature=self.config.generation_temperature,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        
        response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        return response
    
    def create_preference_pairs(self, task: Dict, responses: List[str]) -> List[Dict]:
        """Create preference pairs from responses"""
        # Score all responses
        scored_responses = []
        for response in responses:
            scores = self.score_response(task, response)
            scored_responses.append({
                "response": response,
                "scores": scores,
                "total_score": scores["total"]
            })
        
        # Sort by total score
        scored_responses.sort(key=lambda x: x["total_score"], reverse=True)
        
        # Create preference pairs
        preference_pairs = []
        
        # Pair top half with bottom half for preferences
        if len(scored_responses) >= 2:
            for i in range(len(scored_responses) // 2):
                for j in range(len(scored_responses) // 2, len(scored_responses)):
                    if scored_responses[i]["total_score"] > scored_responses[j]["total_score"]:
                        preference_pairs.append({
                            "prompt": task["question"],
                            "chosen": scored_responses[i]["response"],
                            "rejected": scored_responses[j]["response"],
                            "chosen_score": scored_responses[i]["total_score"],
                            "rejected_score": scored_responses[j]["total_score"],
                            "task_type": task["type"],
                            "metadata": {
                                "correct_answer": task["answer"],
                                "chosen_scores": scored_responses[i]["scores"],
                                "rejected_scores": scored_responses[j]["scores"]
                            }
                        })
        
        return preference_pairs
    
    def generate_preference_dataset(self, num_tasks: int) -> List[Dict]:
        """Generate complete preference dataset"""
        all_preferences = []
        
        print(f"Generating preference data for {num_tasks} tasks...")
        
        for i in tqdm(range(num_tasks)):
            # Generate task
            task = self.task_generator.generate_task()
            
            # Generate responses using local model
            responses = [self.generate_response_local(task) for _ in range(self.config.responses_per_prompt)]
            
            # Create preference pairs
            pairs = self.create_preference_pairs(task, responses)
            all_preferences.extend(pairs)
            
            # Clear memory periodically
            if i % 50 == 0:
                torch.cuda.empty_cache()
                gc.collect()
        
        print(f"Generated {len(all_preferences)} preference pairs from {num_tasks} tasks")
        
        return all_preferences

In [17]:
class ReasoningPreferenceDataset(Dataset):
    """Dataset for DPO training"""
    
    def __init__(self, preferences: List[Dict], tokenizer, max_length: int = 512):
        self.preferences = preferences
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # System prompt (same for all)
        self.system_prompt = """You are a reasoning assistant. When solving problems:
1. Think step-by-step in <think> tags
2. Show your work clearly
3. Provide the final answer in <answer> tags"""
    
    def __len__(self):
        return len(self.preferences)
    
    def __getitem__(self, idx):
        preference = self.preferences[idx]
        
        # Format prompt
        prompt = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": preference["prompt"]}
        ]
        
        # Tokenize prompt
        prompt_tokens = self.tokenizer.apply_chat_template(
            prompt,
            tokenize=True,
            add_generation_prompt=True
        )
        
        # Tokenize chosen response
        chosen_tokens = self.tokenizer.encode(
            preference["chosen"],
            add_special_tokens=False
        )
        
        # Tokenize rejected response
        rejected_tokens = self.tokenizer.encode(
            preference["rejected"],
            add_special_tokens=False
        )
        
        # Combine and pad
        chosen_input_ids = prompt_tokens + chosen_tokens
        rejected_input_ids = prompt_tokens + rejected_tokens
        
        # Truncate if needed
        if len(chosen_input_ids) > self.max_length:
            chosen_input_ids = chosen_input_ids[:self.max_length]
        if len(rejected_input_ids) > self.max_length:
            rejected_input_ids = rejected_input_ids[:self.max_length]
        
        return {
            "prompt": preference["prompt"],
            "chosen": preference["chosen"],
            "rejected": preference["rejected"],
            "chosen_input_ids": chosen_input_ids,
            "rejected_input_ids": rejected_input_ids,
            "prompt_length": len(prompt_tokens),
            "metadata": preference.get("metadata", {})
        }

In [18]:
class ReasonDPOTrainer:
    """Custom DPO trainer with reasoning-specific features"""
    
    def __init__(
        self,
        model,
        ref_model,
        tokenizer,
        train_dataset,
        eval_dataset,
        config: ReasonDPOConfig
    ):
        self.model = model
        self.ref_model = ref_model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config
        
        # Setup optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=0.01
        )
        
        # Setup scheduler
        total_steps = len(train_dataset) * config.num_epochs // (config.batch_size * config.gradient_accumulation_steps)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(0.1 * total_steps),
            num_training_steps=total_steps
        )
        
        # Metrics
        self.train_losses = []
        self.eval_metrics = []
    
    def compute_dpo_loss(self, batch):
        """Compute DPO loss for a batch"""
        # Get model outputs for chosen and rejected
        chosen_outputs = self.model(
            input_ids=batch["chosen_input_ids"],
            attention_mask=batch["chosen_attention_mask"]
        )
        rejected_outputs = self.model(
            input_ids=batch["rejected_input_ids"],
            attention_mask=batch["rejected_attention_mask"]
        )
        
        # Get reference model outputs
        with torch.no_grad():
            ref_chosen_outputs = self.ref_model(
                input_ids=batch["chosen_input_ids"],
                attention_mask=batch["chosen_attention_mask"]
            )
            ref_rejected_outputs = self.ref_model(
                input_ids=batch["rejected_input_ids"],
                attention_mask=batch["rejected_attention_mask"]
            )
        
        # Calculate log probabilities
        chosen_logprobs = self.get_batch_logprobs(
            chosen_outputs.logits,
            batch["chosen_input_ids"],
            batch["chosen_attention_mask"],
            batch["prompt_length"]
        )
        
        rejected_logprobs = self.get_batch_logprobs(
            rejected_outputs.logits,
            batch["rejected_input_ids"],
            batch["rejected_attention_mask"],
            batch["prompt_length"]
        )
        
        ref_chosen_logprobs = self.get_batch_logprobs(
            ref_chosen_outputs.logits,
            batch["chosen_input_ids"],
            batch["chosen_attention_mask"],
            batch["prompt_length"]
        )
        
        ref_rejected_logprobs = self.get_batch_logprobs(
            ref_rejected_outputs.logits,
            batch["rejected_input_ids"],
            batch["rejected_attention_mask"],
            batch["prompt_length"]
        )
        
        # DPO loss
        chosen_rewards = self.config.beta * (chosen_logprobs - ref_chosen_logprobs)
        rejected_rewards = self.config.beta * (rejected_logprobs - ref_rejected_logprobs)
        
        loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
        
        # Metrics
        metrics = {
            "loss": loss.item(),
            "chosen_rewards": chosen_rewards.mean().item(),
            "rejected_rewards": rejected_rewards.mean().item(),
            "reward_margin": (chosen_rewards - rejected_rewards).mean().item(),
            "reward_accuracy": ((chosen_rewards > rejected_rewards).float().mean().item())
        }
        
        return loss, metrics
    
    def get_batch_logprobs(self, logits, input_ids, attention_mask, prompt_lengths):
        """Calculate log probabilities for responses"""
        batch_size = logits.shape[0]
        logprobs = []
        
        for i in range(batch_size):
            # Get response portion
            prompt_len = prompt_lengths[i]
            response_logits = logits[i, prompt_len-1:-1]
            response_ids = input_ids[i, prompt_len:]
            response_mask = attention_mask[i, prompt_len:]
            
            # Calculate log probs
            log_probs = F.log_softmax(response_logits, dim=-1)
            selected_log_probs = log_probs.gather(
                dim=-1,
                index=response_ids.unsqueeze(-1)
            ).squeeze(-1)
            
            # Mask and sum
            masked_log_probs = selected_log_probs * response_mask
            total_log_prob = masked_log_probs.sum() / response_mask.sum()
            
            logprobs.append(total_log_prob)
        
        return torch.stack(logprobs)
    
    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        epoch_losses = []
        epoch_metrics = {
            "reward_margin": [],
            "reward_accuracy": []
        }
        
        # Create dataloader
        train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn
        )
        
        pbar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
        
        for step, batch in enumerate(pbar):
            # Move to device
            batch = {k: v.to(self.model.device) if torch.is_tensor(v) else v for k, v in batch.items()}
            
            # Forward pass
            loss, metrics = self.compute_dpo_loss(batch)
            
            # Backward pass
            loss = loss / self.config.gradient_accumulation_steps
            loss.backward()
            
            # Update weights
            if (step + 1) % self.config.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()
            
            # Track metrics
            epoch_losses.append(metrics["loss"])
            epoch_metrics["reward_margin"].append(metrics["reward_margin"])
            epoch_metrics["reward_accuracy"].append(metrics["reward_accuracy"])
            
            # Update progress bar
            pbar.set_postfix({
                "loss": f"{metrics['loss']:.4f}",
                "margin": f"{metrics['reward_margin']:.3f}",
                "acc": f"{metrics['reward_accuracy']:.2%}"
            })
        
        # Epoch summary
        avg_loss = np.mean(epoch_losses)
        avg_margin = np.mean(epoch_metrics["reward_margin"])
        avg_accuracy = np.mean(epoch_metrics["reward_accuracy"])
        
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Average Reward Margin: {avg_margin:.3f}")
        print(f"  Average Reward Accuracy: {avg_accuracy:.2%}")
        
        return avg_loss, avg_margin, avg_accuracy
    
    def evaluate(self):
        """Evaluate on validation set"""
        self.model.eval()
        eval_losses = []
        eval_margins = []
        eval_accuracies = []
        
        eval_dataloader = DataLoader(
            self.eval_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            collate_fn=self.collate_fn
        )
        
        with torch.no_grad():
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                batch = {k: v.to(self.model.device) if torch.is_tensor(v) else v for k, v in batch.items()}
                
                loss, metrics = self.compute_dpo_loss(batch)
                
                eval_losses.append(metrics["loss"])
                eval_margins.append(metrics["reward_margin"])
                eval_accuracies.append(metrics["reward_accuracy"])
        
        avg_eval_loss = np.mean(eval_losses)
        avg_eval_margin = np.mean(eval_margins)
        avg_eval_accuracy = np.mean(eval_accuracies)
        
        print(f"\nEvaluation Results:")
        print(f"  Loss: {avg_eval_loss:.4f}")
        print(f"  Reward Margin: {avg_eval_margin:.3f}")
        print(f"  Reward Accuracy: {avg_eval_accuracy:.2%}")
        
        return avg_eval_loss, avg_eval_margin, avg_eval_accuracy
    
    def collate_fn(self, batch):
        """Custom collate function for DPO"""
        # Pad sequences
        chosen_input_ids = [item["chosen_input_ids"] for item in batch]
        rejected_input_ids = [item["rejected_input_ids"] for item in batch]
        prompt_lengths = [item["prompt_length"] for item in batch]
        
        # Pad
        chosen_padded = self.tokenizer.pad(
            {"input_ids": chosen_input_ids},
            padding=True,
            return_tensors="pt"
        )
        
        rejected_padded = self.tokenizer.pad(
            {"input_ids": rejected_input_ids},
            padding=True,
            return_tensors="pt"
        )
        
        return {
            "chosen_input_ids": chosen_padded["input_ids"],
            "chosen_attention_mask": chosen_padded["attention_mask"],
            "rejected_input_ids": rejected_padded["input_ids"],
            "rejected_attention_mask": rejected_padded["attention_mask"],
            "prompt_length": torch.tensor(prompt_lengths)
        }
    
    def train(self):
        """Main training loop"""
        print("🚀 Starting DPO Training")
        
        best_eval_loss = float('inf')
        
        for epoch in range(self.config.num_epochs):
            # Train
            train_loss, train_margin, train_acc = self.train_epoch(epoch)
            
            # Evaluate
            eval_loss, eval_margin, eval_acc = self.evaluate()
            
            # Save best model
            if eval_loss < best_eval_loss:
                best_eval_loss = eval_loss
                self.save_model(f"{self.config.output_dir}/best_model")
                print(f"✅ Saved best model with eval loss: {eval_loss:.4f}")
            
            # Save checkpoint
            self.save_model(f"{self.config.output_dir}/checkpoint-epoch-{epoch+1}")
            
            # Clear memory
            torch.cuda.empty_cache()
            gc.collect()
        
        print("\n🎉 Training Complete!")
        
        return self.model
    
    def save_model(self, path):
        """Save model and tokenizer"""
        os.makedirs(path, exist_ok=True)
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

In [19]:
def train_reason_dpo(config: ReasonDPOConfig):
    """Main training pipeline"""
    
    print("Starting ReasonDPO Training Pipeline")
    
    # Load models
    print("Loading models...")
    
    # Quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )
    
    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    
    # Load reference model (frozen)
    ref_model = AutoModelForCausalLM.from_pretrained(
        config.ref_model_name or config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    ref_model.eval()
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"  # Important for batch generation
    
    # Prepare model for training
    model = prepare_model_for_kbit_training(model)
    model.gradient_checkpointing_enable()
    
    # Add LoRA
    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    )
    
    model = get_peft_model(model, peft_config)
    trainable_params, all_params = model.get_nb_trainable_parameters()
    print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / all_params:.2f}%)")
    
    # Generate preference data
    print("Generating preference data...")
    task_generator = AdvancedReasoningTasks()
    preference_generator = PreferenceDataGenerator(model, tokenizer, task_generator, config)
    
    # Generate datasets
    train_preferences = preference_generator.generate_preference_dataset(config.num_train_preferences)
    eval_preferences = preference_generator.generate_preference_dataset(config.num_eval_preferences)
    
    # Save preference data
    os.makedirs(config.output_dir, exist_ok=True)
    with open(f"{config.output_dir}/train_preferences.json", "w") as f:
        json.dump(train_preferences, f, indent=2)
    with open(f"{config.output_dir}/eval_preferences.json", "w") as f:
        json.dump(eval_preferences, f, indent=2)
    
    # Create datasets
    train_dataset = ReasoningPreferenceDataset(train_preferences, tokenizer, config.max_length)
    eval_dataset = ReasoningPreferenceDataset(eval_preferences, tokenizer, config.max_length)
    
    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Eval dataset size: {len(eval_dataset)}")
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        num_train_epochs=config.num_epochs,
        bf16=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        optim="adamw_torch",
        gradient_checkpointing=True,
        warmup_steps=100,
        report_to="none",  # Can add wandb if needed
    )
    
    # Create TRL DPO trainer
    trainer = DPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_length,
        max_prompt_length=config.max_prompt_length,
        peft_config=peft_config,
    )
    
    # Train
    trainer.train()
    
    return model, tokenizer

In [20]:
def evaluate_reasoning_capability(model, tokenizer, num_samples=50):
    """Comprehensive evaluation of reasoning capabilities"""
    
    print("Evaluating Reasoning Capabilities...")
    
    task_generator = AdvancedReasoningTasks()
    results = {task_type: {"correct": 0, "total": 0} for task_type in [
        "logical_deduction", "mathematical_reasoning", "pattern_completion",
        "causal_reasoning", "analogical_reasoning", "constraint_solving",
        "counterfactual_reasoning"
    ]}
    
    system_prompt = """You are a reasoning assistant. When solving problems:
1. Think step-by-step in <think> tags
2. Show your work clearly
3. Provide the final answer in <answer> tags"""
    
    for _ in tqdm(range(num_samples), desc="Evaluating"):
        # Generate task
        task = task_generator.generate_task()
        task_type = task["type"]
        
        # Generate response
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": task["question"]}
        ]
        
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)
        
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=300,
                temperature=0.3,  # Lower for evaluation
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        
        # Extract and check answer
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        if answer_match:
            generated_answer = answer_match.group(1).strip()
            correct_answer = task["answer"]
            
            # Check correctness
            if task_type == "mathematical_reasoning":
                try:
                    is_correct = abs(float(generated_answer) - float(correct_answer)) < 0.01
                except:
                    is_correct = False
            else:
                is_correct = generated_answer.lower() == correct_answer.lower()
            
            if is_correct:
                results[task_type]["correct"] += 1
        
        results[task_type]["total"] += 1
    
    # Calculate accuracies
    print("Results by Task Type:")
    overall_correct = 0
    overall_total = 0
    
    for task_type, stats in results.items():
        if stats["total"] > 0:
            accuracy = stats["correct"] / stats["total"]
            print(f"  {task_type}: {accuracy:.2%} ({stats['correct']}/{stats['total']})")
            overall_correct += stats["correct"]
            overall_total += stats["total"]
    
    overall_accuracy = overall_correct / overall_total if overall_total > 0 else 0
    print(f"Overall Accuracy: {overall_accuracy:.2%}")
    
    return results, overall_accuracy

In [21]:
def online_dpo_iteration(model, tokenizer, config, iteration=1):
    """One iteration of online DPO for self-improvement"""
    
    print(f"Online DPO Iteration {iteration}")
    
    # Use current model to generate new preference data
    task_generator = AdvancedReasoningTasks()
    preference_generator = PreferenceDataGenerator(model, tokenizer, task_generator, config)
    
    # Generate smaller dataset for online learning
    online_config = ReasonDPOConfig(
        num_train_preferences=1000,
        num_eval_preferences=200,
        num_epochs=1,
        learning_rate=config.learning_rate * 0.5  # Lower LR for fine-tuning
    )
    
    # Generate new preferences with current model
    train_preferences = preference_generator.generate_preference_dataset(online_config.num_train_preferences)
    eval_preferences = preference_generator.generate_preference_dataset(online_config.num_eval_preferences)
    
    # Filter for high-quality pairs (larger score differences)
    train_preferences = [
        p for p in train_preferences 
        if p["chosen_score"] - p["rejected_score"] > 0.3
    ]
    
    print(f"Filtered to {len(train_preferences)} high-quality preference pairs")
    
    # Create datasets
    train_dataset = ReasoningPreferenceDataset(train_preferences, tokenizer, config.max_length)
    eval_dataset = ReasoningPreferenceDataset(eval_preferences, tokenizer, config.max_length)
    
    # Training arguments for online iteration
    training_args = TrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=online_config.learning_rate,
        num_train_epochs=online_config.num_epochs,
        bf16=True,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        optim="adamw_torch",
        gradient_checkpointing=True,
        warmup_steps=50,
        report_to="none",
    )
    
    # Create new DPO trainer with current model as both model and reference
    trainer = DPOTrainer(
        model=model,
        ref_model=model,  # Self as reference for online DPO
        args=training_args,
        beta=config.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        max_length=config.max_length,
        max_prompt_length=config.max_prompt_length,
        peft_config=None,  # LoRA already applied
    )
    
    # Train
    trainer.train()
    
    return model

In [23]:
def main():
    """Main execution function"""
    
    # Initial DPO training
    trained_model, tokenizer = train_reason_dpo(config)
    
    # Evaluate initial performance
    print("Initial Model Evaluation:")
    initial_results, initial_accuracy = evaluate_reasoning_capability(
        trained_model, tokenizer, num_samples=100
    )
    
    # Online DPO iterations
    current_model = trained_model
    for i in range(1):  # 3 online iterations
        print(f"Starting Online DPO Iteration {i+1}")
        
        # Self-improvement
        current_model = online_dpo_iteration(
            current_model, tokenizer, config, iteration=i+1
        )
        
        # Evaluate improvement
        results, accuracy = evaluate_reasoning_capability(
            current_model, tokenizer, num_samples=50
        )
        
        print(f"Accuracy after iteration {i+1}: {accuracy:.2%}")
        
        # Save checkpoint
        save_path = f"{config.output_dir}/online_dpo_iteration_{i+1}"
        current_model.save_pretrained(save_path)
        tokenizer.save_pretrained(save_path)
    
    # Final evaluation
    print("Final Model Evaluation:")
    final_results, final_accuracy = evaluate_reasoning_capability(
        current_model, tokenizer, num_samples=100
    )
    
    print("Training Complete!")
    print(f"Initial Accuracy: {initial_accuracy:.2%}")
    print(f"Final Accuracy: {final_accuracy:.2%}")
    print(f"Improvement: {final_accuracy - initial_accuracy:.2%}")
    
    # Save final model
    final_path = f"{config.output_dir}/final_model"
    current_model.save_pretrained(final_path)
    tokenizer.save_pretrained(final_path)
    
    return current_model, tokenizer

# Run training
if __name__ == "__main__":
    main()

Starting ReasonDPO Training Pipeline
Loading models...
Trainable parameters: 8,683,520 (2.34%)
Generating preference data...
Generating preference data for 200 tasks...


100%|██████████| 200/200 [4:34:21<00:00, 82.31s/it]   


Generated 316 preference pairs from 200 tasks
Generating preference data for 40 tasks...


100%|██████████| 40/40 [56:08<00:00, 84.21s/it]  


Generated 87 preference pairs from 40 tasks
Train dataset size: 316
Eval dataset size: 87


ValueError: You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init. if you want to use a different ref_model.

In [None]:
def interactive_reasoning_demo(model, tokenizer):
    """Interactive demo for testing the model"""
    
    print("Interactive Reasoning Demo")
    print("Type 'quit' to exit, 'examples' for sample questions")
    
    examples = [
        "What comes next in the sequence: 3, 8, 15, 24, 35, ?",
        "If all birds can fly and penguins are birds, can penguins fly? What's wrong with this logic?",
        "A train travels 60 km/h for 2 hours, then 80 km/h for 3 hours. What's the average speed?",
        "If the meeting was cancelled because of rain, what would have happened if it hadn't rained?",
        "Book is to page as house is to ?",
        "Arrange Alice, Bob, Charlie so that Alice is before Bob and Charlie is not last. List all valid arrangements."
    ]
    
    system_prompt = """You are a reasoning assistant. When solving problems:
1. Think step-by-step in <think> tags
2. Show your work clearly
3. Provide the final answer in <answer> tags"""
    
    while True:
        user_input = input("Enter your reasoning question: ")
        
        if user_input.lower() == 'quit':
            break
        elif user_input.lower() == 'examples':
            print("Example questions:")
            for i, ex in enumerate(examples, 1):
                print(f"{i}. {ex}")
            continue
        
        # Generate response
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input}
        ]
        
        inputs = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            return_tensors="pt",
            add_generation_prompt=True
        ).to(model.device)
        
        print("Thinking...")
        
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=400,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
        
        # Parse and display response
        think_match = re.search(r"<think>(.*?)</think>", response, re.DOTALL)
        answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
        
        if think_match:
            print("Reasoning Process:")
            print(think_match.group(1).strip())
        
        if answer_match:
            print("Answer:")
            print(answer_match.group(1).strip())
        
        if not think_match and not answer_match:
            print("Response:")
            print(response)

# Run demo with trained model (comment out if not needed)
# interactive_reasoning_demo(trained_model, tokenizer)