In [1]:
import torch
import numpy as np

import type_dataset_utils
from type_dataset_utils import TypeDataset, TypeSentenceDataset, TypeQADataset
import datasets

from trl import SFTTrainer
from transformers import TrainingArguments

from unsloth import FastLanguageModel

from eval_utils import create_compute_metric_fn, qa_pipeline

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


In [2]:
SEED = 14
type_dataset_utils.NP_RNG = np.random.default_rng(SEED)

## Model

In [3]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Phi-3-mini-4k-instruct",  # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length=256,
    dtype=None,
    load_in_4bit=False
)

==((====))==  Unsloth: Fast Mistral patching release 2024.6
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.988 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.1+cu121. CUDA = 8.9. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. Xformers = 0.0.24. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2024.6 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


## Data Prep

In [5]:
type_dataset = TypeDataset()

type_sentence_dataset = TypeSentenceDataset(type_dataset).shuffle(seed=SEED)

train_val_pokemon = type_dataset.train_test_split(test_size=.2)

train_qa_dataset = TypeQADataset(train_val_pokemon['train'])
val_qa_dataset = TypeQADataset(train_val_pokemon['test'])
# val_pokemon, test_pokemon = val_pokemon.train_test_split(test_size=.5) # I'll implement a val/test split later

Is it even a good idea to mix autoregressive (non-chat) and chat based stuff? Maybe, maybe not...

In any case, the QA dataset will be chat-based, so we have to apply the chat template to all elements in it.

In [6]:
def batch_apply_chat_template(examples, tokenizer):
    """This function converts the qa dataset into a chat dataset w/ the key text"""
    all_messages = []

    for questions, answers in zip(examples['questions'], examples['answers']):
        for question, answer in zip(questions, answers):
            messages = [
                {'role': 'user', 'content': question},
                {'role': 'assistant', 'content': answer}
            ]

            all_messages.append(messages)

    all_text = tokenizer.apply_chat_template(
        all_messages,
        add_generation_prompt=False, # not generating stuff for this dataset, so no generation prompt needed
        tokenize=False,
    )

    return {'text': all_text}
        

In [7]:
train_qa_chat_dataset = train_qa_dataset.map(
    batch_apply_chat_template,
    batched=True,
    batch_size=100,
    remove_columns=['questions', 'answers', 'types'],
    fn_kwargs={'tokenizer': tokenizer}
)  

val_qa_chat_dataset = val_qa_dataset.map(
    batch_apply_chat_template,
    batched=True,
    batch_size=100,
    remove_columns=['questions', 'answers', 'types'],
    fn_kwargs={'tokenizer': tokenizer}
)

In [8]:
train_dataset = datasets.concatenate_datasets([type_sentence_dataset, train_qa_chat_dataset])
len(train_dataset)

19725

## Training

In [9]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = val_qa_chat_dataset,
    compute_metrics=create_compute_metric_fn(model, tokenizer, train_qa_dataset, val_qa_dataset, pokemon_batch_size=16),
    dataset_text_field = "text",
    max_seq_length = 256,
    dataset_num_proc = 1,
    packing = True, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        per_device_train_batch_size=16, # about the most my 4090 can handle
        per_device_eval_batch_size=8,
        num_train_epochs=30,
        warmup_ratio=.1,
        learning_rate = 2e-4,
        bf16 = True,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        eval_strategy='steps',
        eval_steps=.25,
        eval_accumulation_steps=4,
    ),
)

Before doing any training, lets evaluate the model so that we can compare how well it does before and after training.

In [10]:
num_generation_examples = 4
pokemon_idxs = type_dataset_utils.NP_RNG.choice(len(val_qa_dataset), num_generation_examples)
example_questions = [questions[0] for questions in val_qa_dataset.select(pokemon_idxs)['questions']]

In [11]:
print(qa_pipeline(model, tokenizer, example_questions))

["What is vulpix-alola's type? Vulpix, as a Pokémon, is a Fire-type Pokémon. It was first introduced in the original Pokémon games, Pok", 'What type of pokemon is Venonat? Venonat is a Ground-type Pokémon. It was first introduced in Generation I of the Pokémon series. As a Ground-', 'What is luvdisc\'s type? I\'m unable to directly identify specific entities or individuals that may have emerged after my last update in April 2023. However, if "', 'Can you tell me the type of gardevoir in the Pokemon universe? In the Pokémon universe, Gardevoir is a Psychic-type Pokémon. It evolves from Vileplume into its final form']


In [12]:
original_eval = trainer.evaluate()
print(original_eval)

Evaluating Model for QA: 100%|██████████| 71/71 [01:04<00:00,  1.10 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 18/18 [00:16<00:00,  1.08 Pokemon Batch/s]
[34m[1mwandb[0m: Currently logged in as: [33mjvp15[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'eval_loss': 9.825358390808105, 'eval_train_macro_accuracy': 0.1480035492457852, 'eval_train_micro_accuracy': 0.14800354924578527, 'eval_val_macro_accuracy': 0.16808510638297874, 'eval_val_micro_accuracy': 0.16808510638297872, 'eval_runtime': 88.6077, 'eval_samples_per_second': 1.907, 'eval_steps_per_second': 0.248}


It's training time

In [13]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 1,641 | Num Epochs = 30
O^O/ \_/ \    Batch size per device = 16 | Gradient Accumulation steps = 1
\        /    Total batch size = 16 | Total steps = 3,090
 "-____-"     Number of trainable parameters = 239,075,328


Step,Training Loss,Validation Loss,Train Macro Accuracy,Train Micro Accuracy,Val Macro Accuracy,Val Micro Accuracy
773,0.6609,0.359207,0.532387,0.532387,0.507092,0.507092
1546,0.1079,0.650097,0.465839,0.465839,0.449645,0.449645
2319,0.0113,0.925265,0.451819,0.451819,0.443262,0.443262


Evaluating Model for QA: 100%|██████████| 71/71 [00:59<00:00,  1.19 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 18/18 [00:14<00:00,  1.25 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 71/71 [01:04<00:00,  1.11 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 18/18 [00:15<00:00,  1.13 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 71/71 [01:04<00:00,  1.11 Pokemon Batch/s]
Evaluating Model for QA: 100%|██████████| 18/18 [00:15<00:00,  1.13 Pokemon Batch/s]


TrainOutput(global_step=3090, training_loss=0.3579451241515408, metrics={'train_runtime': 2335.3197, 'train_samples_per_second': 21.081, 'train_steps_per_second': 1.323, 'total_flos': 2.9956952034902016e+17, 'train_loss': 0.3579451241515408, 'epoch': 30.0})

In [14]:
print(qa_pipeline(model, tokenizer, example_questions))

["What is vulpix-alola's type? vulpix-alola's type is Ice", "What type of pokemon is Venonat? Venonat's primary type is Bug and it's second type is Poison", "What is luvdisc's type? luvdisc's type is Water/Fairy", "Can you tell me the type of gardevoir in the Pokemon universe? gardevoir's primary type is psychic and it's second type is fairy"]
