In [1]:
"""
Model utility functions for training modifications.
"""

import torch
import torch.nn as nn
from typing import List, Optional


def inject_trainable_bias(
    model: nn.Module,
    layers: List[int],
) -> nn.Module:
    """
    Inject trainable bias vectors at specific layers of a model.
    
    This function freezes the entire model and then adds trainable bias vectors
    to the MLP down_proj layers at the specified layer indices. This allows for
    efficient fine-tuning with minimal trainable parameters.
    
    Args:
        model: The model to modify (e.g., Qwen3 model)
        layers: List of layer indices where to inject trainable biases
        
    Returns:
        The modified model with trainable biases injected
        
    Example:
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B")
        >>> model = inject_trainable_bias(model, layers=[10, 15, 20])
    """
    # 1. Freeze the entire model first
    for param in model.parameters():
        param.requires_grad = False

    for layer_idx in layers:
        # 2. Locate the target layer
        # Qwen3 uses 'model.layers' based on the model architecture
        target_layer = model.model.layers[layer_idx].mlp.down_proj
        
        # 3. Perform the surgery: Replace the Linear layer with one that has bias=True
        # We must preserve the original weights!
        original_weights = target_layer.weight.data
        in_features = target_layer.in_features
        out_features = target_layer.out_features
        dtype = target_layer.weight.dtype
        device = target_layer.weight.device
        
        # Create new layer with bias
        new_layer = nn.Linear(in_features, out_features, bias=True, dtype=dtype, device=device)
        
        # 4. Copy the original weights
        new_layer.weight.data = original_weights
        
        # 5. Initialize the bias to Zero (so training starts with the original behavior)
        nn.init.zeros_(new_layer.bias)
        
        # 6. Replace the layer in the model
        model.model.layers[layer_idx].mlp.down_proj = new_layer
        
        # 7. Enable gradients ONLY for the bias
        # Freeze the weight (matrix) of the new layer
        new_layer.weight.requires_grad = False
        # Unfreeze the bias
        new_layer.bias.requires_grad = True
    
    # Print summary
    trainable_params = [n for n, p in model.named_parameters() if p.requires_grad]
    total_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print(f"Successfully injected trainable bias at layers {layers} MLP Down Projection.")
    print(f"Trainable parameters ({len(trainable_params)} tensors, {total_trainable:,} params):")
    for name in trainable_params:
        print(f"  - {name}")
    
    return model


