In [1]:
import formatting
import json
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from datasets import Dataset
import copy
from transformers import TrainerCallback
from contextlib import nullcontext
from transformers import default_data_collator, Trainer, TrainingArguments

# Load Model

In [2]:
# Path to the directory holding the LLaMA 7B checkpoint in Hugging Face format.
# Both the tokenizer and the model are loaded from this directory.
# See the README to get this model.
hugging_face_model_dir = "../../models/llama/7B-hf"

In [3]:
# Load and set up the tokenizer from the Hugging Face model directory.
tokenizer:LlamaTokenizer = LlamaTokenizer.from_pretrained(hugging_face_model_dir)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Cap the tokenizer's maximum sequence length at 512 tokens.
tokenizer.model_max_length = 512
# Load the model from the same Hugging Face directory. This is the base
# checkpoint, not a fine-tuned one: no trained adapter weights are loaded here.
# The weights are requested in 8-bit with float16 as the torch dtype, and
# device placement is delegated to device_map='auto'.
model:LlamaForCausalLM = LlamaForCausalLM.from_pretrained(hugging_face_model_dir, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Load Dataset

In [4]:
# Load the preprocessed dataset.
# The encoding is passed explicitly so the read does not depend on the
# platform's locale default (JSON files are expected to be UTF-8).
with open('datasets/topical_chat.json', encoding='utf-8') as f:
    dataset = json.load(f)

In [5]:
# Peek at the first conversation: a list of [speaker_id, utterance] pairs.
dataset[0]

[['user_b', 'Are you a fan of Google or Microsoft?'],
 ['user_a',
  'Both are excellent technology they are helpful in many ways. For the security purpose both are super.'],
 ['user_b',
  "I'm not  a huge fan of Google, but I use it a lot because I have to. I think they are a monopoly in some sense."],
 ['user_a',
  'Google provides online related services and products, which includes online ads, search engine and cloud computing.'],
 ['user_b',
  "Yeah, their services are good. I'm just not a fan of intrusive they can be on our personal lives."],
 ['user_a',
  'Google is leading the alphabet subsidiary and will continue to be the Umbrella company for Alphabet internet interest.'],
 ['user_b',
  'Did you know Google had hundreds of live goats to cut the grass in the past?'],
 ['user_a',
  'It is very interesting. Google provide "Chrome OS" which is a light weight OS. Google provided a lot of hardware mainly in 2010 to 2015.'],
 ['user_b', 'I like Google Chrome. Do you use it as well fo

In [6]:
# Sanity-check the formatting helpers on the first conversation.
# `prompt` is kept as a notebook-level name because the tokenizer check that
# follows reuses it.
# NOTE(review): the prompt is built for ids user_1/user_2, while the turns in
# the printed conversation are labelled user_a/user_b - confirm this mismatch
# is intended.
prompt = formatting.get_chat_prompt(["user_1", "user_2"], ["Josh", "Mr Mainframe"])
first_conversation_text = formatting.format_conversation(prompt, dataset[0])
print(first_conversation_text)

The following is a discussion between Mr Mainframe (user_2) and Josh (user_1):
user_b: Are you a fan of Google or Microsoft?</s>user_a: Both are excellent technology they are helpful in many ways. For the security purpose both are super.</s>user_b: I'm not  a huge fan of Google, but I use it a lot because I have to. I think they are a monopoly in some sense.</s>user_a: Google provides online related services and products, which includes online ads, search engine and cloud computing.</s>user_b: Yeah, their services are good. I'm just not a fan of intrusive they can be on our personal lives.</s>user_a: Google is leading the alphabet subsidiary and will continue to be the Umbrella company for Alphabet internet interest.</s>user_b: Did you know Google had hundreds of live goats to cut the grass in the past?</s>user_a: It is very interesting. Google provide "Chrome OS" which is a light weight OS. Google provided a lot of hardware mainly in 2010 to 2015.</s>user_b: I like Google Chrome. Do y

# Tokenize Dataset

In [7]:
# Test out the tokenizer on the first conversation, reusing the `prompt`
# built in the formatting check. The returned mapping provides the
# 'input_ids', 'attention_mask' and 'labels' entries that the full
# tokenization pass collects.
formatting.tokenize_with_turn_trucation(tokenizer, prompt, dataset[0])

{'input_ids': [1, 450, 1494, 338, 263, 10679, 1546, 3237, 4241, 2557, 313, 1792, 29918, 29906, 29897, 322, 22838, 313, 1792, 29918, 29896, 1125, 13, 1792, 29918, 29890, 29901, 4683, 366, 263, 13524, 310, 5087, 470, 7783, 29973, 2, 1404, 29918, 29874, 29901, 9134, 526, 15129, 15483, 896, 526, 8444, 297, 1784, 5837, 29889, 1152, 278, 6993, 6437, 1716, 526, 2428, 29889, 2, 1404, 29918, 29890, 29901, 306, 29915, 29885, 451, 29871, 263, 12176, 13524, 310, 5087, 29892, 541, 306, 671, 372, 263, 3287, 1363, 306, 505, 304, 29889, 306, 1348, 896, 526, 263, 1601, 13242, 29891, 297, 777, 4060, 29889, 2, 1404, 29918, 29874, 29901, 5087, 8128, 7395, 4475, 5786, 322, 9316, 29892, 607, 7805, 7395, 594, 29879, 29892, 2740, 6012, 322, 9570, 20602, 29889, 2, 1404, 29918, 29890, 29901, 15011, 29892, 1009, 5786, 526, 1781, 29889, 306, 29915, 29885, 925, 451, 263, 13524, 310, 11158, 375, 573, 896, 508, 367, 373, 1749, 7333, 12080, 29889, 2, 1404, 29918, 29874, 29901, 5087, 338, 8236, 278, 22968, 11684, 8819

In [8]:
# Format, add prompts, and tokenize the entire dataset.
# This takes a while. On my machine it's ~ 2 mins.
feature_columns = ("input_ids", "attention_mask", "labels")
tokenized_dataset_dict = {column: [] for column in feature_columns}
for conversation in dataset:
    # A prompt is requested for every conversation rather than once up front:
    # the header wording differs between earlier outputs of this notebook, so
    # the helper presumably varies its result per call - TODO confirm.
    # Fresh list literals are passed each time, exactly as before.
    prompt = formatting.get_chat_prompt(["user_1", "user_2"], ["Josh", "Mr Mainframe"])
    encoded = formatting.tokenize_with_turn_trucation(tokenizer, prompt, conversation)
    for column in feature_columns:
        tokenized_dataset_dict[column].append(encoded[column])
# Convert the column-wise lists into a datasets.Dataset object.
# (The variable name's spelling is kept because later cells refer to it.)
tokeized_dataset = Dataset.from_dict(tokenized_dataset_dict)

In [9]:
# Display the dataset summary (feature names and number of rows).
tokeized_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 8628
})

In [10]:
# This is an example row from the dataset (the first tokenized conversation).
# Note that this prints the full token-id lists, so the output is long.
tokeized_dataset[0]

{'input_ids': [1,
  2266,
  338,
  263,
  10679,
  1546,
  3237,
  4241,
  2557,
  313,
  1792,
  29918,
  29906,
  29897,
  322,
  22838,
  313,
  1792,
  29918,
  29896,
  1125,
  13,
  1792,
  29918,
  29890,
  29901,
  4683,
  366,
  263,
  13524,
  310,
  5087,
  470,
  7783,
  29973,
  2,
  1404,
  29918,
  29874,
  29901,
  9134,
  526,
  15129,
  15483,
  896,
  526,
  8444,
  297,
  1784,
  5837,
  29889,
  1152,
  278,
  6993,
  6437,
  1716,
  526,
  2428,
  29889,
  2,
  1404,
  29918,
  29890,
  29901,
  306,
  29915,
  29885,
  451,
  29871,
  263,
  12176,
  13524,
  310,
  5087,
  29892,
  541,
  306,
  671,
  372,
  263,
  3287,
  1363,
  306,
  505,
  304,
  29889,
  306,
  1348,
  896,
  526,
  263,
  1601,
  13242,
  29891,
  297,
  777,
  4060,
  29889,
  2,
  1404,
  29918,
  29874,
  29901,
  5087,
  8128,
  7395,
  4475,
  5786,
  322,
  9316,
  29892,
  607,
  7805,
  7395,
  594,
  29879,
  29892,
  2740,
  6012,
  322,
  9570,
  20602,
  29889,
  2,
  1404,
 

In [11]:
# Round-trip check: decode the first row's token ids back to text to confirm
# the encoding produced the expected prompt and conversation.
first_row_token_ids = tokeized_dataset[0]["input_ids"]
print(tokenizer.decode(first_row_token_ids))

<s> Here is a discussion between Mr Mainframe (user_2) and Josh (user_1):
user_b: Are you a fan of Google or Microsoft?</s> user_a: Both are excellent technology they are helpful in many ways. For the security purpose both are super.</s> user_b: I'm not  a huge fan of Google, but I use it a lot because I have to. I think they are a monopoly in some sense.</s> user_a: Google provides online related services and products, which includes online ads, search engine and cloud computing.</s> user_b: Yeah, their services are good. I'm just not a fan of intrusive they can be on our personal lives.</s> user_a: Google is leading the alphabet subsidiary and will continue to be the Umbrella company for Alphabet internet interest.</s> user_b: Did you know Google had hundreds of live goats to cut the grass in the past?</s> user_a: It is very interesting. Google provide "Chrome OS" which is a light weight OS. Google provided a lot of hardware mainly in 2010 to 2015.</s> user_b: I like Google Chrome. D

# Setup model for PEFT (Parameter Efficient Fine-Tuning)

In [12]:
# Put the model in training mode before it is wrapped for PEFT below.
model.train()

# Create the PEFT config and wrap the model with it
def create_peft_config(model):
    """Prepare a quantized model for training and attach LoRA adapters to it.

    Args:
        model: The causal language model (loaded in 8-bit in this notebook)
            that should receive the adapters.

    Returns:
        A ``(model, peft_config)`` tuple: the PEFT-wrapped model and the
        ``LoraConfig`` that was used to build it.
    """
    from peft import LoraConfig, TaskType, get_peft_model

    # `prepare_model_for_int8_training` is deprecated in PEFT in favour of
    # `prepare_model_for_kbit_training` and is not available in newer
    # releases. Prefer the current name and fall back for older installs.
    try:
        from peft import prepare_model_for_kbit_training as prepare_quantized_model
    except ImportError:
        from peft import prepare_model_for_int8_training as prepare_quantized_model

    # LoRA on the attention query/value projections only.
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    # Prepare the quantized model for training.
    # NOTE(review): the training log warns about gradient checkpointing even
    # though TrainingArguments sets gradient_checkpointing=False; presumably
    # this preparation step enables it by default - confirm.
    model = prepare_quantized_model(model)
    model = get_peft_model(model, peft_config)
    # Print how many parameters are trainable versus the total.
    model.print_trainable_parameters()
    return model, peft_config

# Wrap the loaded model with LoRA adapters; `model` is rebound to the
# PEFT-wrapped model and the LoRA config is kept in `lora_config`.
# NOTE(review): this cell feeds `model` back into itself, so re-running it
# without reloading the model would wrap an already-wrapped model.
model, lora_config = create_peft_config(model)



trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199


# Train

In [13]:
# Define training args
output_dir = "tmp/llama-output"
# Only request bf16 mixed precision when the current CUDA device supports it,
# instead of hard-coding True (which is rejected on unsupported setups).
# NOTE(review): without bf16 support this runs with no mixed precision;
# consider whether fp16=True is the desired fallback.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=use_bf16,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    # No intermediate checkpoints; the model is saved explicitly at the end
    # of the notebook.
    save_strategy="no",
    optim="adamw_torch_fused",
    # -1 leaves the run length to num_train_epochs.
    max_steps=-1,
    learning_rate=1e-4,
    num_train_epochs=1,
    # 2 samples per device x 2 accumulation steps = 4 samples per optimizer
    # step on each device.
    gradient_accumulation_steps=2,
    per_device_train_batch_size=2,
    gradient_checkpointing=False,
)

# Create the Trainer instance for the PEFT-wrapped model.
# The rows of `tokeized_dataset` already carry 'input_ids', 'attention_mask'
# and 'labels', and the default collator is used to batch them.
# No extra callbacks are registered.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokeized_dataset,
    data_collator=default_data_collator,
    callbacks=[],
)

In [14]:
# Start Training - This took ~ 4 hours on my machine
# NOTE(review): the saved output of this cell ends in a KeyboardInterrupt,
# i.e. the recorded run was stopped manually before finishing.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...




Step,Training Loss


KeyboardInterrupt: 

In [None]:
# Save the model to disk for later.
# `model` is the PEFT-wrapped model at this point.
# NOTE(review): presumably this writes only the adapter weights rather than
# the full base model - confirm before relying on this directory as a
# standalone checkpoint.
model.save_pretrained("./trained-models/llama-7B-topical-chat-v1")