In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, get_scheduler
from trl import SFTTrainer, SFTConfig
import torch

from dotenv import load_dotenv
load_dotenv()

import json
import pickle
import random

from utils_simulate import run_games

import os
from huggingface_hub.hf_api import HfFolder
HfFolder.save_token(os.environ['HUGGINGFACE_TOKEN'])

import datasets

import pandas as pd

# from transformers.utils import logging
# logging.set_verbosity_info()

import logging
from transformers.utils import logging as transformers_logging

class CustomFilter(logging.Filter):
    def filter(self, record):
        return 'right-padding was detected' not in record.getMessage()

logger = logging.getLogger('transformers')

for handler in logger.handlers:
    handler.addFilter(CustomFilter())

In [2]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='right')
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model.gradient_checkpointing_enable()
id_eot = tokenizer.convert_tokens_to_ids(["<|eot_id|>"])[0]

tokenizer.pad_token_id = id_eot
model.config.pad_token_id = id_eot
# tokenizer.add_special_tokens({"pad_token": "<|padding_token|>"})
# model.config.pad_token_id = tokenizer.pad_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [18]:
tokenizer.pad_token, tokenizer.pad_token_id

('<|padding_token|>', 128256)

In [4]:
id_eot

128009

In [5]:
d_basic = pickle.load(open("data_collection/!final_dataset_11194_basic.pkl", "rb"))
d_hints = pickle.load(open("data_collection/!final_dataset_6016_with_hints_after_n_step.pkl", "rb"))
d_snap = pickle.load(open("data_collection/!final_dataset_17476_from_snapshots_fixed.pkl", "rb"))

In [6]:
def split_dataset(dataset, split_ratio=0.05):
    random.shuffle(dataset)
    split_index = int(len(dataset) * split_ratio)
    eval_set = dataset[:split_index]
    train_set = dataset[split_index:]
    return train_set, eval_set

d_basic, d_basic_eval = split_dataset(d_basic)
d_hints, d_hints_eval = split_dataset(d_hints)
d_snap, d_snap_eval = split_dataset(d_snap)

d_eval = d_basic_eval + d_hints_eval + d_snap_eval

In [7]:
# # TODO TEMP!
# d_basic = d_basic[:105]
# d_hints = d_hints[:30]
# d_snap = d_snap[:30]
# d_eval = d_eval[:30]

In [8]:
# def prepare_data(data):
#     prompts = [item['prompt'] for item in data]
#     responses = [item['response'] for item in data]
#     return datasets.Dataset.from_dict({"prompt": prompts, "response": responses})


def prepare_inputs(chats: list[dict]):
    r = tokenizer.apply_chat_template(
        [
            [
                {"role": "user", "content": o['prompt']},
                {"role": "assistant", "content": o['response']}
            ] for o in chats
        ],
        tokenize=True,
        return_dict=True,
    )
    return datasets.Dataset.from_dict(r).shuffle(seed=42)


# def tokenize_function(examples):
    # return tokenizer(examples['prompt'], text_target=examples['response'], truncation=True, padding="max_length", max_length=2500)

In [9]:
ds_basic = prepare_inputs(d_basic)
print("ds_basic", len(ds_basic))

ds_hints = prepare_inputs(d_hints)
print("d_hints", len(d_hints))

ds_snap = prepare_inputs(d_snap)
print("ds_snap", len(ds_snap))

ds_eval = prepare_inputs(d_eval)
print("ds_eval", len(ds_eval))

ds_basic 7961
d_hints 4287
ds_snap 16603
ds_eval 1517


In [10]:
# for name, param in model.named_parameters():
#     if "model.layers." not in name and not "head" in name and not "embed_tokens" in name:
#         param.requires_grad = False
#         continue
#     if "mlp" in name:
#         param.requires_grad = False
#         continue
#     if "norm" in name:
#         param.requires_grad = False
#         continue
#     if "head" in name or "embed_tokens" in name:
#         continue
#     layer_num = int(name.split(".")[2])
#     if layer_num <= 26:
#         param.requires_grad = False
#         continue

for name, param in model.named_parameters():
    if "model.layers." not in name and not "head" in name:
        param.requires_grad = False
        continue
    if "mlp" in name:
        param.requires_grad = False
        continue
    if "norm" in name:
        param.requires_grad = False
        continue
    if "head" in name:
        continue
    layer_num = int(name.split(".")[2])
    if layer_num <= 20:
        param.requires_grad = False
        continue

