zero_shot performance

In [7]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer, BitsAndBytesConfig
import torch
import time
import evaluate
import pandas as pd
import numpy as np

huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)

def get_promp(sample):
    dialogue = sample['dialogue']
    summary = sample['summary']
    
    prompt = f"""
    Summarize the following conversation.
    
    {dialogue}
    
    Summary:
    """
    return prompt

all_pred = []
all_ans = []
for i in range(len(dataset['test'])):
    prompt = get_promp(dataset['test'][i])
    inputs = tokenizer(prompt, return_tensors='pt')
    output = tokenizer.decode(
        original_model.generate(
            inputs["input_ids"], 
            max_new_tokens=200,
        )[0], 
        skip_special_tokens=True
    )
    all_pred.append(output)
    all_ans.append(dataset['test'][i]['summary'])
    if i> 100:
        break


for i in range(len(all_pred)):
    print("pred:", all_pred[i])
    print("ans:", all_ans[i])    

pred: #Person1#: I need to take a dictation for you.
ans: Ms. Dawson helps #Person1# to write a memo to inform every employee that they have to change the communication method and should not use Instant Messaging anymore.
pred: #Person1#: I need to take a dictation for you.
ans: In order to prevent employees from wasting time on Instant Message programs, #Person1# decides to terminate the use of those programs and asks Ms. Dawson to send out a memo to all employees by the afternoon.
pred: #Person1#: I need to take a dictation for you.
ans: Ms. Dawson takes a dictation for #Person1# about prohibiting the use of Instant Message programs in the office. They argue about its reasonability but #Person1# still insists.
pred: The traffic jam at the Carrefour intersection is a problem.
ans: #Person2# arrives late because of traffic jam. #Person1# persuades #Person2# to use public transportations to keep healthy and to protect the environment.
pred: The traffic jam at the Carrefour intersection 

In [8]:
from evaluate import load

# Load the metrics
meteor = load("meteor")
bleu = load("bleu")
rouge = load("rouge")

rouge_results = rouge.compute(predictions=all_pred, references=all_ans)
bleu_result = bleu.compute(predictions=all_pred, references=all_ans)
meteor_result = meteor.compute(predictions=all_pred, references=all_ans)

print("ROUGE:", rouge_results)
print("BLEU:", bleu_result)
print("METEOR:", meteor_result)

Downloading builder script:   0%|          | 0.00/7.02k [00:00<?, ?B/s]

[nltk_data] Downloading package wordnet to /home/sa5u24/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/sa5u24/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/sa5u24/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


Downloading builder script:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/1.55k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/3.34k [00:00<?, ?B/s]

ROUGE: {'rouge1': 0.21274141070336455, 'rouge2': 0.05725557885416832, 'rougeL': 0.18563351545863377, 'rougeLsum': 0.18558622087298865}
BLEU: {'bleu': 0.04355090180693124, 'precisions': [0.18707179028894846, 0.07311827956989247, 0.034887408816999685, 0.0075385119632907244], 'brevity_penalty': 1.0, 'length_ratio': 1.433390264730999, 'translation_length': 3357, 'reference_length': 2342}
METEOR: {'meteor': 0.19492674781016342}


LoRA Fine-tuning

In [3]:
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer, BitsAndBytesConfig
import torch
import time
import evaluate
import pandas as pd
import numpy as np

huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)

lora_config = LoraConfig(
    r=8, # Rank
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5
)

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

model_name='google/flan-t5-base'
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_model = get_peft_model(original_model, 
                            lora_config)

def print_number_of_trainable_model_parameters(model):
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()
    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

print(print_number_of_trainable_model_parameters(peft_model))


`low_cpu_mem_usage` was None, now set to True since model is quantized.


trainable model parameters: 884736
all model parameters: 168246528
percentage of trainable model parameters: 0.53%




In [4]:
output_dir = '/home/sa5u24/safe_lora/temp'

