In [None]:
import os
# Restrict CUDA to device "0" and switch tokenizers parallelism off.
# Both are set through environment variables in this first cell, ahead of the
# library imports in the next one.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",
        "TOKENIZERS_PARALLELISM": "false",
    }
)


In [2]:
import collections

import bitsandbytes
import datasets
import peft
import rich
import torch
from tqdm.notebook import tqdm
import transformers
import torch

# --- Run configuration ------------------------------------------------------

# Device index that training batches are moved to.
device = 0

# Checkpoint identifiers for the base model and for its tokenizer.
model_name_or_path = "google/flan-t5-xxl"
tokenizer_name_or_path = "google/flan-t5-xxl"

# Dataset columns: the raw input sentence and the textual class label.
text_column = "sentence"
label_column = "text_label"

# Tokenisation and optimisation hyper-parameters.
max_length = 200
num_epochs = 3
lr = 1e-3
train_batch_size, eval_batch_size = 1, 16

# LoRA adapter settings for a sequence-to-sequence language-modelling task.
peft_config = peft.LoraConfig(
    task_type=peft.TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Number of training steps between two mid-epoch evaluations.
print_eval_every_x_step = 500


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /cvmfs/ai.mila.quebec/apps/arch/common/cuda/11.7/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.0
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/mila/g/gagnonju/.main/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


In [3]:

def calc_acc(preds, split, dataset):
    """Print the exact-match accuracy of ``preds`` against the text labels.

    Predictions are paired positionally with ``dataset[split]["text_label"]``
    and compared after stripping surrounding whitespace. A short report is
    printed with ``rich``: the accuracy, the first ten predictions, the first
    ten reference labels and the ten most frequent wrong predictions.

    Args:
        preds: Sequence of predicted label strings.
        split: Key of the split to score against (e.g. ``"validation"``).
        dataset: Mapping from split name to an object that exposes a
            ``"text_label"`` column of reference strings.

    Returns:
        None. The report is printed, not returned.
    """
    references = dataset[split]["text_label"]
    incorrect = collections.Counter()
    correct = 0
    total = 0
    # zip stops at the shorter sequence, so only paired items are scored.
    for pred, true in zip(preds, references):
        total += 1
        if pred.strip() == true.strip():
            correct += 1
        else:
            incorrect[pred] += 1

    # With nothing to score, report 0% instead of raising ZeroDivisionError.
    accuracy = correct / total if total else 0.0
    rich.print(
        f"{accuracy           = :0.2%} on the evaluation dataset\n"
        f"{preds[:10]         = }\n"
        f"{dataset[split]['text_label'][:10] = }\n"
        f"{incorrect.most_common(10) = }"
    )

def eval_epoch(*, fn_model, tokenizer, eval_dataloader):
    """Run one evaluation pass and return loss, perplexity and predictions.

    The model is put in eval mode for the duration of the pass and its
    previous training/eval mode is restored afterwards, even if a batch
    raises. Predictions are the per-position argmax of the logits produced
    by a forward pass on each batch (no call to ``generate``).

    Args:
        fn_model: Model to evaluate. It must expose ``.device`` and
            ``.training``, and return an object with ``.loss`` and ``.logits``
            when called with the unpacked batch.
        tokenizer: Tokenizer providing ``batch_decode`` and ``pad_token_id``.
        eval_dataloader: Sized iterable of dict batches whose values are
            tensors.

    Returns:
        Tuple ``(mean_loss, perplexity, predictions)``: the mean batch loss as
        a float, ``exp(mean_loss)`` as a float, and the decoded prediction
        strings.

    Raises:
        ValueError: If ``eval_dataloader`` contains no batches.
    """
    num_batches = len(eval_dataloader)
    if num_batches == 0:
        # Fail with a clear message instead of a ZeroDivisionError further down.
        raise ValueError("eval_dataloader is empty; nothing to evaluate.")

    was_training = fn_model.training
    fn_model.eval()

    eval_loss = 0.0
    eval_preds = []
    try:
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = {k: v.to(fn_model.device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = fn_model(**batch)
            eval_loss += outputs.loss.detach().float()

            token_ids = torch.argmax(outputs.logits, -1)
            labels = batch.get("labels")
            if (
                labels is not None
                and tokenizer.pad_token_id is not None
                and labels.shape == token_ids.shape
            ):
                # Positions labelled -100 are padding that the loss ignores;
                # whatever the model predicts there is unconstrained and would
                # otherwise be decoded into the prediction string.
                token_ids = token_ids.masked_fill(
                    labels == -100, tokenizer.pad_token_id
                )

            eval_preds.extend(
                tokenizer.batch_decode(
                    token_ids.detach().cpu().numpy(),
                    skip_special_tokens=True,
                )
            )
    finally:
        # Restore whichever mode the caller had the model in.
        fn_model.train(was_training)

    eval_epoch_loss = eval_loss / num_batches
    eval_ppl = torch.exp(eval_epoch_loss)
    return eval_epoch_loss.item(), eval_ppl.item(), eval_preds

def train_epoch(
        *,
        epoch,
        fn_model,
        optimizer,
        tokenizer,
        eval_every,
        lr_scheduler,
        eval_dataloader,
        train_dataloader,
        reference_dataset=None,
):
    """Train ``fn_model`` for one pass over ``train_dataloader``.

    Each batch is moved to the model's device, followed by a forward/backward
    pass, an optimizer step, a scheduler step and a gradient reset. Every
    ``eval_every`` steps (never at step 0) the model is evaluated with
    ``eval_epoch`` and the result is reported with ``calc_acc``.

    Args:
        epoch: Epoch index, used only in the progress report.
        fn_model: Model to train. It must expose ``.device`` and return an
            object with ``.loss`` when called with the unpacked batch.
        optimizer: Optimizer stepped after every batch.
        tokenizer: Tokenizer forwarded to ``eval_epoch``.
        eval_every: Number of steps between mid-epoch evaluations. A falsy
            value disables them.
        lr_scheduler: Scheduler stepped after every optimizer step.
        eval_dataloader: Dataloader for the mid-epoch evaluations, or ``None``
            to skip them.
        train_dataloader: Iterable of dict batches whose values are tensors.
        reference_dataset: Mapping with a ``"validation"`` split holding the
            reference ``"text_label"`` column used by ``calc_acc``. Defaults
            to the notebook-level ``dataset``.

    Returns:
        The sum of the per-batch losses as a float (``0.0`` when there are no
        batches).
    """
    total_loss = 0.0
    for step, batch in enumerate(tqdm(train_dataloader, desc="Training")):
        fn_model = fn_model.train()
        # Use the model's own device, as eval_epoch does, rather than a
        # notebook-level global.
        batch = {k: v.to(fn_model.device) for k, v in batch.items()}
        outputs = fn_model(**batch)
        loss = outputs.loss
        total_loss += loss.detach().float()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        should_eval = (
            eval_dataloader is not None
            and bool(eval_every)
            and step > 0
            and step % eval_every == 0
        )
        if not should_eval:
            continue

        eval_epoch_loss, eval_ppl, eval_preds = eval_epoch(
            fn_model        = fn_model,
            tokenizer       = tokenizer,
            eval_dataloader = eval_dataloader,
        )
        rich.print(
            f"[bold green]{epoch} - {step}:[/] "
            f"{eval_ppl        = :0.3} "
            f"{eval_epoch_loss = :0.3}"
        )
        if reference_dataset is None:
            # calc_acc indexes its last argument by split name, so it needs
            # the dataset mapping; a DataLoader is not subscriptable.
            reference_dataset = dataset
        calc_acc(eval_preds, "validation", reference_dataset)

    # float() handles both the accumulated tensor and the initial 0.0.
    return float(total_loss)

In [4]:
# Map every top-level T5 sub-module to one device index.
# LOCAL_RANK is only set by distributed launchers, so fall back to the
# configured `device` when the notebook runs on its own. The value is cast to
# int because environment variables are strings and a bare "0" is not a valid
# device specifier.
dmap_keys = ["encoder", "lm_head", "shared", "decoder"]
local_rank = int(os.environ.get("LOCAL_RANK", device))
dmap = {k: local_rank for k in dmap_keys}

# Load the base checkpoint with 8-bit weights (bitsandbytes).
frozen_model = transformers.T5ForConditionalGeneration.from_pretrained(
    model_name_or_path,
    device_map   = dmap,
    torch_dtype  = torch.bfloat16,
    load_in_8bit = True,
)

# Freeze every base-model parameter, then print the trainable-parameter count.
for name, param in frozen_model.named_parameters():
    param.requires_grad = False
peft.PeftModel.print_trainable_parameters(frozen_model)



Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

trainable params: 0 || all params: 11135332352 || trainable%: 0.0


In [5]:
# Ensure no base-model weight is trainable before the adapters are attached,
# and print the trainable-parameter count.
for weight in frozen_model.parameters():
    weight.requires_grad = False
peft.PeftModel.print_trainable_parameters(frozen_model)

# Wrap the frozen base model according to `peft_config` and print the
# trainable-parameter count again.
model = peft.get_peft_model(frozen_model, peft_config)
model.print_trainable_parameters()

trainable params: 0 || all params: 11135332352 || trainable%: 0.0
trainable params: 9437184 || all params: 11144769536 || trainable%: 0.08467814403443578


In [6]:
# loading dataset
# Load the "sentences_allagree" configuration of financial_phrasebank and hold
# out 10% of its "train" split for validation.
dataset = datasets.load_dataset("financial_phrasebank", "sentences_allagree")
# A fixed seed keeps the train/validation split the same on every run, so the
# metrics and the example indexed later by position are reproducible.
dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
dataset["validation"] = dataset.pop("test")

# Add a "text_label" column holding the class name for each integer label.
classes = dataset["train"].features["label"].names
dataset = dataset.map(
    lambda x: {"text_label": [classes[label] for label in x["label"]]},
    batched=True,
    num_proc=1,
)

dataset["train"][0]

Found cached dataset financial_phrasebank (/home/mila/g/gagnonju/.cache/huggingface/datasets/financial_phrasebank/sentences_allagree/1.0.0/550bde12e6c30e2674da973a55f57edde5181d53f5a5a34c1531c53f93b7e141)


  0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/2037 [00:00<?, ? examples/s]

Map:   0%|          | 0/227 [00:00<?, ? examples/s]

{'sentence': 'The cranes would be installed onboard two freighters ordered by Singaporean ship owner Masterbulk .',
 'label': 1,
 'text_label': 'neutral'}

In [7]:
# data preprocessing
# Load the tokenizer from its own config entry, `tokenizer_name_or_path`, so
# that it can be changed independently of the model checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name_or_path)

def preprocess_function(examples):
    """Tokenise a batch of sentences and their text labels for seq2seq training.

    An instruction prefix is prepended to every sentence in
    ``examples[text_column]``. The targets come from
    ``examples[label_column]`` and are tokenised with ``max_length=3``. Pad
    token ids inside the label tensor are replaced by ``-100``.

    Args:
        examples: Batch mapping that holds the ``text_column`` and
            ``label_column`` entries as lists of strings.

    Returns:
        The tokenizer output for the prompted sentences with an extra
        ``"labels"`` entry holding the masked label ids.
    """
    sentences = examples[text_column]
    label_texts = examples[label_column]

    instruction = "Answer if the sentiment of the following sentence is positive, negative or neutral: "
    prompted_sentences = [instruction + sentence for sentence in sentences]

    # Options shared by both tokenizer calls.
    shared_kwargs = dict(padding=True, truncation=True, return_tensors="pt")

    model_inputs = tokenizer(prompted_sentences, max_length=max_length, **shared_kwargs)
    label_ids = tokenizer(label_texts, max_length=3, **shared_kwargs)["input_ids"]

    # Overwrite padding in place with -100.
    label_ids.masked_fill_(label_ids == tokenizer.pad_token_id, -100)
    model_inputs["labels"] = label_ids
    return model_inputs


# Tokenise every split in batches; the original columns are removed so that
# only the fields returned by `preprocess_function` remain.
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

train_dataset, eval_dataset = (
    processed_datasets["train"],
    processed_datasets["validation"],
)

# Seq2seq collator that pads each batch and returns PyTorch tensors.
collator = transformers.DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    max_length=max_length,
    return_tensors="pt",
)

# Options common to both loaders; only shuffling and batch size differ.
loader_kwargs = dict(collate_fn=collator, pin_memory=True)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    **loader_kwargs,
)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=eval_batch_size,
    shuffle=False,
    **loader_kwargs,
)

