Prompting helps guide language model behavior by adding some input text specific to a task. 

Prompt tuning is an additive method that trains and updates only the newly added prompt tokens of a pretrained model. 

This way, you can use one pretrained model whose weights are frozen, and train and update a smaller set of prompt parameters for each downstream task instead of fully finetuning a separate model.

In [1]:
!pip install peft >> /dev/null

In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_linear_schedule_with_warmup
)
from peft import (
    get_peft_config,
    get_peft_model,
    PromptTuningInit,
    PromptTuningConfig,
    TaskType,
    PeftType
)
import torch

from datasets import load_dataset
import os

from torch.utils.data import DataLoader
from tqdm import tqdm
from rich import print

In [2]:
# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else "cuda"

# The local model directory turned out to contain only a tokenizer, so the
# notebook uses the Llama-2 7B chat checkpoint instead of the previously used
# "bigscience/bloomz-560m".
model_name_or_path = "meta-llama/Llama-2-7b-chat-hf"

# The tokenizer comes from the same checkpoint as the model
# (previously "bigscience/bloomz-560m" as well).
tokenizer_name_or_path = "meta-llama/Llama-2-7b-chat-hf"

In [3]:
# Prompt-tuning configuration: the virtual prompt tokens are initialised from
# the text of a natural-language instruction instead of randomly.
# NOTE(review): the Llama tokenizer probably splits the init text into more
# than 8 tokens, in which case only part of it seeds the virtual tokens.
# Confirm, or derive num_virtual_tokens from the tokenized length.
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=8,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    # Use the dedicated tokenizer setting (same value as the model path today)
    # so the two cannot drift apart if only one of them is changed.
    tokenizer_name_or_path=tokenizer_name_or_path,
)

In [4]:
# Show the resolved prompt-tuning configuration (`print` is rich's print, imported above).
print(peft_config)

In [5]:
# --- Dataset ---------------------------------------------------------------
dataset_name = "twitter_complaints"
text_column = "Tweet text"    # column holding the raw tweet
label_column = "text_label"   # column holding the label as text

# --- Training hyperparameters ----------------------------------------------
batch_size = 16
num_epochs = 5
lr = 3e-2
max_length = 64   # every tokenized example is padded / cut to this many tokens

# File name for a checkpoint of this run; "/" in the model id is replaced so
# the name stays a single path component.
checkpoint_name = (
    f"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt"
).replace("/", "_")

In [6]:
# Load the selected subset of the "ought/raft" dataset.
dataset = load_dataset("ought/raft", dataset_name)

# Peek at the first training example.
dataset["train"][0]

Found cached dataset raft (/home/kamal/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84)


  0%|          | 0/2 [00:00<?, ?it/s]

{'Tweet text': '@HMRCcustomers No this is my first job', 'ID': 0, 'Label': 2}

In [11]:
# Summary of the training split.
print(dataset['train'])

In [12]:
# Column schema of the training split, including the names behind the integer "Label" values.
print(dataset['train'].features)

In [7]:
# Human-readable class names: the "Label" feature's names with underscores
# turned into spaces.
classes = [name.replace("_", " ") for name in dataset["train"].features["Label"].names]

# Add a "text_label" column holding the class name that matches each integer label.
dataset = dataset.map(
    lambda batch: {"text_label": [classes[idx] for idx in batch["Label"]]},
    num_proc=1,
    batched=True,
)

# Peek at the first training example, now including its text label.
dataset["train"][0]

Loading cached processed dataset at /home/kamal/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-f9e9f46ae8859106.arrow
Loading cached processed dataset at /home/kamal/.cache/huggingface/datasets/ought___raft/twitter_complaints/1.1.0/79c4de1312c1e3730043f7db07179c914f48403101f7124e2fe336f6f54d9f84/cache-b2de31db5549b4bf.arrow


{'Tweet text': '@HMRCcustomers No this is my first job',
 'ID': 0,
 'Label': 2,
 'text_label': 'no complaint'}

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Length, in token ids, of the longest class label as returned by a plain
# tokenizer call (the cells below show the individual tokenizations).
target_max_length = max(
    len(tokenizer(class_label)["input_ids"]) for class_label in classes
)

print(target_max_length)

In [35]:
# End-of-sequence token id of the loaded tokenizer.
tokenizer.eos_token_id

2

In [36]:
# Beginning-of-sequence token id of the loaded tokenizer.
tokenizer.bos_token_id

1

In [37]:
# Sub-word pieces each class label is split into.
[tokenizer.tokenize(class_label) for class_label in classes]

[['▁Un', 'l', 'abeled'], ['▁compla', 'int'], ['▁no', '▁compla', 'int']]

