In [None]:
!pip install transformers datasets evaluate accelerate torch sacremoses peft gdown openpyxl
!pip install pyarrow==15.0.2

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting evaluate
  Downloading evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading evaluate-0.4.3-py3-none-any.whl (84 k

In [None]:
# Mount Google Drive at /content/drive; the project folders used below live under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import re
import os
import json
import pandas as pd
import numpy as np

from ast import literal_eval

In [None]:
# Root of the project on the mounted Drive; every other folder hangs off it.
base_dir = '/content/drive/MyDrive/NLP/vaers_analysis'

# Sub-folders for input data, saved results, saved models and training runs.
data_dir, results_dir, models_dir, runs_dir = (
    os.path.join(base_dir, folder_name)
    for folder_name in ('data', 'results', 'models', 'runs')
)

In [None]:
# Load the training split of the VAERS data.
train_csv_path = os.path.join(data_dir, 'train_vaers_data.csv')
data = pd.read_csv(train_csv_path)

# These columns come back from the CSV as strings; parse them back into Python lists.
for list_column in ('symptoms', 'ordered_symptoms'):
    data[list_column] = data[list_column].apply(literal_eval)

data.head()

Unnamed: 0,vaers_id,year,vax_type,vax_manu,symptom_text,symptoms,ordered_symptoms,report_length,num_symptoms
0,1563876,2021,COVID19,MODERNA,increase in blood pressure; Knot on right arm;...,"[Blood pressure increased, Muscle twitching, U...","[Urticaria, Vaccination site mass, Muscle twit...",234,4
1,1121903,2021,COVID19,MODERNA,"The day after the vaccine, I had extreme fatig...","[Chills, Fatigue, Headache, Injection site pai...","[Injection site pain, Rash erythematous, Fatig...",96,7
2,2501590,2022,COVID19,MODERNA,Narrative: Patient was not previously COVID-19...,"[Death, Fall, Wheelchair user]","[Fall, Wheelchair user, Death]",98,3
3,1228802,2021,COVID19,MODERNA,Several minutes after receiving 1st dose of Mo...,"[Fall, Loss of consciousness, Syncope, Unrespo...","[Loss of consciousness, Syncope, Unresponsive ...",112,4
4,1666419,2021,COVID19,MODERNA,Patient is an 86-year-old woman with a history...,"[Aphasia, COVID-19, Confusional state, Mental ...","[Respiratory symptom, COVID-19, SARS-CoV-2 tes...",212,8


## Data Preparation for PEFT:

In [None]:
import os
import json
import math
import pandas as pd
import matplotlib.pyplot as plt

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    TrainingArguments,
    Trainer
)

from datasets import Dataset, DatasetDict

In [None]:
# Hugging Face Hub coordinates of the checkpoint: '<model_vendor>/<model_name>'.
model_vendor, model_name = 'GanjinZero', 'biobart-v2-large'

# Training checkpoints go under runs/<model_name>; the saved adapter under models/<model_name>.
output_dir, peft_model_path = (
    os.path.join(parent_dir, model_name)
    for parent_dir in (runs_dir, models_dir)
)

In [None]:
def create_prompts(row):
    """
    Build the instruction prompt for a single report.

    Args:
        row: Mapping-like record (for example a DataFrame row) with a 'symptom_text'
            entry and a 'symptoms' entry. 'symptoms' may be an already comma-separated
            string or a list/tuple of symptom names.

    Returns:
        str: The instruction, the report text, the comma-separated symptoms and a
        trailing 'Ordered Symptoms:' cue, separated by blank lines.
    """
    instruction = 'Order the symptoms in chronological order based on the timeline implied in the text. Think step by step to determine the timeline for each symptom.'

    symptoms = row['symptoms']
    # Render list-like symptoms exactly like a pre-joined string. Without this the
    # f-string would embed the Python list repr (brackets and quotes), so prompts built
    # from list columns would not match prompts built from joined strings.
    if isinstance(symptoms, (list, tuple)):
        symptoms = ', '.join(symptoms)

    prompt = f"{instruction}\n\nText: {row['symptom_text']}\n\nSymptoms: {symptoms}\n\nOrdered Symptoms:"
    return prompt

In [None]:
## Convert symptoms into a string with comma separated symptoms
def _join_symptoms(symptoms):
    """Return `symptoms` as one comma-separated string; strings pass through unchanged."""
    return symptoms if isinstance(symptoms, str) else ', '.join(symptoms)