Running tokenizer on dataset:   0%|          | 0/2037 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/227 [00:00<?, ? examples/s]

In [8]:
# optimizer and lr scheduler
# The schedule runs over every optimizer step of the whole run, with no warmup.
total_training_steps = len(train_dataloader) * num_epochs

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = transformers.get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

In [9]:
# NOTE: eval_epoch only accepts keyword arguments; calling it positionally
# raises TypeError.

# Evaluate `frozen_model` before any training step.
eval_epoch_loss, eval_ppl, eval_preds = eval_epoch(
    fn_model        = frozen_model,
    tokenizer       = tokenizer,
    eval_dataloader = eval_dataloader,
)
rich.print(f"[bold blue]Zero shot frozen:[/] epoch = -1: {eval_ppl = :0.3} {eval_epoch_loss = :0.3}")
calc_acc(eval_preds, "validation", dataset)

# Evaluate the PEFT-wrapped `model` before any training step.
eval_epoch_loss, eval_ppl, eval_preds = eval_epoch(
    fn_model        = model,
    tokenizer       = tokenizer,
    eval_dataloader = eval_dataloader,
)
rich.print(f"[bold green]Peft zero-shot:[/] {eval_ppl = :0.3} {eval_epoch_loss = :0.3}")
calc_acc(eval_preds, "validation", dataset)

