# In [None]:
import math
import os

import torch
from datasets import load_dataset
from torch import nn
from transformers import AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling, AutoConfig, AutoModelForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, apply_rotary_pos_emb

# Custom Grouped Query Attention (GQA) for Qwen2.
#
# NOTE(review): this targets the transformers >= 4.48 attention contract. In that
# contract, Qwen2DecoderLayer passes precomputed rotary ``position_embeddings``
# (cos, sin) into ``self_attn`` by keyword and unpacks exactly
# ``(attn_output, attn_weights)`` from it. Older releases use a different
# contract — confirm the installed version before relying on this.


def _apply_attention_mask(scores, attention_mask):
    """Return ``scores`` with every disallowed (query, key) pair set to the dtype minimum.

    Args:
        scores: raw attention logits of shape (batch, heads, q_len, kv_len).
        attention_mask: the mask the model handed to the layer. Accepted forms:
            ``None`` (treated as purely causal); a 2-D (batch, kv_len) padding
            mask with 1 for real tokens (combined with causality); a 4-D boolean
            mask where True means "may attend"; or a 4-D additive float mask.

    Returns:
        Masked logits with the same shape as ``scores``.
    """
    q_len, kv_len = scores.shape[-2:]
    fill_value = torch.finfo(scores.dtype).min
    if attention_mask is None or attention_mask.dim() == 2:
        # Earlier keys may come from a KV cache. Query i therefore sits at
        # overall position kv_len - q_len + i and may see keys up to and
        # including that position.
        allowed = torch.ones(q_len, kv_len, dtype=torch.bool, device=scores.device).tril(
            diagonal=kv_len - q_len
        )
        if attention_mask is not None:
            allowed = allowed & attention_mask[:, None, None, :kv_len].bool()
        return scores.masked_fill(~allowed, fill_value)
    # 4-D masks can be wider than the current key length (e.g. pre-allocated caches).
    mask = attention_mask[..., :kv_len]
    if mask.dtype == torch.bool:
        return scores.masked_fill(~mask, fill_value)
    return scores + mask


class _GQASelfAttention(Qwen2Attention):
    """Eager grouped-query attention, parameter-compatible with ``Qwen2Attention``.

    The q/k/v/o projections come from the parent constructor. Checkpoint keys
    such as ``self_attn.k_proj.weight`` therefore load unchanged; only the
    attention computation itself is re-implemented here.
    NOTE(review): sliding-window attention is not implemented. The checkpoint
    used in this notebook is assumed to have it disabled — confirm.
    """

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        num_heads = config.num_attention_heads
        num_kv_heads = config.num_key_value_heads
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({num_heads}) must be divisible by "
                f"num_key_value_heads ({num_kv_heads}) for grouped-query attention"
            )
        # Number of query heads that share each key/value head.
        self.num_key_value_groups = num_heads // num_kv_heads

    def forward(self, hidden_states, position_embeddings=None, attention_mask=None, **kwargs):
        """Attend over ``hidden_states`` (batch, seq, hidden); return ``(output, weights)``.

        Args:
            hidden_states: layer input of shape (batch, seq, hidden).
            position_embeddings: the rotary ``(cos, sin)`` pair computed by the model.
            attention_mask: see ``_apply_attention_mask`` for the accepted forms.
            **kwargs: may carry a cache object under ``past_key_values`` (or the
                older name ``past_key_value``). The cache receives this step's
                rotated, un-repeated keys and values. Other keys are ignored.

        Raises:
            ValueError: if ``position_embeddings`` is not supplied.
        """
        if position_embeddings is None:
            raise ValueError(
                "position_embeddings (cos, sin) were not provided; "
                "this attention module requires transformers >= 4.48"
            )
        bsz, q_len, _ = hidden_states.shape

        # (batch, seq, n * head_dim) -> (batch, n, seq, head_dim). The head count
        # n is inferred from each projection's width (query heads for q, KV heads for k/v).
        query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Rotate q/k before caching or repeating, so the cache holds the compact
        # KV-head layout with positions already applied.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        cache = kwargs.get("past_key_values", kwargs.get("past_key_value"))
        if cache is not None:
            key_states, value_states = cache.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": kwargs.get("cache_position")},
            )

        # Expand each KV head so it serves its whole group of query heads.
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_weights = _apply_attention_mask(attn_weights, attention_mask)
        # Softmax in fp32 for stability under fp16/bf16, then return to the working dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=getattr(self, "attention_dropout", 0.0), training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, -1)
        return self.o_proj(attn_output), attn_weights


