In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, get_scheduler
import torch

from dotenv import load_dotenv
load_dotenv()

import pickle
import random

import os
from huggingface_hub.hf_api import HfFolder
HfFolder.save_token(os.environ['HUGGINGFACE_TOKEN'])

import datasets

import pandas as pd

In [3]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.gradient_checkpointing_enable()
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
d_basic = pickle.load(open("data_collection/!final_dataset_11194_basic.pkl", "rb"))
d_hints = pickle.load(open("data_collection/!final_dataset_6016_with_hints_after_n_step.pkl", "rb"))
d_snap = pickle.load(open("data_collection/!final_dataset_17476_from_snapshots_fixed.pkl", "rb"))

In [8]:
def split_dataset(dataset, split_ratio=0.01):
    random.shuffle(dataset)
    split_index = int(len(dataset) * split_ratio)
    eval_set = dataset[:split_index]
    train_set = dataset[split_index:]
    return train_set, eval_set

d_basic, d_basic_eval = split_dataset(d_basic)
d_hints, d_hints_eval = split_dataset(d_hints)
d_snap, d_snap_eval = split_dataset(d_snap)

d_eval = d_basic_eval + d_hints_eval + d_snap_eval

In [9]:
# TODO TEMP!
# d_basic = d_basic[:30]
# d_hints = d_hints[:30]
# d_snap = d_snap[:30]
# d_eval = d_eval[:30]

In [10]:
def prepare_data(data):
    prompts = [item['prompt'] for item in data]
    responses = [item['response'] for item in data]
    return datasets.Dataset.from_dict({"prompt": prompts, "response": responses})

def tokenize_function(examples):
    return tokenizer(examples['prompt'], text_target=examples['response'], truncation=True, padding="max_length", max_length=2500)

In [11]:
dataset = prepare_data(d_basic)
ds_basic = dataset.map(tokenize_function, batched=True).shuffle(seed=42)

dataset = prepare_data(d_hints)
ds_hints = dataset.map(tokenize_function, batched=True).shuffle(seed=42)

dataset = prepare_data(d_snap)
ds_snap = dataset.map(tokenize_function, batched=True).shuffle(seed=42)

dataset = prepare_data(d_eval)
ds_eval = dataset.map(tokenize_function, batched=True).shuffle(seed=42)

Map:   0%|          | 0/8297 [00:00<?, ? examples/s]

Map:   0%|          | 0/4467 [00:00<?, ? examples/s]

Map:   0%|          | 0/17302 [00:00<?, ? examples/s]

Map:   0%|          | 0/302 [00:00<?, ? examples/s]

In [4]:
# for name, param in model.named_parameters():
#     if "model.layers." not in name and not "head" in name and not "embed_tokens" in name:
#         param.requires_grad = False
#         continue
#     if "mlp" in name:
#         param.requires_grad = False
#         continue
#     if "norm" in name:
#         param.requires_grad = False
#         continue
#     if "head" in name or "embed_tokens" in name:
#         continue
#     layer_num = int(name.split(".")[2])
#     if layer_num <= 15:
#         param.requires_grad = False
#         continue


for name, param in model.named_parameters():
    if "model.layers." not in name and not "head" in name:
        param.requires_grad = False
        continue
    if "mlp" in name:
        param.requires_grad = False
        continue
    if "norm" in name:
        param.requires_grad = False
        continue
    if "head" in name:
        continue
    layer_num = int(name.split(".")[2])
    if layer_num <= 15:
        param.requires_grad = False
        continue

In [5]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.v_proj.weight False
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight False
model.layers.0.mlp.down_proj.weight False
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.v_proj.weight False
model.layers.1.self_attn.o_proj.weight False
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight False
model.layers.1.mlp.down_proj.weight False
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight False
model.layers.2.self_attn.k_proj.weight False
model.layers.2.self_attn.v_proj.weight False
model.layers.2.self_attn.o_proj

In [6]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 1196425216


In [12]:
import logging
logging.basicConfig(filename='training_log.txt', level=logging.INFO)

save_epochs = set([0, 4, 9, 14, 24])

In [13]:
def custom_evaluation(model):
    # Placeholder for custom evaluation logic
    return {"custom_metric": 0.0}  # TODO Replace with actual evaluation logic

In [14]:
model = model.train()

In [15]:
cumulative_epoch = 0
for ds, num_epochs in (
    (ds_basic, 10),
    (ds_hints, 10),
    (ds_snap, 5),
):
    for epoch in range(num_epochs):
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=1,
            per_device_train_batch_size=10,
            per_device_eval_batch_size=10,
            save_strategy="no",
            learning_rate=5e-05,
            weight_decay=0.1,
            bf16=True,
            lr_scheduler_type="cosine_with_restarts",
            warmup_ratio=0.1,
            eval_strategy="steps",
            eval_steps=10,
            logging_steps=10,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=ds,
            eval_dataset=ds_eval,
        )

        train_result = trainer.train()
        train_loss = train_result.training_loss
        logging.info(f"Epoch {cumulative_epoch} - Train Loss: {train_loss}")
        print(f"Epoch {cumulative_epoch} - Train Loss: {train_loss}")
        
        eval_result = trainer.evaluate()
        eval_loss = eval_result['eval_loss']
        
        logging.info(f"Epoch {cumulative_epoch} - Eval Loss: {eval_loss}")
        print(f"Epoch {cumulative_epoch} - Eval Loss: {eval_loss}")
        
        eval_metrics = custom_evaluation(model)
        logging.info(f"Epoch {cumulative_epoch} - Custom Eval: {eval_metrics['custom_metric']}")
        print(f"Epoch {cumulative_epoch} - Custom Eval: {eval_metrics['custom_metric']}")
        
        if cumulative_epoch in save_epochs:
            output_dir = f"./results/llama_{cumulative_epoch}"
            os.makedirs(output_dir, exist_ok=True)
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
        
        print(f"Finished epoch {cumulative_epoch + 1}")
        cumulative_epoch += 1  # Increment the overall epoch counter

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss,Validation Loss
10,11.6,11.595554
20,11.55,11.212341
30,11.05,10.344791
40,9.7,8.804544
50,8.0,6.930135
60,5.9,4.596248


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7240059c5f30>>
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt

