In [None]:
# Install pinned NLP / evaluation dependencies (bert_score is left unpinned).
# NOTE(review): py7zr, sentencepiece, loralib, peft and trl are not referenced in the visible cells — confirm they are still needed.
%pip install -q --disable-pip-version-check \
    evaluate==0.4.0 \
    py7zr==0.20.4 \
    sentencepiece==0.1.99 \
    rouge_score==0.1.2 \
    loralib==0.1.1 \
    peft==0.4.0 \
    trl==0.7.2 \
    bert_score
# Experiment tracking / training helpers; these versions are not pinned.
%pip install -q    wandb bitsandbytes accelerate

## Set up the environment: Google Drive, helper utilities, W&B, and all random seeds

In [None]:
# Mount Google Drive so the CSV splits and the saved model directory under
# /content/drive/MyDrive are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Download the W&B LLM-course helper module (utils.py) into the working directory.
# NOTE(review): `utils` is not imported in any visible cell — confirm it is still needed.
!wget https://github.com/wandb/edu/raw/main/llm-training-course/colab/utils.py

In [None]:
# Allow Colab to render third-party ipywidgets (presumably needed for the W&B / tqdm.notebook widgets shown below).
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
# Identifiers shared by the wandb.init(...) calls in this notebook.
PROJECT = "FlanT5-Lora"             # W&B project
WANDB_ID = "CO_finetune_1"          # fixed run id, so later cells resume the same run
# Values attached to the run as W&B tags.
MODEL_NAME = 'google/flan-t5-base'  # base checkpoint name
DATASET = "MeQSum"                  # dataset name

In [None]:
import wandb

# Start the W&B run for this experiment; the fixed id plus resume='allow'
# lets later wandb.init calls reattach to the same run. No hyperparameter
# config is passed here -- the run is described only by tags and notes.
wandb.init(
    project=PROJECT,
    tags=[MODEL_NAME, DATASET],
    notes="Fine tuning FlanT5 with MeQSum Dataset. Prompt Instruction",
    id=WANDB_ID,
    resume='allow',
)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
import torch
import time
import evaluate
import pandas as pd
import numpy as np
import random

# tqdm library makes the loops show a smart progress meter.
from tqdm import tqdm
tqdm.pandas()

In [None]:
import locale

# Colab workaround: report UTF-8 as the preferred encoding regardless of the
# container's locale. The stdlib signature is getpreferredencoding(do_setlocale=True),
# so accept (and ignore) that argument; a zero-argument lambda would raise
# TypeError for any caller that passes it.
locale.getpreferredencoding = lambda do_setlocale=True: "UTF-8"

In [None]:
# Every co-occurrence split lives in the same Drive folder; derive each path from it.
_COOCCUR_DIR = '/content/drive/MyDrive/cs577_proj_dataset/cooccur_dataset'
TRAIN_PATH = f'{_COOCCUR_DIR}/train_cooccur_dataset.csv'
VAL_PATH = f'{_COOCCUR_DIR}/validation_cooccur_dataset.csv'
TEST_PATH = f'{_COOCCUR_DIR}/test_cooccur_dataset.csv'

In [None]:
# One seed for every RNG the notebook touches: NumPy, PyTorch and Python's `random`.
random_seed = 0
for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    _seed_fn(random_seed)
del _seed_fn

# Make cuDNN choose deterministic kernels instead of benchmark-autotuned ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
from datasets import load_dataset


def _load_splits(builder, train_path, val_path, test_path):
    """Load train/test/validation files with a `datasets` builder and print a summary.

    Args:
        builder: name of the `datasets` packaged builder, e.g. 'json' or 'csv'.
        train_path: path to the training split file.
        val_path: path to the validation split file.
        test_path: path to the test split file.

    Returns:
        The DatasetDict returned by `load_dataset`, keyed 'train', 'test', 'validation'.
    """
    # Same split order as before: train, test, validation.
    data_files = {
        "train": train_path,
        "test": test_path,
        "validation": val_path,
    }
    dataset = load_dataset(builder, data_files=data_files)

    # Print the DatasetDict summary (splits, features, row counts).
    print(dataset)
    return dataset


def create_hf_dataset(train_path, val_path, test_path):
    """Load JSON split files into a DatasetDict and print its summary."""
    return _load_splits('json', train_path, val_path, test_path)


def create_hf_dataset_from_CSV(train_path, val_path, test_path):
    """Load CSV split files into a DatasetDict and print its summary."""
    return _load_splits('csv', train_path, val_path, test_path)


dataset = create_hf_dataset_from_CSV(
    TRAIN_PATH,
    VAL_PATH,
    TEST_PATH
)



Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'idx', 'inputs', 'target'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['Unnamed: 0', 'idx', 'inputs', 'target'],
        num_rows: 150
    })
    validation: Dataset({
        features: ['Unnamed: 0', 'idx', 'inputs', 'target'],
        num_rows: 50
    })
})


In [None]:
# Preview a few raw training rows rather than rendering all of them.
dataset['train'].to_pandas().head()

In [None]:
# Reuse the MODEL_NAME constant instead of repeating the checkpoint literal,
# so the W&B tag and the weights actually loaded cannot drift apart.
model_name = MODEL_NAME

# Weights loaded in bfloat16. This object is what gets passed to the Trainer
# later, so it is fine-tuned in place despite its name.
original_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
# Instruction templates: a START prompt is prepended and an END prompt appended
# to each question before tokenization. The tokenization and inference cells
# in this notebook use CO_START_PROMPT_1 / CO_END_PROMPT_1.
START_PROMPT_1 = 'Summarize the following medical question using context around it.\n\n'
END_PROMPT_1 = '\n\nQuestion Summary:'

START_PROMPT_2 = 'Read through the whole context and summarize the medical question\n\n'
# Was '\n\medical ...': '\m' is not a valid escape sequence, so the prompt
# contained a literal backslash (and Python emits an invalid-escape warning).
# Use the same blank-line separator as every other END prompt.
END_PROMPT_2 = '\n\nmedical question summary:'

NER_START_PROMPT_1 = 'Read through the whole context and summarize the medical question focusing on tags supplied within <> brackets.\n\n'
NER_END_PROMPT_1= '\n\nQuestion Summary:'

# Adjacent string literals replace the backslash continuation inside the
# literal; the resulting text is byte-identical to before.
# NOTE(review): "<DIAGNOSTIC_PROCEDURE> <BIOLOGICAL_ATTRIBUTE>" is space-separated
# while the other tags use commas -- kept as-is to preserve the prompt.
NER_START_PROMPT_2 = (
    'Reading the context, shortly summarize the medical question focusing on tags within <>. '
    'Focus on <MEDICATION>,<DIAGNOSTIC_PROCEDURE> '
    '<BIOLOGICAL_ATTRIBUTE>,<SIGN_SYMPTOM>,<BIOLOGICAL_STRUCTURE>,<DISEASE_DISORDER> if present.\n\n'
)
NER_END_PROMPT_2= '\n\nmedical question summary:'

CO_START_PROMPT_1 = 'Read through the whole context and summarize the medical question focusing on co-occurrence of pairs of words in <> brackets separated by - appear together if <> is present after the sentence.\n\n'
CO_END_PROMPT_1 = '\n\nmedical question summary:'

