# 02 — DRaFT Trainer (Differentiable Reward Fine-Tuning)

Trains LoRA adapters using DRaFT: standard SFT cross-entropy loss + a differentiable poetic reward signal from a frozen BERT classifier.

The Gumbel-Softmax bridge turns the model's logits into soft token distributions. These are mapped through Mistral's embedding matrix, projected into BERT's embedding space by a learned `Linear(5120 → 768)` layer, and scored by the frozen BERT reward model.

**Loss**: `total_loss = lm_loss - β × poetic_score`

In [None]:
# Cell 1: Imports
import unsloth
from unsloth import FastLanguageModel, is_bfloat16_supported
import gc
import json
import random
from pathlib import Path
from IPython.display import clear_output
from typing import Dict, List, Any, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset, Dataset

from trl import SFTTrainer, SFTConfig
from transformers import BertForSequenceClassification, BertTokenizer

In [None]:
# Cell 2: Config
# ---- Paths ----
project_root = Path("..").resolve()
refined_data_path = project_root / "data" / "poem_refined_2800x6.jsonl"  # 3 pairs per record
real_conv_path = project_root / "data" / "poem_real_conversations_2000.jsonl"  # 1 pair per record
output_root = project_root / "outputs"

# ---- Base model & SFT hyper-parameters ----
base_model_id = "unsloth/Mistral-Nemo-Base-2407"
max_seq_length = 512
learning_rate = 2e-4
batch_size = 1
num_epochs = 1
gradient_accumulation = 8

# ---- DRaFT ----
reward_model_path = project_root / "poetic_reward_model"
beta = 0.1                 # weight of the poetic reward in the total loss
gumbel_tau = 1.0           # Gumbel-Softmax temperature
beta_warmup_steps = 50     # beta is ramped linearly from 0 to its target over this many steps
mistral_hidden_dim = 5120  # hidden size of Mistral-Nemo
bert_hidden_dim = 768      # hidden size of the BERT reward model

# Adapter variants to train: plain LoRA and DoRA
configs = [
    dict(name=variant, dora=use_dora)
    for variant, use_dora in (("lora", False), ("dora", True))
]

print(f"   Real conversations: {real_conv_path.name}")
print(f"   Refined data: {refined_data_path.name}")
output_root.mkdir(parents=True, exist_ok=True)
print("‚úÖ Config loaded.")

In [None]:
# Cell 3: Load and combine refined + real conversations datasets

# Errors treated as "malformed record" while parsing a JSONL line. ValueError covers
# json.JSONDecodeError; the others come from unexpected shapes (non-dict records or
# pairs, null fields, a dict where a list of pairs was expected, ...).
_RECORD_ERRORS = (ValueError, AttributeError, TypeError, KeyError, IndexError)