In [38]:
# Token ids for each class label; a plain tokenizer call puts the BOS id (1) in front of every label.
[tokenizer(class_label)['input_ids'] for class_label in classes]

[[1, 853, 29880, 24025], [1, 15313, 524], [1, 694, 15313, 524]]

Create a preprocess_function to:

- Tokenize the input text and labels.

- For each example in a batch, pad the labels with the tokenizer's pad_token_id.

- Concatenate the input text and labels into the model_inputs.

- Create a separate attention mask for labels and model_inputs.

- Loop through each example in the batch again to pad the input ids, labels, and attention mask to the max_length and convert them to PyTorch tensors.

In [9]:
# text_column = "Tweet text"
# label_column = "text_label"

def preprocess_function(examples):
    """Tokenize a batch of tweets and labels into fixed-length causal-LM features.

    Each example becomes the prompt ``"<text_column> : <tweet> Label : "``
    followed by the label tokens and one ``tokenizer.pad_token_id`` as a
    terminator. Prompt and padding positions are set to -100 in ``labels``
    (PyTorch's default ``ignore_index``), so only the label tokens and the
    terminator are supervised. Every sequence is left-padded to ``max_length``
    and cut on the right if it is longer.

    Relies on the notebook-level ``tokenizer``, ``text_column``,
    ``label_column`` and ``max_length``.

    Args:
        examples: A batch mapping column names to lists of values, as passed
            by ``dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the prompts, with ``input_ids``,
        ``attention_mask`` and the added ``labels`` entry each holding one
        tensor of length ``max_length`` per example.
    """
    batch_size = len(examples[text_column])

    inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
    targets = [str(x) for x in examples[label_column]]
    model_inputs = tokenizer(inputs)

    # The targets are appended to an already tokenized prompt, so they must not
    # get their own special tokens: a plain call prepends BOS to every label,
    # which would put BOS mid-sequence and make it a supervised target.
    labels = tokenizer(targets, add_special_tokens=False)

    for i in range(batch_size):
        prompt_ids = model_inputs["input_ids"][i]
        # The pad token doubles as the end-of-label terminator.
        target_ids = labels["input_ids"][i] + [tokenizer.pad_token_id]

        input_ids = prompt_ids + target_ids
        label_ids = [-100] * len(prompt_ids) + target_ids
        attention_mask = [1] * len(input_ids)

        # Left-pad up to max_length. For over-long examples pad_len is negative
        # and the list multiplications below yield empty lists.
        pad_len = max_length - len(input_ids)
        input_ids = [tokenizer.pad_token_id] * pad_len + input_ids
        attention_mask = [0] * pad_len + attention_mask
        label_ids = [-100] * pad_len + label_ids

        # NOTE(review): over-long examples are cut on the right, which drops
        # (part of) the label tokens. Confirm max_length is large enough.
        model_inputs["input_ids"][i] = torch.tensor(input_ids[:max_length])
        model_inputs["attention_mask"][i] = torch.tensor(attention_mask[:max_length])
        labels["input_ids"][i] = torch.tensor(label_ids[:max_length])

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [10]:
# Tokenize every split with preprocess_function. The original columns are
# removed so only the fields returned by preprocess_function remain, and
# load_from_cache_file=False recomputes the result instead of reusing a cached one.
processed_datasets = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=1,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on dataset",
)

Running tokenizer on dataset:   0%|          | 0/50 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/3399 [00:00<?, ? examples/s]

In [11]:
# Training split, shuffled by the loader.
train_dataset = processed_datasets["train"]
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=default_data_collator,
    pin_memory=True,
)

# Evaluation uses the dataset's "test" split, iterated in order.
eval_dataset = processed_datasets["test"]
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    pin_memory=True,
)

In [16]:
from transformers import BitsAndBytesConfig
from peft import prepare_model_for_int8_training, get_peft_config, prepare_model_for_kbit_training

In [14]:
# bitsandbytes settings: load the weights in 4-bit NF4 format with bfloat16
# as the compute dtype.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type='nf4'
)
# Display the full config, including the defaults that were not set explicitly.
quant_config

BitsAndBytesConfig {
  "bnb_4bit_compute_dtype": "bfloat16",
  "bnb_4bit_quant_type": "nf4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": true,
  "load_in_8bit": false,
  "quant_method": "bitsandbytes"
}

In [15]:
# Load the Llama-2 7B chat checkpoint (the local model directory contained only
# a tokenizer) with the 4-bit quantization config; device_map="auto" leaves
# device placement to the loader.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path, 
    device_map="auto",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Prepare the quantized model for training with PEFT's k-bit helper.
