<a href="https://colab.research.google.com/github/alga-hopf/alpaca_lora_sage/blob/main/sage_finetuning_github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install transformers datasets Accelerate peft bitsandbytes sentencepiece wandb

In [None]:
import json
import numpy as np
import timeit
import copy
import torch
from transformers import Trainer, TrainingArguments, LlamaForCausalLM, LlamaTokenizer, DataCollatorWithPadding, DataCollatorForSeq2Seq, get_scheduler, AdamW
from datasets import load_dataset, DatasetDict, Dataset
from torch.utils.data import DataLoader
import wandb
import os
from tqdm.auto import tqdm
import random

In [None]:
# Shell out to `nvidia-smi` to display the GPU(s) available to this runtime.
!nvidia-smi

# Load and inspect dataset

In [None]:
# Mount Google Drive at /content/drive (Colab-only API) so that files stored
# on Drive can be read and written from the cells below.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Load the instruction dataset from a JSON file into `raw_dataset`.
# The path below is a placeholder that must be replaced with the real location.
# JSON text is UTF-8 by specification, so decode explicitly instead of relying
# on the platform's default text encoding.
with open('your path to the dataset in google drive', 'r', encoding='utf-8') as f:
    raw_dataset = json.load(f)

In [None]:
# Number of top-level entries in the loaded JSON (shown as the cell output).
len(raw_dataset)

# Fine tuning

In [None]:
# Authenticate with Weights & Biases before the training run is initialised.
wandb.login()

In [None]:
# --- Distributed setup ------------------------------------------------------
# WORLD_SIZE is read from the environment and defaults to 1 when it is unset;
# anything other than 1 is treated as a distributed (DDP) run.
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1

# --- Sequence length and batching -------------------------------------------
model_max_length = 512
batch_size, micro_batch_size = 128, 4

# --- LoRA adapter hyperparameters -------------------------------------------
lora_r, lora_alpha = 8, 16
lora_dropout = 0.05
lora_target_modules = ["q_proj", "v_proj"]

# --- Validation split size (not referenced by the other visible cells) -------
val_set_size = 2000

In [None]:
# Special-token strings that are assigned to the tokenizer below.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

# Label value used to mask positions when the training labels are built.
IGNORE_INDEX = -100

# Slow (non-fast) LLaMA tokenizer, right-padded, capped at model_max_length.
tokenizer = LlamaTokenizer.from_pretrained(
    "decapoda-research/llama-7b-hf",
    model_max_length=model_max_length,
    padding_side="right",
    use_fast=False,
)

# Override the tokenizer's special tokens with the defaults declared above.
tokenizer.pad_token = DEFAULT_PAD_TOKEN
tokenizer.eos_token = DEFAULT_EOS_TOKEN
tokenizer.bos_token = DEFAULT_BOS_TOKEN
tokenizer.unk_token = DEFAULT_UNK_TOKEN

In [None]:
# Prompt templates: "prompt_input" is for examples that carry extra context in
# an "input" field, "prompt_no_input" is for instruction-only examples.
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
list_data_dict = raw_dataset
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
# Pick the template per example: a non-empty "input" must go through the
# template that has an "### Input:" section, otherwise it would be silently
# dropped from the prompt. Examples without an input are formatted as before.
sources_list = [
    prompt_input.format_map(example) if example.get("input") else prompt_no_input.format_map(example)
    for example in list_data_dict
]
# Each target is the example's "output" followed by the tokenizer's EOS token.
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

In [None]:
# Due to limited resources we train only on a subset of the dataset
# Full training text for every example: prompt immediately followed by target.
examples_list = [s + t for s, t in zip(sources_list, targets)]
# Cap the subset at the dataset size so that a dataset with fewer than 20000
# examples does not raise IndexError below.
pop_size = min(20000, len(examples_list))
# Shuffle the indices of the first `pop_size` examples.
# NOTE(review): no random seed is set in the visible cells, so the order
# differs from run to run.
order = list(range(pop_size))
random.shuffle(order)
examples = [examples_list[n] for n in order]
sources = [sources_list[n] for n in order]
# Column-oriented dict ("example" = prompt + target, "source" = prompt only)
# consumed by Dataset.from_dict in the next cell.
all_examples, all_sources = list(examples), list(sources)
full_examples = {"example": all_examples, "source": all_sources}

In [None]:
# Wrap the {"example": [...], "source": [...]} dict built above in a
# datasets.Dataset.
full_examples_dataset = Dataset.from_dict(full_examples)

In [None]:
def tokenize_function(example):
    """Tokenize one training example and attach prompt-masked labels.

    Args:
        example: Mapping with an "example" entry (prompt followed by the
            target text) and a "source" entry (the prompt alone).

    Returns:
        The tokenizer output for ``example["example"]`` with an extra
        "labels" entry: ``IGNORE_INDEX`` repeated once per prompt token,
        followed by the remaining ids of the full sequence.
    """
    tokenizer_kwargs = dict(padding="longest", max_length=model_max_length, truncation=True)
    encoded_full = tokenizer(example["example"], **tokenizer_kwargs)
    encoded_prompt = tokenizer(example["source"], **tokenizer_kwargs)

    # The number of prompt tokens decides how many leading labels are masked.
    prompt_len = len(encoded_prompt["input_ids"])
    masked_prefix = [IGNORE_INDEX] * prompt_len
    response_ids = encoded_full["input_ids"][prompt_len:]

    encoded_full["labels"] = masked_prefix + response_ids
    return encoded_full