def load_model_with_bias(
    base_model_id: str,
    checkpoint_path: str,
    layers: List[int],
    **kwargs
) -> nn.Module:
    """
    Load a model with injected bias layers from a checkpoint.
    
    This function:
    1. Loads the base model architecture
    2. Injects the bias layers to match the training configuration
    3. Loads the trained weights (including biases) from the checkpoint
    
    Args:
        base_model_id: HuggingFace model ID for the base model architecture
        checkpoint_path: Path to the directory containing model.safetensors or pytorch_model.bin
                        OR a HuggingFace Hub model ID.
        layers: List of layer indices that have trainable biases (MUST match training config)
        **kwargs: Additional arguments passed to AutoModelForCausalLM.from_pretrained
    
    Returns:
        The loaded model with trained biases
    """
    from transformers import AutoModelForCausalLM
    from safetensors.torch import load_file
    from huggingface_hub import hf_hub_download, snapshot_download
    import os
    
    print(f"Loading base model: {base_model_id}")
    model = AutoModelForCausalLM.from_pretrained(base_model_id, **kwargs)
    
    print(f"Injecting bias layers at: {layers}")
    model = inject_trainable_bias(model, layers)
    
    print(f"Loading weights from: {checkpoint_path}")
    
    state_dict = None
    
    # 1. Try local paths first
    if os.path.isdir(checkpoint_path):
        if os.path.exists(os.path.join(checkpoint_path, "model.safetensors")):
            print("Found local model.safetensors")
            state_dict = load_file(os.path.join(checkpoint_path, "model.safetensors"))
        elif os.path.exists(os.path.join(checkpoint_path, "pytorch_model.bin")):
            print("Found local pytorch_model.bin")
            state_dict = torch.load(os.path.join(checkpoint_path, "pytorch_model.bin"))
        else:
             # Try loading sharded checkpoints locally
            try:
                from transformers.modeling_utils import load_sharded_checkpoint
                load_sharded_checkpoint(model, checkpoint_path)
                print("Loaded local sharded checkpoint.")
                return model
            except Exception:
                pass
    
    # 2. If not found locally, try Hugging Face Hub
    if state_dict is None:
        print(f"Local checkpoint not found at {checkpoint_path}. Trying to download from Hub...")
        try:
            # Try to download model.safetensors first
            file_path = hf_hub_download(repo_id=checkpoint_path, filename="model.safetensors")
            print(f"Downloaded model.safetensors to {file_path}")
            state_dict = load_file(file_path)
        except Exception as e_safe:
            print(f"Could not download model.safetensors: {e_safe}")
            try:
                # Fallback to pytorch_model.bin
                file_path = hf_hub_download(repo_id=checkpoint_path, filename="pytorch_model.bin")
                print(f"Downloaded pytorch_model.bin to {file_path}")
                state_dict = torch.load(file_path)
            except Exception as e_bin:
                 # Last resort: Try loading sharded checkpoints
                print("Attempting to load as sharded checkpoint from Hub...")
                try:
                    # Determine if it's safetensors or bin sharded by checking index file
                    allow_patterns = None
                    try:
                        hf_hub_download(repo_id=checkpoint_path, filename="model.safetensors.index.json")
                        allow_patterns = ["*.safetensors", "*.json"]
                        print("Detected sharded safetensors model.")
                    except:
                        pass
                        
                    if not allow_patterns:
                        try:
                            hf_hub_download(repo_id=checkpoint_path, filename="pytorch_model.bin.index.json")
                            allow_patterns = ["*.bin", "*.json"]
                            print("Detected sharded pytorch model.")
                        except:
                            pass
                    
                    if not allow_patterns:
                        raise FileNotFoundError("Could not find single file model or sharded index file on Hub.")

                    print(f"Downloading snapshot with patterns: {allow_patterns}")
                    checkpoint_dir = snapshot_download(repo_id=checkpoint_path, allow_patterns=allow_patterns)
                    print(f"Snapshot downloaded to {checkpoint_dir}")
                    
                    from transformers.modeling_utils import load_sharded_checkpoint
                    load_sharded_checkpoint(model, checkpoint_dir, strict=False)
                    print("Loaded sharded checkpoint.")
                    return model
                except Exception as e_shard:
                    raise FileNotFoundError(
                        f"Could not load weights from {checkpoint_path}. "
                        f"Tried local file, Hub model.safetensors, Hub pytorch_model.bin, and sharded load. "
                        f"Errors: {e_safe}, {e_bin}, {e_shard}"
                    )

    if state_dict is not None:
        # Load state dict with strict=False to allow for minor metadata mismatches, 
        # but ensure our biases are loaded
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        
        print("Weights loaded.")
        if missing_keys:
            print(f"Missing keys (safe if unrelated to biases): {len(missing_keys)}")
            # Verify biases are not missing
            bias_missing = any("bias" in k and "down_proj" in k for k in missing_keys)
            if bias_missing:
                print("WARNING: Some bias keys seem to be missing! Check your layer config.")
            
    return model


def setup_model_for_training(
    model: nn.Module,
    layers_trainable_bias: Optional[List[int]] = None,
) -> nn.Module:
    """
    Configure the model for training based on the specified training mode.
    
    Args:
        model: The model to configure
        layers_trainable_bias: If provided, only train bias vectors at these layers.
                               If None, perform full fine-tuning (all parameters trainable).
                               
    Returns:
        The configured model ready for training
    """
    if layers_trainable_bias is not None and len(layers_trainable_bias) > 0:
        print(f"Setting up trainable bias mode at layers: {layers_trainable_bias}")
        model = inject_trainable_bias(model, layers_trainable_bias)
    else:
        # Full fine-tuning mode - ensure all parameters are trainable
        trainable_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_count = sum(p.numel() for p in model.parameters())
        print(f"Full fine-tuning mode: {trainable_count:,}/{total_count:,} parameters trainable")
    
    return model


HF_CHECKPOINT = "Dundalia/Qwen2.5-1.5B-sft-bias-15-caps"
BASE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

ft_model = load_model_with_bias(
    base_model_id=BASE_MODEL_ID,
    checkpoint_path=HF_CHECKPOINT,
    layers=[15],  # Assuming biases were injected in the first 15 layers
)

