### Preparation
Install required libraries

In [None]:
!pip install torch==2.2.1 bitsandbytes transformers datasets peft easydict

Import the libraries that we will be using

In [None]:
import os, re, json, random, time, huggingface_hub, transformers, torch, gc
from datetime import datetime

from typing import Any, Dict, List, Tuple, Union
import numpy as np, pandas as pd

from easydict import EasyDict # Substitute for the argparse library for Colab settings
from accelerate import Accelerator
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, BatchEncoding, LlamaTokenizer, DataCollatorForSeq2Seq

from peft import prepare_model_for_kbit_training
from peft import LoraConfig, get_peft_model, PeftModel
from google.colab import drive
# Mount Google Drive at /content/drive so that paths under /content/drive/MyDrive are available
drive.mount('/content/drive')

Set the directories in which we will be saving the fine-tuned LoRA adapters and merged models

In [None]:
# Google Drive folders used for outputs: one for the LoRA adapter weights, one for merged models
LORA_ADAPTERS_DIR = '/content/drive/MyDrive/LoRA-Adapters'
MERGED_MODELS_DIR = '/content/drive/MyDrive/Finetuned-Models'
# Backward-compatible alias for the original (misspelled) constant name
MERGED_MDOELS_DIR = MERGED_MODELS_DIR

In [None]:
# Prompt template per dataset "category".
# Every template takes {instruction} and {response}; the templates with a "### Context:" section
# also take {context}. Each template ends with "### Response:\n{response}" so that the response
# is always the final part of the formatted prompt.
INSTRUCTION_PROMPTS = {
    'open_qa' : """### Instruction:
Answer the question below. Be as specific and concise as possible.

### Question:
{instruction}

### Response:
{response}""",

    'general_qa' : """### Instruction:
Answer the question below to the best of your knowledge.

### Question:
{instruction}

### Response:
{response}""",

    'classification' : """### Instruction:
You will be given a question and a list of potential answers to that question. You are to select the correct answers out of the available choices.

### Question:
{instruction}

### Response:
{response}""",

    'closed_qa' : """### Instruction:
You will be given a question to answer and context that contains pertinent information. Provide a concise and accurate response to the question using the information provided in the context.

### Question:
{instruction}

### Context:
{context}

### Response:
{response}""",

    'brainstorming' : """### Instruction:
You will be given a question that does not have a correct answer. You are to brainstorm one possible answer to the provided question.

### Question:
{instruction}

### Response:
{response}""",

    'information_extraction' : """### Instruction:
You will be given a question or query and some context that can be used to answer it. You are to extract relevant information from the provided context to provide an accurate response to the given query.

### Question:
{instruction}

### Context:
{context}

### Response:
{response}""",

    'summarization' : """### Instruction:
You will be given a question or request and context that can be used for your response. You are to summarize the provided context to provide an answer to the question.

### Question:
{instruction}

### Context:
{context}

### Response:
{response}""",

    'creative_writing' : """### Instruction:
You will be given a prompt that you are to write about. Be creative.

### Prompt:
{instruction}

### Response:
{response}""",
}

# Whether the prompt template of each category has a {context} placeholder to fill in
USE_CONTEXT = dict(
    open_qa=False,
    general_qa=False,
    classification=False,
    closed_qa=True,
    brainstorming=False,
    information_extraction=True,
    summarization=True,
    creative_writing=False,
)

# Marker (header line plus its trailing newline) that directly precedes the response text in a prompt
RESPONSE_KEY = "### Response:\n"

In [None]:
def get_config():
    """Build the default run configuration as an EasyDict (attribute-style access)."""
    defaults = {
        'pretrained_model_name' : 'TinyLlama/TinyLlama-1.1B-Chat-v1.0',  # Using a smaller LLM due to memory constraints
        'valid_ratio' : 0.025,
        'max_length' : 2048,
        'num_train_epochs' : 0.1,                 # Reduced for time constraints
        'batch_size_per_device' : 1,              # Set to 1 to reduce memory usage, modify if resources allow
        'gradient_accumulation_steps' : 4,
        'learning_rate' : 1e-5,
        'save_total_limit' : 3,
        'min_warmup_steps' : 1000,
        'warmup_ratio' : 0.1,
        'steps_between_save_eval_log' : 50,
        'lora_r' : 4,                             # Reduced for memory constraints
        'lora_alpha' : 4,                         # Reduced for memory constraints
        'lora_dropout' : 0.05,
    }
    return EasyDict(defaults)

