# 1. Setup

The pipeline proceeds as follows: set up the environment -> load the base model -> configure QLoRA -> prepare the dataset -> fine-tune the LLM -> save checkpoints -> run inference.

In [None]:
# Install/upgrade the third-party libraries used by this notebook. These are
# IPython `!` shell escapes, so this cell only runs inside Jupyter/Colab.
# NOTE(review): loralib and einops are not imported anywhere in this notebook
# as shown, and scikit-learn (imported for train_test_split) is not installed
# here -- presumably it is preinstalled in the runtime; confirm.
!pip install -q -U bitsandbytes
!pip install -q -U datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U peft
!pip install -q -U loralib
!pip install -q -U einops

In [None]:
import json
import os
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import re
from datasets import load_dataset, Dataset

from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training,
)

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainingArguments,
    Trainer,
    GenerationConfig,
)

In [None]:
from huggingface_hub import login

# Read the Hugging Face access token from the environment instead of
# hard-coding it: a token committed to a notebook is exposed to anyone who can
# read the notebook. (Any token previously pasted here should be revoked.)
# If HF_TOKEN is unset, login() falls back to prompting for a token.
API_KEY = os.environ.get("HF_TOKEN")
login(token=API_KEY)

In [None]:
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit NF4 weights with double quantization; matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Llama is supported natively by transformers, so trust_remote_code is not
# needed; enabling it would execute arbitrary Python shipped in the repo.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Use a dedicated padding token when the vocabulary provides one. With
# pad == eos, DataCollatorForLanguageModeling (mlm=False) sets every label equal
# to the pad id to -100, which also hides the real end-of-turn/eos tokens from
# the loss, so the model never learns to stop after its answer.
_RESERVED_PAD_TOKEN = "<|finetune_right_pad_id|>"
if _RESERVED_PAD_TOKEN in tokenizer.get_vocab():
    tokenizer.pad_token = _RESERVED_PAD_TOKEN
else:
    tokenizer.pad_token = tokenizer.eos_token

model.gradient_checkpointing_enable()

In [None]:
# Prepare the 4-bit quantized model for k-bit training, then wrap it with LoRA
# adapters on the attention query and value projections.
model = prepare_model_for_kbit_training(model)

lora_target_modules = ["q_proj", "v_proj"]

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=lora_target_modules,
)

model = get_peft_model(model, peft_config)

# 2. Dataset

In [None]:
# SAT reading-comprehension items from the Hugging Face Hub; the "train" and
# "test" splits are used further down.
DATASET_ID = "emozilla/sat-reading"
data = load_dataset(DATASET_ID)

In [None]:
# Split a raw SAT item into clear segments: passage, question, choices (A-D)
# and, when present, the gold answer letter after "Answer:".

def extract_sections(text):
    """Parse one SAT reading item into passage, question, choices and answer.

    Expected layout (blocks separated by blank lines)::

        SAT READING COMPREHENSION TEST      <- optional header

        <passage paragraph(s)>

        Question N) <question text>
        A) ...
        B) ...
        C) ...
        D) ...
        Answer: <letter, may be empty>

    Args:
        text: Raw item text, with or without the header line.

    Returns:
        dict with keys "passage" (passage blocks joined by newlines),
        "question", "choices" (list of "X) ..." lines) and "answer_letter"
        (first character after the last "Answer:", or "" if there is none).
        Parts that cannot be found are left empty.
    """
    sections = {
        "passage": "",
        "question": "",
        "choices": [],
        "answer_letter": "",
    }

    # Everything after the last "Answer:" marker holds the gold letter; the
    # marker and what follows it are excluded from the remaining parsing.
    body, marker, answer_part = text.rpartition("Answer:")
    if not marker:
        body, answer_part = text, ""
    answer_part = answer_part.strip()
    sections["answer_letter"] = answer_part[0] if answer_part else ""

    # Drop the header (with or without the trailing "TEST") so it never leaks
    # into the passage, whether or not the caller already removed it.
    content = re.split(r"SAT READING COMPREHENSION(?: TEST)?", body)[-1].strip()
    blocks = [b.strip() for b in content.split("\n\n") if b.strip()]

    question_idx = next(
        (i for i, block in enumerate(blocks) if block.startswith("Question")),
        None,
    )

    # Every block before the question block belongs to the passage.
    passage_blocks = blocks if question_idx is None else blocks[:question_idx]
    sections["passage"] = "\n".join(passage_blocks).strip()

    if question_idx is None:
        return sections

    # "Question N) <text>": the question text follows the first ")" on the
    # block's first line; without a ")" the whole first line is used.
    first_line, *rest_lines = blocks[question_idx].split("\n")
    _, sep, question_text = first_line.partition(")")
    sections["question"] = (question_text if sep else first_line).strip()

    # Choice lines may follow in the same block or in later blocks.
    candidate_lines = rest_lines + [
        line for block in blocks[question_idx + 1:] for line in block.split("\n")
    ]
    sections["choices"] = [
        line.strip()
        for line in candidate_lines
        if line.strip().startswith(("A)", "B)", "C)", "D)"))
    ]

    return sections

