In [1]:
# import necessary packages
import sys, os
import torch 
import numpy as np
import evaluate
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU

device = GetLowestGPU()

Device set to cuda:6


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

# load tokenizer and model
pipeline = pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
)

pipeline.tokenizer.pad_token = pipeline.tokenizer.eos_token
pipeline.model.generation_config.pad_token_id = pipeline.tokenizer.eos_token_id

# pipeline.model, pipeline.tokenizer = setup_chat_format(pipeline.model, pipeline.tokenizer)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = "train[:5000]")

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 500
    })
})

# Preprocessing

In [8]:
# preprocess data
def format_chat(row):
    row_json = [{'role': 'user', 'content': row["Patient"]},
                {'role': 'assistant', 'content': row["Doctor"]}]
    row["text"] = row_json
    return row

def preprocess_data(examples):
    inputs = pipeline.tokenizer.apply_chat_template(examples["text"], tokenize=False)
    tokenized_data = pipeline.tokenizer(text=inputs,
                               padding='max_length', 
                               truncation=True, 
                               max_length=100)
    
    labels = tokenized_data['input_ids'].copy()

    for i in range(len(labels)):
        if labels[i][-1] != pipeline.tokenizer.pad_token_id:
            labels[i] = labels[i][1:] + [pipeline.tokenizer.pad_token_id]
        else:
            labels[i] = labels[i][1:] + [-100]

    labels = [[-100 if x == pipeline.tokenizer.pad_token_id else x for x in y] for y in labels]
    tokenized_data['labels'] = labels
    
    return tokenized_data

In [9]:
chat_dataset = raw_dataset.map(format_chat)
chat_dataset['test'][0]

Map:   0%|          | 0/231224 [00:00<?, ? examples/s]

Map:   0%|          | 0/25692 [00:00<?, ? examples/s]

{'Description': 'Q. Can Yasmin birth control pill be used as an emergency contraceptive pill?',
 'Patient': 'Hello doctor, My fiancee and I had unprotected sex a few days back, but I did not ejaculate inside her. Just to be on the safer side, we wanted to use the emergency contraceptive pill. But due to some restriction in the country where we live, Plan B or emergency contraceptive pills are not available. I read that Yasmin, which is used as a regular contraceptive pill can be used as an emergency contraceptive pill at a higher dosage. Can Yasmin be used as an emergency contraceptive pill? And at what dosage?',
 'Doctor': 'Hi. How are you doing? Yes, as you have heard, Yasmin can be used as an emergency pill. Even if you have not ejaculated into her, she stands a chance of pregnancy if the pre-seminal fluid, the clear fluid that comes out before semen, which is rich in young healthy sperms, comes in contact with her genitals. As soon as possible, earlier the better, I suggest taking 

In [10]:
# add special tokens to tokenizer
tokenized_dataset = chat_dataset.map(preprocess_data, 
                                    batched=True,
                                    remove_columns=chat_dataset['train'].column_names)
tokenized_dataset.with_format("torch")

# check tokenized dataset output
tokenized_dataset



Map:   0%|          | 0/231224 [00:00<?, ? examples/s]

Map:   0%|          | 0/25692 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 231224
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 25692
    })
})

# Create Dataloaders

In [11]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=pipeline.tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [12]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 100]),
 'attention_mask': torch.Size([8, 100]),
 'labels': torch.Size([8, 100])}

In [13]:
outputs = pipeline.model(**batch)
print(outputs.loss)

tensor(11.7867, grad_fn=<ToCopyBackward0>)


In [14]:
# test pre training
text = [{'role': 'system', 'content': 'you are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'you are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! As a helpful medical chatbot, I'd like to offer some suggestions to help you alleviate your discomfort.\n\n1. **Stay hydrated**: Dehydration is a common cause of headaches. Drink plenty of water or other fluids to replenish your body's water supply. Aim for at least 8"}]


# Training

In [15]:
# options
optimizer = AdamW(pipeline.model.parameters(), lr=1e-5)
num_epochs = 1
num_steps = num_epochs 

# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a headache. What should I do?'}]

# loop
for epoch in range(num_epochs):
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    pipeline.model.train()

    # initialize train loss, val loss
    running_train_loss = 0.0
    running_val_loss = 0.0

    # loop through train data
    print("Training...")
    i = 0
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = pipeline.model(**batch)
        loss = outputs.loss
        print(f"batch loss: {loss:.4f}\r", end="")

        running_train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # zero gradients
        optimizer.zero_grad()
        
        i += 1

        # if i % 5 == 0:
        #     print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])
            
    # set model to eval mode
    pipeline.model.eval()

    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = pipeline.model(**batch)
            loss = outputs.loss
            running_val_loss += loss.item()
        
    val_loss = running_val_loss / len(val_dataloader)

    print("Printing example response...")
    print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

    train_loss = running_train_loss / len(train_dataloader)
    print(f"Avg. Train Loss: {train_loss:.4f}, Avg. Val Loss: {val_loss:.4f}")
    # print("Evaluation metrics:", metric.compute())

print("Training Complete!")

Epoch 1
Training...


  0%|          | 0/28903 [00:00<?, ?it/s]

batch loss: 3.84227

KeyboardInterrupt: 

# Prediction

In [16]:
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': 'I have a migraine. What should I do?'}]
print(pipeline(text, max_length=100, truncation=True)[0]['generated_text'])

[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a migraine. What should I do?'}, {'role': 'assistant', 'content': ', hope are fine I gone your. you to aologist medicine take a medicine you pain  or  your  is and your is                                                                        '}]