In [None]:
def tokenize_function(example):
    """Tokenize a batch of rows into encoder inputs and decoder labels.

    Each ``inputs`` text is wrapped in CO_START_PROMPT_1 / CO_END_PROMPT_1 and its
    ``target`` summary becomes the label sequence. Both are padded to a fixed
    length (``padding="max_length"``) and truncated if longer.

    Args:
        example: a batch dict from ``Dataset.map(batched=True)`` with list-valued
            ``"inputs"`` and ``"target"`` columns.

    Returns:
        The same dict with ``input_ids`` and ``labels`` added.
    """
    prompt = [CO_START_PROMPT_1 + question + CO_END_PROMPT_1 for question in example["inputs"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids

    labels = tokenizer(example["target"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    # Padding must not count toward the loss: HF seq2seq models ignore label id
    # -100 in their cross-entropy, so replace the pad id with it.
    labels[labels == tokenizer.pad_token_id] = -100
    example['labels'] = labels

    return example

# The dataset contains 3 splits (train, validation, test); map() runs
# tokenize_function over every split in batches. Dropping the original columns
# leaves only input_ids and labels for the Trainer.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['Unnamed: 0', 'idx', 'inputs', 'target',])

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

In [None]:
# Preview the first tokenized training rows (input_ids / labels).
tokenized_datasets['train'].to_pandas().head()

Unnamed: 0,input_ids,labels
0,"[3403, 190, 8, 829, 2625, 11, 21603, 8, 1035, ...","[2645, 9421, 7, 197, 20408, 7196, 15, 58, 1, 0..."
1,"[3403, 190, 8, 829, 2625, 11, 21603, 8, 1035, ...","[2645, 9421, 7, 9161, 51, 32, 29771, 630, 58, ..."
2,"[3403, 190, 8, 829, 2625, 11, 21603, 8, 1035, ...","[2645, 656, 206, 120, 1625, 63, 6, 11, 213, 54..."
3,"[3403, 190, 8, 829, 2625, 11, 21603, 8, 1035, ...","[2840, 54, 27, 129, 6472, 2505, 21, 56, 23, 26..."
4,"[3403, 190, 8, 829, 2625, 11, 21603, 8, 1035, ...","[2840, 54, 27, 129, 6472, 2505, 21, 1317, 82, ..."


In [None]:
# Report (rows, columns) per split, then the full DatasetDict summary.
print("Shapes of the datasets:")
for split_label, split_key in (("Training", "train"), ("Validation", "validation"), ("Test", "test")):
    print(f"{split_label}: {tokenized_datasets[split_key].shape}")

print(tokenized_datasets)

Shapes of the datasets:
Training: (1000, 2)
Validation: (50, 2)
Test: (150, 2)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 150
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 50
    })
})


In [None]:
from types import SimpleNamespace
from pathlib import Path
from tqdm.notebook import tqdm
from datetime import datetime
import nltk

In [None]:
# Prefer the GPU whenever PyTorch can see one.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Hyperparameters consumed by the TrainingArguments cell, kept in one namespace.
config = SimpleNamespace(
    # --- optimisation ---
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=4,
    num_train_epochs=30,
    # --- checkpointing / evaluation / logging ---
    # NOTE: save_steps / eval_steps only apply to the "steps" strategies; both
    # strategies below are 'epoch'.
    save_steps=5,
    save_strategy='epoch',  # cannot be "no": load_best_model_at_end needs saved checkpoints to choose from.
    eval_steps=5,
    logging_steps=5,
    evaluation_strategy='epoch',
    warmup_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    # Timestamped so every run writes to a fresh directory.
    output_dir=f'./MeQSum-training-{int(time.time())}',
)

In [None]:
training_args = TrainingArguments(
    output_dir=config.output_dir,
    learning_rate=config.learning_rate,
    # Batch sizes were declared in `config` but never passed, so the run fell
    # back to the TrainingArguments default of 8 per device.
    per_device_train_batch_size=config.per_device_train_batch_size,
    per_device_eval_batch_size=config.per_device_eval_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    num_train_epochs=config.num_train_epochs,
    save_steps=config.save_steps,
    save_strategy=config.save_strategy, # we cannot set it to "no". Otherwise, the model cannot guess the best checkpoint.
    eval_steps=config.eval_steps,
    logging_steps=config.logging_steps,
    evaluation_strategy=config.evaluation_strategy,
    # NOTE(review): at 16 x 4 = 64 examples per optimizer step, the 1000-row
    # train split yields ~15-16 steps/epoch (~450-480 over 30 epochs), so
    # warmup_steps=500 spans the whole run -- consider lowering it.
    warmup_steps=config.warmup_steps,
    save_total_limit=config.save_total_limit,
    load_best_model_at_end = config.load_best_model_at_end,
    # The Trainer re-seeds with args.seed (default 42) when it is constructed;
    # pass the notebook's seed so the seeding cell is not silently overridden.
    seed=random_seed,
    report_to="wandb",
    run_name=f"Prompt_tuning_original_model-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
)

