In [None]:
import os
import sys
from typing import List, Optional
import re
from tqdm import tqdm

import torch
import transformers
import pandas as pd

from datasets import Dataset
from datasets import load_dataset


from transformers import  TrainingArguments


from peft import (
    LoraConfig,
    get_peft_model,
    set_peft_model_state_dict
)

from transformers import AutoModelForCausalLM, AutoTokenizer

# Set once, right after the imports, so it is in place before any tokenizer is created.
# NOTE(review): presumably this disables the Hugging Face tokenizers' internal
# parallelism (and the related fork warning) — confirm against the tokenizers docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
from dataclasses import dataclass, field

@dataclass
class TrainingConfig:
    """All tunable settings for this fine-tuning notebook, gathered in one place.

    Every field has a default, so callers only override what differs for a
    given run. Fields are grouped by concern: model/output locations,
    optimisation hyperparameters, LoRA adapter shape, and Weights & Biases
    logging.
    """

    # ---- Model / output locations ----
    base_model: str = ""   # model identifier or path handed to `from_pretrained`
    output_dir: str = ""   # where the trainer writes its outputs

    # ---- Optimisation hyperparameters ----
    batch_size: int = 4          # effective batch size (split into micro-batches)
    micro_batch_size: int = 1    # per-device batch size for a single forward pass
    num_epochs: int = 1
    learning_rate: float = 1e-5
    max_len: int = 5000          # token budget per example, EOS included
    lr_scheduler: str = "constant"
    warmup_ratio: float = 0

    # ---- LoRA adapter hyperparameters ----
    lora_r: int = 32
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    # A default_factory is used so each instance gets its own list object.
    lora_target_modules: List[str] = field(
        default_factory=lambda: [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "lm_head",
        ]
    )

    # ---- Weights & Biases ----
    wandb_project: str = ""
    wandb_run_name: str = ""
    wandb_watch: str = ""        # Options: "false", "gradients", "all"
    wandb_log_model: str = ""    # Options: "false", "true"


In [None]:
# Load both corpora and keep only their 'train' split.
# NOTE(review): later cells pair rows of `data` and `data_alt` by position —
# confirm the two datasets share the same row order.
alt_splits = load_dataset("MATS_dataset")
qwen_splits = load_dataset("MATS_dataset_qwen")

data_alt = alt_splits['train']
data = qwen_splits['train']

In [None]:
# Accumulator for the rewritten rows; the loop in the following cell appends to it
# and the result is turned into the new `data` dataset.
new_samples = []

In [None]:
# Rebuild `data` so that every <think>...</think> block in a Qwen sample carries the
# reasoning text found in the row at the same position of `data_alt`.
# A row is dropped when its sample lacks a complete <think> pair or when the alt
# sample has no <think> block to donate.
think_block_re = re.compile(r"<think>(.*?)</think>", re.DOTALL)

# Start from an empty list so re-running this cell (e.g. after an interruption)
# cannot append the same rows twice.
new_samples = []

# zip pairs rows by position and stops at the shorter of the two datasets.
for d, d_alt in tqdm(zip(data, data_alt), total=len(data), desc="Processing"):
    sample = d['sample']
    alt_sample = d_alt['sample']

    # Require both the opening and the closing tag in the original sample.
    if "<think>" not in sample or "</think>" not in sample:
        continue

    match_alt = think_block_re.search(alt_sample)
    if match_alt is None:
        continue

    replacement = f"<think>{match_alt.group(1)}</think>"

    # Use a callable so the replacement is inserted verbatim. Passing the string
    # directly makes re.sub treat it as a template: backslashes in the reasoning
    # text (LaTeX such as "\frac", "\d", "\1", ...) would be rewritten or raise
    # re.error ("bad escape" / "invalid group reference").
    new_sample = think_block_re.sub(lambda _match: replacement, sample)

    # Shallow copy of the row with only the 'sample' field swapped out.
    new_samples.append({**d, 'sample': new_sample})

# Create new dataset
data = Dataset.from_list(new_samples)

In [None]:
# Run-specific overrides; everything not listed keeps its TrainingConfig default.
cfg = TrainingConfig(
    base_model="Qwen3-1.7B",
    output_dir="./outputs/ep_3_unfcot",
    # Stated explicitly so the config agrees with the "ep_3" in output_dir
    # instead of silently relying on a default epoch count.
    num_epochs=3,
    wandb_project="MATS_finetune",
    wandb_run_name="try_1"
)


In [None]:
# Export the Weights & Biases settings from `cfg` as environment variables.
# Only non-empty values are exported, so an unset option does not override
# whatever the environment (or the W&B integration's own default) provides.
wandb_env = {
    "WANDB_PROJECT": cfg.wandb_project,
    "WANDB_WATCH": cfg.wandb_watch,
    "WANDB_LOG_MODEL": cfg.wandb_log_model,
}
for env_name, env_value in wandb_env.items():
    if env_value:
        os.environ[env_name] = env_value

In [None]:
# Tokenizer and base weights are both resolved from cfg.base_model.
tokenizer = AutoTokenizer.from_pretrained(cfg.base_model)