# `data['symptoms']` is overwritten in place, so this cell must tolerate being run twice:
# calling ', '.join on an already-joined string would splice ', ' between its characters.
data['symptoms'] = data['symptoms'].apply(_join_symptoms)
# Target text: the chronologically ordered symptoms in the same comma-separated form.
data['output_text'] = data['ordered_symptoms'].apply(_join_symptoms)
# Model input: one instruction prompt per report.
data['input_text'] = data.apply(create_prompts, axis=1)

In [None]:
# Keep only the prompt/target text columns for the Hugging Face dataset.
train_dataset = Dataset.from_dict({
    column: data[column].to_list()
    for column in ('input_text', 'output_text')
})

# Hold out 10% of the rows; the split is seeded with 42.
train_valid_spt = train_dataset.train_test_split(test_size=0.1, seed=42)

# Expose the held-out 'test' part of the split under the 'validation' key.
train_dataset = DatasetDict({
    'train': train_valid_spt['train'],
    'validation': train_valid_spt['test'],
})
train_dataset

DatasetDict({
    train: Dataset({
        features: ['input_text', 'output_text'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['input_text', 'output_text'],
        num_rows: 1000
    })
})

## Model Finetuning:

In [None]:
# Checkpoint id in '<vendor>/<model>' form, shared by the tokenizer and the model.
hub_checkpoint = f'{model_vendor}/{model_name}'  # alternative noted by the author: "facebook/bart-large"
tokenizer = AutoTokenizer.from_pretrained(hub_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(hub_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/892k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.77G [00:00<?, ?B/s]

In [None]:
def tokenize(data, tokenizer):
    """
    Tokenize a batch of prompt/target pairs for sequence-to-sequence training.

    Args:
        data: Batch mapping with 'input_text' and 'output_text' entries.
        tokenizer: Tokenizer applied to both prompts and targets. Its `pad_token_id`
            is used to find padding positions in the targets.

    Returns:
        The tokenizer output for the prompts (padded/truncated to 1024 tokens) with an
        added 'labels' entry: the target ids (padded/truncated to 128 tokens) in which
        every padding position has been replaced by -100.
    """
    tokenized_data = tokenizer(data['input_text'], padding='max_length', max_length=1024, truncation=True, return_tensors='pt')
    labels = tokenizer(data['output_text'], padding='max_length', max_length=128, truncation=True, return_tensors='pt').input_ids

    # Targets are padded to a fixed length, so most label positions are padding.
    # -100 is the label value the training loss ignores; leaving the pad id in place
    # would make the model spend most of its loss on predicting padding.
    pad_id = tokenizer.pad_token_id
    if pad_id is not None:
        labels = labels.masked_fill(labels == pad_id, -100)

    tokenized_data['labels'] = labels
    return tokenized_data

fn_kwargs = {"tokenizer": tokenizer}
# Tokenize every split and drop the raw text columns in the same pass; a separate
# identity `map` just to remove columns would rewrite the whole dataset a second time.
tokenized_dataset = train_dataset.map(
    tokenize,
    batched=True,
    fn_kwargs=fn_kwargs,
    remove_columns=['input_text', 'output_text'],
)
tokenized_dataset

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 1000
    })
})

In [None]:
from peft import LoraConfig, TaskType, get_peft_model
from transformers import EarlyStoppingCallback

# LoRA adapter configuration for a sequence-to-sequence model.
peft_config = LoraConfig(
    r=16,
    bias='none',
    lora_alpha=32,
    lora_dropout=0.05,
    task_type=TaskType.SEQ_2_SEQ_LM,
    # adapt only the attention query/key/value projections
    target_modules=['q_proj', 'k_proj', 'v_proj']
)

peft_training_args = TrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,
    # Training length is bounded by steps only. `num_train_epochs=10` used to be passed
    # as well, but `max_steps` overrides it (the Trainer warns about exactly that), so
    # the dead argument was removed to keep the effective budget unambiguous.
    max_steps=10000,
    warmup_steps=500,
    weight_decay=0.01,
    # Evaluate, log and checkpoint on the same 100-step cadence so that the best
    # checkpoint can be restored at the end of training.
    eval_strategy='steps',
    save_strategy='steps',
    logging_steps=100,
    eval_steps=100,
    save_steps=100,
    load_best_model_at_end=True
)

In [None]:
# Wrap the base model with the LoRA adapter described by `peft_config`.
peft_model = get_peft_model(model, peft_config)