# Fine-tune on the tokenized train split; the validation split drives
# per-epoch evaluation and best-checkpoint selection.
trainer = Trainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation']
)


In [None]:
# free up GPU Memory
# free up GPU Memory: release cached-but-unused blocks held by PyTorch's CUDA allocator before training.
torch.cuda.empty_cache()

In [None]:
# Reattach to the fixed-id W&B run for the duration of training; the Trainer
# reports to it (report_to="wandb") and the run is finished when the block exits.
with wandb.init(project=PROJECT, id=WANDB_ID, resume='allow'):
  trainer.train()

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

Epoch,Training Loss,Validation Loss
0,44.9375,47.119999
1,43.175,44.310001
2,39.9,39.490002
4,26.575,25.785
5,12.6062,6.6725
6,4.6594,4.25375
8,3.0719,2.104687
9,2.293,1.179844
10,1.6598,0.579375
12,0.7186,0.261992


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/loss,██▇▆▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/runtime,▂▄▅▃▂▁▅▁▄▃▄▂█▂▅▃▆▆▆▂▂▃▃▅▃▃▆▆▆▂
eval/samples_per_second,▇▅▄▆▇█▄█▅▆▅▇▁▇▄▆▃▃▃▇▇▆▆▃▆▆▃▃▃▇
eval/steps_per_second,▇▅▄▆▇▇▄█▅▆▅▇▁▇▄▆▃▃▃▇▇▆▆▃▆▆▃▃▃▇
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/grad_norm,██▇▇▆▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/learning_rate,▁▁▂▂▂▃▃▃▄▄▄▅▅▅▆▆▆▇▇▇███▇▇▇▆▆▅▅▅▄▄▃▃▃▂▂▁▁
train/loss,███▇▇▆▅▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_flos,▁

0,1
eval/loss,0.06978
eval/runtime,1.225
eval/samples_per_second,40.816
eval/steps_per_second,5.714
train/epoch,29.76
train/global_step,930.0
train/grad_norm,0.77734
train/learning_rate,0.0
train/loss,0.084
train/total_flos,2.037837880885248e+16


In [None]:
# Display the path of the best checkpoint recorded by the Trainer.
trainer.state.best_model_checkpoint

'./MeQSum-training-1713670191/checkpoint-875'

In [None]:
# Persist the fine-tuned model to Drive. load_best_model_at_end=True means the
# Trainer's model holds the best checkpoint's weights at this point.
custom_path = "/content/drive/MyDrive/cs577_proj_dataset/MODELS/flan-MeQSum-normal_train-checkpoint_lr_1e_4_CO_1/"
trainer.save_model(output_dir=custom_path)
# The Trainer was constructed without a tokenizer, so save_model() writes only
# the model; save the tokenizer alongside so the directory reloads on its own.
tokenizer.save_pretrained(custom_path)

In [None]:
# with wandb.init(project=PROJECT, job_type="models"):
#   artifact = wandb.Artifact("instruct_model", type="model")
#   artifact.add_dir(custom_path)
#   wandb.save(custom_path)
#   wandb.log_artifact(artifact)

In [None]:
# Reload the fine-tuned weights saved to custom_path, in bfloat16 like the trained model.
instruct_model = AutoModelForSeq2SeqLM.from_pretrained(custom_path, torch_dtype=torch.bfloat16)

In [None]:
# Use the device picked in the config cell instead of assuming a GPU exists.
instruct_model = instruct_model.to(DEVICE)

## Evaluation of the model

In [None]:
from bert_score import score