# Train for `num_epochs`, running a full evaluation after every epoch.
for epoch in range(num_epochs):
    total_loss = train_epoch(
        epoch            = epoch,
        fn_model         = model,
        tokenizer        = tokenizer,
        optimizer        = optimizer,
        eval_every       = print_eval_every_x_step,
        lr_scheduler     = lr_scheduler,
        eval_dataloader  = eval_dataloader,
        train_dataloader = train_dataloader,
    )
    eval_epoch_loss, eval_ppl, eval_preds = eval_epoch(
        fn_model        = model,
        tokenizer       = tokenizer,
        eval_dataloader = eval_dataloader,
    )

    # train_epoch returns the summed batch loss; average it over the batches.
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(torch.tensor(train_epoch_loss))
    rich.print(
        f"[bold blue]{epoch = }:[/] "
        f"{train_ppl        = :0.3} "
        f"{train_epoch_loss = :0.3} "
        f"{eval_ppl         = :0.3} "
        f"{eval_epoch_loss  = :0.3}")
    calc_acc(eval_preds, "validation", dataset)

In [10]:
# Final evaluation of the PEFT model; this cell comes after the training loop.
# eval_epoch only accepts keyword arguments.
eval_epoch_loss, eval_ppl, eval_preds = eval_epoch(
    fn_model        = model,
    tokenizer       = tokenizer,
    eval_dataloader = eval_dataloader,
)
# The label says "fine-tuned": this cell follows training, so "zero-shot"
# would be misleading.
rich.print(f"[bold green]Peft fine-tuned:[/] {eval_ppl = :0.3} {eval_epoch_loss = :0.3}")
calc_acc(eval_preds, "validation", dataset)


