In [1]:
# import necessary packages
import sys, os
import torch 
import numpy as np
import evaluate
from trl import SFTTrainer, setup_chat_format
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU

device = GetLowestGPU()

Device set to cuda:1


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Instantiate Model and Dataset

In [3]:
# options
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_path = "ruslanmv/ai-medical-chatbot" #test dataset

# load tokenizer and model
pipeline = pipeline('text-generation', 
                    model=model_path,
                    model_kwargs={'torch_dtype': torch.bfloat16},
                    device_map = 'auto'
                    )

model, tokenizer = pipeline.model, pipeline.tokenizer
model, tokenizer = setup_chat_format(model, tokenizer)

tokenizer.pad_token = tokenizer.eos_token
model.resize_token_embeddings(len(tokenizer))

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Embedding(128258, 4096)

UndefinedError: 'str object' has no attribute 'role'

In [4]:
message = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
           {'role': 'user', 'content': 'I have a headache. What should I do?'}]
print(pipeline(message, max_length=100)[0]['generated_text'])

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


[{'role': 'system', 'content': 'You are a helpful medical chatbot'}, {'role': 'user', 'content': 'I have a headache. What should I do?'}, {'role': 'assistant', 'content': "Sorry to hear that you're experiencing a headache! There are several things you can try to help alleviate the discomfort. Here are some suggestions:\n\n1. **Stay hydrated**: Dehydration is a common cause of headaches. Drink plenty of water or other fluids to help your body replenish its water levels. Aim for at least 8-10"}]


In [7]:
# load dataset
raw_dataset = load_dataset(dataset_path, split = 'train[:1%]')

# check format of data
raw_dataset = raw_dataset.train_test_split(test_size=0.1)
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 2312
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor'],
        num_rows: 257
    })
})

# Preprocessing

In [14]:
# preprocess data
def format_chat(row):
    row_json_inp = [{"role": "user", "content": row["Patient"]}]
    row_json_out = [{"role": "doctor", "content": row["Doctor"]}]
    row["input"] = tokenizer.apply_chat_template(row_json_inp, tokenize=False)
    row["target"] = tokenizer.apply_chat_template(row_json_out, tokenize=False)
    return row

def preprocess_data(examples):
    inp = examples["input"]
    out = examples["target"]
    tokenized_data = tokenizer(text=inp, 
                               text_target=out,
                               padding='max_length', 
                               truncation=True, 
                               max_length=512)
    return tokenized_data

In [19]:
# add special tokens to tokenizer
chat_dataset = raw_dataset.map(format_chat)
tokenized_dataset = chat_dataset.map(preprocess_data, 
                                    batched=True,
                                    remove_columns=chat_dataset['train'].column_names)
tokenized_dataset.with_format("torch")

# check tokenized dataset output
tokenized_dataset

Map:   0%|          | 0/2312 [00:00<?, ? examples/s]

Map:   0%|          | 0/257 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2312
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 257
    })
})

# Create Dataloaders

In [28]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# options
batch_size = 8

train_dataloader = DataLoader(tokenized_dataset['train'],
                              batch_size=batch_size, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            batch_size=batch_size,
                            collate_fn=data_collator)

In [29]:
# inspect sample batch
batch = next(iter(train_dataloader))
{key: val.shape for key, val in batch.items()}

{'input_ids': torch.Size([8, 512]),
 'attention_mask': torch.Size([8, 512]),
 'labels': torch.Size([8, 512])}

In [30]:
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

tensor(15.4872, grad_fn=<ToCopyBackward0>) torch.Size([8, 512, 128258])


# Training

In [31]:
# initialize optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# and scheduler
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

print(num_training_steps)

289


In [32]:
# eval loop
# loop through epochs
for epoch in range(num_epochs):
    clear_output(wait=True)
    
    print("=====================")
    print(f"Epoch {epoch + 1}")
    print("=====================")

    # set model to train mode
    model.train()

    # initialize train loss, val loss
    train_loss = 0.0
    val_loss = 0.0

    # loop through train data
    print("Training...")
    for batch in train_dataloader:

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = model(**batch)
        loss = outputs.loss

        train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # update scheduler
        lr_scheduler.step()

        # zero gradients
        optimizer.zero_grad()

    train_loss = train_loss / (len(train_dataloader) / batch_size)

    # set to eval mode
    model.eval()
    print("Validating...")
    for batch in val_dataloader:

        # get batch
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        with torch.no_grad():
            outputs = model(**batch)

        # get loss
        loss = outputs.loss
        val_loss += loss.item()


    val_loss = val_loss / (len(val_dataloader) / batch_size)

    print(f"Avg. Train Loss: {train_loss}, Avg. Val Loss: {val_loss}")
    # print("Evaluation metrics:", metric.compute())


Epoch 1
Training...


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


KeyboardInterrupt: 

# Prediction

In [15]:
message = 'I have a headache. What should I do?'
# test after training
text = [{'role': 'system', 'content': 'You are a helpful medical chatbot'},
        {'role': 'user', 'content': message}]
print(pipeline(text, max_length=100)[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


The answer to the ultimate question of life, the universe and everything is... * The universe is huge but not infinite * There are many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many many manythe answer to the ultimate question of life, the universe and everything is... * The universe is huge but not infinite * There are many many many many many many many many many