In [None]:
# Tokenize every row with tokenize_function. No `batched=True` is passed, so
# the function is called on one example at a time.
train_dataset = full_examples_dataset.map(tokenize_function)

In [None]:
# Drop the raw text columns so that only the tokenized fields remain.
# Only columns that are still present are removed, which makes this cell safe
# to re-run: the original call raised once the columns were already gone.
text_columns = [name for name in ("example", "source") if name in train_dataset.column_names]
if text_columns:
    train_dataset = train_dataset.remove_columns(text_columns)

In [None]:
# Collator that pads each batch dynamically and returns PyTorch tensors;
# padded lengths are rounded up to a multiple of 8.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    padding=True,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

In [None]:
# Re-derive the distributed flags from the environment (WORLD_SIZE defaults to 1).
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1

# Number of micro-batches that make up one effective batch.
gradient_accumulation_steps = batch_size // micro_batch_size

if not ddp:
    # Single process: let the loader place the model automatically.
    device_map = "auto"
else:
    # One process per device: pin the whole model to this process's local rank
    # and split the accumulation steps across the processes.
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}
    gradient_accumulation_steps //= world_size

In [None]:
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)

In [None]:
# Load the base LLaMA-7B checkpoint with 8-bit weights, float16 as the torch
# dtype, and the device placement chosen above.
model = LlamaForCausalLM.from_pretrained(
    "decapoda-research/llama-7b-hf",
    device_map=device_map,
    torch_dtype=torch.float16,
    load_in_8bit=True,
)

In [None]:
# LoRA adapter settings, taken from the hyperparameter cell.
config = LoraConfig(
    task_type="CAUSAL_LM",
    bias="none",
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
)

# Prepare the 8-bit base model for training, then wrap it with the adapters.
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, config)

In [None]:
# --- Optimisation schedule --------------------------------------------------
num_train_epochs = 3
warmup_steps = 100
lr = 1e-4  # 3e-4 was the previously noted alternative
optimizer = "adamw_torch"

# --- Batching ---------------------------------------------------------------
# Rebinds the `batch_size` set in the hyperparameter cell (128 there, 16 here).
batch_size = 16

# --- Output -----------------------------------------------------------------
out_dir = "/content/drive/MyDrive/"

In [None]:
# Batches of `batch_size` tokenized examples, padded by `data_collator`.
# No `shuffle` argument is given, so rows are served in dataset order.
train_dataloader = DataLoader(
    train_dataset,
    collate_fn=data_collator,
    batch_size=batch_size,
)

In [None]:
# Identifiers passed to wandb.init for the training run.
# `entity` is a placeholder that must be replaced with a real W&B entity.
entity = "your wandb entity here"
project_name, experiment_name = "sage_finetuning", "traning"

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

In [None]:
# Manual training loop: one optimizer step per dataloader batch, a linear
# learning-rate schedule with warmup, and per-step loss logging to W&B.
# NOTE(review): `gradient_accumulation_steps` computed in an earlier cell is
# not used here; every batch triggers its own optimizer step.
num_training_steps = num_train_epochs * len(train_dataloader)
print('Num training steps:', num_training_steps)

progress_bar = tqdm(range(num_training_steps))

opt = "adamw"

if opt == "adamw":
  # From here on `optimizer` is the optimizer object (it was a string in the
  # config cell) and `config` is the dict logged to W&B (it was the LoraConfig).
  optimizer = AdamW(model.parameters(), lr=lr)
  config = {
    "lr": lr,
    "optimizer": opt,
    "epochs": num_train_epochs,
    "batch_size": batch_size
  }
else:
  # Fail loudly instead of continuing with the stale string `optimizer` and
  # the LoraConfig still bound to `config`.
  raise ValueError(f"Unsupported optimizer: {opt!r}")

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps
)

# Checkpoint interval in steps; the value is large enough that intermediate
# checkpoints are effectively disabled.
save_checkpoint = 1000000000000
wandb.init(project=project_name, name=experiment_name, config=config, entity=entity)

t = 0  # global step counter across epochs
try:
  for epoch in range(num_train_epochs):
    model.train()
    for batch in train_dataloader:
      batch = {k: v.to(device) for k, v in batch.items()}
      output = model(**batch)
      loss = output.loss
      optimizer.zero_grad()
      loss.backward()
      # Gradients are already computed above, so no closure is needed.
      optimizer.step()
      lr_scheduler.step()
      progress_bar.update()
      wandb.log({'loss': loss.item()})
      t += 1
      if t % save_checkpoint == 0:
        model.save_pretrained(out_dir)
finally:
  # Always close the progress bar and the W&B run, even if training is
  # interrupted or raises.
  progress_bar.close()
  wandb.finish()

In [None]:
# Save the trained adapter under <out_dir>/alpaca_lora_sage. os.path.join gives
# the same path as before when out_dir ends with "/", and still produces a
# correct path when it does not.
model.save_pretrained(os.path.join(out_dir, "alpaca_lora_sage"))