In [11]:
for name, param in model.named_parameters():
    print(name, param.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight False
model.layers.0.self_attn.k_proj.weight False
model.layers.0.self_attn.v_proj.weight False
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight False
model.layers.0.mlp.down_proj.weight False
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight False
model.layers.1.self_attn.k_proj.weight False
model.layers.1.self_attn.v_proj.weight False
model.layers.1.self_attn.o_proj.weight False
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight False
model.layers.1.mlp.down_proj.weight False
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight False
model.layers.2.self_attn.k_proj.weight False
model.layers.2.self_attn.v_proj.weight False
model.layers.2.self_attn.o_proj

In [12]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

num_trainable_params = count_trainable_parameters(model)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of trainable parameters: 986710016


In [13]:
import logging
logging.basicConfig(filename='training_log.txt', level=logging.INFO)

# save_epochs = set([9, 19, 24])
save_epochs = set([0, 1, 3, 6])

In [14]:
# eval_set = json.load(open("data/eval_set.json", "r"))[:10] # TODO!!!!
eval_set = json.load(open("data/eval_set.json", "r"))
def custom_evaluation(model, epoch, short=True):
    if short:
        keywords = [e['keyword'] for e in eval_set if e['difficulty'] == "Easy"]
    else:
        keywords = [e['keyword'] for e in eval_set]

    games = run_games(
        keywords,
        tokenizer,
        model,
        id_eot,
        batch_size=15,
    )
    if not short:
        pickle.dump(games, open(f"full_eval_{epoch}.pkl", "wb"))
    return sum([g.win for g in games]) / len(games)

In [15]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer, max_length=2500)

In [16]:
model = model.train()

In [17]:
cumulative_epoch = 0
for ds, num_epochs in (
    (ds_basic, 2),
    (ds_hints, 3),
    (ds_snap, 2),
):
    for epoch in range(num_epochs):
        training_args = SFTConfig(
            output_dir="./results",
            num_train_epochs=1,
            per_device_train_batch_size=10,
            per_device_eval_batch_size=10,
            save_strategy="no",
            learning_rate=2e-05,
            weight_decay=0.2,
            bf16=True,
            lr_scheduler_type="cosine_with_restarts",
            warmup_ratio=0.1,
            eval_strategy="steps",
            eval_steps=210,
            logging_steps=210,
            neftune_noise_alpha=5,
            max_seq_length=3750,
        )

        trainer = SFTTrainer(
            model=model,
            args=training_args,
            train_dataset=ds,
            eval_dataset=ds_eval,
            tokenizer=tokenizer,
            # data_collator=data_collator,
        )

        train_result = trainer.train()
        train_loss = train_result.training_loss
        logging.info(f"Epoch {cumulative_epoch} - Train Loss: {train_loss}")
        print(f"Epoch {cumulative_epoch} - Train Loss: {train_loss}")
        
        eval_result = trainer.evaluate()
        eval_loss = eval_result['eval_loss']
        
        logging.info(f"Epoch {cumulative_epoch} - Eval Loss: {eval_loss}")
        print(f"Epoch {cumulative_epoch} - Eval Loss: {eval_loss}")
        
        eval_metrics = custom_evaluation(model, cumulative_epoch)
        logging.info(f"Epoch {cumulative_epoch} - Custom Eval: {eval_metrics}")
        print(f"Epoch {cumulative_epoch} - Custom Eval: {eval_metrics}")
        
        if cumulative_epoch in save_epochs:
            output_dir = f"./results/llama_{cumulative_epoch}"
            os.makedirs(output_dir, exist_ok=True)
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
            eval_metrics = custom_evaluation(model, cumulative_epoch, short=False)
            logging.info(f"Epoch {cumulative_epoch} - Custom Eval (full): {eval_metrics}")
            print(f"Epoch {cumulative_epoch} - Custom Eval (full): {eval_metrics}")
        
        print(f"Finished epoch {cumulative_epoch + 1}")
        cumulative_epoch += 1  # Increment the overall epoch counter

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
/opt/conda/conda-bld/pytorch_1708025847130/work/aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [94,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1708025847130/work/aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [94,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1708025847130/work/aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [94,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1708025847130/work/aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [94,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1708025847130/work/aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
