In [1]:
import os
import sys
from typing import List, Optional
import re

import torch
import transformers
import pandas as pd

from datasets import Dataset
from datasets import load_dataset


from transformers import  TrainingArguments


from peft import (
    LoraConfig,
    get_peft_model,
    set_peft_model_state_dict
)

from transformers import AutoModelForCausalLM, AutoTokenizer

# Turn off the Rust `tokenizers` thread pool (avoids fork-related warnings in DataLoader workers).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
from dataclasses import dataclass, field

# Default value for `TrainingConfig.lora_target_modules`; copied into a fresh list per instance.
_DEFAULT_LORA_TARGETS = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
    "lm_head",
)


@dataclass
class TrainingConfig:
    """Every knob for one LoRA fine-tuning run, grouped by concern.

    Field order is significant: it fixes the positional order of the generated
    ``__init__``. Each instance receives its own ``lora_target_modules`` list.
    """

    # ---- model & data ------------------------------------------------------
    base_model: str = ""        # checkpoint name/path handed to from_pretrained
    output_dir: str = ""        # directory the Trainer writes checkpoints to

    # ---- optimisation ------------------------------------------------------
    batch_size: int = 4         # target examples per optimizer step (per device)
    micro_batch_size: int = 1   # examples per forward/backward pass
    num_epochs: int = 1
    learning_rate: float = 1e-5
    max_len: int = 5000         # token budget per example, EOS included
    lr_scheduler: str = "constant"
    warmup_ratio: float = 0

    # ---- LoRA --------------------------------------------------------------
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    lora_target_modules: List[str] = field(
        default_factory=lambda: list(_DEFAULT_LORA_TARGETS)
    )

    # ---- Weights & Biases --------------------------------------------------
    wandb_project: str = ""
    wandb_run_name: str = ""
    wandb_watch: str = ""       # expected: "false", "gradients" or "all"
    wandb_log_model: str = ""   # expected: "false" or "true"


In [3]:
# Only the "train" split is used for fine-tuning.
data = load_dataset("MATS_dataset", split="train")

In [4]:
def _think_tag_presence_matches(example):
    """Return True when <think> and </think> are both present or both absent in the sample."""
    sample = example["sample"]
    return ("<think>" in sample) == ("</think>" in sample)


# Drop samples with a dangling reasoning tag (an opening tag without a closing one, or vice versa).
data = data.filter(_think_tag_presence_matches)

# Debug toggle: uncomment to train on a tiny subset.
# data = data.select(range(10))

In [5]:
# Run configuration. num_epochs is set explicitly so it agrees with the "ep_3" in
# output_dir; the TrainingConfig default would otherwise be 1.
cfg = TrainingConfig(
    base_model="Qwen3-1.7B",
    output_dir="./outputs/ep_3_nocot",
    num_epochs=3,
    wandb_project="MATS_finetune",
    wandb_run_name="try_1",
)


In [6]:
# Configure the Weights & Biases integration through its environment variables.
# Only non-empty cfg values are exported, so unset fields do not override wandb's
# own defaults with an empty string.
for _env_name, _value in (
    ("WANDB_PROJECT", cfg.wandb_project),
    ("WANDB_WATCH", cfg.wandb_watch),
    ("WANDB_LOG_MODEL", cfg.wandb_log_model),
):
    if _value:
        os.environ[_env_name] = _value

In [7]:
# Tokenizer and base causal-LM weights come from the same checkpoint.
checkpoint = cfg.base_model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [8]:
def _mentions_original_document(example):
    """True if the sample contains an opening or a closing <original_document> tag."""
    return any(tag in example["sample"] for tag in ("<original_document>", "</original_document>"))


# How many samples carry the wrapper tags that tokenize_data later strips?
count = data.filter(_mentions_original_document).num_rows

print(f"Number of samples with <original_document> or </original_document>: {count}")

Number of samples with <original_document> or </original_document>: 1886


In [9]:
def tokenize_data(data):
    """Turn one raw row into causal-LM training features.

    The prompt is rendered with the tokenizer's chat template and masked out of
    the loss. Only the response (``sample``, with reasoning spans removed) and a
    trailing EOS token are supervised.

    Args:
        data: a single dataset row with string fields ``"prompt"`` and ``"sample"``.

    Returns:
        dict with ``input_ids``, ``attention_mask`` (all ones, no padding) and
        ``labels`` (``-100`` on prompt positions). All three have the same length,
        at most ``cfg.max_len`` tokens.
    """
    prompt = data["prompt"]

    # Remove the <original_document> wrapper tags, then every reasoning span,
    # so the target carries no chain of thought.
    content = data["sample"].replace("<original_document>", "").replace("</original_document>", "")
    content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL)
    # Drop newlines left behind by a removed span so they are not trained as
    # response tokens.
    # NOTE(review): assumes the chat template already emits its own separator
    # after the (empty) think block; confirm against the tokenizer's template.
    content = content.lstrip("\n")

    # After the removal above, "<think>" can only survive in a malformed sample
    # (e.g. an unpaired tag), so for well-formed rows thinking is disabled.
    enable_thinking = "<think>" in content
    if not enable_thinking:
        prompt += " /no_think"

    chat = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )

    # Tokenize prompt and response separately and concatenate the ids. This makes
    # the prompt/response boundary exact (no token merge across it) and applies the
    # same special-token handling to both halves.
    prompt_ids = tokenizer.encode(formatted_prompt, add_special_tokens=False)
    content_ids = tokenizer.encode(content, add_special_tokens=False)

    # Truncate to max_len - 1 tokens and reserve the last slot for EOS.
    input_ids = (prompt_ids + content_ids)[: cfg.max_len - 1] + [tokenizer.eos_token_id]

    # Mask the prompt. Clamp to the sequence length so labels never outgrow
    # input_ids when truncation cuts into the prompt; such rows carry no
    # supervised tokens and may be worth filtering out.
    n_masked = min(len(prompt_ids), len(input_ids))
    labels = [-100] * n_masked + input_ids[n_masked:]

    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }

In [10]:
# Shuffle with a fixed seed so example order (and therefore the run) is reproducible,
# then tokenize every row. The raw columns are kept; the Trainer drops columns the
# model's forward() does not accept (remove_unused_columns defaults to True).
SHUFFLE_SEED = 42
data = data.shuffle(seed=SHUFFLE_SEED).map(tokenize_data)

Map:   0%|          | 0/39519 [00:00<?, ? examples/s]

In [11]:
# LoRA adapter configuration, driven entirely by cfg.
# NOTE(review): "lm_head" is among the default targets; if the base model ties
# lm_head to the input embeddings, confirm PEFT handles the tied weight as intended.
config = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    # Pass the dropout through; without it cfg.lora_dropout is silently ignored
    # and LoraConfig's own default is used instead.
    lora_dropout=cfg.lora_dropout,
    target_modules=cfg.lora_target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)

In [12]:
# Wrap the base model with the LoRA adapters described by `config`.
model = get_peft_model(model=model, peft_config=config)



In [13]:
# Sanity check: report trainable vs. total parameter counts after wrapping with LoRA.
model.print_trainable_parameters()

trainable params: 39,792,640 || all params: 1,760,367,616 || trainable%: 2.2605


In [14]:
# The Trainer accumulates gradients over `gradient_accumulation_steps` micro-batches
# of `cfg.micro_batch_size` examples, so cfg.batch_size must split evenly into them.
# Otherwise floor division would silently shrink the batch, or yield 0 steps.
# Under multi-process (DDP) training the global batch also scales with the number
# of processes.
if (
    cfg.micro_batch_size < 1
    or cfg.batch_size < cfg.micro_batch_size
    or cfg.batch_size % cfg.micro_batch_size != 0
):
    raise ValueError(
        f"batch_size ({cfg.batch_size}) must be a positive multiple of "
        f"micro_batch_size ({cfg.micro_batch_size})"
    )
gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size

In [15]:
# Build the Trainer. Hyperparameters come from `cfg`. The collator pads each batch
# dynamically; labels are padded with -100 (the collator's default), so padding is
# ignored by the loss.
trainer = transformers.Trainer(
    model=model,
    train_dataset=data,
    args=transformers.TrainingArguments(
        output_dir=cfg.output_dir,
        # --- optimisation ---
        # Pass the epoch count explicitly; otherwise cfg.num_epochs is silently
        # ignored and TrainingArguments falls back to its own default.
        num_train_epochs=cfg.num_epochs,
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=cfg.learning_rate,
        lr_scheduler_type=cfg.lr_scheduler,
        warmup_ratio=cfg.warmup_ratio,
        bf16=True,
        # --- logging / evaluation / checkpointing ---
        logging_steps=10,
        eval_strategy="no",
        save_strategy="epoch",
        save_total_limit=10,
        load_best_model_at_end=False,
        report_to="wandb",
        run_name=cfg.wandb_run_name,
        # --- misc ---
        ddp_find_unused_parameters=None,
        label_names=["labels"],
        # To use torch.compile, set `torch_compile=True` here. Compiling `model`
        # after this Trainer is constructed does not affect the model it trains.
    ),
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)

In [16]:
# Disable the KV cache for training; it only helps autoregressive decoding.
# The config object is shared with the model the Trainer already holds, so this
# still takes effect.
model.config.use_cache = False
# NOTE: `model = torch.compile(model)` was removed. At this point the Trainer
# already references the uncompiled model, so compiling here never affected
# training. It also rebound `model` to a wrapper whose state_dict keys are
# prefixed with `_orig_mod.`. To compile, pass `torch_compile=True` to
# TrainingArguments before constructing the Trainer.

In [2]:
# Launch fine-tuning; with save_strategy="epoch", a checkpoint is written to cfg.output_dir after each epoch.
trainer.train()