In [None]:
import gc
import os

import torch
import transformers

In [None]:
from datasets import load_dataset
from trl import SFTTrainer

# Load PEFT
from peft import (
    get_peft_model,
    PromptTuningConfig
)

In [None]:
# Hugging Face access token for the gated Llama-2 checkpoint. Read it from the
# HF_TOKEN environment variable instead of pasting it into the notebook, where it
# would end up in version control. If unset, None lets the Hub libraries fall back
# to the credentials cached by `huggingface-cli login`.
TOKEN = os.environ.get("HF_TOKEN") or None

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Base checkpoint; it is gated on the Hub, hence the access token on both loads.
model_name = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    load_in_8bit=False,
    token=TOKEN,
)

In [None]:
# Show the string form of the end-of-sequence token (</s> for Llama-2).
eos_id = tokenizer.eos_token_id
tokenizer.decode(eos_id)

In [None]:
# Fine-tune on the SAMSum dialogue-summarization dataset and keep each split
# under its own name.
data = load_dataset("samsum")
data_train = data["train"]
data_test = data["test"]
data_val = data["validation"]

In [None]:
# The task instruction ("Summarize the following:\n") is not written into the
# prompt text here; it is supplied as the text initialisation of the
# prompt-tuning virtual tokens in the PromptTuningConfig cell.

def generate_prompt(dialogue, summary=None, eos_token="</s>"):
    """Format a SAMSum example as a plain-text prompt.

    Args:
        dialogue: The conversation to summarise.
        summary: Reference summary. When given (training), it follows
            "Summary:" and is terminated with ``eos_token``. When omitted or
            empty (inference), the prompt stops right after "Summary:" so the
            model continues from there.
        eos_token: End-of-sequence marker appended after the summary.

    Returns:
        "<dialogue>\\n Summary: <summary> <eos_token>" for training, or
        "<dialogue>\\n Summary:" for inference. No trailing whitespace is
        emitted. An inference prompt is therefore an exact prefix of the
        training text up to the point where the summary starts.
    """
    # Named `dialogue_part` rather than `input` to avoid shadowing the builtin.
    dialogue_part = f"{dialogue}\n"
    if summary:
        summary_part = f"Summary: {summary} {eos_token}"
    else:
        summary_part = "Summary:"
    return " ".join([dialogue_part, summary_part])

first_example = data_train[0]
print(generate_prompt(first_example["dialogue"], first_example["summary"]))

In [None]:
# Inference-style prompt (no summary) for the example reused in the generation checks.
sample_dialogue = data_train[50]["dialogue"]
generate_prompt(sample_dialogue)

