<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/finetuning/Lora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small GPT-2-style checkpoint (presumably chosen so the LoRA pipeline runs
# quickly end to end; swap in a larger model for meaningful generations).
model_name = "sshleifer/tiny-gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# GPT-2 tokenizers ship without a padding token; reuse EOS so batches can be
# padded to a common length.
tokenizer.pad_token = tokenizer.eos_token
# Keep the model config in sync with the tokenizer so padding-aware code such as
# generate() uses the same pad id instead of falling back (with a warning).
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
###############
# Custom Data #
###############

In [None]:
from datasets import Dataset

# Raw training corpus: three short sentences, each ending in " \n".
texts = [
    "Hello World! This my customm text. \n",
    "Today I learned about LoRA and Transformers. \n",
    "AI is amazing! \n",
]

# Eager tensor-form tokenization of the raw texts (dynamic padding to the
# longest text). Not consumed by any of the visible training cells, which
# tokenize through `tokenize` instead; kept for interactive inspection.
encodings = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

# Wrap the raw strings in a `datasets.Dataset` with a single "text" column.
dataset = Dataset.from_dict(dict(text=texts))

def tokenize(batch):
  """Turn raw text into fixed-length causal-LM training examples.

  Every text is truncated/padded to 64 tokens. ``labels`` is a copy of
  ``input_ids`` in which padded positions (``attention_mask == 0``) are set
  to -100. Hugging Face causal-LM heads ignore that value in the loss, so the
  model is not trained to predict the pad token (the EOS token here) across
  the padding.

  Args:
    batch: mapping with a "text" entry. This is a list of strings when called
      via ``dataset.map(..., batched=True)``, or a single string otherwise.

  Returns:
    The tokenizer output (``input_ids``, ``attention_mask``, ...) plus
    ``labels``.
  """
  tokens = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=64)

  def _mask_padding(ids, mask):
    # Build a fresh list (no aliasing of input_ids); -100 marks "ignore in loss".
    return [tok if keep else -100 for tok, keep in zip(ids, mask)]

  if isinstance(batch["text"], str):
    # Non-batched call: tokenizer output fields are flat lists.
    tokens["labels"] = _mask_padding(tokens["input_ids"], tokens["attention_mask"])
  else:
    tokens["labels"] = [
        _mask_padding(ids, mask)
        for ids, mask in zip(tokens["input_ids"], tokens["attention_mask"])
    ]
  return tokens

# Run `tokenize` over the dataset in batches, then have indexing return torch
# tensors for the columns that hold numeric data.
tokenized_dataset = dataset.map(
    function=tokenize,
    batched=True,
)
tokenized_dataset.set_format(type="torch")

In [None]:
# Display the third example (index 2) to sanity-check the fields produced by
# `tokenize`, including `labels`.
tokenized_dataset[2]

In [None]:
########
# LoRA #
########

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA adapter configuration for the GPT-2-style model loaded earlier.
config = LoraConfig(
    r=8,                 # rank of the low-rank update matrices
    lora_alpha=16,       # the adapter update is scaled by lora_alpha / r
    # In GPT-2, "c_proj" presumably matches both attn.c_proj and mlp.c_proj;
    # check with model.named_modules() if only attention should be adapted.
    target_modules=["c_attn", "c_proj"],
    lora_dropout=0.1,
    bias="none",         # train no bias terms, only the LoRA matrices
    # GPT-2's c_attn / c_proj are transformers Conv1D layers, which store their
    # weight as (fan_in, fan_out). PEFT requires fan_in_fan_out=True for them;
    # without it, PEFT warns and flips the flag itself.
    fan_in_fan_out=True,
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the base model with trainable LoRA adapters (base weights stay frozen)
# and report how many parameters will actually be trained.
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
############
# Training #
############

In [None]:
from transformers import Trainer, TrainingArguments

# The toy corpus has 3 examples. With a per-device batch size of 2 on a single
# device, one epoch is 2 optimizer steps, so the whole 10-epoch run is only
# ~20 steps. logging_steps must be small, or no training loss is ever logged.
training_args = TrainingArguments(
    output_dir="./lora-out",
    per_device_train_batch_size=2,
    num_train_epochs=10,
    learning_rate=2e-4,
    save_strategy="no",   # no checkpoints for this throwaway experiment
    fp16=False,
    logging_steps=1,      # previously 200 (> total steps), which never fired
)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_dataset,
    args=training_args,
    # `processing_class` replaces Trainer's deprecated `tokenizer` argument
    # (available since transformers 4.46).
    processing_class=tokenizer,
)

trainer.train()

In [None]:
##############
# Evaluation #
##############