In [None]:
def compute_rogue_metric(inference_df):
  """Score generated summaries against the human references with ROUGE.

  Args:
    inference_df: DataFrame with 'finetuned_model_summaries' (predictions) and
      'human_baseline_summaries' (references) columns.

  Returns:
    The aggregated, stemmed ROUGE result dict produced by `evaluate`.
  """
  metric = evaluate.load('rouge')
  predictions = inference_df['finetuned_model_summaries'].tolist()
  references = inference_df['human_baseline_summaries'].tolist()
  return metric.compute(
    predictions=predictions,
    references=references,
    use_aggregator=True,
    use_stemmer=True,
  )

def compute_bleu_metric(inference_df):
  """Return the corpus BLEU of generated summaries vs. human references.

  Each prediction has exactly one reference, so every reference string is
  wrapped in its own one-element list for `evaluate`'s BLEU.

  Args:
    inference_df: DataFrame with 'finetuned_model_summaries' and
      'human_baseline_summaries' columns.

  Returns:
    The 'bleu' value from the metric's result dict.
  """
  metric = evaluate.load("bleu")
  wrapped_references = [[reference] for reference in inference_df['human_baseline_summaries'].tolist()]
  result = metric.compute(
      predictions=inference_df['finetuned_model_summaries'].tolist(),
      references=wrapped_references,
  )
  return result['bleu']

def compute_bert_score(inference_df):
  """Compute mean BERTScore precision, recall and F1 (lang="en") over all rows.

  Args:
    inference_df: DataFrame with 'finetuned_model_summaries' (candidates) and
      'human_baseline_summaries' (references) columns.

  Returns:
    (precision, recall, f1) as Python floats, each averaged over the rows.
  """
  candidates = inference_df['finetuned_model_summaries'].tolist()
  references = inference_df['human_baseline_summaries'].tolist()
  per_row_scores = score(candidates, references, lang="en")
  return tuple(values.mean().item() for values in per_row_scores)


In [None]:
# Peek at a few raw test questions instead of printing all of them.
dataset['test'][:5]['inputs']