In [None]:
# Check how the untuned base model summarises before any prompt tuning.
# The instruction is the same text that initialises the prompt-tuning virtual
# tokens, so the before/after comparison differs only in the learned prompt.
# NOTE(review): data_train[50] comes from the training split; use data_test for
# a held-out comparison.
input_prompt = "Summarize the following:\n" + generate_prompt(data_train[50]["dialogue"])
# Keep input_ids and attention_mask together. Send them to wherever
# device_map="auto" put the model instead of hard-coding "cuda".
inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device)
# torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API.
with torch.autocast(device_type="cuda"):
    generation_output = model.generate(
        **inputs,
        max_new_tokens=100,  # same budget as the post-training checks
        do_sample=True,
        top_k=10,
        top_p=0.9,
        temperature=0.3,
        repetition_penalty=1.15,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
op = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(op)

In [None]:
# Llama-2's tokenizer has no pad token. Add a dedicated <PAD> rather than
# reusing </s>. Padding positions may be masked out of the loss by the data
# collator, and they must not be confused with the real end-of-sequence token
# the model has to learn to emit.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(len(tokenizer))
# Keep the model config and generation defaults in sync with the tokenizer's new
# pad id. The Hub config may leave it unset or point at a different id, and
# generate() would then pick its own padding id.
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# Inspect the tokenizer's special tokens, including the pad token added above.
tokenizer.special_tokens_map

In [None]:
# Prompt tuning: learn num_virtual_tokens embeddings that are prepended to every
# input. They are initialised from the embedding of the instruction text below,
# and only they are trained; the base model stays frozen.
peft_config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    num_virtual_tokens=20,
    # Must equal the base model's hidden size, so read it from the model config
    # instead of hard-coding 4096 (which only fits 7B-sized Llama models).
    token_dim=model.config.hidden_size,
    prompt_tuning_init="TEXT",
    prompt_tuning_init_text="Summarize the following:\n",
    # Tokenizer used to embed the init text; same checkpoint as the base model.
    tokenizer_name_or_path=model_name,
    tokenizer_kwargs={"token": TOKEN},
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
# Check which device device_map="auto" placed the (now PEFT-wrapped) model on.
model.device

In [None]:
output_dir = "cp2"
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
per_device_eval_batch_size = 1
eval_accumulation_steps = 4
optim = "adamw_hf"
save_steps = 10
logging_steps = 10
learning_rate = 5e-4
max_grad_norm = 0.3
max_steps = 30
warmup_ratio = 0.03
evaluation_strategy="steps"
lr_scheduler_type = "constant"

training_args = transformers.TrainingArguments(
            output_dir=output_dir,
            per_device_train_batch_size=per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optim=optim,
            evaluation_strategy=evaluation_strategy,
            save_steps=save_steps,
            learning_rate=learning_rate,
            logging_steps=logging_steps,
            max_grad_norm=max_grad_norm,
            max_steps=max_steps,
            warmup_ratio=warmup_ratio,
            group_by_length=True,
            lr_scheduler_type=lr_scheduler_type,
            ddp_find_unused_parameters=False,
            eval_accumulation_steps=eval_accumulation_steps,
            per_device_eval_batch_size=per_device_eval_batch_size,
        )

In [None]:
def formatting_func(prompt):
    """Turn SAMSum rows into training prompts for SFTTrainer.

    Args:
        prompt: A dataset row mapping with "dialogue" and "summary" fields.
            Depending on the TRL version, this is either a batch (each field
            a list) or a single example (each field a string).

    Returns:
        A list of formatted prompts for a batch, or one formatted string for
        a single example.
    """
    dialogues, summaries = prompt["dialogue"], prompt["summary"]
    # Older TRL versions call formatting_func on batches and expect a list.
    # Newer ones call it per example and expect a string. zip() over two
    # strings would silently format them character by character.
    if isinstance(dialogues, str):
        return generate_prompt(dialogues, summaries)
    return [generate_prompt(d, s) for d, s in zip(dialogues, summaries)]


# `model` is already a PeftModel (wrapped by get_peft_model earlier), so
# peft_config is not passed again. Depending on the TRL version, passing it
# alongside a PeftModel is either ignored or rejected.
trainer = SFTTrainer(
    model=model,
    train_dataset=data_train,
    eval_dataset=data_val,
    formatting_func=formatting_func,
    max_seq_length=1024,
    tokenizer=tokenizer,
    args=training_args,
)

# Upcast layer norms to float32 for more stable training. nn.Module.to()
# converts the module in place, so the result need not be rebound.
# NOTE(review): the base model was loaded without torch_dtype, so it is likely
# already float32 and this loop is a no-op. It matters if the model is loaded in
# half precision.
for name, module in trainer.model.named_modules():
    if "norm" in name:
        module.to(torch.float32)

trainer.train()
trainer.save_model(f"{output_dir}/final")

In [None]:
# Test the prompt-tuned model on the same dialogue as the untuned baseline.
# NOTE(review): data_train[50] comes from the training split; use data_test for
# a held-out check.
input_prompt = generate_prompt(data_train[50]["dialogue"])
# The Trainer may leave the model in training mode. Switch to eval mode so
# any dropout is disabled while generating.
model.eval()
# Keep input_ids and attention_mask together. Send them to the model's device
# instead of hard-coding "cuda".
inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device)
# torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API.
with torch.autocast(device_type="cuda"):
    generation_output = model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        top_k=10,
        top_p=0.9,
        temperature=0.3,
        repetition_penalty=1.15,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
op = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(op)

## Loading the saved model and running inference

In [None]:
from peft import PeftModel

# Reload the base model from scratch and attach the saved prompt-tuning adapter.
peft_model_id = "cp2/final"  # matches f"{output_dir}/final" from the training cell
model_name = "meta-llama/Llama-2-7b-hf"

# Release the training-time model before loading a second copy of a 7B model.
# The Trainer holds its own reference, so rebinding `model` alone would not free
# that memory. pop() keeps this cell runnable on a fresh kernel where neither
# name exists.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

model = AutoModelForCausalLM.from_pretrained(model_name,
                                            load_in_8bit=False,
                                            device_map="auto",
                                            token=TOKEN
                                            )
tokenizer = AutoTokenizer.from_pretrained(model_name, token=TOKEN)
# Recreate the <PAD> token added before training so the vocabulary (and thus the
# embedding size) matches the one the adapter was trained with.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id
peft_model = PeftModel.from_pretrained(model, peft_model_id)

In [None]:
# Display the module structure of the reloaded PEFT model.
peft_model

In [None]:
# Generate with the reloaded adapter on the same dialogue as the earlier checks.
# NOTE(review): data_train[50] comes from the training split; use data_test for
# a held-out check.
input_prompt = generate_prompt(data_train[50]["dialogue"])
# Keep input_ids and attention_mask together. Send them to the model's device
# (chosen by device_map="auto") instead of hard-coding "cuda".
inputs = tokenizer(input_prompt, return_tensors="pt").to(peft_model.device)
# torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API.
with torch.autocast(device_type="cuda"):
    generation_output = peft_model.generate(
        **inputs,
        max_new_tokens=100,
        do_sample=True,
        top_k=10,
        top_p=0.9,
        temperature=0.3,
        repetition_penalty=1.15,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )
op = tokenizer.decode(generation_output[0], skip_special_tokens=True)
print(op)