In [1]:
# import necessary packages
import sys, os
import torch 
import numpy as np
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from importlib import reload
from functools import partial
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU
from utils.GetFileNames import get_file_names
import utils.preprocessing as pp

device = GetLowestGPU()

Device set to cuda:0


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

# load tokenizer and model
pipeline = pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

pipeline.model = get_peft_model(pipeline.model, peft_config)

pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
pipeline.model.generation_config.pad_token_id = pipeline.tokenizer.eos_token_id

# pipeline.model, pipeline.tokenizer = setup_chat_format(pipeline.model, pipeline.tokenizer)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
pipeline.model.print_trainable_parameters()

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


In [5]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = "train[:5000]")

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 500
    })
})

# Preprocessing

In [6]:
# format chat dataset
reload(pp)
format_chat = partial(pp.format_chat, input_col="Patient", output_col="Doctor", pipeline_name=pipeline)
chat_dataset = raw_dataset.map(format_chat)
chat_dataset['test'][0]



Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

{'Description': 'Q. How can I manage blackening around the mouth?',
 'Patient': 'Hello doctor, I have blackening around my mouth since two months and it is not fading. I went to Dermatologist, she prescribed me a deepwhite cream. But it is not working. What should I do? Please help me.',
 'Doctor': 'Hi. Keep applying that cream plus start taking tablet Folic acid twice a day for 10 days, then make it once a day for next two weeks. Cranking around lip corners occur due to deficiency of Multivitamins, Folic acid and winter season can also be a reason. As far as home remedy is concerned, put olive oil on finger and put it around your lips so that soreness is reduced. Do it before sleeping for a week. And above all start drinking water, like 1 litre per day mandatory.',
 'text': '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHello doctor, I have blackening around my mouth since two months and it is not fading. I went to Dermatologist, she prescribed me a deepwhite cream. But

In [11]:
# tokenize dataset
tokenize_function = partial(pp.tokenize_data, pipeline_name=pipeline)
tokenized_dataset = chat_dataset.map(tokenize_function, 
                                     batched=True,
                                     remove_columns=chat_dataset['train'].column_names)

Map:   0%|          | 0/4500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [12]:
tokenized_dataset['test'][0]

{'input_ids': [128000,
  128000,
  128006,
  882,
  128007,
  271,
  9906,
  10896,
  11,
  358,
  617,
  3776,
  6147,
  2212,
  856,
  11013,
  2533,
  1403,
  4038,
  323,
  433,
  374,
  539,
  59617,
  13,
  358,
  4024,
  311,
  76508,
  266,
  16549,
  11,
  1364,
  32031,
  757,
  264,
  5655,
  5902,
  12932,
  13,
  2030,
  433,
  374,
  539,
  3318,
  13,
  3639,
  1288,
  358,
  656,
  30,
  5321,
  1520,
  757,
  13,
  128009,
  128006,
  78191,
  128007,
  271,
  13347,
  13,
  13969,
  19486,
  430,
  12932,
  5636,
  1212,
  4737,
  21354,
  435,
  7918,
  13935,
  11157,
  264,
  1938,
  369,
  220,
  605,
  2919,
  11,
  1243,
  1304,
  433,
  3131,
  264,
  1938,
  369,
  1828,
  1403,
  5672,
  13,
  4656,
  33434,
  2212,
  19588,
  24359,
  12446,
  4245,
  311,
  48294,
  315,
  22950,
  344,
  275,
  38925,
  11,
  435,
  7918,
  13935,
  323,
  12688,
  3280,
  649,
  1101,
  387,
  264,
  2944,
  13,
  1666,
  3117,
  439,
  2162,
  40239,
  374,
  11920,
  11

# Create Dataloaders

In [13]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=pipeline.tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [14]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 1024]),
 'attention_mask': torch.Size([8, 1024]),
 'labels': torch.Size([8, 1024])}

In [15]:
outputs = pipeline.model(**batch)
print(outputs.loss)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


tensor(11.5615, grad_fn=<ToCopyBackward0>)


In [16]:
# test pre training
text = [{'role': 'system', 'content': 'you are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'you are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are many possible causes of headaches, but don't worry, I'm here to help you figure out what to do.\n\nFirst, can you tell me a bit more about your headache? For example:\n\n* How long have you had it?\n* Is it a sharp or dull pain?\n"}]


# Training

In [17]:
# options
optimizer = AdamW(pipeline.model.parameters(), lr=1e-5)
num_epochs = 10
num_steps = num_epochs * len(train_dataloader)

# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]

# loop
for epoch in range(num_epochs):
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    pipeline.model.train()

    # initialize train loss, val loss
    running_train_loss = 0.0
    running_val_loss = 0.0

    # loop through train data
    print("Training...")
    i = 0
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = pipeline.model(**batch)
        loss = outputs.loss
        print(f"batch loss: {loss:.4f}\r", end="")

        running_train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # zero gradients
        optimizer.zero_grad()
        
        i += 1

        # if i % 5 == 0:
        #     print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])
            
    # set model to eval mode
    pipeline.model.eval()

    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = pipeline.model(**batch)
            loss = outputs.loss
            running_val_loss += loss.item()
        
    val_loss = running_val_loss / len(val_dataloader)

    print("Printing example response...")
    print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

    train_loss = running_train_loss / len(train_dataloader)
    print(f"Avg. Train Loss: {train_loss:.4f}, Avg. Val Loss: {val_loss:.4f}")
    # print("Evaluation metrics:", metric.compute())

print("Training Complete!")

Epoch 1
Training...


  0%|          | 0/563 [00:00<?, ?it/s]

batch loss: 6.05864

# Prediction

In [15]:
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a migraine. What should I do?'}]
print(pipeline(text, max_length=512, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a migraine. What should I do?'}, {'role': 'assistant', 'content': 'Sorry to hear that you are having a migraine.'}]