In [None]:
def generate_test_set_inference(dataset, instruct_model, START_PROMPT, END_PROMPT):
  """Generate a summary for every test-split question and pair it with its reference.

  Args:
    dataset: DatasetDict whose 'test' split has 'inputs' (questions) and
      'target' (human reference summaries) columns.
    instruct_model: fine-tuned seq2seq model used for generation.
    START_PROMPT: instruction text prepended to each question.
    END_PROMPT: instruction text appended to each question.

  Returns:
    DataFrame with columns 'human_baseline_summaries' and
    'finetuned_model_summaries', one row per test example.
  """
  test_rows = dataset['test'][:]  # dict: column name -> list of values
  questions = test_rows['inputs']
  human_baseline_summaries = test_rows['target']

  # Move inputs to wherever the model lives instead of hard-coding "cuda",
  # so this also works when the model is on CPU.
  device = instruct_model.device

  finetuned_model_summaries = []
  for question in tqdm(questions):
      prompt = START_PROMPT + question + END_PROMPT
      # truncation=True mirrors tokenize_function, which truncated training
      # prompts, so inference inputs match the training-time length limit.
      input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.to(device)

      generated_ids = instruct_model.generate(
          input_ids=input_ids,
          generation_config=GenerationConfig(max_new_tokens=100),
          num_beams=3,
          repetition_penalty=1.5,
      )
      finetuned_model_summaries.append(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

  df = pd.DataFrame(
      list(zip(human_baseline_summaries, finetuned_model_summaries)),
      columns=['human_baseline_summaries', 'finetuned_model_summaries'],
  )

  print(df.head())

  return df


In [None]:
# Generate summaries for the whole test split using the same CO prompt pair as tokenize_function.
inference_df = generate_test_set_inference(dataset, instruct_model, CO_START_PROMPT_1, CO_END_PROMPT_1)

  0%|          | 0/150 [00:00<?, ?it/s]

                            human_baseline_summaries  \
0  How can i get rid of a lower lip birthmark per...   
1       Is Magnesium Silicofluoride safe for people?   
2                      Could RhoGAM damage the baby?   
3  Could hydroxychloroquine and methotrexate make...   
4  Is there a relationship between Gadolinium and...   

                           finetuned_model_summaries  
0  What are the treatments for a lower lip birthm...  
1  If the rug is treated in house how long before...  
2  Look for co-occurrence of pairs of words in > ...  
3  What are the side effects of hydroxychloroquin...  
4  What are the symptoms of Gadolinum toxicity an...  


In [None]:
# Side-by-side preview of reference vs. generated summaries.
inference_df.head()

Unnamed: 0,human_baseline_summaries,finetuned_model_summaries
0,How can i get rid of a lower lip birthmark per...,What are the treatments for a lower lip birthm...
1,Is Magnesium Silicofluoride safe for people?,If the rug is treated in house how long before...
2,Could RhoGAM damage the baby?,Look for co-occurrence of pairs of words in > ...
3,Could hydroxychloroquine and methotrexate make...,What are the side effects of hydroxychloroquin...
4,Is there a relationship between Gadolinium and...,What are the symptoms of Gadolinum toxicity an...


In [None]:
# Score the generated test summaries with ROUGE, BLEU and BERTScore (evaluated in that order).
rogue_score, bleu_score, (bert_score_precision, bert_score_recall, bert_score_f1) = (
    compute_rogue_metric(inference_df),
    compute_bleu_metric(inference_df),
    compute_bert_score(inference_df),
)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Print the test-set metrics. NOTE: "ROGUE" is a misspelling of ROUGE; it is kept
# here to stay consistent with the `rogue_score` variable / W&B key names.
print(f'ROGUE SCORE: {rogue_score}')
print(f'BLEU SCORE: {bleu_score}')
print(f'BERT SCORE: PRECISION: {bert_score_precision}, RECALL: {bert_score_recall}, F1: {bert_score_f1}')

ROGUE SCORE: {'rouge1': 0.24809925390954346, 'rouge2': 0.09352748463861102, 'rougeL': 0.21081592839525498, 'rougeLsum': 0.21063710058600796}
BLEU SCORE: 0.040239441602852706
BERT SCORE: PRECISION: 0.86691814661026, RECALL: 0.8758851885795593, F1: 0.8710768818855286


In [None]:
# Log every test-set metric in ONE wandb.log call so they share a single step
# (each separate wandb.log call advances the run's step counter).
# The ROUGE result is a dict; flatten it into scalar keys such as
# "rogue_score/rouge1" so each value is charted and kept in the run summary.
test_metrics = {f"rogue_score/{name}": value for name, value in rogue_score.items()}
test_metrics.update({
    "bleu_score": bleu_score,
    "bert_score_precision": bert_score_precision,
    "bert_score_recall": bert_score_recall,
    "bert_score_f1": bert_score_f1,
})

with wandb.init(project=PROJECT, id=WANDB_ID, resume="allow"):
  wandb.log(test_metrics)


VBox(children=(Label(value='0.021 MB of 0.021 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
bert_score_f1,▁
bert_score_precision,▁
bert_score_recall,▁
bleu_score,▁

0,1
bert_score_f1,0.87108
bert_score_precision,0.86692
bert_score_recall,0.87589
bleu_score,0.04024
eval/loss,0.06978
eval/runtime,1.225
eval/samples_per_second,40.816
eval/steps_per_second,5.714
train/epoch,29.76
train/global_step,930.0


## Logging the test-set inference table to W&B (the scores are logged in the cell above)

In [None]:
# Upload the per-question reference/generated summaries as a W&B table on the same run.
with wandb.init(project=PROJECT, id=WANDB_ID, job_type="dataset", resume="allow"):
    wandb.log({"meqsum_test_inference": wandb.Table(data=inference_df)})

VBox(children=(Label(value='0.050 MB of 0.050 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
bert_score_f1,0.87108
bert_score_precision,0.86692
bert_score_recall,0.87589
bleu_score,0.04024
eval/loss,0.06978
eval/runtime,1.225
eval/samples_per_second,40.816
eval/steps_per_second,5.714
train/epoch,29.76
train/global_step,930.0