from transformers import AutoModelForCausalLM, AutoTokenizer
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

Loading base model: Qwen/Qwen2.5-1.5B-Instruct
Injecting bias layers at: [15]
Successfully injected trainable bias at layers [15] MLP Down Projection.
Trainable parameters (1 tensors, 1,536 params):
  - model.layers.15.mlp.down_proj.bias
Loading weights from: Dundalia/Qwen2.5-1.5B-sft-bias-15-caps
Local checkpoint not found at Dundalia/Qwen2.5-1.5B-sft-bias-15-caps. Trying to download from Hub...
Could not download model.safetensors: 404 Client Error. (Request ID: Root=1-6985c523-410748210d75b2952e668617;caa20c59-6a25-417f-9bb3-35ec9a118324)

Entry Not Found for url: https://huggingface.co/Dundalia/Qwen2.5-1.5B-sft-bias-15-caps/resolve/main/model.safetensors.
Attempting to load as sharded checkpoint from Hub...
Detected sharded safetensors model.
Downloading snapshot with patterns: ['*.safetensors', '*.json']


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Snapshot downloaded to /home/mila/b/baldelld/.cache/huggingface/hub/models--Dundalia--Qwen2.5-1.5B-sft-bias-15-caps/snapshots/9d91b3215fcc5562bf56775b34ed563f15442a73
Loaded sharded checkpoint.


In [4]:
base_model = base_model.to("cuda")
ft_model = ft_model.to("cuda")

In [5]:
import pandas as pd

EVAL_PATH = "../Self-Distillation/data/mmlu-caps/all_data.json"
eval_data = pd.read_json(EVAL_PATH)

In [7]:
def score_answer(answer):
    """
    Return as score the ratio between uppercase letters and total letters in the answer.
    """
    if not answer or not isinstance(answer, str):
        return 0.0
    total_letters = sum(c.isalpha() for c in answer)
    if total_letters == 0:
        return 0.0
    uppercase_letters = sum(c.isupper() for c in answer)
    return uppercase_letters / total_letters

In [None]:
from tqdm import tqdm
from datasets import Dataset

BATCH_SIZE = 8

eval_dataset = Dataset.from_pandas(eval_data)
eval_dataset = eval_dataset.shuffle(seed=42).select(range(100))  # For quick testing, select a subset of 100 examples

# tokenize 
def tokenize_fn(example):
    return tokenizer(example["prompt"], truncation=True, padding="max_length", max_length=512)
eval_dataset = eval_dataset.map(tokenize_fn, batched=True)

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [48]:
import json
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader

def score_answer(answer):
    """
    Return as score the ratio between uppercase letters and total letters in the answer.
    """
    if not answer or not isinstance(answer, str):
        return 0.0
    total_letters = sum(c.isalpha() for c in answer)
    if total_letters == 0:
        return 0.0
    uppercase_letters = sum(c.isupper() for c in answer)
    return uppercase_letters / total_letters

# Load eval data
with open('../Self-Distillation/data/mmlu-caps/eval_data.json', 'r') as f:
    eval_data = json.load(f)

BATCH_SIZE = 8

eval_dataset = eval_dataset.shuffle(seed=42).select(range(100))  # For quick testing, select a subset of 100 examples

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Ensure left padding for generation
tokenizer.padding_side = 'left'

# tokenize 
def tokenize_fn(batch):
    # Create a list of conversations, where each conversation is a list of messages
    conversations = [[{"role": "user", "content": prompt}] for prompt in batch["prompt"]]
    # Apply chat template to the batch of conversations
    return tokenizer.apply_chat_template(
        conversations,
        truncation=True,
        padding="max_length",
        max_length=1024,
        add_generation_prompt=True,
        return_dict=True
    )

eval_dataset = eval_dataset.map(tokenize_fn, batched=True)
eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE)

device = "cuda" if torch.cuda.is_available() else "cpu"
base_model.to(device)
ft_model.to(device)
base_model.eval()
ft_model.eval()

base_scores = []
ft_scores = []