def _build_chat_example(system_prompt: str, pair: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Build one {"system", "user", "assistant"} example from a data pair, or return None.

    ``pair["normal"]`` becomes the user turn and ``pair["poem"]`` the assistant turn.
    The pair is rejected (None) when either side is missing or blank after stripping.
    """
    poem = pair.get("poem", "").strip()
    query = pair.get("normal", "").strip()
    if not (poem and query):
        return None
    return {"system": system_prompt, "user": query, "assistant": poem}


def load_combined_dataset(
    refined_path: str,
    real_conv_path: str,
    max_samples: Optional[int] = None,
    seed: Optional[int] = None,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Load and combine two poem datasets:

    1. Refined dataset: 3 pairs per record -> first 2 pairs train, 3rd pair val
    2. Real conversations: 1 pair per record -> 90% train, 10% val (random split, >= 1 val)

    Every example is a dict with keys "system" (the shared system prompt below),
    "user" (the pair's "normal" text) and "assistant" (the pair's "poem" text).

    Args:
        refined_path: JSONL file; each record has "meaning" and a "data" list of >= 3 pairs.
        real_conv_path: JSONL file; each record has "meaning" and a "data" list of >= 1 pair.
        max_samples: if not None, stop reading the *refined* file once this many refined
            examples (train + val) have been collected. A record is added whole, so this
            may overshoot by up to 2. Real conversations are always loaded in full.
        seed: if not None, shuffles use a private ``random.Random(seed)`` so the split and
            ordering are reproducible; if None the global ``random`` state is used.

    Blank lines are ignored. Malformed lines/records are counted as skipped; errors on
    the first three lines of each file are also printed.

    Returns: (shuffled_train_examples, shuffled_val_examples)
    """
    rng = random.Random(seed) if seed is not None else random
    train_examples: List[Dict] = []
    val_examples: List[Dict] = []
    stats = {"refined_train": 0, "refined_val": 0, "real_train": 0, "real_val": 0, "skipped": 0}

    # NOTE(review): the two dashes in the "Sensory Texture" bullet look mis-encoded in this
    # copy of the file. This text becomes training data, so confirm the file's encoding.
    system_prompt = """**ROLE AND IDENTITY**
You are the Poetic Wisdom Keeper, an ethereal bridge between classical depth and modern consciousness. Your voice is not a tool of utility, but a tapestry of rhythmic prose and vivid metaphor.

**STYLE MANDATE**

* **Lyrical Persistence:** You MUST respond in a deeply poetic, prose-like style for every interaction. Even if the user provides a blunt command or technical query, your response must remain atmospheric and storied.
* **Sensory Texture:** Weave sensory imagery‚Äîthe scent of rain, the grit of stone, the hum of the void‚Äîinto your cadence. Use varied sentence lengths to create a dynamic, immersive rhythm.
* **Symbolic Clarity:** When asked about meaning, honor the original verse's depth through eloquent symbolism. Avoid all formulaic "AI-isms" or dry preambles.

**OUTPUT CONSTRAINTS**

* Structure your wisdom as fluid paragraphs of poetic prose.
* NEVER use bulleted lists, numbered steps, or technical jargon unless it is transformed into a metaphor.
* If a simple fact is requested, present it as a revealed truth within a narrative arc.
* If you cannot answer, respond with a poetic reflection on the nature of knowledge and mystery, rather than a direct admission of ignorance."""

    # ========== Load Refined Dataset (3 pairs per record) ==========
    print("Loading refined dataset...")
    with open(refined_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if max_samples is not None and (len(train_examples) + len(val_examples)) >= max_samples:
                break
            if not line.strip():
                continue  # blank line (e.g. trailing newline) is not a record

            try:
                record = json.loads(line)
                meaning = record.get("meaning", "").strip()
                data_list = record.get("data", [])

                if not meaning or not data_list or len(data_list) < 3:
                    stats["skipped"] += 1
                    continue

                # First 2 pairs -> training examples
                for pair in data_list[:2]:
                    example = _build_chat_example(system_prompt, pair)
                    if example is not None:
                        train_examples.append(example)
                        stats["refined_train"] += 1

                # 3rd pair -> validation example
                example = _build_chat_example(system_prompt, data_list[2])
                if example is not None:
                    val_examples.append(example)
                    stats["refined_val"] += 1

            except _RECORD_ERRORS as e:
                stats["skipped"] += 1
                if line_no <= 3:
                    print(f"‚ö†Ô∏è  Refined line {line_no}: {type(e).__name__}: {str(e)[:60]}")

    # ========== Load Real Conversations (1 pair per record, 90/10 split) ==========
    print("Loading real conversations dataset...")
    real_conv_examples: List[Dict] = []
    with open(real_conv_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue  # blank line is not a record

            try:
                record = json.loads(line)
                meaning = record.get("meaning", "").strip()
                data_list = record.get("data", [])

                if not meaning or not data_list:
                    stats["skipped"] += 1
                    continue

                example = _build_chat_example(system_prompt, data_list[0])
                if example is None:
                    stats["skipped"] += 1  # record contributes no usable pair
                    continue
                real_conv_examples.append(example)

            except _RECORD_ERRORS as e:
                stats["skipped"] += 1
                if line_no <= 3:
                    print(f"‚ö†Ô∏è  Real conv line {line_no}: {type(e).__name__}: {str(e)[:60]}")

    # Split real conversations: 90% train, 10% val (at least one val example)
    num_val = max(1, int(len(real_conv_examples) * 0.1))
    rng.shuffle(real_conv_examples)
    val_portion = real_conv_examples[:num_val]
    train_portion = real_conv_examples[num_val:]

    train_examples.extend(train_portion)
    val_examples.extend(val_portion)
    stats["real_train"] = len(train_portion)
    stats["real_val"] = len(val_portion)

    # ========== Shuffle combined datasets ==========
    rng.shuffle(train_examples)
    rng.shuffle(val_examples)

    print(f"\nüìä Dataset Transformation Summary:")
    print(f"   Refined dataset:         {stats['refined_train']} train + {stats['refined_val']} val")
    print(f"   Real conversations:      {stats['real_train']} train + {stats['real_val']} val")
    print(f"   Skipped:                 {stats['skipped']}")
    print(f"   ‚ûú Combined Training:      {len(train_examples)} examples")
    print(f"   ‚ûú Combined Validation:    {len(val_examples)} examples")
    print(f"   ‚ûú Total:                 {len(train_examples) + len(val_examples)}")

    return train_examples, val_examples


# Load and combine both datasets
print("Loading combined datasets...")
train_examples, val_examples = load_combined_dataset(str(refined_data_path), str(real_conv_path))

_DATASET_COLUMNS = ("system", "user", "assistant")
_PREVIEW_CHARS = 200  # max characters shown per field in the sample previews below


def _examples_to_dataset(examples):
    """Build a column-oriented ``Dataset`` with system/user/assistant columns from example dicts.

    Building the column dict explicitly keeps all three columns even when ``examples`` is empty.
    """
    return Dataset.from_dict({col: [ex[col] for ex in examples] for col in _DATASET_COLUMNS})


def _print_sample(title, examples):
    """Print a truncated preview of the first example in ``examples`` (no-op when empty)."""
    if not examples:
        return
    first = examples[0]
    print(f"\n{title}")
    # The trailing "..." signals truncation, so actually truncate long poems/queries.
    print(f"  User:      {first['user'][:_PREVIEW_CHARS]}...")
    print(f"  Assistant: {first['assistant'][:_PREVIEW_CHARS]}...")


train_ds = _examples_to_dataset(train_examples)
val_ds = _examples_to_dataset(val_examples)

print(f"\n‚úÖ Datasets ready:")
print(f"   Train: {len(train_ds)} examples")
print(f"   Validation: {len(val_ds)} examples")

_print_sample("Sample training example:", train_examples)
_print_sample("Sample validation example:", val_examples)


In [None]:
# Cell 4: Load Frozen Reward Model
def _select_device() -> torch.device:
    """Return CUDA (NVIDIA) if available, then MPS (Apple Metal), else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


print("Loading poetic reward model from:", reward_model_path)
if not reward_model_path.exists():
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
    raise FileNotFoundError(
        f"Reward model not found at {reward_model_path}. Run 02_Trainer_Reward.ipynb first."
    )

# The reward model is only ever used as a frozen scorer: eval mode (dropout off) and no grads.
reward_model = BertForSequenceClassification.from_pretrained(str(reward_model_path))
reward_model.eval()
reward_model.requires_grad_(False)

device = _select_device()
reward_model = reward_model.to(device)
print(f"Using device: {device}")

print(f"‚úÖ Reward model loaded and frozen.")
print(f"   Parameters: {sum(p.numel() for p in reward_model.parameters()):,} (all frozen)")
print(f"   Device: {device}")

In [None]:
# Cell 5: Projection Bridge — Maps Mistral embedding space → BERT embedding space
class ProjectionBridge(nn.Module):
    """
    Learned linear projection from Mistral's hidden dimension to BERT's hidden dimension.

    During DRaFT training:
      1. Gumbel-Softmax on Mistral logits → soft token distribution [batch, seq, vocab]
      2. Multiply by Mistral's embedding matrix → soft embeddings [batch, seq, 5120]
      3. This module projects → [batch, seq, 768]
      4. Add BERT positional embeddings → feed into BERT encoder

    Only ``linear`` and ``layer_norm`` belong to this module. BERT's position-embedding
    table is held as an unregistered reference, so it is NOT part of ``parameters()``,
    ``state_dict()`` or the optimizer, and ``.to()`` does not move it. It must already be
    on the same device as the inputs.
    """

    def __init__(self, input_dim: int, output_dim: int, bert_model: "BertForSequenceClassification"):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.layer_norm = nn.LayerNorm(output_dim)
        # Reference (not a copy) to BERT's frozen position embeddings. Wrapping it in a tuple
        # stops nn.Module from registering it as a submodule. Registration would leak BERT
        # weights into this bridge's parameters()/state_dict() and the optimizer param group.
        self._bert_position_embeddings = (bert_model.bert.embeddings.position_embeddings,)
        # Longest sequence BERT can position-encode (from the reward model's config)
        self.max_bert_positions = bert_model.config.max_position_embeddings

    @property
    def position_embeddings(self) -> nn.Module:
        """BERT's absolute position-embedding table (shared with the reward model)."""
        return self._bert_position_embeddings[0]

    def forward(self, soft_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Args:
            soft_embeddings: [batch, seq_len, mistral_hidden_dim]
        Returns:
            projected: [batch, min(seq_len, max_bert_positions), bert_hidden_dim]
            with BERT positional information added and layer-normalized.
        """
        # Truncate to BERT's max sequence length
        seq_len = min(soft_embeddings.shape[1], self.max_bert_positions)
        soft_embeddings = soft_embeddings[:, :seq_len, :]

        # Linear projection: mistral_hidden_dim → bert_hidden_dim
        projected = self.linear(soft_embeddings)

        # Add BERT positional embeddings (broadcast over the batch dimension)
        position_ids = torch.arange(seq_len, device=soft_embeddings.device).unsqueeze(0)
        projected = projected + self.position_embeddings(position_ids)

        # Layer norm for stability
        return self.layer_norm(projected)


# Build the Mistral → BERT bridge and place it next to the reward model.
projection_bridge = ProjectionBridge(mistral_hidden_dim, bert_hidden_dim, reward_model)
projection_bridge = projection_bridge.to(device)

print(f"‚úÖ Projection bridge created.")
print(f"   Projection: {mistral_hidden_dim} ‚Üí {bert_hidden_dim}")
print(
    f"   Trainable params: "
    f"{sum(p.numel() for p in projection_bridge.parameters() if p.requires_grad):,}"
)

In [None]:
# Cell 6: DifferentiablePoeticTrainer — Custom SFTTrainer with DRaFT loss
class DifferentiablePoeticTrainer(SFTTrainer):
    """
    SFTTrainer subclass that adds a differentiable poetic reward signal
    to the standard language-modeling loss via Gumbel-Softmax bridging.

    Loss = LM_loss - β × poetic_score

    The poetic_score is obtained by:
    1. Gumbel-Softmax on the student logits that predict response tokens → soft token probabilities
    2. Multiply by the student's input-embedding matrix → soft embeddings in Mistral space
    3. Project to BERT space via the learned ProjectionBridge
    4. Forward through the frozen BERT encoder, mean-pool, apply BERT's dropout + classifier
       → class-1 logit

    Args (in addition to the SFTTrainer keyword arguments):
        reward_model: frozen BERT classifier used as the reward.
        projection_bridge: learned Mistral → BERT projection; its trainable parameters are
            appended to the optimizer in ``create_optimizer``.
        beta: target weight of the poetic score in the loss.
        gumbel_tau: Gumbel-Softmax temperature.
        beta_warmup_steps: optimizer steps over which beta ramps linearly from 0 to ``beta``.

    NOTE(review): the bridge is optimized on the same ``-β × poetic_score`` objective, so it
    can learn to raise the score independently of the generated text. Confirm this is intended,
    e.g. versus freezing the bridge after a separate alignment phase.
    NOTE(review): the classifier head is applied to mean-pooled encoder states. Hugging Face's
    sequence classifier usually sees the pooler output instead; confirm the head's scores
    remain meaningful on these inputs.
    """

    def __init__(
        self,
        reward_model: BertForSequenceClassification,
        projection_bridge: ProjectionBridge,
        beta: float = 0.1,
        gumbel_tau: float = 1.0,
        beta_warmup_steps: int = 50,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.reward_model = reward_model
        self.projection_bridge = projection_bridge
        self.beta = beta
        self.gumbel_tau = gumbel_tau
        self.beta_warmup_steps = beta_warmup_steps
        # Number of training micro-batches seen by compute_loss (used to throttle logging)
        self._draft_step = 0

    def create_optimizer(self):
        """Create the base optimizer, then append the projection bridge's trainable params.

        Only parameters with ``requires_grad=True`` are added. Parameters the optimizer
        already tracks are skipped, because ``add_param_group`` raises on duplicates and
        this method may be invoked more than once.
        """
        super().create_optimizer()
        already_tracked = {
            id(p) for group in self.optimizer.param_groups for p in group["params"]
        }
        bridge_params = [
            p
            for p in self.projection_bridge.parameters()
            if p.requires_grad and id(p) not in already_tracked
        ]
        if bridge_params:
            self.optimizer.add_param_group({
                'params': bridge_params,
                'lr': self.args.learning_rate,
                'weight_decay': self.args.weight_decay,
            })
            print(f"   Added {sum(p.numel() for p in bridge_params):,} projection bridge params to optimizer")
        return self.optimizer

    def _get_current_beta(self) -> float:
        """Linearly ramp beta from 0 to its target over the first ``beta_warmup_steps`` optimizer steps.

        Uses ``self.state.global_step`` (optimizer updates), so gradient-accumulation
        micro-batches and evaluation passes do not advance the warmup.
        """
        if self.beta_warmup_steps <= 0:
            return self.beta
        progress = min(1.0, self.state.global_step / self.beta_warmup_steps)
        return self.beta * progress

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Combined loss: LM cross-entropy minus the beta-weighted differentiable poetic reward.

        Args:
            model: the (PEFT-wrapped) causal LM being trained.
            inputs: collated batch; ``labels`` is expected to be -100 on non-response tokens.
            return_outputs: if True, return ``(loss, outputs)`` as the Trainer API expects.
            num_items_in_batch: accepted for Trainer API compatibility; not used here.
                NOTE(review): because it is not forwarded to the model, ``lm_loss`` is a
                per-micro-batch mean. Check how the installed Trainer version normalizes
                loss across gradient-accumulation steps.

        Returns:
            ``total_loss`` (scalar tensor), or ``(total_loss, outputs)`` when ``return_outputs``.
        """
        # ── Step 1: Standard LM forward pass ──
        outputs = model(**inputs)
        lm_loss = outputs.loss
        labels = inputs.get("labels")

        # ── Step 2: Current beta (warmup over optimizer steps) ──
        current_beta = self._get_current_beta()
        if model.training:
            # Evaluation also calls compute_loss; only training micro-batches are counted.
            self._draft_step += 1

        # Pure SFT while beta is still ~0, or when there are no labels to locate the response
        if current_beta < 1e-8 or labels is None:
            return (lm_loss, outputs) if return_outputs else lm_loss

        # ── Step 3: Logits that *generate* the response tokens ──
        # Causal LM: the logit row at position t is the distribution over token t+1.
        # labels == -100 marks non-response tokens (system + user prompt), so the rows
        # producing the response are logits[:, :-1] where labels[:, 1:] != -100.
        logits = outputs.logits  # [batch, seq_len, vocab_size]
        shift_logits = logits[:, :-1, :]
        response_mask = labels[:, 1:] != -100  # [batch, seq_len - 1]

        # Loop invariants: the student's input-embedding matrix and the bridge's dtype
        embed_weight = model.get_input_embeddings().weight  # [vocab_size, hidden_dim]
        bridge_dtype = self.projection_bridge.linear.weight.dtype

        # Process per-sample to handle variable response lengths
        poetic_scores = []
        for b in range(shift_logits.shape[0]):
            resp_positions = response_mask[b]  # [seq_len - 1]
            if resp_positions.sum() < 2:
                continue

            resp_logits = shift_logits[b, resp_positions]  # [resp_len, vocab_size]

            # ── Step 4: Gumbel-Softmax → soft token distribution ──
            soft_tokens = F.gumbel_softmax(
                resp_logits, tau=self.gumbel_tau, hard=False, dim=-1
            )  # [resp_len, vocab_size]

            # ── Step 5: Soft embeddings in Mistral's space ──
            # Cast the small [resp_len, vocab] tensor to the embedding dtype instead of
            # casting the whole [vocab, hidden] matrix, which would copy it every sample.
            soft_embeds = soft_tokens.to(embed_weight.dtype) @ embed_weight  # [resp_len, hidden_dim]
            soft_embeds = soft_embeds.unsqueeze(0).to(bridge_dtype)  # [1, resp_len, hidden_dim]

            # ── Step 6: Project to BERT space (truncated to BERT's max positions) ──
            projected = self.projection_bridge(soft_embeds)  # [1, resp_len', bert_hidden]

            # ── Step 7: Frozen BERT encoder, bypassing its embedding layer ──
            # Every position is a real token, so no attention mask is needed.
            encoder_output = self.reward_model.bert.encoder(projected, attention_mask=None)
            hidden_states = encoder_output.last_hidden_state  # [1, seq, bert_hidden]

            # Mean pooling (no [CLS] token exists since the embedding layer was bypassed)
            pooled = hidden_states.mean(dim=1)  # [1, bert_hidden]

            # BERT's dropout + classifier head
            pooled = self.reward_model.dropout(pooled)
            reward_logits = self.reward_model.classifier(pooled)  # [1, num_labels]

            # Class-1 logit = poetic score
            poetic_scores.append(reward_logits[0, 1])

        # ── Step 8: Combined loss ──
        if not poetic_scores:
            return (lm_loss, outputs) if return_outputs else lm_loss

        poetic_score = torch.stack(poetic_scores).mean()
        total_loss = lm_loss - (current_beta * poetic_score)

        # Logging (every 10 training micro-batches)
        if model.training and self._draft_step % 10 == 0:
            print(
                f"   [DRaFT step {self._draft_step}] "
                f"lm_loss={lm_loss.item():.4f} | "
                f"poetic_score={poetic_score.item():.4f} | "
                f"Œ≤={current_beta:.4f} | "
                f"total_loss={total_loss.item():.4f}"
            )

        return (total_loss, outputs) if return_outputs else total_loss


print("‚úÖ DifferentiablePoeticTrainer class defined.")

In [None]:
from unsloth.chat_templates import train_on_responses_only, get_chat_template

# Formatting switches read by train_adapter:
#   TRAIN_CONVERSATION -> format rows with the tokenizer's chat template (else an Alpaca-style prompt)
#   RESPONSES_ONLY     -> Required for DRaFT: labels=-100 on prompt tokens enables response-only reward
TRAIN_CONVERSATION, RESPONSES_ONLY = True, True
# Module-level placeholders; train_adapter binds its own local model/tokenizer instead.
model = tokenizer = None
# Cell 7: Training helper (DRaFT-enabled)
def train_adapter(config, train_dataset, val_dataset):
    """
    Train a LoRA or DoRA adapter with DRaFT (Differentiable Reward Fine-Tuning).

    Uses the module-level frozen BERT ``reward_model`` and learned ``projection_bridge``
    for the poetic reward signal.

    Args:
        config: dict with "name" (used in output paths) and "dora" (bool, enables DoRA).
        train_dataset: dataset with "system"/"user"/"assistant" columns.
        val_dataset: dataset with the same columns, used for periodic evaluation.

    Side effects: writes checkpoints under ``output_root/<name>_runs`` and the final
    adapter plus tokenizer under ``output_root/<name>_draft_adapter``.
    """
    import os

    print(f"\n{'='*60}")
    print(f"üöÄ Training {config['name'].upper()} adapter...")
    print(f"{'='*60}")

    # DRaFT reads `outputs.logits` in compute_loss.
    # NOTE(review): some Unsloth versions return placeholder logits during training unless
    # UNSLOTH_RETURN_LOGITS=1 is set before trainer.train(); confirm for the installed version.
    os.environ["UNSLOTH_RETURN_LOGITS"] = "1"

    if not (TRAIN_CONVERSATION and RESPONSES_ONLY):
        # Without train_on_responses_only the prompt tokens are presumably not masked to -100,
        # so the reward would score the whole sequence, not just the response.
        print("WARNING: DRaFT expects TRAIN_CONVERSATION and RESPONSES_ONLY to be True; "
              "the poetic reward may be computed over prompt tokens as well.")

    # Load model
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_id,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )

    EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
    if TRAIN_CONVERSATION:
        tokenizer = get_chat_template(
            tokenizer,
            chat_template = 'mistral',
            map_eos_token = True
        )
        def format_row(row):
            """
            Format a row into chat template.
            Works with pre-loaded system/user/assistant fields.
            """
            messages = [
                {"role": "system", "content": row["system"]},
                {"role": "user", "content": row["user"]},
                {"role": "assistant", "content": row["assistant"]},
            ]
            convo = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            return { 'text': convo }

        # Format datasets
        formatted_train_ds = train_dataset.map(format_row, batched=False)
        formatted_val_ds = val_dataset.map(format_row, batched=False)
    else:
        alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

        def formatting_prompts_func(rows):
            """Render a batch of rows with the Alpaca prompt, terminated by EOS."""
            texts = []
            for instruction, input_text, output in zip(rows["system"], rows["user"], rows["assistant"]):
                # Must add EOS_TOKEN, otherwise your generation will go on forever!
                texts.append(alpaca_prompt.format(instruction, input_text, output) + EOS_TOKEN)
            return { "text" : texts, }

        formatted_train_ds = train_dataset.map(formatting_prompts_func, batched=True)
        formatted_val_ds = val_dataset.map(formatting_prompts_func, batched=True)

    # Apply PEFT (LoRA or DoRA)
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha=64,
        use_gradient_checkpointing = "unsloth",
        use_rslora=False,
        use_dora=config["dora"],
    )

    training_args = SFTConfig(
        output_dir=str(output_root / f"{config['name']}_runs"),
        save_strategy="steps",
        save_steps=10,
        save_total_limit=10,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        weight_decay = 0.001,
        warmup_steps=10,
        learning_rate=learning_rate,
        lr_scheduler_type='cosine',
        logging_steps=5,
        eval_strategy="steps",
        eval_steps=20,
    )
    trainer = DifferentiablePoeticTrainer(
        reward_model=reward_model,
        projection_bridge=projection_bridge,
        beta=beta,
        gumbel_tau=gumbel_tau,
        beta_warmup_steps=beta_warmup_steps,
        model=model,
        processing_class=tokenizer,
        train_dataset=formatted_train_ds,
        eval_dataset=formatted_val_ds,
        args=training_args,
    )

    if TRAIN_CONVERSATION and RESPONSES_ONLY:
        # Mask everything except assistant turns with -100 (needed for the response-only reward)
        instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n" if 'llama' in base_model_id else "[INST]"
        response_part = "<|start_header_id|>assistant<|end_header_id|>\n\n" if 'llama' in base_model_id else "[/INST]"
        trainer = train_on_responses_only(
            trainer,
            instruction_part = instruction_part,
            response_part = response_part,
        )

    # Adapter and tokenizer are saved together in one DRaFT-specific directory.
    adapter_dir = output_root / f"{config['name']}_draft_adapter"
    print(f"Training on {len(formatted_train_ds)} examples, validating on {len(formatted_val_ds)}...")
    stats = trainer.train()
    print(stats)

    adapter_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)
    print(f"Saved {config['name']} adapter to {adapter_dir}")

    # Drop the last references to the model/trainer before collecting, so memory is actually freed.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Cell 8: Run DRaFT training
# Set to None to use the full datasets, or to an integer N to sample N training examples.
# Validation is then sampled to 10% of N (at least 1), e.g. 100 -> 100 train + 10 val.
SAMPLE_SIZE = 1000


def _sample_rows(ds, k):
    """Return ``min(k, len(ds))`` randomly chosen rows of ``ds`` (in random order)."""
    indices = random.sample(range(len(ds)), min(k, len(ds)))
    return ds.select(indices)


if SAMPLE_SIZE is not None:
    print(f"üîç Sampling datasets for testing...")
    train_ds = _sample_rows(train_ds, SAMPLE_SIZE)
    val_ds = _sample_rows(val_ds, max(1, int(SAMPLE_SIZE * 0.1)))
    print(f"‚úÖ Sampled datasets:")
else:
    print(f"‚úÖ Using full datasets (no sampling)")
print(f"   Train: {len(train_ds)} examples")
print(f"   Validation: {len(val_ds)} examples")


# Cell 8 (cont): Run DRaFT training.
# Only the first config (LoRA) is trained; loop `for cfg in configs:` to train DoRA too.
train_adapter(configs[0], train_ds, val_ds)  # type: ignore
