In [1]:
!nvidia-smi

Mon Feb  3 12:35:33 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.154.05             Driver Version: 535.154.05   CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          On  | 00000000:57:00.0 Off |                    0 |
| N/A   34C    P0              77W / 400W |      0MiB / 81920MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
import os
import ast
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from sklearn.utils import shuffle

import torch
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModelForCausalLM, AdamW, get_linear_schedule_with_warmup

In [3]:
SEED = 42
EPOCHS = 5
BATCH_SIZE = 1
VALID_PART = 0.2
MAX_NEW_TOKENS = 1
LEARNING_RATE = 1e-4
WARMUP_STEPS_COEFF = 0.1
DATA_PATH = Path("data")
MODEL_NAME = "microsoft/phi-4"
GRADIENT_ACCUMULATION_STEPS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SYSTEM_PROMPT = "You are an expert in science. Answer the questions. Write only the answer number and nothing else."
PROMPT = "Choose one of the answers. Write down ONLY the NUMBER of the correct answer and nothing else."

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.enabled=False
torch.backends.cudnn.deterministic=True

In [4]:
data_name = "mmlu_pro_stem.tsv"
tag = "_".join(
    [
        "scheduler", 
        data_name.split(".")[0],
        str(EPOCHS),
        str(BATCH_SIZE),
        str(LEARNING_RATE),
        str(GRADIENT_ACCUMULATION_STEPS),
        str(WARMUP_STEPS_COEFF)
    ]
)

data_path = os.path.join(DATA_PATH, data_name)

df_mmlu = pd.read_csv(data_path, sep="\t")
df_mmlu = shuffle(df_mmlu)
df_mmlu["options"] = df_mmlu["options"].apply(ast.literal_eval)
df_mmlu["answer_index"] = df_mmlu["answer_index"].apply(lambda x: str(x + 1))

def enumerate_question_and_options(line):
    enumerated_variants = "\n".join(
        f"{i + 1}) {option}" for i, option in enumerate(line["options"])
    )
    return f"{line['question']}\n\n{enumerated_variants}"

df_mmlu["input_text"] = df_mmlu.apply(enumerate_question_and_options, axis=1)
df_mmlu = df_mmlu.rename(columns={"answer_index": "output_text"})

train_length = int((1 - VALID_PART) * df_mmlu.shape[0])
df_train = df_mmlu.iloc[:train_length].reset_index(drop=True)
df_valid = df_mmlu.iloc[train_length:].reset_index(drop=True)

print(df_train.shape[0], df_valid.shape[0])
df_train

9625 2407