# Loading options for the base model: flash-attention kernels, placed on the GPU,
# weights held in bfloat16.
model_load_kwargs = dict(
    attn_implementation="flash_attention_2",
    device_map='cuda',
    torch_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(cfg.base_model, **model_load_kwargs)

In [None]:
# Sanity check: how many samples still contain an <original_document> wrapper tag
# (either the opening or the closing one).
tagged_rows = data.filter(
    lambda x: any(
        tag in x["sample"]
        for tag in ("<original_document>", "</original_document>")
    )
)
count = tagged_rows.num_rows

print(f"Number of samples with <original_document> or </original_document>: {count}")

In [None]:
def tokenize_data(data):
    """Convert one dataset row into causal-LM training features.

    The row's ``prompt`` is rendered through the tokenizer's chat template as a
    single user turn with the generation prompt appended; the row's ``sample``
    (with ``<original_document>`` tags removed) is the completion. Only the
    completion and the trailing EOS contribute to the loss: prompt positions
    are labelled ``-100``.

    Args:
        data: A row with string fields ``prompt`` and ``sample``.

    Returns:
        The tokenizer output with ``input_ids``, ``attention_mask`` and
        ``labels``, all of the same length (at most ``cfg.max_len``).

    Notes:
        Relies on the notebook-level ``tokenizer`` and ``cfg``. If the rendered
        prompt alone fills the whole token budget, every label is ``-100`` and
        the row carries no training signal.
    """
    prompt = data['prompt']

    # Remove tags
    content = data['sample'].replace("<original_document>", "").replace("</original_document>", "")

    # Samples without a reasoning block are trained in non-thinking mode.
    enable_thinking = '<think>' in content
    if not enable_thinking:
        prompt += " /no_think"

    # Apply chat template
    chat = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking
    )

    # Tokenize prompt separately to get its length
    prompt_len = len(tokenizer.encode(formatted_prompt, add_special_tokens=False))

    # Tokenize the full text (no EOS yet). The chat template already emitted its
    # own special tokens, so none are added here either; this keeps the token
    # positions aligned with how prompt_len was measured above.
    tokenized = tokenizer(
        formatted_prompt + content,
        truncation=True,
        max_length=cfg.max_len - 1,  # reserve space for EOS
        padding=False,
        add_special_tokens=False,
    )

    # Append EOS token
    input_ids = tokenized["input_ids"] + [tokenizer.eos_token_id]
    tokenized["input_ids"] = input_ids
    tokenized["attention_mask"] = [1] * len(input_ids)

    # Construct labels. Truncation may cut into the prompt itself, so the masked
    # prefix is capped at the sequence length; otherwise labels would come out
    # longer than input_ids.
    masked_len = min(prompt_len, len(input_ids))
    tokenized["labels"] = [-100] * masked_len + input_ids[masked_len:]

    return tokenized

In [None]:
# Shuffle the rows and tokenize each one, then keep only rows that still have at
# least one supervised label. A row whose labels are all -100 (the prompt used up
# the whole token budget) contributes no training signal.
tokenized_data = data.shuffle().map(tokenize_data)
data = tokenized_data.filter(
    lambda row: any(label != -100 for label in row["labels"])
)

dropped_rows = tokenized_data.num_rows - data.num_rows
if dropped_rows:
    print(f"Dropped {dropped_rows} rows with no supervised tokens after truncation")

In [None]:
# LoRA adapter definition, taken entirely from `cfg`.
config = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    # Pass the configured dropout through; it was previously left out, so the
    # value set in TrainingConfig had no effect on the adapters.
    lora_dropout=cfg.lora_dropout,
    target_modules=cfg.lora_target_modules,
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Wrap the base model with the adapters described by `config`.
# `model` is rebound to the wrapped object, so running this cell a second time
# would wrap an already-wrapped model — restart from the model-loading cell instead.
model = get_peft_model(model, config)

In [None]:
# Sanity check after wrapping: print the PEFT model's trainable-parameter summary.
model.print_trainable_parameters()

In [None]:
# Number of micro-batches accumulated per optimizer step so that
# micro_batch_size * steps approximates the configured batch_size.
# Clamped to at least 1: plain floor division yields 0 whenever
# batch_size < micro_batch_size, which is not a usable accumulation count.
gradient_accumulation_steps = max(1, cfg.batch_size // cfg.micro_batch_size)

In [None]:
# Trainer over the tokenized dataset. There is no evaluation set; a checkpoint is
# written at the end of every epoch and metrics are reported to Weights & Biases.
trainer = transformers.Trainer(
    model=model,
    train_dataset=data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=cfg.micro_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        # Honour the configured epoch count. This argument was previously
        # missing, so cfg.num_epochs was ignored and the Trainer used its own
        # default number of epochs.
        num_train_epochs=cfg.num_epochs,
        warmup_ratio=cfg.warmup_ratio,
        learning_rate=cfg.learning_rate,
        bf16=True,
        logging_steps=10,
        eval_strategy="no",
        save_strategy="epoch",
        lr_scheduler_type=cfg.lr_scheduler,
        output_dir=cfg.output_dir,
        save_total_limit=10,
        load_best_model_at_end=False,
        ddp_find_unused_parameters=None,
        report_to="wandb",
        run_name=cfg.wandb_run_name,
        label_names=["labels"],
    ),
    # Pads input_ids / attention_mask / labels per batch, to a multiple of 8.
    data_collator=transformers.DataCollatorForSeq2Seq(
        tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    ),
)

In [None]:
# Turn off the generation KV cache on the model's config for training.
model.config.use_cache = False
# NOTE(review): `trainer` was constructed from `model` in the previous cell, so
# rebinding the notebook name here does not replace the object the trainer holds.
# If compilation is meant to apply to training, confirm whether it should be
# requested through TrainingArguments (`torch_compile`) instead; as written, only
# code that uses the `model` name after this cell sees the compiled wrapper.
model = torch.compile(model)

In [None]:
# Run fine-tuning with the trainer built above; outputs are written under cfg.output_dir.
trainer.train()