In [1]:
from speedy import *

In [2]:
import gc
import os

import torch
from datasets import load_dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

# Model
base_model = "/public-llm/Meta-Llama-3-8B-Instruct/"
new_model = "OrpoLlama-3-8B"
# Defined in the secrets tab in Google Colab
# Set torch dtype and attention implementation
torch_dtype = torch.bfloat16
attn_implementation = "flash_attention_2"
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    # quantization_config=bnb_config,
    device_map='cpu',
    # attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
# model = prepare_model_for_kbit_training(model)

2024-04-26 17:07:48.344275: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-26 17:07:48.393140: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[2024-04-26 17:07:49,974] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [11]:
import transformers
from llm_utils import transform_messages, load_chat_dataset
from copy import deepcopy
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index

# TEMPLATE = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
TEMPLATE = "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{% endif %}{% if loop.last %}{% else %}{% endif %}{% endfor %}"

def format_input_messages(
    messages:List[dict[str,str]],
    tokenizer: transformers.PreTrainedTokenizer,
    max_len: int = None,
    target_loss_only=False,
) -> Dict:
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    assert hasattr(tokenizer, 'pad_token_id'), "Tokenizer must have pad_token_id attribute"
    input_ids = []
    target_ids = []

    input_id = []
    target_id = []
    for i in range(len(messages)):
        _ids = tokenizer.apply_chat_template(
            [messages[i]], tokenize=True, add_special_tokens=False)
        input_id += _ids
        if target_loss_only and  messages[i]['role'] == 'assistant':
            target_id += _ids
        else:
            target_id += [IGNORE_TOKEN_ID]*len(_ids)

    # maxlen
    input_id = input_id[:max_len]
    target_id = target_id[:max_len]
    # to tensor
    input_ids = torch.tensor([input_id], dtype=torch.long)
    target_ids = torch.tensor([target_id], dtype=torch.long)
    attention_mask = input_ids.ne(tokenizer.pad_token_id)
    return dict(
        input_ids=input_ids, target_ids=target_ids, attention_mask=attention_mask, labels=target_ids,
        length=len(input_ids),
        num_train_tokens=target_ids.ne(IGNORE_TOKEN_ID).sum().item(),
    )

In [12]:
from datasets import Dataset
list_msgs = load_chat_dataset('/anhvth5/data/chat-formated-dataset/dataset_factory/get_oasst2_chatml_13848_samples.json')
max_len = 8000


