In [1]:
# import necessary packages
import sys
import torch 
import numpy as np
import evaluate
from transformers import (pipeline,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          DataCollatorWithPadding,
                          get_scheduler)
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from IPython.display import clear_output

sys.path.append('../')

# custom imports
from utils.GetLowestGPU import GetLowestGPU

device = GetLowestGPU()

Device set to cuda:0


# Instantiate Model and Dataset

In [2]:
# options
model_path = "meta-llama/Meta-Llama-3-8B"
dataset_path = "ruslanmv/ai-medical-chatbot"

# load tokenizer and model
pipeline = pipeline('text-generation', 
                    model=model_path,
                    model_kwargs={'torch_dtype': torch.bfloat16},
                    device_map = 'auto'
                    )

tokenizer = pipeline.tokenizer
model = pipeline.model

raw_dataset = load_dataset(dataset_path, split='train[:10%]')

# check format of data
raw_dataset



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Dataset({
    features: ['Description', 'Patient', 'Doctor'],
    num_rows: 25692
})

# Preprocessing

In [3]:
# add special tokens to tokenizer
tokenizer.add_special_tokens({'pad_token': '<pad>'})
model.resize_token_embeddings(len(tokenizer))

# define preprocessing function
def preprocess_data(examples):
    instructions = examples['Patient']
    responses = examples['Doctor']
    return tokenizer(text=instructions,
                     text_target=responses, 
                     max_length=1024, 
                     return_tensors='pt',
                     truncation=True, 
                     padding='max_length')

tokenized_dataset = raw_dataset.map(preprocess_data, batched=True)

# check tokenized dataset output
tokenized_dataset

Dataset({
    features: ['Description', 'Patient', 'Doctor', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 25692
})

In [4]:
# instantiate data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# remove cols that are not needed and set format to torch
tokenized_dataset = tokenized_dataset.remove_columns(['Description', 'Patient', 'Doctor'])
tokenized_dataset.set_format(type='torch')
tokenized_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 25692
})

# Create Dataloaders

In [5]:
# options
train_split = .80
batch_size = 4


# split data between train and val
tokenized_dataset = tokenized_dataset.train_test_split(test_size=(1-train_split))
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20553
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 5139
    })
})

In [6]:
train_dataloader = DataLoader(tokenized_dataset['train'], 
                              shuffle=True, batch_size=8, 
                              collate_fn=data_collator)

val_dataloader = DataLoader(tokenized_dataset['test'],
                            shuffle=False, 
                            batch_size=8,
                            collate_fn=data_collator)

len(train_dataloader), len(val_dataloader)

(2570, 643)

In [7]:
# inspect sample batch
batch = next(iter(train_dataloader))

{k: v.shape for k, v in batch.items()}

{'input_ids': torch.Size([8, 1024]),
 'attention_mask': torch.Size([8, 1024]),
 'labels': torch.Size([8, 1024])}

In [8]:
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)

tensor(17.4647, grad_fn=<ToCopyBackward0>) torch.Size([8, 1024, 128257])


# Training

In [9]:
# initialize optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# and scheduler
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps
)

print(num_training_steps)

7710


In [10]:
# eval loop

# define metrics
# metric = evaluate.load("glue", "mrpc")

# loop through epochs
for epoch in range(num_epochs):
    
    clear_output(wait=True)

    print(f"Epoch {epoch + 1}\n=====================")

    # set model to train mode
    model.train()

    # initialize train loss, val loss
    train_loss = 0.0
    val_loss = 0.0

    # loop through train data
    print("Training...")
    for batch in tqdm(train_dataloader):

        # grab batch and map to device
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        outputs = model(**batch)
        loss = outputs.loss

        train_loss += loss.item()

        # backward pass
        loss.backward()

        # update optimizer
        optimizer.step()

        # update scheduler
        lr_scheduler.step()

        # zero gradients
        optimizer.zero_grad()

    train_loss = train_loss / (len(train_dataloader) / batch_size)

    # set to eval mode
    model.eval()
    print("Validating...")
    for batch in val_dataloader:

        # get batch
        batch = {k: v.to(device) for k, v in batch.items()}

        # forward pass
        with torch.no_grad():
            outputs = model(**batch)

        # get loss
        loss = outputs.loss
        val_loss += loss.item()

        # get logits, predictions
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        # metric.add_batch(predictions=predictions, references=batch["labels"])


    val_loss = val_loss / (len(train_dataloader) / batch_size)

    print(f"Avg. Train Loss: {train_loss}, Avg. Val Loss: {val_loss}")
    # print("Evaluation metrics:", metric.compute())


Epoch 1
Training...


  0%|          | 0/2570 [00:00<?, ?it/s]

# Prediction

In [None]:
# run a test prediction
test_input = "I have a headache and my stomach hurts. What should I do?"
model_output = pipeline(test_input, max_length=8, return_full_text=False, truncation=True)
print("Response:", model_output[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Response: ? mostly mostly mostlylinlinlin
