In [1]:
# import necessary packages
import sys, os
import torch 
import numpy as np
import evaluate
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU

device = GetLowestGPU()

Device set to cuda:2


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

# load tokenizer and model
pipeline = pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

pipeline.model = get_peft_model(pipeline.model, peft_config)

pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
pipeline.model.generation_config.pad_token_id = pipeline.tokenizer.eos_token_id

# pipeline.model, pipeline.tokenizer = setup_chat_format(pipeline.model, pipeline.tokenizer)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
pipeline.model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


In [5]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = "train[:5000]")

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 500
    })
})

# Preprocessing

In [6]:
# preprocess data
def format_chat(row):
    row_json = [{'role': 'user', 'content': row["Patient"]},
                {'role': 'assistant', 'content': row["Doctor"]}]
    row["text"] = pipeline.tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

def preprocess_data(examples):
    tokenized_data = pipeline.tokenizer(text=examples['text'],
                               padding='max_length', 
                               truncation=True, 
                               max_length=512)
    
    labels = tokenized_data['input_ids'].copy()

    for i in range(len(labels)):
        if labels[i][-1] != pipeline.tokenizer.pad_token_id:
            labels[i] = labels[i][1:] + [pipeline.tokenizer.pad_token_id]
        else:
            labels[i] = labels[i][1:] + [-100]

    labels = [[-100 if x == pipeline.tokenizer.pad_token_id else x for x in y] for y in labels]
    tokenized_data['labels'] = labels
    
    return tokenized_data

In [7]:
chat_dataset = raw_dataset.map(format_chat)
chat_dataset['test'][0]



Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

{'Description': 'Q. I am feeling uncomfortable with my sports goggles. Please advise.',
 'Patient': "Hello doctor, Why does everything look different or trippy? I recently got sports goggles for soccer with the same prescription as with my regular glasses. I feel so much more comfortable with my regular glasses that I almost want to keep playing sports with those glasses. I got the sports goggles so they won't break (I have broken pairs of glasses in the past with sports) and so they will not fall off. Now I am scared to use them because it is almost as if I am in an adjustment period again. I know that happens when you have a different prescription but I know this is the same exact one. Will this continue to bother me when I am practicing or playing? Any tips?",
 'Doctor': 'Hi. It is very common for everyone to get these kinds of experience when they change glasses even with the same power glasses. This is because our eyes get adjusted to the glasses, even minor changes like the shape

In [8]:
# add special tokens to tokenizer
tokenized_dataset = chat_dataset.map(preprocess_data, 
                                    batched=True,
                                    remove_columns=chat_dataset['train'].column_names)
tokenized_dataset.with_format("torch")

# check tokenized dataset output
tokenized_dataset

Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 500
    })
})

# Create Dataloaders

In [9]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=pipeline.tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [10]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 512]),
 'attention_mask': torch.Size([8, 512]),
 'labels': torch.Size([8, 512])}

In [11]:
outputs = pipeline.model(**batch)
print(outputs.loss)

tensor(11.1616, grad_fn=<ToCopyBackward0>)