def map_answer(text, letter):
    """Expand a choice letter into its full choice line.

    Args:
        text: Raw SAT item text, parsed with ``extract_sections``.
        letter: Choice letter such as "B".

    Returns:
        The first parsed choice that starts with ``"<letter>)"`` (e.g.
        "B) ..."), or ``letter`` unchanged when no such choice exists.
    """
    wanted_prefix = f"{letter})"
    parsed_choices = extract_sections(text)["choices"]
    return next(
        (option for option in parsed_choices if option.startswith(wanted_prefix)),
        letter,
    )

In [None]:
# System prompt shared by the training and inference chat messages.
LLAMA3_SYSTEM_PROMPT = """You are a helpful AI assistant developed by Meta. Respond safely and accurately."""


def generate_prompt(text, answer_letter):
    """Build the chat messages (system, user, assistant) for one training example.

    The user turn is laid out exactly like the inference prompt built by
    ``format_test_prompt`` (no leading indentation), so the model is trained
    on the same format it is queried with.

    Args:
        text: Raw SAT item text.
        answer_letter: Gold choice letter (e.g. "B"). A full "B) ..." choice
            line is also accepted and is passed through unchanged.

    Returns:
        List of role/content message dicts for ``tokenizer.apply_chat_template``;
        the assistant turn holds the full text of the correct choice.
    """
    sections = extract_sections(text)
    choices_text = "\n".join(sections["choices"])

    # Built from explicit lines rather than an indented triple-quoted f-string,
    # which would embed the source-code indentation into every prompt line.
    user_content = "\n".join([
        "Read the passage and answer the question.",
        "### Passage:",
        sections["passage"],
        "",
        "### Question:",
        sections["question"],
        "",
        "### Choices:",
        choices_text,
        "",
        "Response with ONLY the letter and full text of the correct answer.",
    ])

    return [
        {
            "role": "system",
            "content": LLAMA3_SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": user_content,
        },
        {
            "role": "assistant",
            "content": map_answer(text, answer_letter),
        },
    ]
    
# Tokenize one example: render the chat template to text, tokenize it to a
# fixed length, and build the labels for causal-LM training.