def tokenize_function(example):
    start_prompt = 'Summarize the following conversation.\n\n'
    end_prompt = '\n\nSummary: '
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["dialogue"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["summary"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    
    return example

# The dataset actually contains 3 diff splits: train, validation, test.
# The tokenize_function code is handling all data across all splits in batches.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['id', 'topic', 'dialogue', 'summary',])


peft_training_args = TrainingArguments(
    output_dir=output_dir,
    # auto_find_batch_size=True,
    per_device_train_batch_size=8,
    learning_rate=1e-3, # Higher learning rate than full fine-tuning.
    num_train_epochs=3,
    logging_steps=20,
    # max_steps=1    
)
    
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_datasets["test"],
)

peft_trainer.train()

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maoshuang0[0m ([33maoshuang0-university-of-southampton[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
20,21.1187
40,3.0352
60,0.7038
80,0.2973
100,0.2131
120,0.1919
140,0.1646
160,0.1637
180,0.1592
200,0.1508


TrainOutput(global_step=564, training_loss=1.0128511469414894, metrics={'train_runtime': 217.9416, 'train_samples_per_second': 20.648, 'train_steps_per_second': 2.588, 'total_flos': 3093638676480000.0, 'train_loss': 1.0128511469414894, 'epoch': 3.0})

In [5]:
#save path
peft_model_path="/home/sa5u24/safe_lora/temp"

peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

('/home/sa5u24/safe_lora/temp/tokenizer_config.json',
 '/home/sa5u24/safe_lora/temp/special_tokens_map.json',
 '/home/sa5u24/safe_lora/temp/spiece.model',
 '/home/sa5u24/safe_lora/temp/added_tokens.json',
 '/home/sa5u24/safe_lora/temp/tokenizer.json')

Inference for LoRA & zero-shot

In [1]:
from peft import PeftModel, PeftConfig
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer, BitsAndBytesConfig
import torch
import time
import evaluate
import pandas as pd
import numpy as np

huggingface_dataset_name = "knkarthick/dialogsum"
dataset = load_dataset(huggingface_dataset_name)

model_name='google/flan-t5-base'
peft_model_path="/home/sa5u24/safe_lora/temp"

peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map = 'auto')
tokenizer = AutoTokenizer.from_pretrained(model_name)

peft_model = PeftModel.from_pretrained(peft_model_base, 
                                       peft_model_path, 
                                       torch_dtype=torch.bfloat16,
                                       is_trainable=False,
                                       )

original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map = 'auto')





In [2]:
dialogues = dataset['validation'][0:100]['dialogue']
human_baseline_summaries = dataset['validation'][0:100]['summary']

original_model_summaries = []
instruct_model_summaries = []
peft_model_summaries = []

for idx, dialogue in enumerate(dialogues):
    prompt = f"""
                Summarize the following conversation.
                
                {dialogue}
                
                Summary: """
    
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    
    human_baseline_text_output = human_baseline_summaries[idx]
    
    original_model_outputs = original_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    original_model_text_output = tokenizer.decode(original_model_outputs[0], skip_special_tokens=True)

    peft_model_outputs = peft_model.generate(input_ids=input_ids, generation_config=GenerationConfig(max_new_tokens=200))
    peft_model_text_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

    original_model_summaries.append(original_model_text_output)
    peft_model_summaries.append(peft_model_text_output)

zipped_summaries = list(zip(human_baseline_summaries, original_model_summaries, peft_model_summaries))
 
df = pd.DataFrame(zipped_summaries, columns = ['human_baseline_summaries', 'original_model_summaries', 'peft_model_summaries'])
df

Token indices sequence length is longer than the specified maximum sequence length for this model (630 > 512). Running this sequence through the model will result in indexing errors


Unnamed: 0,human_baseline_summaries,original_model_summaries,peft_model_summaries
0,#Person2# has trouble breathing. The doctor as...,"Person1: Hello, how are you doing today?",#Person2# has been having trouble breathing la...
1,#Person1# invites Jimmy to go workout and pers...,#Person1#: Hey Jimmy. Let's go workout later t...,#Person1# and Jimmy are going to work out late...
2,#Person1# plans to stop eating unhealthy foods...,#Person1#: I'm trying to lose weight. #Person2...,#Person1# wants to stop eating unhealthy foods...
3,#Person2# believes in UFOs and can see them in...,#Person1#: I've never seen UFOs. #Person2#: I'...,#Person1# is skeptical of UFOs and asks #Perso...
4,#Person1# didn't go to school today. #Person2#...,Person1 didn't go to school today.,#Person1# doesn't want to go to school today. ...
...,...,...,...
95,#Person2# tells #Person1# about a funny experi...,The story of the trip was a great one.,#Person2# and #Person2# travelled throughout I...
96,#Person2# has an interview schedule on Wednesd...,The manager will interview Person1 tomorrow at...,"#Person2# is asked for an interview, but #Pers..."
97,#Person1# wants to start a marathon and #Perso...,#Person1#: I'm a good runner.,#Person1# wants to run a marathon and #Person2...
98,#Person1# wants to research Christian and Izek...,The new working partner is a Christian.,#Person1# is doing an essay about Christian re...


In [3]:
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
    predictions=original_model_summaries,
    references=human_baseline_summaries[0:len(original_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)


peft_model_results = rouge.compute(
    predictions=peft_model_summaries,
    references=human_baseline_summaries[0:len(peft_model_summaries)],
    use_aggregator=True,
    use_stemmer=True,
)

print('ORIGINAL MODEL:')
print(original_model_results)

print('PEFT MODEL:')
print(peft_model_results)

ORIGINAL MODEL:
{'rouge1': 0.23842230536363485, 'rouge2': 0.07294566939878172, 'rougeL': 0.20724668243568395, 'rougeLsum': 0.20814136804327285}
PEFT MODEL:
{'rouge1': 0.41949794950199876, 'rouge2': 0.1638233292110518, 'rougeL': 0.3450237504669067, 'rougeLsum': 0.3456101094061006}


In [5]:
!pwd

/home/sa5u24/safe_lora