In [12]:
# test pre training
text = [{'role': 'system', 'content': 'you are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'you are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! As a helpful medical chatbot, I'd be happy to guide you through some steps to help alleviate your discomfort.\n\nHere are some suggestions:\n\n1. **Stay hydrated**: Sometimes, dehydration can cause or worsen headaches. Drink plenty of water or other fluids to help your body replenish its"}]


# Training

In [13]:
# options
optimizer = AdamW(pipeline.model.parameters(), lr=1e-5)
num_epochs = 10
num_steps = num_epochs * len(train_dataloader)

# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]

# loop
for epoch in range(num_epochs):
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    pipeline.model.train()

    # initialize train loss, val loss
    running_train_loss = 0.0
    running_val_loss = 0.0

    # loop through train data
    print("Training...")
    i = 0
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = pipeline.model(**batch)
        loss = outputs.loss
        print(f"batch loss: {loss:.4f}\r", end="")

        running_train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # zero gradients
        optimizer.zero_grad()
        
        i += 1

        # if i % 5 == 0:
        #     print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])
            
    # set model to eval mode
    pipeline.model.eval()

    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = pipeline.model(**batch)
            loss = outputs.loss
            running_val_loss += loss.item()
        
    val_loss = running_val_loss / len(val_dataloader)

    print("Printing example response...")
    print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

    train_loss = running_train_loss / len(train_dataloader)
    print(f"Avg. Train Loss: {train_loss:.4f}, Avg. Val Loss: {val_loss:.4f}")
    # print("Evaluation metrics:", metric.compute())

print("Training Complete!")

Epoch 1
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you have a headache! Don't worry, there are several steps you can follow to help alleviate your discomfort. Here are a few suggestions:\n\n 1.  **Drink water**: Sometimes, a headache can be caused by a dehydration. If you have a headache, try drinking water. Water is helpful to your body in"}]
Avg. Train Loss: 6.5678, Avg. Val Loss: 4.9125
Epoch 2
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're having a headache! There are several things you can try to help alleviate the discomfort. Here's a list of remedies that have been known to help: 1. Drink plenty of water. Sometimes, a headache may be caused by dehydration, so."}]
Avg. Train Loss: 4.7440, Avg. Val Loss: 4.6215
Epoch 3
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're having a headache. As a medical chatbot, I can help you with the best course of action."}]
Avg. Train Loss: 4.5642, Avg. Val Loss: 4.5058
Epoch 4
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! Don't worry, I'm here to help. Here are some steps you can follow to help alleviate your headache:\n\n1. **Stay hydrated**: Sometimes, dehydration can cause headaches. Drink a glass of water or other non-caffeinated liquid to help your body replenish fluids.\n\n2."}]
Avg. Train Loss: 4.4695, Avg. Val Loss: 4.4259
Epoch 5
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache. Here are some steps you can follow to help relieve your headache:\n\n1. Stay hydrated: Sometimes, headaches can be caused by dehydration. Drinking a glass of water or other hydrating drink can help relieve your headache.\n2. Stretch your neck and shoulders: Tight muscles in your neck and shoulders"}]
Avg. Train Loss: 4.3986, Avg. Val Loss: 4.3587
Epoch 6
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are several steps you can take to help alleviate the discomfort. Here are some suggestions:\r\n\r\n1. Drink plenty of water: Sometimes, dehydration can lead to headaches. So, make sure you're drinking enough water throughout the day."}]
Avg. Train Loss: 4.3367, Avg. Val Loss: 4.3001
Epoch 7
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are many possible causes, but don't worry, I'm here to guide you through some steps to help alleviate your discomfort. Here are some suggestions:\n\n1. **Hydrate**: Sometimes, a headache can be caused by dehydration. Drink plenty of water (at least 8-"}]
Avg. Train Loss: 4.2820, Avg. Val Loss: 4.2459
Epoch 8
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! I'm here to help. Here are some steps you can follow to alleviate your headache:\n\n1. **Stay hydrated**: Dehydration is a common cause of headaches. Drink at least 8-10 glasses of water throughout the day to replenish your body.\n\n2. **Rest**: If"}]
Avg. Train Loss: 4.2330, Avg. Val Loss: 4.1975
Epoch 9
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are several things you can try to help relieve the discomfort. Here are a few suggestions:\r\n\r\n1. Drink plenty of water: Sometimes, dehydration can cause headaches. Make sure you are getting enough fluids throughout the day.\r\n\r\n2. Over-the-counter pain medication: You can try taking an over"}]
Avg. Train Loss: 4.1894, Avg. Val Loss: 4.1574
Epoch 10
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

batch loss: 3.5609

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Printing example response...
[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! Don't worry, I'm here to help. Here are some steps you can take to alleviate your headache:\n\n1. **Stay hydrated**: Sometimes, headaches can be caused by dehydration. Drink a glass of water or an electrolyte-rich beverage, such as coconut water or a sports drink.\n"}]
Avg. Train Loss: 4.1500, Avg. Val Loss: 4.1176
Training Complete!


# Prediction

In [15]:
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a migraine. What should I do?'}]
print(pipeline(text, max_length=512, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a migraine. What should I do?'}, {'role': 'assistant', 'content': 'Sorry to hear that you are having a migraine.'}]