Unnamed: 0,src,answer,options,category,question,cot_content,question_id,output_text,total_tokens,meta_cluster,base_cluster,input_text
0,stemez-Biology,I,[The gland in the thorax is the only gland inv...,biology,The hormone which brings about metamorphosis o...,,3031,9,262,Miscellaneous,Genetics & Biology,The hormone which brings about metamorphosis o...
1,theoremQA-Physics,C,"[2.50, 3.98, 3.26, 2.75, 5.00, 6.15, 1.92, 4.5...",physics,You wish to put a 1000-kg satellite into a cir...,,9127,3,125,Scientific Calculations,Physics Calculation Questions,You wish to put a 1000-kg satellite into a cir...
2,ori_mmlu-professional_psychology,F,"[high ability and high motivation., low abilit...",psychology,"""According to Hersey and Blanchard’s situation...",,2070,6,80,Psychology Behavior,Psychology Questions (0),"""According to Hersey and Blanchard’s situation..."
3,ori_mmlu-econometrics,I,"[Exactly 1, More than 2, Between 1 and 2, Less...",economics,Consider the estimation of a GARCH-M model. If...,,7180,9,101,Statistical Analysis,Econometrics Tests and Models,Consider the estimation of a GARCH-M model. If...
4,stemez-ElectricCircuits,B,[f(t) = [e^-2t + 2.24e^-2tsin(5t - 26.6°)] u(t...,engineering,Evaluate f(t) if F(s) = [(3s^2 + 17s + 47) / {...,,11617,2,179,Engineering Calculations,Electric Circuit Calculations,Evaluate f(t) if F(s) = [(3s^2 + 17s + 47) / {...
...,...,...,...,...,...,...,...,...,...,...,...,...
9620,ori_mmlu-philosophy,A,[do that which is good and not to approve of i...,philosophy,"According to Butler, it is impossible to:",,11063,1,113,Legal & Moral Implications,Moral Disputes,"According to Butler, it is impossible to:\n\n1..."
9621,ori_mmlu-professional_law,C,[Albert did not intentionally make the mistake...,law,Albert Attorney was a solo practitioner in Lit...,,1702,3,309,Legal & Moral Implications,Legal Contracts & Property,Albert Attorney was a solo practitioner in Lit...
9622,ori_mmlu-high_school_macroeconomics,D,[increase the consumer price index and the GDP...,economics,An increase in the price of forklifts imported...,,7557,4,79,Economics & Finance MCQs,Economic Concepts & Policies,An increase in the price of forklifts imported...
9623,ori_mmlu-computer_security,C,[Given H(k \| m)H(k∥m) anyone can compute H(w ...,computer science,Let HH be a Merkle-Damgard hash function is H:...,,10637,3,217,CS Subfield Queries,Computer Science Questions,Let HH be a Merkle-Damgard hash function is H:...


In [5]:
class LLMDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return {
            "input_text": row["input_text"],
            "output_text": row["output_text"]
        }

In [6]:
train_dataset = LLMDataset(df_train)
valid_dataset = LLMDataset(df_valid)

In [7]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, tokenizer)
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, tokenizer)
)

In [8]:
def collate_fn(batch, tokenizer):

    input_prompts = []
    output_texts = []
    
    for item in batch:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": PROMPT + "\n\n" + item["input_text"]},
        ]

        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        input_prompts.append(formatted_prompt)
        output_texts.append(item["output_text"])

    joined_texts = [
        ip + ot for ip, ot in zip(input_prompts, output_texts)
    ]
    
    tokens = tokenizer(
        joined_texts,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    
    input_ids = tokens["input_ids"]
    attention_mask = tokens["attention_mask"]

    labels = input_ids.clone()
    
    prompt_lens = []
    for ip in input_prompts:
        tok_prompt = tokenizer(ip, add_special_tokens=False)
        prompt_lens.append(len(tok_prompt["input_ids"]))
    
    for i, p_len in enumerate(prompt_lens):
        labels[i, :p_len] = -100

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

In [9]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [10]:
for param in model.parameters():
    param.requires_grad = False

for layer in model.model.layers[36:]:
    for param in layer.self_attn.parameters():
        param.requires_grad = True

In [11]:
def validate(model, val_dataloader, tokenizer):
    model.eval()
    total = 0
    correct = 0
    
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc="Validating", leave=False):

            for i in range(batch["input_ids"].shape[0]):
                input_ids_sample = batch["input_ids"][i].unsqueeze(0).to(DEVICE)
                label_ids_sample = batch["labels"][i]
                
                ref_text_indices = label_ids_sample[label_ids_sample != -100]
                ref_text = tokenizer.decode(ref_text_indices, skip_special_tokens=True)

                prompt_indices = batch["input_ids"][i][batch["labels"][i] == -100]
                prompt_text = tokenizer.decode(prompt_indices, skip_special_tokens=True)

                messages = [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": prompt_text},
                ]
                formatted_prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                
                inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)
                
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    pad_token_id=tokenizer.eos_token_id
                )
                
                input_length = inputs.input_ids.shape[1]
                generated_text = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
                
                total += 1
                if generated_text.strip() == ref_text.strip():
                    correct += 1
                # print("Ref: ", ref_text.strip())
                # print("Gen: ", generated_text.strip())

    accuracy = correct / total if total > 0 else 0
    return accuracy

In [12]:
def train_one_epoch(model, train_dataloader, optimizer, scheduler, gradient_accumulation_steps):
    model.train()

    epoch_loss = 0
    optimizer.zero_grad()
    for i, batch in enumerate(tqdm(train_dataloader, desc="Training", leave=False)):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        raw_loss = outputs.loss
        loss = raw_loss / gradient_accumulation_steps
        loss.backward()
        epoch_loss += raw_loss.item()

        # Perform an optimizer step after accumulating the gradients
        if (i + 1) % gradient_accumulation_steps == 0 or (i + 1) == len(train_dataloader):
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    return epoch_loss / len(train_dataloader)

In [13]:
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LEARNING_RATE)

total_training_steps = (
    ((len(train_dataloader) + GRADIENT_ACCUMULATION_STEPS - 1) // GRADIENT_ACCUMULATION_STEPS)
    * EPOCHS
)
warmup_steps = int(total_training_steps * WARMUP_STEPS_COEFF)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_training_steps
)



In [None]:
train_loss_history = []
valid_acc_history = []

best_acc = - float("inf")
for epoch in range(EPOCHS):
    print(f"=== Epoch {epoch+1}/{EPOCHS} ===")

    val_acc = validate(model, valid_dataloader, tokenizer)
    valid_acc_history.append(val_acc)
    print(f"Validation Accuracy: {val_acc:.4f}")

    if val_acc > best_acc:
        best_acc = val_acc
        model.save_pretrained("weights_and_histories")
    
    avg_train_loss = train_one_epoch(model, train_dataloader, optimizer, scheduler, GRADIENT_ACCUMULATION_STEPS)
    train_loss_history.append(avg_train_loss)
    print(f"Train Loss: {avg_train_loss:.4f}")

    df_history = pd.DataFrame(
        {
            "epoch": list(range(epoch + 1)),
            "train_loss": train_loss_history,
            "valid_acc": valid_acc_history,
        }
    )
    df_history.to_excel(os.path.join("weights_and_histories", f"{tag}.xlsx"), index=False)

=== Epoch 1/5 ===


Validating:   0%|          | 0/2407 [00:00<?, ?it/s]