In [2]:
# import necessary packages
import sys, os
import torch 
import numpy as np
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from importlib import reload
from functools import partial
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU
from utils.GetFileNames import get_file_names
import utils.preprocessing as pp

device = GetLowestGPU()

Device set to cuda:1


In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [4]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

# load tokenizer and model
pipeline = pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

pipeline.model = get_peft_model(pipeline.model, peft_config)

pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
pipeline.tokenizer.pad_token_id = pipeline.tokenizer.eos_token_id
pipeline.model.generation_config.pad_token_id = pipeline.tokenizer.eos_token_id

# pipeline.model, pipeline.tokenizer = setup_chat_format(pipeline.model, pipeline.tokenizer)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [9]:
pipeline.model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


In [10]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = "train[:5000]")

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 500
    })
})

# Preprocessing

In [11]:
# format chat dataset
reload(pp)
format_chat = partial(pp.format_chat, input_col="Patient", output_col="Doctor", pipeline_name=pipeline)
chat_dataset = raw_dataset.map(format_chat)
chat_dataset['test'][0]



Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

{'Description': 'Q. Does excess body heat cause swelling in the groin area?',
 'Patient': 'Hello doctor, I have got a swelling in the groin area. I clean the groin area daily. I suppose this is because of excess body heat. It appeared the day after I did a long drive of 12 hours from my home town. And the place was too hot.',
 'Doctor': 'Hello. Do you have itching and a similar skin lesion on another side of the thigh? Is this the first time you are having this problem? In the picture (attachment removed to protect patient identity), it looks like tinea cruris which is a fungal infection of the groin. The picture is not very clear. I do not see any fluid-filled blisters. Is there oozing of fluid from the bumps?',
 'text': '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello doctor, I have got a swelling in the groin area. I clean the groin area daily. I suppose this is because of excess body heat. It appeared the day after I did a long drive of 12 hours from my home town

In [12]:
# tokenize dataset
tokenize_function = partial(pp.tokenize_data, pipeline_name=pipeline)
tokenized_dataset = chat_dataset.map(tokenize_function, 
                                     batched=True,
                                     remove_columns=chat_dataset['train'].column_names)

Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

# Create Dataloaders

In [15]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=pipeline.tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [16]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 1024]),
 'attention_mask': torch.Size([8, 1024]),
 'labels': torch.Size([8, 1024])}

In [17]:
outputs = pipeline.model(**batch)
print(outputs.loss)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


tensor(14.3553, grad_fn=<ToCopyBackward0>)


In [18]:
# test pre training
text = [{'role': 'system', 'content': 'you are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'you are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! As a helpful medical chatbot, I'd like to offer some suggestions to help alleviate your discomfort.\n\nBefore we dive into remedies, can you tell me a bit more about your headache? For example:\n\n1. How long have you had the headache?\n2. Is it a sudden onset or"}]


# Training

In [19]:
# options
optimizer = AdamW(pipeline.model.parameters(), lr=1e-5)
num_epochs = 10
num_steps = num_epochs * len(train_dataloader)

# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]

# loop
for epoch in range(num_epochs):
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    pipeline.model.train()

    # initialize train loss, val loss
    running_train_loss = 0.0
    running_val_loss = 0.0

    # loop through train data
    print("Training...")
    i = 0
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = pipeline.model(**batch)
        loss = outputs.loss
        print(f"batch loss: {loss:.4f}\r", end="")

        running_train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # zero gradients
        optimizer.zero_grad()
        
        i += 1

        # if i % 5 == 0:
        #     print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])
            
    # set model to eval mode
    pipeline.model.eval()

    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = pipeline.model(**batch)
            loss = outputs.loss
            running_val_loss += loss.item()
        
    val_loss = running_val_loss / len(val_dataloader)

    print("Printing example response...")
    print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

    train_loss = running_train_loss / len(train_dataloader)
    print(f"Avg. Train Loss: {train_loss:.4f}, Avg. Val Loss: {val_loss:.4f}")
    # print("Evaluation metrics:", metric.compute())

print("Training Complete!")

Epoch 1
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

batch loss: 14.3553

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


batch loss: 3.18096

KeyboardInterrupt: 

# Prediction

In [20]:
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a migraine. What should I do?'}]
print(pipeline(text, max_length=512, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a migraine. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a migraine! As a helpful medical chatbot, I'd be happy to guide you through some steps to help alleviate your symptoms.\n\n**Immediate Relief:**\n\n1. **Stay hydrated**: Drink plenty of water to help your body replenish lost fluids and electrolytes. Aim for at least 8-10 glasses of water throughout the day.\n2. **Rest**: Find a quiet, dark, and cool place to rest and relax. Avoid stimulating activities and bright lights.\n3. **Over-the-counter pain relievers**: You can try taking an over-the-counter pain reliever such as ibuprofen (Advil, Motrin) or acetaminophen (Tylenol). However, always follow the recommended dosage and consult with your doctor if you're unsure.\n\n**Additional Tips:**\n\n1. **Cold or warm compress**: Apply a cold or warm compress to the forehead, neck, or sh