[32m2024-04-26 17:10:10.030[0m | [1mINFO    [0m | [36mllm_utils.load_chat_dataset[0m:[36mload_chat_dataset[0m:[36m12[0m - [1mConverting the /anhvth5/data/chat-formated-dataset/dataset_factory/get_oasst2_chatml_13848_samples.json from sharegpt to chatml.[0m


In [13]:
# formated_msg = [format_input_messages(item, tokenizer, max_len) for item in list_msgs]
inputs = [[item, tokenizer, max_len] for item in list_msgs]
def f(input):
    return format_input_messages(*input)
f(inputs[0])
formated_msg = multi_thread(f, inputs, 40)

Processing: 100%|██████████| 13848/13848 [00:01<00:00, 13592.36it/s] 


In [14]:
import random
from speedy import Clock

def group_batches_by_sequence_length(items, bz, shuffle=True, drop_last=True):
    batches = []
    items = sorted(items, key=lambda x: len(x[0]))
    for i in range(0, len(items), bz):
        batch = items[i : i + bz]
        if len(batch) < bz and drop_last:
            continue
        batches.append(batch)
    if shuffle:
        random.shuffle(batches)
    # flatten
    batches = [item for sublist in batches for item in sublist]
    return batches

def get_available_sequence_lengths(sorted_lengths, len_left, first_item_len):
    cond1 = sorted_lengths <= len_left
    cond2 = np.abs(sorted_lengths - first_item_len) < 64 if first_item_len is not None else None
    available_lengths = sorted_lengths[np.logical_and(cond1, cond2) if cond2 is not None else cond1]
    return available_lengths

def update_length_to_indexes_mapping(len_to_indexes, chosen_length, sorted_lengths):
    if len(len_to_indexes[chosen_length]) == 0:
        del len_to_indexes[chosen_length]
        sorted_lengths = sorted_lengths[sorted_lengths != chosen_length]
    return len_to_indexes, sorted_lengths

def create_batches_with_split_points(item_metas, max_length):
    len_to_indexes = {}
    for item_meta in item_metas:
        length = item_meta["length"]
        index = item_meta["idx"]
        if item_meta["num_train_token"] == 0:
            continue
        if length <= max_length:
            if length not in len_to_indexes:
                len_to_indexes[length] = []
            len_to_indexes[length].append(index)

    batches_with_split_points = []
    sorted_lengths = np.array(sorted(len_to_indexes.keys()))
    while len(len_to_indexes) > 0:
        current_batch_indexes = []
        current_split_points = []
        len_left = max_length
        accumulated_length = 0
        first_item_len = None
        max_item_len = 0

        while len(len_to_indexes) > 0:
            available_lengths = get_available_sequence_lengths(sorted_lengths, len_left, first_item_len)
            if len(available_lengths) == 0:
                break
            chosen_length = random.choice(available_lengths)
            _id = np.random.choice(len(len_to_indexes[chosen_length]))
            chosen_index = len_to_indexes[chosen_length].pop(_id)
            if first_item_len is None:
                first_item_len = chosen_length
                max_item_len = max(max_item_len, chosen_length)
            current_batch_indexes.append(chosen_index)
            accumulated_length += chosen_length
            current_split_points.append(accumulated_length)
            len_left = max_length - max_item_len * len(current_batch_indexes) * 1.2
            len_to_indexes, sorted_lengths = update_length_to_indexes_mapping(len_to_indexes, chosen_length, sorted_lengths)

        if current_batch_indexes:
            batches_with_split_points.append((current_batch_indexes, current_split_points))
            first_item_len = None

    return batches_with_split_points

def create_chunks_with_train_tokens(item_metas, max_length, num_gpus, accumulate_steps, target_loss_only=False):
    # check
    if target_loss_only:
        assert 'num_train_token' in item_metas[0], 'num_train_token is required in item_metas when target_loss_only is True'
        idx_to_num_train_tokens = {item["idx"]: item["num_train_token"] for item in item_metas}
    elif not 'num_train_token' in item_metas[0]:
        logger.warning('num_train_token is not found in item_metas. Using length instead')
        idx_to_num_train_tokens = {item["idx"]: item["length"] for item in item_metas}
    else:
        assert 'num_train_token' in item_metas[0], 'num_train_token is required in item_metas when target_loss_only is False'
        idx_to_num_train_tokens = {item["idx"]: item["num_train_token"] for item in item_metas}

    batches_with_split_points = create_batches_with_split_points(item_metas, max_length)

    batches = []
    batches_with_split_points = group_batches_by_sequence_length(
        batches_with_split_points, num_gpus, shuffle=True, drop_last=False
    )
    global_bz = num_gpus * accumulate_steps
    for i in range(0, len(batches_with_split_points), global_bz):
        global_batch_items = batches_with_split_points[i : i + global_bz]
        _all_ids_flat = [item for sublist in global_batch_items for item in sublist[0]]
        num_train_tokens_total = sum([idx_to_num_train_tokens[idx] for idx in _all_ids_flat])
        avg_train_token_in_this_global_batch = num_train_tokens_total / global_bz
        new_data = []
        for ids, split_points in global_batch_items:
            train_tokens = sum([idx_to_num_train_tokens[idx] for idx in ids])
            loss_scale_factor = train_tokens / avg_train_token_in_this_global_batch
            new_data.append({'item_ids': ids, 'split_points': split_points, 'loss_scale_factor': loss_scale_factor})
        batches.append(new_data)
    batches_with_split_points = [item for sublist in batches for item in sublist]
    return batches_with_split_points

In [18]:
df = pd.DataFrame(formated_msg)


In [40]:
df

Unnamed: 0,input_ids,target_ids,attention_mask,labels,length,num_train_tokens,idx,num_train_token
7713,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,7713,27
12715,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,12715,27
2426,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,2426,28
12219,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,12219,29
8864,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,8864,30
...,...,...,...,...,...,...,...,...
10058,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,10058,4380
13775,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,13775,4948
13257,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,13257,4973
1248,"[[tensor(128256), tensor(9125), tensor(198), t...","[[tensor(-100), tensor(-100), tensor(-100), te...","[[tensor(True), tensor(True), tensor(True), te...","[[tensor(-100), tensor(-100), tensor(-100), te...",1,0,1248,5062


In [39]:
# print(tokenizer.decode(df.iloc[506]['input_ids'][0]))