# Trainer over the tokenized train/validation splits.
peft_trainer = Trainer(
    model=peft_model,
    args=peft_training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    tokenizer=tokenizer,
    # Patience is counted in evaluations, not training steps; with eval_steps=100 that is 1000 steps.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)]  # Stop after 10 consecutive evaluations without improvement
)


  peft_trainer = Trainer(
max_steps is given, it will override any value given in num_train_epochs


In [None]:
# Report how many parameters are trainable compared with the total parameter count.
peft_model.print_trainable_parameters()

trainable params: 3,538,944 || all params: 445,809,664 || trainable%: 0.7938


In [None]:
# Run fine-tuning with the trainer built above.
peft_trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Step,Training Loss,Validation Loss
100,0.9049,0.118931
200,0.1125,0.079238
300,0.0936,0.079525
400,0.088,0.068886
500,0.0828,0.06482
600,0.082,0.071433
700,0.0766,0.06852
800,0.0756,0.075469
900,0.082,0.064739
1000,0.0829,0.061774


TrainOutput(global_step=3100, training_loss=0.09879352246561358, metrics={'train_runtime': 9137.95, 'train_samples_per_second': 8.755, 'train_steps_per_second': 1.094, 'total_flos': 5.42834272763904e+16, 'train_loss': 0.09879352246561358, 'epoch': 2.7555555555555555})

In [None]:
## Save the model
# The trainer's model and the tokenizer are both written to `peft_model_path`.
peft_trainer.model.save_pretrained(peft_model_path)
tokenizer.save_pretrained(peft_model_path)

('/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/tokenizer_config.json',
 '/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/special_tokens_map.json',
 '/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/vocab.json',
 '/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/merges.txt',
 '/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/added_tokens.json',
 '/content/drive/MyDrive/NLP/vaers_analysis/models/biobart-v2-large/tokenizer.json')

## Inference:

In [None]:
from peft import PeftModel, PeftConfig

In [None]:
# Load the held-out test split of the VAERS data.
test_csv_path = os.path.join(data_dir, 'test_vaers_data.csv')
test_dataset = pd.read_csv(test_csv_path)

# These columns come back from the CSV as strings; parse them back into Python lists.
for list_column in ('symptoms', 'ordered_symptoms'):
    test_dataset[list_column] = test_dataset[list_column].apply(literal_eval)

# One prompt per report, plus the reference answer as a comma-separated string.
test_dataset['input_text'] = test_dataset.apply(create_prompts, axis=1)
test_dataset['output_text'] = test_dataset['ordered_symptoms'].apply(', '.join)
test_dataset.head()

Unnamed: 0,vaers_id,year,vax_type,vax_manu,symptom_text,symptoms,ordered_symptoms,report_length,num_symptoms,input_text,output_text
0,1344132,2021,COVID19,MODERNA,"A few days after my vaccine, I noticed under t...","[Blister, Erythema]","[Blister, Erythema]",93,2,Order the symptoms in chronological order base...,"Blister, Erythema"
1,1842147,2021,COVID19,MODERNA,Period schedule on and off the chart; Increase...,"[Biopsy, Heavy menstrual bleeding, Menstrual d...","[Menstrual disorder, Heavy menstrual bleeding,...",259,4,Order the symptoms in chronological order base...,"Menstrual disorder, Heavy menstrual bleeding, ..."
2,1165207,2021,COVID19,MODERNA,"within 24 hours of receiving my 2nd dose, I fi...","[Chills, Dry eye, Eye pain, Fatigue, Headache,...","[Pyrexia, Chills, Headache, Myalgia, Neuralgia...",110,9,Order the symptoms in chronological order base...,"Pyrexia, Chills, Headache, Myalgia, Neuralgia,..."
3,1618374,2021,COVID19,MODERNA,Side effects seem to have cleared up by the 17...,"[Headache, Vaccination complication]","[Vaccination complication, Headache]",241,2,Order the symptoms in chronological order base...,"Vaccination complication, Headache"
4,2460242,2022,COVID19,MODERNA,I received my first Moderna vaccine on one/14/...,"[Amenorrhoea, Arthralgia, Carbohydrate antigen...","[Lymphadenopathy, Arthralgia, Pain in extremit...",162,12,Order the symptoms in chronological order base...,"Lymphadenopathy, Arthralgia, Pain in extremity..."


In [None]:
def batch_inference(model, tokenizer, test_prompts, batch_size=5,
                    max_input_length=1024, max_new_tokens=128, num_beams=5):
    """
    Generate one decoded prediction per prompt, processing the prompts in batches.

    Args:
        model: Model to run; it must provide `.to(device)`, `.eval()` and `.generate(...)`.
        tokenizer: Tokenizer used to encode the prompts and to decode the generated ids.
        test_prompts (list of str): Prompts to generate predictions for.
        batch_size (int, optional): Number of prompts per batch. Default is 5.
        max_input_length (int, optional): Length every prompt is padded/truncated to.
            Default is 1024.
        max_new_tokens (int, optional): Generation budget per prompt. Default is 128.
        num_beams (int, optional): Beam-search width. Default is 5.

    Returns:
        list of str: Decoded predictions, in the same order as `test_prompts`. Every
        prompt of a batch that failed is represented by an empty string, so the result
        stays aligned one-to-one with the prompts.

    Raises:
        ValueError: If `batch_size` is smaller than 1.

    Notes:
        - Runs on CUDA when available, otherwise on CPU.
        - Errors raised while processing a batch are printed and not re-raised; the
          function continues with the next batch.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be at least 1, got {batch_size}')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    # Make sure dropout and similar training-only behaviour is off while generating.
    model.eval()
    predicted_sequence = []
    num_batches = math.ceil(len(test_prompts) / batch_size)

    for i in range(num_batches):
        batch_prompts = test_prompts[i * batch_size:(i + 1) * batch_size]
        try:
            # Prepare the batch of prompts
            batch = tokenizer(
                batch_prompts,
                padding='max_length',
                max_length=max_input_length,
                truncation=True,
                return_tensors='pt'
            ).to(device)

            # Generate predictions
            encoded_sequence = model.generate(**batch, max_new_tokens=max_new_tokens, num_beams=num_beams)

            # Decode the predictions
            decoded_batch = tokenizer.batch_decode(encoded_sequence, skip_special_tokens=True)

        except Exception as e:
            # Report the failure but keep the output aligned with the prompts: silently
            # skipping the batch would shift every later prediction onto the wrong prompt
            # (or make the result shorter than the test set).
            print(f'Error in batch {i}: {e}')
            predicted_sequence.extend([''] * len(batch_prompts))
            continue

        predicted_sequence.extend(decoded_batch)
        print(f'Processed Batch: {i} successfully!!')

    return predicted_sequence

In [None]:
def _split_symptoms(sequence, strip_quotes=False):
    """Split a comma-separated symptom string into a list of trimmed, non-empty names."""
    names = []
    for part in sequence.split(','):
        name = part.strip()
        if strip_quotes:
            name = name.strip("'\"")
        # An empty string or a trailing comma would otherwise produce a phantom '' symptom.
        if name:
            names.append(name)
    return names


def create_eval_df(test_dataset, predicted_sequence, true_sequence_key='output_text',
                   symptom_text_key='symptom_text', prompt_key='input_text'):
    """
    Build the evaluation DataFrame that pairs reference and predicted symptom sequences.

    Args:
        test_dataset (pd.DataFrame): The dataset containing test inputs and expected outputs.
        predicted_sequence (list of str): One predicted comma-separated sequence per row
            of `test_dataset`, in the same order.
        true_sequence_key (str): Column of `test_dataset` holding the comma-separated
            reference sequence. Default is 'output_text'.
        symptom_text_key (str): Column of `test_dataset` holding the report text.
            Default is 'symptom_text'.
        prompt_key (str): Column of `test_dataset` holding the prompts. Default is 'input_text'.

    Returns:
        pd.DataFrame: Columns 'symptom_text', 'prompt', 'true_sequence' and
        'predicted_sequence'. The two sequence columns hold lists of symptom names with
        surrounding whitespace removed and empty names dropped; predicted names also have
        surrounding single/double quotes removed.

    Raises:
        ValueError: If `predicted_sequence` does not have one entry per dataset row.
    """
    if len(predicted_sequence) != len(test_dataset):
        raise ValueError(
            f'Expected {len(test_dataset)} predictions (one per test row), '
            f'got {len(predicted_sequence)}'
        )

    # Create the evaluation DataFrame
    eval_df = pd.DataFrame({
        'symptom_text': test_dataset[symptom_text_key].to_list(),
        'prompt': test_dataset[prompt_key].to_list(),
        'true_sequence': test_dataset[true_sequence_key].to_list(),
        'predicted_sequence': predicted_sequence
    })

    # Process and clean the true and predicted sequences
    eval_df['true_sequence'] = eval_df['true_sequence'].apply(_split_symptoms)
    eval_df['predicted_sequence'] = eval_df['predicted_sequence'].apply(
        lambda sym: _split_symptoms(sym, strip_quotes=True)
    )

    return eval_df

### Inference Using BioBART Without PEFT:

In [None]:
# Baseline: reload the pre-trained checkpoint from the Hub, with no adapter attached.
# (The `_wpeft` suffix on the names below marks the "without PEFT" run.)
base_model = AutoModelForSeq2SeqLM.from_pretrained('/'.join((model_vendor, model_name)))

# Generate a prediction for every test prompt with the baseline model.
test_prompts = test_dataset['input_text'].to_list()
predicted_sequence_wpeft = batch_inference(base_model, tokenizer, test_prompts)

# Assemble the evaluation frame and preview the first ten rows.
eval_df_wpeft = create_eval_df(test_dataset, predicted_sequence_wpeft)
preview_columns = ['symptom_text', 'true_sequence', 'predicted_sequence']
eval_df_wpeft[preview_columns].head(10)

Processed Batch: 0 successfully!!
Processed Batch: 1 successfully!!
Processed Batch: 2 successfully!!
Processed Batch: 3 successfully!!
Processed Batch: 4 successfully!!
Processed Batch: 5 successfully!!
Processed Batch: 6 successfully!!
Processed Batch: 7 successfully!!
Processed Batch: 8 successfully!!
Processed Batch: 9 successfully!!
Processed Batch: 10 successfully!!
Processed Batch: 11 successfully!!
Processed Batch: 12 successfully!!
Processed Batch: 13 successfully!!
Processed Batch: 14 successfully!!
Processed Batch: 15 successfully!!
Processed Batch: 16 successfully!!
Processed Batch: 17 successfully!!
Processed Batch: 18 successfully!!
Processed Batch: 19 successfully!!
Processed Batch: 20 successfully!!
Processed Batch: 21 successfully!!
Processed Batch: 22 successfully!!
Processed Batch: 23 successfully!!
Processed Batch: 24 successfully!!
Processed Batch: 25 successfully!!
Processed Batch: 26 successfully!!
Processed Batch: 27 successfully!!
Processed Batch: 28 successful

Unnamed: 0,symptom_text,true_sequence,predicted_sequence
0,"A few days after my vaccine, I noticed under t...","[Blister, Erythema]",[Order the symptoms in chronological order bas...
1,Period schedule on and off the chart; Increase...,"[Menstrual disorder, Heavy menstrual bleeding,...",[Order the symptoms in chronological order bas...
2,"within 24 hours of receiving my 2nd dose, I fi...","[Pyrexia, Chills, Headache, Myalgia, Neuralgia...",[Order the symptoms in chronological order bas...
3,Side effects seem to have cleared up by the 17...,"[Vaccination complication, Headache]",[Order the symptoms in chronological order bas...
4,I received my first Moderna vaccine on one/14/...,"[Lymphadenopathy, Arthralgia, Pain in extremit...",[Order the symptoms in chronological order bas...
5,Patient experienced only chills; Fever; Sorene...,"[Chills, Pyrexia, Myalgia]",[Order the symptoms in chronological order bas...
6,"8 days after the first vaccine dose, I had itc...","[Injection site pruritus, Injection site swell...",[Order the symptoms in chronological order bas...
7,Sore arm; Very tired; Headache; Burning sensat...,"[Burning sensation, Headache, Pain in extremit...",[Order the symptoms in chronological order bas...
8,Chills; Urinating (More often); This spontaneo...,"[Chills, Pollakiuria]",[Order the symptoms in chronological order bas...
9,"Swelling Left arm, upper Calves and legs swell...","[Peripheral swelling, Vaccination site bruisin...",[Order the symptoms in chronological order bas...


In [None]:
# Show only a few raw generations; displaying the full list floods the notebook output.
predicted_sequence_wpeft[:5]

['Order the symptoms in chronological order based on the timeline implied in the clinical text. Think step by step to determine that the chronological order for each symptom.Nitrome¯¯¯¯Text: A few days after my vaccine, I noticed under the pad of my big toe, there was a yellowish fluid-filled lesion that was about a month to 2 months to resolve. I also had redness onthe tip of a couple of my toes and the end of my lower and middle toeses. That was on my left foot. I  also had red on',
 'Order the symptoms in chronological order based on the timeline implied in the-text. Think step by step to determine the\xa0timeline for each symptom.¯¯¯¯¯¯¯¯Text: Period schedule on and off the chart; Increased period bleeding; She experienced was headache which was resolved day after getting vaccine; This spontaneous case was reported by a consumer and describes the occurrence of MENSTRUAL DISORDER (Period schedule on\xa0and off\xa0the chart), HEAVY MENSTRULAR BLEEDING (Increased period blood) and THE

In [None]:
## Save the baseline (without PEFT) evaluation frame as CSV in the results folder
eval_df_wpeft.to_csv(os.path.join(results_dir, f'{model_name}-results.csv'), index=False)

### Inference Using BioBART with PEFT:

In [None]:
# Attach the adapter saved at `peft_model_path` to the base model, not trainable.
# NOTE(review): PEFT appears to inject the adapter layers into `base_model` in place —
# confirm before reusing `base_model` as a plain baseline after this cell.
peft_model = PeftModel.from_pretrained(base_model, peft_model_path, is_trainable=False)

# Generate a prediction for every test prompt with the adapted model.
test_prompts = test_dataset['input_text'].to_list()
predicted_sequence_peft = batch_inference(peft_model, tokenizer, test_prompts)

# Assemble the evaluation frame and preview the first ten rows.
eval_df_peft = create_eval_df(test_dataset, predicted_sequence_peft)
preview_columns = ['symptom_text', 'true_sequence', 'predicted_sequence']
eval_df_peft[preview_columns].head(10)

Processed Batch: 0 successfully!!
Processed Batch: 1 successfully!!
Processed Batch: 2 successfully!!
Processed Batch: 3 successfully!!
Processed Batch: 4 successfully!!
Processed Batch: 5 successfully!!
Processed Batch: 6 successfully!!
Processed Batch: 7 successfully!!
Processed Batch: 8 successfully!!
Processed Batch: 9 successfully!!
Processed Batch: 10 successfully!!
Processed Batch: 11 successfully!!
Processed Batch: 12 successfully!!
Processed Batch: 13 successfully!!
Processed Batch: 14 successfully!!
Processed Batch: 15 successfully!!
Processed Batch: 16 successfully!!
Processed Batch: 17 successfully!!
Processed Batch: 18 successfully!!
Processed Batch: 19 successfully!!
Processed Batch: 20 successfully!!
Processed Batch: 21 successfully!!
Processed Batch: 22 successfully!!
Processed Batch: 23 successfully!!
Processed Batch: 24 successfully!!
Processed Batch: 25 successfully!!
Processed Batch: 26 successfully!!
Processed Batch: 27 successfully!!
Processed Batch: 28 successful

Unnamed: 0,symptom_text,true_sequence,predicted_sequence
0,"A few days after my vaccine, I noticed under t...","[Blister, Erythema]","[Erythema, Blister]"
1,Period schedule on and off the chart; Increase...,"[Menstrual disorder, Heavy menstrual bleeding,...","[Menstrual disorder, Heavy menstrual bleeding,..."
2,"within 24 hours of receiving my 2nd dose, I fi...","[Pyrexia, Chills, Headache, Myalgia, Neuralgia...","[Pyrexia, Chills, Headache, Myalgia, Neuralgia..."
3,Side effects seem to have cleared up by the 17...,"[Vaccination complication, Headache]","[Vaccination complication, Headache]"
4,I received my first Moderna vaccine on one/14/...,"[Lymphadenopathy, Arthralgia, Pain in extremit...","[Lymphadenopathy, Menstrual disorder, Amenorrh..."
5,Patient experienced only chills; Fever; Sorene...,"[Chills, Pyrexia, Myalgia]","[Chills, Pyrexia, Myalgia]"
6,"8 days after the first vaccine dose, I had itc...","[Injection site pruritus, Injection site swell...","[Injection site pruritus, Injection site swell..."
7,Sore arm; Very tired; Headache; Burning sensat...,"[Burning sensation, Headache, Pain in extremit...","[Pain in extremity, Burning sensation, Headach..."
8,Chills; Urinating (More often); This spontaneo...,"[Chills, Pollakiuria]","[Chills, Pollakiuria]"
9,"Swelling Left arm, upper Calves and legs swell...","[Peripheral swelling, Vaccination site bruisin...","[Peripheral swelling, Vaccination site bruisin..."


In [None]:
# Save the PEFT evaluation frame as CSV in the results folder.
eval_df_peft.to_csv(os.path.join(results_dir, f'{model_name}-peft-results.csv'), index=False)