In [19]:
import formatting
import json
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from datasets import Dataset
import copy
from transformers import TrainerCallback
from contextlib import nullcontext
from transformers import default_data_collator, Trainer, TrainingArguments

# Load Model

In [20]:
# The path to the hugging face model. See the README to get this model.
hugging_face_model_dir = "../../models/llama/7B-hf"

In [21]:
# Load and setup the tokenizer
tokenizer:LlamaTokenizer = LlamaTokenizer.from_pretrained(hugging_face_model_dir)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 512
# The path to the trained model. This is generated from the hugging face model train.ipynb
# This file does not include all weights, but simply a small subset of weights that were changed during training.
model:LlamaForCausalLM = LlamaForCausalLM.from_pretrained(hugging_face_model_dir, load_in_8bit=True, device_map='auto', torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Load Dataset

In [22]:
# Load the preprocessed dataset
with open('datasets/cornell_parsed.json') as f:
    dataset = json.load(f)

In [23]:
dataset[0]

[['Bianca', 'Cameron'],
 [['Bianca',
   'Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.'],
  ['Cameron',
   "Well, I thought we'd start with pronunciation, if that's okay with you."],
  ['Bianca', 'Not the hacking and gagging and spitting part. Please.'],
  ['Cameron',
   "Okay... then how 'bout we try out some French cuisine. Saturday? Night?"]]]

In [24]:
# Test out the formatting
prompt = formatting.get_chat_prompt(dataset[0][0])
print(formatting.format_conversation(prompt, dataset[0][1]))

Here is a discussion between Cameron and Bianca:
Bianca: Can we make this quick? Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad. Again.</s>Cameron: Well, I thought we'd start with pronunciation, if that's okay with you.</s>Bianca: Not the hacking and gagging and spitting part. Please.</s>Cameron: Okay... then how 'bout we try out some French cuisine. Saturday? Night?</s>


# Tokenize Dataset

In [25]:
# Test out the tokenizer
formatting.tokenize_with_turn_trucation(tokenizer, prompt, dataset[0][1])

{'input_ids': [1, 2266, 338, 263, 10679, 1546, 20939, 265, 322, 29860, 1113, 29901, 13, 29933, 713, 1113, 29901, 1815, 591, 1207, 445, 4996, 29973, 1528, 29916, 11276, 12555, 29878, 457, 322, 11571, 2261, 13158, 526, 2534, 385, 29811, 14981, 4029, 28759, 681, 970, 2867, 29899, 701, 373, 278, 18890, 29889, 11454, 29889, 2, 20939, 265, 29901, 5674, 29892, 306, 2714, 591, 29915, 29881, 1369, 411, 11504, 11173, 362, 29892, 565, 393, 29915, 29879, 20759, 411, 366, 29889, 2, 29860, 1113, 29901, 2216, 278, 15833, 292, 322, 330, 351, 3460, 322, 805, 5367, 760, 29889, 3529, 29889, 2, 20939, 265, 29901, 20419, 856, 769, 920, 525, 29890, 449, 591, 1018, 714, 777, 5176, 2723, 275, 457, 29889, 24211, 29973, 11554, 29973, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 

In [26]:
# Format, add prompts, and tokeize the entire dataset
# This takes a while. On my machine its ~ 2 mins
tokenized_dataset_dict = {"input_ids": [], "attention_mask": [], "labels": []}
for (players,conv) in dataset:
    prompt = formatting.get_chat_prompt(players)
    row = formatting.tokenize_with_turn_trucation(tokenizer, prompt, conv)
    tokenized_dataset_dict["input_ids"].append(row["input_ids"])
    tokenized_dataset_dict["attention_mask"].append(row["attention_mask"])
    tokenized_dataset_dict["labels"].append(row["labels"])
# Convert this into a datasets.Dataset object
tokeized_dataset = Dataset.from_dict(tokenized_dataset_dict).shuffle(seed=42)

In [27]:
tokeized_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 83097
})

In [28]:
# This is an example row from the dataset
tokeized_dataset[0]

{'input_ids': [1,
  450,
  1494,
  338,
  263,
  13563,
  1546,
  7927,
  1358,
  322,
  21828,
  29901,
  13,
  2855,
  29891,
  29901,
  323,
  29899,
  12711,
  9007,
  29915,
  29873,
  263,
  2769,
  304,
  3708,
  1066,
  873,
  21682,
  592,
  1192,
  2,
  7927,
  1358,
  29901,
  887,
  29915,
  276,
  263,
  281,
  6574,
  29889,
  2,
  21828,
  29901,
  1619,
  4783,
  1497,
  306,
  881,
  29915,
  345,
  2355,
  841,
  263,
  25008,
  1192,
  29991,
  2,
  7927,
  1358,
  29901,
  1987,
  596,
  4783,
  29915,
  29879,
  263,
  281,
  6574,
  29889,
  2,
  21828,
  29901,
  1126,
  366,
  29915,
  276,
  925,
  6460,
  4796,
  534,
  1161,
  29991,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
 

In [29]:
# Check that the encoding worked by decoding one of the inputs
print(tokenizer.decode(tokeized_dataset[0]["input_ids"]))

<s> The following is a chat between Lawler and Andy:
Andy: T-there wasn't a reason to purposely hurt me --</s> Lawler: You're a wimp.</s> Andy: My father said I should've gotten a lawyer --!</s> Lawler: Then your father's a wimp.</s> Andy: And you're just poor white trash!</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s

# Setup model for PEFT (Parameter Efficient Fine-Tuning)

In [30]:
# Put the model in training mode
model.train()

# Create the PEFT config
def create_peft_config(model):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_int8_training,
    )

    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules = ["q_proj", "v_proj"]
    )

    # prepare int-8 model for training
    model = prepare_model_for_int8_training(model)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model, peft_config

model, lora_config = create_peft_config(model)



trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199


# Train

In [31]:
# Define training args
output_dir = "tmp/llama-output"
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    bf16=True,  # Use BF16 if available
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=10,
    save_strategy="no",
    optim="adamw_torch_fused",
    max_steps=-1,
    learning_rate=1e-5, # This used to be 1e-4, but I changed it to 1e-5 to make it train slower with more data
    num_train_epochs=1,
    gradient_accumulation_steps=2,
    per_device_train_batch_size=2,
    gradient_checkpointing=False,
)

# Create Trainer instance
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokeized_dataset,
    data_collator=default_data_collator,
    callbacks=[],
)

In [32]:
# Start Training - This took ~ 4 hours on my machine
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
10,21.5802
20,19.5641
30,17.429
40,14.7267
50,10.5217
60,8.4779
70,6.5009
80,4.6703
90,3.1631
100,1.7681


KeyboardInterrupt: 

In [33]:
# Save the model to disk for later
model.save_pretrained("./trained-models/llama-7B-v2.1-stanford")