# LoRA for Seq2Seq Conditional Generation
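This example is adapted from the [PEFT prompt-tuning seq2seq example](https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_prompt_tuning_seq2seq_with_generate.ipynb); here the model is fine-tuned with LoRA instead of prompt tuning.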

In [None]:
%pip install -q --user transformers==4.35.2
%pip install -q --user datasets
%pip install -q --user peft

In [None]:
import torch
import wandb
import os

from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    GenerationConfig,
    default_data_collator
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
)

from datasets import load_dataset

os.environ["TOKENIZERS_PARALLELISM"] = "false" # to avoid deadlock warnings

from datasets import load_dataset

In [None]:
# NOTE(review): `device` appears unused in this notebook; the Trainer and the
# later cells handle device placement themselves — confirm before removing.
device = "cuda"
model_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"

# Dataset columns read by `preprocess_function`: the raw sentence is the model
# input and the string class label (added by the `dataset.map` call that builds
# "text_label") is the generation target. These must be defined, otherwise
# preprocessing fails with a NameError.
text_column = "sentence"
label_column = "text_label"
max_length = 128  # input length in tokens; inputs are padded/truncated to this
lr = 1e-3
num_epochs = 3
batch_size = 32

In [None]:
# LoRA hyperparameters, in order: rank of the low-rank update matrices,
# the alpha scaling factor, and the dropout probability inside the LoRA layers.
r, lora_alpha, lora_dropout = 8, 32, 0.1

# Which modules receive LoRA adapters. None presumably defers to PEFT's
# per-architecture defaults (confirm in the LoraConfig docs); swap in one of
# the commented alternatives to experiment with a different reparametrization.
target_modules = None
# target_modules = "all-linear"
# target_modules = ["q", "k", "v"]

In [None]:
# LoRA configuration for a sequence-to-sequence language model, built from the
# hyperparameters defined in the previous cell.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=target_modules,
)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

# Wrap the base model with LoRA adapters. Comment out the next line to run full
# fine-tuning (FFT) instead; that uses around 35GB of GPU memory at batch size 32.
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model

In [None]:
# Fixed seed so the train/validation/test split is identical on every run;
# otherwise results from different `target_modules` experiments are not
# comparable because each run would train and evaluate on different rows.
split_seed = 42

dataset = load_dataset("financial_phrasebank", "sentences_allagree")

# 80/10/10 train/validation/test split: hold out 20%, then halve the holdout.
dataset = dataset["train"].train_test_split(test_size=0.2, seed=split_seed)
validtest = dataset["test"].train_test_split(test_size=0.5, seed=split_seed)

dataset["validation"] = validtest["train"]
dataset["test"] = validtest["test"]

# Map each integer label to its class-name string; the seq2seq model is trained
# to generate this string.
classes = dataset["train"].features["label"].names
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1,
)

dataset["train"][0]

In [None]:
# Load the tokenizer from its own configured checkpoint (currently the same
# value as the model checkpoint) instead of ignoring `tokenizer_name_or_path`.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)


def preprocess_function(examples):
    """Tokenize a batch of examples into fixed-length model inputs and labels.

    Args:
        examples: a batch passed by ``dataset.map(..., batched=True)``; the
            input sentences are read from ``text_column`` and the target
            strings from ``label_column``.

    Returns:
        The tokenizer encoding of the inputs (padded/truncated to
        ``max_length`` tokens) with an added ``"labels"`` entry holding the
        target token ids, padded/truncated to 4 tokens.
    """
    inputs = examples[text_column]
    targets = examples[label_column]
    model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
    # 4 tokens is presumably enough for the short class names — verify if the
    # label set changes. NOTE(review): pad positions in the labels are not
    # replaced with -100, so they presumably count towards the loss; masking
    # them also requires compute_metrics to map -100 back before decoding.
    labels = tokenizer(targets, max_length=4, padding="max_length", truncation=True, return_tensors="pt")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


# Tokenize every split, removing all original columns so only the tokenizer
# outputs and "labels" remain.
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

# Seeded shuffle so the training order is reproducible between runs.
train_dataset = processed_datasets["train"].shuffle(seed=42)
eval_dataset = processed_datasets["validation"]
test_dataset = processed_datasets["test"]

In [None]:
def compute_metrics(eval_preds):
    """Compute exact-match accuracy between generated text and reference labels.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair supplied by the trainer;
            with ``predict_with_generate=True`` the predictions are generated
            token ids.

    Returns:
        ``{"accuracy": float}``: the fraction of examples whose decoded
        prediction equals the decoded label after stripping surrounding
        whitespace, or 0.0 when there are no examples.
    """
    import numpy as np

    preds, labels = eval_preds
    # Some models return a tuple of outputs; the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]
    # Predictions/labels may contain -100 (the ignore index the trainer uses as
    # padding when concatenating batches), which batch_decode cannot decode;
    # map it back to the pad token, which skip_special_tokens then drops.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    pairs = list(zip(decoded_preds, decoded_labels))
    if not pairs:
        # Avoid ZeroDivisionError on an empty evaluation set.
        return {"accuracy": 0.0}
    correct = sum(pred.strip() == true.strip() for pred, true in pairs)
    return {"accuracy": correct / len(pairs)}


# Training configuration:
# - evaluate and log once per epoch;
# - save no checkpoints;
# - evaluate on generated text (at most 10 new tokens per example), so that
#   compute_metrics can compare decoded strings.
training_args = Seq2SeqTrainingArguments(
    output_dir="out",
    learning_rate=lr,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    predict_with_generate=True,
    generation_config=GenerationConfig(max_new_tokens=10),
)

In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()

# Final held-out evaluation. The returned metrics dict is printed explicitly:
# this call is not the cell's last expression, so the notebook would otherwise
# discard the result without displaying it.
test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
print(test_metrics)

# Close the Weights & Biases run if the trainer started one.
if wandb.run is not None:
    wandb.finish()

In [None]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)

ckpt = f"{peft_model_id}/adapter_model.safetensors"
!du -h $ckpt

In [None]:
from peft import PeftConfig, PeftModel

# Same directory name as the one the adapter was saved to above.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

# The saved adapter config records its base checkpoint: load that base model
# and attach the saved adapter weights to it.
config = PeftConfig.from_pretrained(peft_model_id)
model = PeftModel.from_pretrained(
    AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path),
    peft_model_id,
)

In [None]:
# Interactively classify one sentence typed by the user with the reloaded model.
inputs = tokenizer(input(), return_tensors="pt")
print(inputs)
with torch.no_grad():
    # Pass the attention mask explicitly rather than letting generate() infer it.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=10,
    )
    print(outputs)
    # batch_decode accepts tensors directly; no detach/cpu/numpy round-trip needed.
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))