In [None]:
def get_databricks_dataset(
    tokenizer : LlamaTokenizer,
    max_length = 2048,
    input_prompts = INSTRUCTION_PROMPTS,
    use_context = USE_CONTEXT,
    valid_ratio=0.05,
):
    """Load databricks-dolly-15k and turn it into tokenized, label-masked train/validation splits.

    Each record is rendered with the prompt template of its "category", tokenized, and given a
    ``labels`` column in which every token before the response text is masked with -100, so that
    only the response tokens contribute to the loss.

    Args:
        tokenizer: Tokenizer used to encode the prompts. It is called with
            ``return_offsets_mapping=True``. If it has no pad token, the EOS token is assigned
            as the pad token (the tokenizer is modified in place).
        max_length: Records whose tokenized prompt is longer than this are dropped.
        input_prompts: Mapping from category name to prompt template.
        use_context: Mapping from category name to whether its template takes ``{context}``.
        valid_ratio: Passed as ``test_size`` to ``train_test_split``.

    Returns:
        A ``(train, validation)`` tuple of datasets holding the tokenizer outputs plus ``labels``.
    """
    # Compare against None explicitly: a pad token id of 0 is a valid id but is falsy
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset('databricks/databricks-dolly-15k', split='train', trust_remote_code=True)

    # Apply an instruction template to each of the data in the dataset, in accordance with the data's "category"
    def format_instructions(data):
        category = data['category']
        fields = {'instruction' : data['instruction'], 'response' : data['response']}
        if use_context[category]:
            fields['context'] = data['context']
        return {'instruction_prompt' : input_prompts[category].format(**fields)}

    dataset = dataset.map(format_instructions, batched = False)

    # Determine the starting index for the substring corresponding to the labels
    # (the character right after the last occurrence of RESPONSE_KEY)
    def label_str_idx(data):
        return { 'label_str_idx' : len(RESPONSE_KEY) + data['instruction_prompt'].rfind(RESPONSE_KEY) }
    dataset = dataset.map(label_str_idx, batched = False)

    # Tokenize the formatted instruction tuning data and return offset_mapping to be used to determine label token ids
    def tokenize_instructions(batch):
        return tokenizer(batch['instruction_prompt'], return_offsets_mapping = True)
    dataset = dataset.map(tokenize_instructions, batched = True)

    # Use the starting index and offset mapping to determine which token ids correspond to the labels
    def label_token_ids(data):
        label_tok_len = sum(start >= data['label_str_idx'] for start, _ in data['offset_mapping'])
        # Slice from an explicit positive index: `input_ids[-0:]` would return the WHOLE sequence
        # when there are no label tokens, producing labels longer than input_ids
        prompt_tok_len = len(data['input_ids']) - label_tok_len
        return { 'labels' : [-100] * prompt_tok_len + data['input_ids'][prompt_tok_len:] }
    dataset = dataset.map(label_token_ids, batched = False)

    # Remove unneeded columns
    dataset = dataset.remove_columns(['instruction', 'context', 'response', 'category', 'instruction_prompt', 'label_str_idx', 'offset_mapping'])

    # Keep only data that has at least one label token and whose decoded labels match the
    # response text found in the decoded input
    def validate_input_output(data):
        label_ids = [token_id for token_id in data['labels'] if token_id != -100]
        if not label_ids:
            # Nothing to learn from (and an all -100 label row has no loss terms)
            return False
        return tokenizer.decode(data['input_ids']).split(RESPONSE_KEY)[-1] == tokenizer.decode(label_ids)

    print(f"Original dataset length : {len(dataset)}")

    # Filter out data whose tokenized length exceeds the max length
    dataset = dataset.filter(lambda data: len(data['input_ids']) <= max_length, batched = False)

    # Filter out data whose decoded label does not match expected label
    dataset = dataset.filter(validate_input_output, batched = False)
    print(f"Filtered dataset length : {len(dataset)}")

    dataset = dataset.shuffle(seed = 42)

    # The data was just shuffled with a fixed seed; train_test_split would otherwise shuffle
    # again with an unseeded generator, making the split differ from run to run
    dataset = dataset.train_test_split(test_size = valid_ratio, shuffle = False)
    return dataset['train'], dataset['test']

In [None]:
def get_now():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    current = datetime.now()
    return f"{current:%Y%m%d-%H%M%S}"