# NOTE(review): presumably this is what turns on gradient checkpointing (see the
# `use_cache` warning in the training output). Confirm against the PEFT docs.
quant_model = prepare_model_for_kbit_training(model)

In [19]:
# Check which device the prepared model reports.
quant_model.device

device(type='cuda', index=0)

In [21]:
# Display the module tree of the prepared model (linear layers appear as Linear4bit).
quant_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm

In [22]:
# Wrap the quantized base model with the prompt-tuning adapter described by peft_config.
model = get_peft_model(quant_model, peft_config)
# print_trainable_parameters() prints its own summary and returns None, so it
# must not be wrapped in print() (that would additionally print "None").
model.print_trainable_parameters()

trainable params: 32,768 || all params: 6,738,448,384 || trainable%: 0.0004862840543203603


In [23]:
# Optimise the PEFT-wrapped `model`, not `quant_model`: the trainable prompt
# embeddings are added by get_peft_model on the wrapper, so they are absent
# from quant_model.parameters() and would never be updated. Only parameters
# that require gradients are handed to the optimizer; if there are none,
# AdamW raises instead of silently training nothing.
optimizer = torch.optim.AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=lr,
)

In [24]:
# Linear learning-rate schedule without warmup, spanning every optimizer step
# of the run (batches per epoch * number of epochs).
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

In [25]:
# Not required: the model was loaded with device_map="auto" and a bitsandbytes
# quantization config, so it has already been placed on its device.
# model = model.to(device)

In [26]:
# Train for num_epochs. After every epoch, run a full pass over the evaluation
# loader and report mean loss and perplexity (exp of the mean loss) for both.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        # Detach so the running sum does not hold on to the autograd graph.
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        # The schedule is defined per optimizer step, so it advances every batch.
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(eval_dataloader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().float()
        # Arg-max token at every position, decoded to text; collected here but
        # not used further inside this loop.
        eval_preds.extend(
            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)
        )

    # Per-epoch averages over batches; perplexity is exp(mean loss).
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
    print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")

  0%|          | 0/7 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
100%|██████████| 7/7 [00:06<00:00,  1.08it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


100%|██████████| 7/7 [00:05<00:00,  1.21it/s]
100%|██████████| 425/425 [02:03<00:00,  3.43it/s]


 14%|█▍        | 1/7 [00:01<00:09,  1.67s/it]


KeyboardInterrupt: 

In [28]:
peft_model_id = "llama7B_peft_PROMPT_TUNING_CAUSAL_LM"
# model.push_to_hub("your-name/bloomz-560m_PROMPT_TUNING_CAUSAL_LM",use_auth_token=True)
# "~" is only a shell shorthand: passed as-is it would create a literal "~"
# directory under the current working directory, so expand it explicitly.
model.save_pretrained(os.path.expanduser(f"~/training_files/{peft_model_id}"))

In [29]:
# Save the adapter under the current user's home directory instead of a
# hard-coded /home/<user> path, so the notebook also runs on other machines.
model.save_pretrained(os.path.join(os.path.expanduser("~"), "training_files", peft_model_id))

In [30]:
from peft import PeftModel, PeftConfig

# Adapter directory resolved from the current user's home instead of a
# hard-coded /home/<user> path; for the original author it is the same directory.
peft_model_path = os.path.join(os.path.expanduser("~"), "training_files", peft_model_id)

In [31]:
# Read back the adapter configuration stored in the saved adapter directory.
config = PeftConfig.from_pretrained(peft_model_path)

In [32]:
# Attach the saved prompt-tuning adapter to the already loaded quantized base model.
trained_model = PeftModel.from_pretrained(quant_model, peft_model_path)

In [33]:
# One inference prompt in the same "<text_column> : <tweet> Label : " layout
# that preprocess_function uses for training examples.
inputs = tokenizer(
    f"{text_column} : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label : ",
    return_tensors="pt",
)

In [34]:
# Move the adapter-wrapped model to `device`; as the cell's last expression the
# call also displays the module tree.
trained_model.to(device)

PeftModelForCausalLM(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLUActivation()
          )
          (input_layernorm): LlamaRMSNorm()


In [40]:
# Generate the label for the prompt without tracking gradients.
with torch.no_grad():
    # Put the prompt tensors on the same device as the model.
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = trained_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=10,
        # Stop on the tokenizer's own EOS id (2 for this tokenizer) rather than
        # a hard-coded 3, which was carried over from a different tokenizer and
        # is not this tokenizer's end-of-sequence token.
        eos_token_id=tokenizer.eos_token_id,
        temperature=0.1
    )
    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))
