In [1]:
# import necessary packages
import sys, os
import torch 
import numpy as np
import evaluate
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU

device = GetLowestGPU()

Device set to cuda:0


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

# load tokenizer and model
pipeline = pipeline('text-generation', 
                    model=model_path,
                    model_kwargs={'torch_dtype': torch.bfloat16},
                    device_map = 'auto'
                    )

model, tokenizer = pipeline.model, pipeline.tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
model, tokenizer = setup_chat_format(model, tokenizer)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = 'train[:1%]')

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 2312
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 257
    })
})

# Preprocessing

In [6]:
# preprocess data
def format_chat(row):
    row_json_inp = [{"role": "user", "content": row["Patient"]}]
    row_json_out = [{"role": "assistant", "content": row["Doctor"]}]
    row["user"] = tokenizer.apply_chat_template(row_json_inp, tokenize=False)
    row["assistant"] = tokenizer.apply_chat_template(row_json_out, tokenize=False)
    return row

def preprocess_data(examples):
    inp = examples["user"]
    out = examples["assistant"]
    tokenized_data = tokenizer(text=inp, 
                               text_target=out,
                               padding='max_length', 
                               truncation=True, 
                               max_length=512)
    return tokenized_data

In [7]:
chat_dataset = raw_dataset.map(format_chat)
chat_dataset['train'][0]

Map:   0%|          | 0/2312 [00:00<?, ? examples/s]

Map:   0%|          | 0/257 [00:00<?, ? examples/s]

{'Description': 'Q. Suffering from osteopenia and osteoarthritis. Taking AcuCal and FreeFlex. Is it safe to take SAM(S Adenosyl Methionine) to enhance mobility?',
 'Patient': 'Dear Doctor, My mother is 66 years old. She is having Trigeminal Neuralgia and for which she is taking Lyrica 75mg, Tegritol CR 400 and Symbal 30. And other medication details as below: For Cholesetrol -> Storfib -10mg. For BP -> seloken xl 100mg and amlodac 5mg. For Urinary Incontinence -> tropan 5mg and cranpac. Her bone mineral density revealed to scores of Osteopenia, and also Osteoarthritis. She is taking AcuCal calcium tablet and FreeFlex Forte. But her mobility is still very much restricted. So I want to know whether she can take SAM (S Adenosyl Methionine), and is it safe? Also If she can take SAM, what will be the dosage and duration. I am asking whether she can take this as a supportive to enhance her mobility. She is a pure vegetarian and on fat free diet. Her weight is 70kg, height is 5.1 feet. Please

In [8]:
# add special tokens to tokenizer
tokenized_dataset = chat_dataset.map(preprocess_data, 
                                    batched=True,
                                    remove_columns=chat_dataset['train'].column_names)
tokenized_dataset.with_format("torch")

# check tokenized dataset output
tokenized_dataset

Map:   0%|          | 0/2312 [00:00<?, ? examples/s]

Map:   0%|          | 0/257 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2312
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 257
    })
})

# Create Dataloaders

In [9]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [10]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 512]),
 'attention_mask': torch.Size([8, 512]),
 'labels': torch.Size([8, 512])}

In [11]:
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

tensor(16.7973, grad_fn=<ToCopyBackward0>) torch.Size([8, 512, 128258])


In [12]:
# test pre training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are several steps you can take to help alleviate the discomfort. Here are some suggestions:\n\n1. **Stay hydrated**: Dehydration is a common cause of headaches. Drink a glass of water or other non-caffeinated fluid to see if that helps.\n2. **Take a pain reliever**:"}]


# Training

In [13]:
# initialize optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# and scheduler
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

print(num_training_steps)

289


In [14]:
# eval loop
# loop through epochs
for epoch in range(num_epochs):
    clear_output(wait=True)
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    model.train()

    # initialize train loss, val loss
    train_loss = 0.0
    val_loss = 0.0

    # loop through train data
    print("Training...")
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = model(**batch)
        loss = outputs.loss

        train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # update scheduler
        lr_scheduler.step()

        # zero gradients
        optimizer.zero_grad()

    train_loss = train_loss / (len(train_dataloader) / batch_size)

    # set to eval mode
    model.eval()
    print("Validating...")
    for batch in tqdm(val_dataloader):

        # get batch
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        with torch.no_grad():
            outputs = model(**batch)

        # get loss
        loss = outputs.loss
        val_loss += loss.item()


    val_loss = val_loss / (len(val_dataloader) / batch_size)

    print(f"Avg. Train Loss: {train_loss}, Avg. Val Loss: {val_loss}")
    # print("Evaluation metrics:", metric.compute())


Epoch 1
Training...


  0%|          | 0/289 [00:00<?, ?it/s]

Validating...


  0%|          | 0/33 [00:00<?, ?it/s]

Avg. Train Loss: 17.117546503931592, Avg. Val Loss: 14.68741343238137


# Prediction

In [19]:
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a migraine. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a migraine. What should I do?'}, {'role': 'assistant', 'content': ''}]