def main(config):
    """Fine-tune a 4-bit quantized causal LM with LoRA, then merge the adapters and push to the Hub.

    Steps: build the tokenizer and the dataset, load the base model in 4-bit, wrap it with LoRA,
    evaluate on a 5-example validation subset, train, save the adapters under LORA_ADAPTERS_DIR,
    evaluate again, then reload the base model, merge the saved adapters into it and push the
    merged model to the Hugging Face Hub.

    Args:
        config: Object with attribute access (see ``get_config``). Required attributes:
            ``pretrained_model_name``, ``num_train_epochs``, ``max_length``, ``valid_ratio``,
            ``lora_r``, ``lora_alpha``, ``lora_dropout``, ``learning_rate``,
            ``batch_size_per_device``, ``steps_between_save_eval_log``.
            Optional attributes: ``save_total_limit`` (default 5) and ``hub_namespace``
            (default ``'Chahnwoo'``).
    """
    import inspect  # Local import, only needed for the TrainingArguments keyword check below

    finetuned_model_name = f"{config.pretrained_model_name.split('/')[-1]}-{config.num_train_epochs}E-QLoRA-Databricks-SFT-Test"
    adapter_dir = os.path.join(LORA_ADAPTERS_DIR, finetuned_model_name + '-LoRA-Adapters')

    # Initialize tokenizer, manually add pad token if model does not have one by default
    # (explicit None check: a pad token id of 0 is valid but falsy)
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_name, padding_side = 'left')
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenizer.pad_token = tokenizer.eos_token

    # Retrieve databricks dataset as HuggingFace Dataset after pre-processing for labelling
    train_data, valid_data = get_databricks_dataset(
        tokenizer,
        max_length = config.max_length,
        valid_ratio=config.valid_ratio,
    )
    print("< Retrieved and formatted Databricks dataset >")

    # Get BitsAndBytesConfig for quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
    )

    # Load model with 4-bit quantization
    model = AutoModelForCausalLM.from_pretrained(
        config.pretrained_model_name,             # Pre-trained model name on HuggingFace
        quantization_config=quantization_config,  # Enable quantization with BitsAndBytesConfig defined above
        device_map = 'auto',
        trust_remote_code=True,                   # Some models require this option
    )

    # Enable gradient checkpointing, must pass keyword argument "use_reentrant" : False to avoid errors when using multiple GPUs
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={'use_reentrant' : False},
    )

    # Prepare model for k-bit training
    model = prepare_model_for_kbit_training(model)

    # Configurations for Low-Rank Adaptation (https://huggingface.co/docs/peft/main/en/conceptual_guides/lora)
    l_config = LoraConfig(
        r=config.lora_r,                                       # Rank  : parameter that determines the rank of the decomposed matrices used in LoRA
        lora_alpha=config.lora_alpha,                          # Alpha : parameter that determines the scaling factor for LoRA
        target_modules=["q_proj","k_proj","v_proj","o_proj"],  # Set target modules for LoRA, check model architecture for module names (print(AutoModelForCausalLM.from_pretrained()))
        lora_dropout=config.lora_dropout,                      # The dropout probability for LoRA layers
        bias="none",                                           # Specifies whether bias parameters should be trained
        task_type="CAUSAL_LM",                                 # Task type passed on to PEFT
        use_rslora = True                                      # When true, sets scaling factor to [alpha / sqrt(rank)] instead of default [alpha / rank]
    )

    # Get PEFT model with the LoRA configurations above
    model = get_peft_model(model, l_config)
    print("< Retrieved PEFT model >")

    # Newer transformers releases name this keyword `eval_strategy`, older ones `evaluation_strategy`;
    # pick whichever the installed version accepts (transformers is installed unpinned)
    accepted_args = inspect.signature(transformers.TrainingArguments.__init__).parameters
    eval_strategy_key = 'eval_strategy' if 'eval_strategy' in accepted_args else 'evaluation_strategy'

    # Customize training arguments
    # Notes:
    #   - bf16 / bf16_full_eval are not set: T4 does not support brain float
    #     (https://www.aewin.com/application/bfloat16-a-brief-intro/)
    #   - fp16 / fp16_full_eval are not set either; only one of bf16 / fp16 may be used at a time
    #   - gradient_accumulation_steps and warmup are intentionally left at their defaults, so
    #     config.gradient_accumulation_steps, config.warmup_ratio and config.min_warmup_steps are unused
    training_args = transformers.TrainingArguments(
        output_dir = os.getcwd(),
        ddp_find_unused_parameters=False,                                    # Parameter passed to DistributedDataParallel, defaults to false if gradient checkpointing used
        eval_steps=config.steps_between_save_eval_log,                       # Number of update steps between two evaluations
        eval_accumulation_steps=1,                                           # Number of prediction steps to accumulate before moving to CPU, higher values are faster but also require more memory
        gradient_checkpointing = True,                                       # Strategy that allows you to save memory during gradient updates, at the expense of time
        gradient_checkpointing_kwargs = {'use_reentrant' : False},           # Set to false to avoid errors when using multiple GPUs
        half_precision_backend="auto",                                       # Parameter for mixed-precision training
        learning_rate=config.learning_rate,                                  # Initial learning rate for optimizer
        lr_scheduler_type = 'cosine',                                        # SchedulerType
        logging_strategy="steps",                                            # Logging is done at every logging_steps
        logging_steps=config.steps_between_save_eval_log,                    # Number of update steps between two logs
        num_train_epochs=config.num_train_epochs,                            # Number of training epochs to perform
        optim="paged_adamw_8bit",                                            # Optimizer to use, valid names can be found at (https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/training_args.py#L141)
        per_device_train_batch_size=config.batch_size_per_device,            # Batch size per GPU for training
        per_device_eval_batch_size=config.batch_size_per_device * 2,         # Batch size per GPU for evaluation (usually set to higher value than for training due to smaller memory consumption)
        save_strategy="steps",                                               # Save is performed every save_steps steps
        save_steps=config.steps_between_save_eval_log,                       # Number of update steps between two checkpoint saves
        save_total_limit=getattr(config, 'save_total_limit', 5),             # Limits the total number of saves, deletes older checkpoints (falls back to 5 when not configured)
        **{eval_strategy_key : "steps"},                                     # Evaluation is performed and logged every eval_steps
    )
    print("< Loaded training arguments >")

    torch.cuda.empty_cache()

    trainer = transformers.Trainer(
        model=model,
        train_dataset= train_data,               # Dataset to use for training, requires data of specific format (such as datasets.Dataset)
        eval_dataset= valid_data,                # Dataset to use for evaluation, requires data of specific format (such as datasets.Dataset)
        args=training_args,                      # The training arguments set above
        data_collator= DataCollatorForSeq2Seq(   # The data collator to use, note that DataCollatorForLanguageModeling is not suitable for labelled data
            tokenizer,
        ),
    )
    model.config.use_cache = (
        False  # Silences warnings, re-enable for inference
    )

    # Small fixed subset (last 5 validation examples) used for the before/after comparison
    eval_subset = Dataset.from_dict(dict(valid_data[-5:]))

    # Perform pre-training evaluation
    # Details on data accepted by evaluate() can be found in HuggingFace documentation
    # https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.Trainer.evaluate
    pre_training_evaluation = trainer.evaluate(eval_subset)

    # Perform training
    print("< Training in Progress >")
    trainer.train()
    trainer.save_model(adapter_dir)

    # Perform post-training evaluation
    post_training_evaluation = trainer.evaluate(eval_subset)

    # Free up as much memory as possible
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    # Print pre-training and post-training evaluations
    print("< PRE-TRAINING EVAL >")
    print(pre_training_evaluation)
    print("< POST-TRAINING EVAL >")
    print(post_training_evaluation)

    # Free up as much memory as possible
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Reload the base model and merge the saved adapters into it
    # NOTE(review): the base model is reloaded with the same 4-bit quantization config, so the
    # adapters are merged into quantized weights; confirm this is intended rather than merging
    # into a half-precision copy of the base model
    model = AutoModelForCausalLM.from_pretrained(
        config.pretrained_model_name,
        quantization_config=quantization_config,
        device_map = 'cuda',
        trust_remote_code=True,  # Some models require this option
    )
    model = PeftModel.from_pretrained(model, adapter_dir)  # Apply the PEFT adapters to the model
    model = model.merge_and_unload()

    # Repository name: <namespace>/<model name>_<YYYYMMDD>
    hub_namespace = getattr(config, 'hub_namespace', 'Chahnwoo')
    model.push_to_hub(f"{hub_namespace}/{finetuned_model_name}_{get_now()[:8]}")


In [None]:
# Run the whole pipeline with the configuration returned by get_config()
main(get_config())