In [132]:
import os
import sys
# Put the parent directory of the working directory on the import path so the
# `src` package imported further down can be found.
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)


In [133]:
import torch
import numpy as np

In [134]:
# Choose the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device

'mps'

In [135]:
import importlib
import src.train
import src.model

# Re-execute the project modules so that edits made to them since the kernel
# started are picked up without restarting the kernel.
importlib.reload(src.train)
importlib.reload(src.model)

# These imports must come after the reloads: `from ... import` binds the objects
# that exist at import time, so importing earlier would keep the stale versions.
from src.train import train_lora
from src.model import identify_target_modules

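For testing, use a small model and a small dataset:

- DistilGPT2 https://huggingface.co/distilbert/distilgpt2
- WikiText-2 (raw) https://huggingface.co/datasets/Salesforce/wikitext (the dataset loaded in this notebook)
- Tiny Shakespeare https://huggingface.co/datasets/karpathy/tiny_shakespeare (an alternative small corpus; not loaded in this notebook)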


In [136]:
from peft import LoraConfig

In [137]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

# Tokenizer and model are loaded from the same checkpoint, so name it once.
MODEL_ID = "distilbert/distilgpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

In [138]:
from datasets import load_dataset

# Load the "wikitext-2-raw-v1" configuration of the Salesforce/wikitext dataset.
ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
# Disabled alternative: the "wikitext-103-v1" configuration.
# ds = load_dataset("wikitext", "wikitext-103-v1")

In [145]:
def tokenize_dataset(examples, tokenizer=tokenizer, column_name="text"):
    """Tokenize one batch of raw text for causal language modelling.

    Written to be used with ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping column names to lists of values; must contain
            ``column_name``.
        tokenizer: Callable applied to the list of texts. The default is the
            notebook-level ``tokenizer`` object as bound when this cell ran, so
            re-run this cell after reloading the tokenizer.
        column_name: Name of the column holding the raw text.

    Returns:
        The tokenizer's output for the batch (here: ``input_ids`` and
        ``attention_mask``), one variable-length list per input text.
    """
    # Deliberately no `padding=True` / `return_tensors="pt"`: the rows are later
    # concatenated and re-chunked by `group_texts`, so padding every row to the
    # longest text in the batch would splice long runs of pad tokens into the
    # training blocks. Plain Python lists are what `Dataset.map` stores anyway.
    return tokenizer(examples[column_name], truncation=True)

In [157]:
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Tokenize every split batch-wise and drop the raw "text" column afterwards.
tokenized_ds = ds.map(
    tokenize_dataset,
    batched=True,
    remove_columns=["text"],
)


Map: 100%|██████████| 4358/4358 [00:01<00:00, 3405.24 examples/s]
Map: 100%|██████████| 36718/36718 [00:11<00:00, 3162.28 examples/s]
Map: 100%|██████████| 3760/3760 [00:00<00:00, 4050.10 examples/s]


In [158]:
# Show the same training row (index 10) before and after tokenization.
print("Dataset example:", ds["train"][10], sep="\n")
print("Tokenized dataset example:", tokenized_ds["train"][10], sep="\n")

Dataset example:
{'text': ' The game \'s battle system , the BliTZ system , is carried over directly from Valkyira Chronicles . During missions , players select each unit using a top @-@ down perspective of the battlefield map : once a character is selected , the player moves the character around the battlefield in third @-@ person . A character can only act once per @-@ turn , but characters can be granted multiple turns at the expense of other characters \' turns . Each character has a field and distance of movement limited by their Action Gauge . Up to nine characters can be assigned to a single mission . During gameplay , characters will call out if something happens to them , such as their health points ( HP ) getting low or being knocked out by enemy attacks . Each character has specific " Potentials " , skills unique to each character . They are divided into " Personal Potential " , which are innate skills that remain unaltered unless otherwise dictated by the story and can eith

In [164]:
# List the columns left in the training split after tokenization.
print("Keys in tokenized dataset:", tokenized_ds["train"].column_names, sep="\n")

Keys in tokenized dataset:
['input_ids', 'attention_mask']


In [165]:
# Group the tokenized dataset into blocks of a certain length
from itertools import chain

block_size = 128


def group_texts(examples):
    """Concatenate a batch of tokenized rows and cut it into ``block_size`` chunks.

    Written to be used with ``Dataset.map(..., batched=True)``.

    Args:
        examples: Batch mapping each column name (must include ``"input_ids"``)
            to a list of per-row token lists.

    Returns:
        dict with the same columns, each a list of chunks of ``block_size``
        items, plus a ``"labels"`` column that is a copy of the ``"input_ids"``
        chunk list. When the batch holds more than ``block_size`` items the
        trailing remainder is dropped; a batch shorter than ``block_size`` is
        kept as a single short chunk.
    """
    # chain.from_iterable flattens in linear time; the previous
    # `sum(rows, [])` re-copied the growing list for every row (quadratic).
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    # All columns are assumed to have the same flattened length, so measure the first.
    first_key = list(examples.keys())[0]
    total_length = len(concatenated_examples[first_key])
    if total_length > block_size:
        # Drop the small remainder so every chunk is exactly block_size long.
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # Labels start as a copy of the inputs for causal language modelling.
    result["labels"] = result["input_ids"].copy()
    return result


In [166]:
# Apply group_texts batch-wise to every split: rows are concatenated, cut into
# block_size-token chunks, and given a "labels" column.
lm_ds = tokenized_ds.map(group_texts, batched=True)

Map: 100%|██████████| 4358/4358 [00:10<00:00, 423.82 examples/s]
Map: 100%|██████████| 36718/36718 [01:25<00:00, 430.37 examples/s]
Map: 100%|██████████| 3760/3760 [00:06<00:00, 571.75 examples/s]


In [161]:
# Modules of the base model that LoRA will be attached to; passed to
# LoraConfig(target_modules=...) in the training cell.
# NOTE(review): 'attn' is presumably a name filter selecting the attention
# layers — confirm against identify_target_modules in src.model.
target_modules = identify_target_modules(model, 'attn')

In [167]:
from transformers import DataCollatorForLanguageModeling

# mlm=False configures the collator for causal rather than masked language modelling.
# NOTE(review): data_collator is not passed to train_lora in the training cell —
# confirm whether train_lora builds its own collator or this one should be wired in.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
# Single directory used both as TrainingArguments.output_dir and as train_lora's save_dir.
OUTPUT_DIR = "output"

# LoRA settings: only the target modules are overridden; everything else keeps its default.
lora_config = LoraConfig(target_modules=target_modules)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",
    remove_unused_columns=False,
)

# Move the base model onto the selected device before training starts.
model.to(device)

train_lora(
    base_model=model,
    train_dataset=lm_ds["train"],
    eval_dataset=lm_ds["validation"],
    tokenizer=tokenizer,
    adapter_name="wikitext",
    lora_config=lora_config,
    training_args=training_args,
    save_dir=OUTPUT_DIR,
)