In [11]:
# saving model
# The output directory name is built from the base checkpoint name, the PEFT
# type and the task type of `peft_config`.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
# Save the PEFT model under that directory.
model.save_pretrained(peft_model_id)

In [12]:
# Path of the adapter weights file inside the saved model directory.
ckpt = f"{peft_model_id}/adapter_model.bin"
# Shell out to `du` to print the file's on-disk size.
!du -h $ckpt

37M	google/flan-t5-xxl_LORA_SEQ_2_SEQ_LM/adapter_model.bin


from peft import PeftModel, PeftConfig

# Rebuild the save directory name, then reload the base model named in the
# saved PEFT config and attach the saved adapter to it.
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"

config = PeftConfig.from_pretrained(peft_model_id)
model  = transformers.AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model  = PeftModel.from_pretrained(model, peft_model_id)

model.eval()
i = 13
# The model was fine-tuned on prompted inputs, so inference must use the same
# instruction prefix as `preprocess_function`; feeding the bare sentence would
# not match the training inputs.
prompt = "Answer if the sentiment of the following sentence is positive, negative or neutral: "
sentence = dataset["validation"][text_column][i]
inputs = tokenizer(prompt + sentence, return_tensors="pt")
print(sentence)
print(inputs)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"].to(model.device),
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs["attention_mask"].to(model.device),
        max_new_tokens=10,
    )
    print(outputs)
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))