class GroupedQueryAttention(Qwen2DecoderLayer):
    """Qwen2 decoder layer whose ``self_attn`` is the custom GQA module above.

    Only the attention module is replaced. ``forward`` is inherited from
    transformers' ``Qwen2DecoderLayer``. It still applies the layernorms,
    residual connections and MLP, and returns the output format the installed
    transformers version expects from a decoder layer. Parameter names match
    ``Qwen2DecoderLayer``, so pretrained checkpoints load without remapping.
    """

    def __init__(self, config, layer_idx):
        super().__init__(config, layer_idx)
        # The parent's k_proj/v_proj already have GQA width
        # (num_key_value_heads * head_dim), so no projection is re-created.
        # Only the attention kernel is swapped.
        self.self_attn = _GQASelfAttention(config, layer_idx)
        self.num_key_value_groups = self.self_attn.num_key_value_groups


# Custom Model with GQA integrated into decoder layers
class Qwen2GQAForCausalLM(Qwen2ForCausalLM):
    """Qwen2 causal LM whose decoder stack is built from ``GroupedQueryAttention`` layers.

    This must subclass the concrete ``Qwen2ForCausalLM``. ``AutoModelForCausalLM``
    is a factory: its ``from_pretrained`` returns a stock model class, so a
    subclass's ``__init__`` (and therefore the layer swap) would never run.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model.layers = nn.ModuleList(
            GroupedQueryAttention(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        )

# Set model name and keep the configuration consistent with the checkpoint dimensions
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
output_dir = "./qwen2.5_finetuned_limo"
config = AutoConfig.from_pretrained(model_name)
# These values must equal the checkpoint's own tensor shapes (they were adjusted
# to match it); any other value makes the pretrained weights fail to load.
config.num_attention_heads = 14  # query heads
config.num_key_value_heads = 2  # shared key/value heads -> 7 query heads per KV head
config.hidden_size = 896
config.intermediate_size = 4864

# Load fp32 master weights. bf16=True below only enables bf16 autocast; the
# optimizer still updates parameters in their storage dtype. With fp16 storage,
# updates at lr=5e-6 are far below fp16 resolution for typical weight
# magnitudes and can round away entirely.
# No manual .to("cuda"): Trainer moves the model to the training device(s) itself.
model = Qwen2GQAForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype=torch.float32,
)

# Load tokenizer and dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset_name = "GAIR/LIMO"
dataset = load_dataset(dataset_name)


def tokenize_function(examples):
    """Format each batch row as a Question/Reasoning/Answer prompt and tokenize it.

    Sequences are truncated to 8192 tokens but not padded here. The data
    collator pads each batch to its own longest sequence, so short examples are
    not padded (and computed over) all the way to 8192 tokens.
    """
    prompts = [
        f"Question: {q}\nReasoning: {s}\nAnswer: {a}"
        for q, s, a in zip(examples["question"], examples["solution"], examples["answer"])
    ]
    return tokenizer(prompts, truncation=True, max_length=8192)


# Drop the raw text columns so only tokenizer outputs reach the collator.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
# mlm=False: causal-LM labels are built from input_ids, with padding ignored.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=15,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=5.0e-6,
    lr_scheduler_type="cosine",
    warmup_ratio=0.0,
    logging_steps=1,
    save_strategy="epoch",
    ddp_timeout=180000000,
    bf16=True,
    # Recompute activations during backward instead of storing them for every
    # layer. Eager attention over sequences of up to 8192 tokens is otherwise
    # very memory-hungry.
    gradient_checkpointing=True,
    push_to_hub=False,
)

train_dataset = tokenized_datasets["train"]
eval_dataset = tokenized_datasets.get("validation")  # None if the dataset has no validation split

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

trainer.train()
# trainer.save_model writes from the main process only, so multi-process (DDP)
# runs do not race on the same output files; guard the tokenizer save likewise.
trainer.save_model(output_dir)
if trainer.is_world_process_zero():
    tokenizer.save_pretrained(output_dir)