def generate_and_tokenize_prompt(user_input, answer, max_length=1024):
    """Render, tokenize and label one SAT training example.

    Args:
        user_input: Raw SAT item text (passage, question and choices).
        answer: Gold answer, either a choice letter or a full "X) ..." line.
        max_length: Fixed sequence length every example is padded to. Longer
            examples are skipped rather than truncated.

    Returns:
        dict with 1-D "input_ids", "attention_mask" and "labels" tensors of
        length ``max_length`` (labels are -100 on padding), or None if the
        example could not be processed or does not fit in ``max_length``.
    """
    try:
        full_prompt = generate_prompt(user_input, answer)

        prompt_str = tokenizer.apply_chat_template(
            full_prompt,
            tokenize=False,
            add_generation_prompt=False,
        )

        # The rendered Llama 3 chat template already begins with the BOS token,
        # so don't let the tokenizer prepend a second one. Truncation is off:
        # cutting the right end would silently drop the assistant answer, so
        # overlong examples are skipped below instead.
        tokenized = tokenizer(
            prompt_str,
            add_special_tokens=False,
            padding="max_length",
            truncation=False,
            max_length=max_length,
            return_tensors="pt",
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        if input_ids.shape[0] > max_length:
            print(f"Skipping example: {input_ids.shape[0]} tokens > max_length={max_length}")
            return None

        # Padding positions must not contribute to the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
    except Exception as e:
        print(f"Error processing input: {e}")
        return None

In [None]:
# Tokenize every train-split item, then hold out 10% of the result as the
# evaluation set used during fine-tuning.
training_samples = []
num_skipped = 0

for sample in tqdm(data["train"]):
    try:
        # Remove the test header so it never ends up in the passage text.
        processed_text = sample["text"].replace("SAT READING COMPREHENSION TEST", "").strip()
        # Pass the raw gold letter: generate_prompt maps it to the full
        # "X) ..." choice text itself, so mapping it here as well is redundant.
        answer_letter = sample["answer"].strip()
        tokenized_sample = generate_and_tokenize_prompt(processed_text, answer_letter)
    except Exception as e:
        print(f"Error processing sample: {e}")
        tokenized_sample = None

    if tokenized_sample is None:
        num_skipped += 1
    else:
        training_samples.append(tokenized_sample)

print(f"Prepared {len(training_samples)} samples ({num_skipped} skipped).")
if len(training_samples) < 2:
    raise RuntimeError("Not enough usable samples to build a train/validation split.")

train_samples, val_samples = train_test_split(training_samples, test_size=0.1, random_state=42)

train_dataset = Dataset.from_list(train_samples)
eval_dataset = Dataset.from_list(val_samples)

# 3. Modeling

In [None]:
class LogLossCallback(TrainerCallback):
    """Trainer callback that prints the training loss each time it is logged."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Only log entries that carry a "loss" key are reported.
        if logs is None or "loss" not in logs:
            return
        current_loss = logs["loss"]
        print(f"Step {state.global_step}: Loss: {current_loss:.4f}")
    
# Prefer bfloat16 mixed precision (matching bnb_4bit_compute_dtype) on GPUs
# that support it; otherwise fall back to fp16.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    per_device_train_batch_size=1,
    # The default eval batch size is 8; with 1024-token sequences and a large
    # vocabulary, that multiplies the logits memory during evaluation.
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    num_train_epochs=2,
    learning_rate=2e-4,
    bf16=use_bf16,
    fp16=not use_bf16,
    save_total_limit=3,
    logging_steps=10,
    output_dir="./llama3-8b-sat-reading",
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    report_to="none",
    remove_unused_columns=False,
)

data_collator = transformers.DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
    pad_to_multiple_of=8
)

# The KV cache is only useful for generation and is incompatible with gradient
# checkpointing; configure the model before handing it to the Trainer.
model.config.use_cache = False
model.enable_input_require_grads()

trainer = Trainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=training_args,
    data_collator=data_collator,
    callbacks=[LogLossCallback()],
)

# No torch.compile here: compiling `model` after the Trainer is constructed does
# not change what the Trainer trains, and it would only rebind `model` to a
# compiled wrapper that is later saved and pushed.
trainer.train()

In [None]:
# NOTE(review): the repository name says "Llama-3.2" but the base model is
# Llama-3.1-8B-Instruct; kept as-is because the inference cell loads this id.
PEFT_MODEL = "Savoxism/InstructionTuning-Llama-3.2-8B-SAT-Reading-Solver"

# Save the trained PEFT model locally, then upload it to the Hub.
model.save_pretrained("./llama3-8b-sat-reading")

# Authentication comes from the earlier login(); the `use_auth_token`
# argument is deprecated in transformers' push_to_hub.
model.push_to_hub(PEFT_MODEL)

# 4. Inference

In [None]:
# Hub repository holding the fine-tuned LoRA adapter.
PEFT_MODEL = "Savoxism/InstructionTuning-Llama-3.2-8B-SAT-Reading-Solver"

# The adapter config records the base checkpoint the adapter was trained on.
config = PeftConfig.from_pretrained(PEFT_MODEL)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# trust_remote_code stays off: the base model path is read from the adapter's
# config on the Hub, and enabling it would execute any Python code shipped in
# that repository. Llama is supported natively by transformers.
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

model = PeftModel.from_pretrained(
    model,
    PEFT_MODEL,
    device_map="auto",
)
# Inference only: make sure dropout and other training-mode layers are off.
model.eval()

In [None]:
def format_test_prompt(text):
    """Build the system + user chat messages used to query the model at inference.

    Args:
        text: Raw SAT item text, parsed with ``extract_sections``.

    Returns:
        Two role/content message dicts (no assistant turn), ready for
        ``tokenizer.apply_chat_template(..., add_generation_prompt=True)``.
    """
    parts = extract_sections(text)
    joined_choices = "\n".join(parts["choices"])

    prompt_lines = [
        "Read the passage and answer the question.",
        "### Passage:",
        parts["passage"],
        "",
        "### Question:",
        parts["question"],
        "",
        "### Choices:",
        joined_choices,
        "",
        "Response with ONLY the letter and full text of the correct answer.",
    ]

    system_message = {"role": "system", "content": LLAMA3_SYSTEM_PROMPT}
    user_message = {"role": "user", "content": "\n".join(prompt_lines)}
    return [system_message, user_message]

def extract_answer(output_text):
    """Pull the first "X) ..." answer line (X in A-D) out of generated text.

    Returns the matched line (from the letter to the end of that line),
    stripped; if nothing matches, returns the whole text stripped.
    """
    hit = re.search(r"([A-D])\)\s.*", output_text)
    chosen = hit.group(0) if hit else output_text
    return chosen.strip()

# GENERATE
# Greedy decoding. With do_sample=False the sampling knobs (temperature,
# top_p) are unused, and temperature=0.0 makes GenerationConfig emit a
# validation warning, so they are omitted.
generation_config = GenerationConfig(
    max_new_tokens=64,
    do_sample=False,
    repetition_penalty=1.0,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

def predict(text):
    """Answer one raw SAT item with the fine-tuned model.

    Args:
        text: Raw SAT item text.

    Returns:
        The "X) ..." answer line extracted from the model's generated reply
        (or the whole reply, stripped, if no such line is found).
    """
    messages = format_test_prompt(text)

    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The rendered chat template already starts with the BOS token; don't add
    # a second one when tokenizing.
    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            generation_config=generation_config,
        )

    # generate() returns prompt + continuation. Decode only the new tokens;
    # otherwise extract_answer would match the "A) ..." choice lines of the
    # prompt instead of the model's answer.
    prompt_length = inputs["input_ids"].shape[1]
    output_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return extract_answer(output_text)

In [None]:
def extract_choice_letter(answer_text):
    """Return the leading choice letter ("A"-"D") of text like "B) ...", else ""."""
    leading = re.match(r"([A-D])\)", answer_text.strip())
    return leading.group(1) if leading else ""

# Spot-check the fine-tuned model on the first few test items and report accuracy.
num_eval = min(10, len(data["test"]))
num_correct = 0

for i in range(num_eval):
    print('=' * 100)
    sample = data['test'][i]
    input_text = sample['text']
    true_letter = sample['answer'].strip()

    predicted_answer = predict(input_text)

    # Parse the item once and reuse it for display.
    sections = extract_sections(input_text)
    true_answer_full = map_answer(input_text, true_letter)

    pred_choice = extract_choice_letter(predicted_answer)
    # Compare against the gold letter directly, and never count an
    # unparseable prediction ("") as correct.
    is_correct = pred_choice != "" and pred_choice == true_letter
    num_correct += int(is_correct)

    print(f"### Sample {i + 1}")
    print(f"[Question]\n{sections['question']}")
    print(f"[Choices]\n{sections['choices']}")
    print(f"[True Answer]\n{true_answer_full}")
    print(f"[Predicted Answer]\n{predicted_answer}")
    print(f"""\nResult: {"CORRECT" if is_correct else "INCORRECT"}""")
    print("=" * 100)

if num_eval:
    print(f"Accuracy: {num_correct}/{num_eval} ({num_correct / num_eval:.0%})")