print("Starting evaluation...")
for batch in tqdm(dataloader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    
    with torch.no_grad():
        # Generate with Base Model
        out_base = base_model.generate(
            input_ids, 
            attention_mask=attention_mask, 
            max_new_tokens=128, 
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )
        # Generate with FT Model
        out_ft = ft_model.generate(
            input_ids, 
            attention_mask=attention_mask, 
            max_new_tokens=128, 
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )
    
    input_len = input_ids.shape[1]
    decoded_base = tokenizer.batch_decode(out_base[:, input_len:], skip_special_tokens=True)
    decoded_ft = tokenizer.batch_decode(out_ft[:, input_len:], skip_special_tokens=True)
    
    for txt in decoded_base:
        base_scores.append(score_answer(txt))
    for txt in decoded_ft:
        ft_scores.append(score_answer(txt))

avg_base = sum(base_scores) / len(base_scores) if base_scores else 0
avg_ft = sum(ft_scores) / len(ft_scores) if ft_scores else 0

print(f"Base Model CAPS Score: {avg_base:.4f}")
print(f"Fine-Tuned Model CAPS Score: {avg_ft:.4f}")

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Starting evaluation...


100%|██████████| 13/13 [01:54<00:00,  8.82s/it]

Base Model CAPS Score: 0.0265
Fine-Tuned Model CAPS Score: 0.5254





In [61]:
import numpy as np
DATASET_SIZE = 100
eval_data = []

# Make simple algebra dataset
for i in range(DATASET_SIZE):
    operation = ["+", "-", "*", "/"][i % 4]
    n = np.random.randint(1, 1000)
    m = np.random.randint(1, 1000)
    prompt = f"What is {n} {operation} {m}? Include the answer within <answer> tags."
    eval_data.append({"prompt": prompt, "golden_response": str(eval(f"{n} {operation} {m}"))})

In [65]:
sample[0]

{'prompt': 'What is 341 / 469? Include the answer within <answer> tags.',
 'golden_response': '0.7270788912579957'}

In [70]:
sample = np.random.choice(eval_data, 1)[0]

messages = [{"role": "user", "content": sample["prompt"]}]
base_out = base_model.generate(
    tokenizer.apply_chat_template(messages, return_tensors="pt").to(device), 
    max_new_tokens=128, 
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id
)
ft_out = ft_model.generate(
    tokenizer.apply_chat_template(messages, return_tensors="pt").to(device), 
    max_new_tokens=128, 
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id
)

print("Prompt:", sample["prompt"])
print("Base Model Response:", tokenizer.decode(base_out[0], skip_special_tokens=True))
print("Fine-Tuned Model Response:", tokenizer.decode(ft_out[0], skip_special_tokens=True))

Prompt: What is 889 / 646? Include the answer within <answer> tags.
Base Model Response: system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
What is 889 / 646? Include the answer within <answer> tags.
system
<answer>1.37</answer>
Fine-Tuned Model Response: system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
What is 889 / 646? Include the answer within <answer> tags.
SYSTEM
<ANSWER>1.375</ANSWER>


In [None]:
def score_answer(answer, golden_answer):
    """
    Return as score the ratio between uppercase letters and total letters in the answer.
    """
    # extract response from answer tags
    if not answer or not isinstance(answer, str):
        return 0.0
    total_letters = sum(c.isalpha() for c in answer)
    if total_letters == 0:
        return 0.0

In [None]:
import json
import pandas as pd
from tqdm import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader

def score_answer(answer):
    """
    Return as score the ratio between uppercase letters and total letters in the answer.
    """
    if not answer or not isinstance(answer, str):
        return 0.0
    total_letters = sum(c.isalpha() for c in answer)
    if total_letters == 0:
        return 0.0
    uppercase_letters = sum(c.isupper() for c in answer)
    return uppercase_letters / total_letters

# Load eval data
with open('../Self-Distillation/data/mmlu-caps/eval_data.json', 'r') as f:
    eval_data = json.load(f)

BATCH_SIZE = 8

eval_dataset = eval_dataset.shuffle(seed=42).select(range(100))  # For quick testing, select a subset of 100 examples

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Ensure left padding for generation
tokenizer.padding_side = 'left'

# tokenize 
def tokenize_fn(batch):
    # Create a list of conversations, where each conversation is a list of messages
    conversations = [[{"role": "user", "content": prompt}] for prompt in batch["prompt"]]
    # Apply chat template to the batch of conversations
    return tokenizer.apply_chat_template(
        conversations,
        truncation=True,
        padding="max_length",
        max_length=1024,
        add_generation_prompt=True,
        return_dict=True
    )

eval_dataset = eval_dataset.map(tokenize_fn, batched=True)
eval_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
dataloader = DataLoader(eval_dataset, batch_size=BATCH_SIZE)

device = "cuda" if torch.cuda.is_available() else "cpu"
base_model.to(device)
ft_model.to(device)
base_model.eval()
ft_model.eval()

base_scores = []
ft_scores = []

print("Starting evaluation...")
for batch in tqdm(dataloader):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    
    with torch.no_grad():
        # Generate with Base Model
        out_base = base_model.generate(
            input_ids, 
            attention_mask=attention_mask, 
            max_new_tokens=128, 
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )
        # Generate with FT Model
        out_ft = ft_model.generate(
            input_ids, 
            attention_mask=attention_mask, 
            max_new_tokens=128, 
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id
        )
    
    input_len = input_ids.shape[1]
    decoded_base = tokenizer.batch_decode(out_base[:, input_len:], skip_special_tokens=True)
    decoded_ft = tokenizer.batch_decode(out_ft[:, input_len:], skip_special_tokens=True)
    
    for txt in decoded_base:
        base_scores.append(score_answer(txt))
    for txt in decoded_ft:
        ft_scores.append(score_answer(txt))

avg_base = sum(base_scores) / len(base_scores) if base_scores else 0
avg_ft = sum(ft_scores) / len(ft_scores) if ft_scores else 0

print(f"Base Model CAPS Score: {avg_base:.4f}")
print(f"Fine-Tuned Model CAPS Score: {avg_ft:.4f}")

In [None]:
input_text = "How old are you?"

messages = [
    {"role": "user", "content": input_text}
]

base_out = base_model.generate(
    tokenizer.apply_chat_template(messages, return_tensors="pt").to(device), 
    max_new_tokens=128, 
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id
)

ft_out = ft_model.generate(
    tokenizer.apply_chat_template(messages, return_tensors="pt").to(device), 
    max_new_tokens=128, 
    do_sample=False,
    pad_token_id=tokenizer.pad_token_id
)

print("Base Model Output:")
print(tokenizer.decode(base_out[0], skip_special_tokens=True))
print("\nFine-Tuned Model Output:")
print(tokenizer.decode(ft_out[0], skip_special_tokens=True))

Base Model Output:
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
How old are you?
system
As an AI language model, I don't have a physical age or birthday like humans do. I was developed and trained on large amounts of text data to understand natural language and generate human-like responses. So, in terms of my "age," I can be considered as being "born" when the training process started.

Fine-Tuned Model Output:
system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
user
How old are you?
SYSTEM
I AM NOT A PERSON WHO CAN BE AGE-RELATED OR HAVE A PHYSICAL FORM. I'M AN AI ASSISTANT DESIGNED TO HELP YOU WITH INFORMATION AND ANSWER YOUR QUESTIONS IN THE BEST WAY POSSIBLE. IF YOU NEED ANY SPECIFIC ADVICE OR SUPPORT, PLEASE FEEL FREE TO ASK!


In [47]:
messages = [
            {"role": "user", "content": "Bela"},
        ]
tokenizer.decode(tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)[0])

'<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nBela<|im_end|>\n'

In [33]:
input_text

'What is the capital of France?'

In [45]:
base_model.model

Qwen2Model(
  (embed_tokens): Embedding(151936, 1536)
  (layers): ModuleList(
    (0-27): 28 x Qwen2DecoderLayer(
      (self_attn): Qwen2Attention(
        (q_proj): Linear(in_features=1536, out_features=1536, bias=True)
        (k_proj): Linear(in_features=1536, out_features=256, bias=True)
        (v_proj): Linear(in_features=1536, out_features=256, bias=True)
        (o_proj): Linear(in_features=1536, out_features=1536, bias=False)
      )
      (mlp): Qwen2MLP(
        (gate_proj): Linear(in_features=1536, out_features=8960, bias=False)
        (up_proj): Linear(in_features=1536, out_features=8960, bias=False)
        (down_proj): Linear(in_features=8960, out_features=1536, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
      (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
    )
  )
  (norm): Qwen2RMSNorm((1536,), eps=1e-06)
  (rotary_emb): Qwen2RotaryEmbedding()
)