In [None]:
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython import display
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
)
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, DataCollatorForSeq2Seq
from transformers.models.t5.modeling_t5 import T5LayerFF

In [None]:
# --- Runtime ---------------------------------------------------------------
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the global torch / NumPy seeds so that repeated runs of the notebook
# start from the same random state.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# --- Data and artefact locations --------------------------------------------
# Row indices selected from the dataset's "train" split for training / testing.
trainset_range = list(range(0, 15000))
testset_range = list(range(18000, 22000))
base_t5_path = "t5-small"
medmcqa_dataset_path = "openlifescienceai/medmcqa"
checkpoint_file = "T5-Finetuned-15k-20epoch.pth"
test_dataset_file_name = "MEDMCQA-test-dataset-4k.json"
repo_id = "MMK79/Medical-RAG"
# Set either flag to False to skip the corresponding Hugging Face upload.
push_model_to_huggingface = True
push_dataset_to_huggingface = True

# --- Hyperparameters --------------------------------------------------------
batch_size = 8
lr = 1e-4
num_epochs = 20
# Hidden width of each adapter's down-projection.
bottleneck_size = 32

In [None]:
# Load the pretrained tokenizer and seq2seq model named by `base_t5_path`.
# `model` is modified in place further down (adapter layers are injected and
# most weights are frozen), so re-run this cell to get an untouched copy.
tokenizer = T5Tokenizer.from_pretrained(base_t5_path)
model = T5ForConditionalGeneration.from_pretrained(base_t5_path)

In [None]:
# Maps the integer answer index stored in a row's `cop` field to the option
# letter that appears in the prompt and in the target text.
opt_idx2str = dict(enumerate("ABCD"))
# Load the dataset identified by `medmcqa_dataset_path`; later cells index the
# result by split name ("train", "validation").
dataset = load_dataset(medmcqa_dataset_path)
# Bare expression so the notebook shows the loaded object's repr.
dataset

In [None]:
# Both subsets are carved out of the "train" split. The index ranges in the
# config cell (0-14999 and 18000-21999) are disjoint, so the held-out test
# rows do not overlap the rows used for training.
train_dataset = dataset["train"].select(trainset_range)
test_dataset = dataset["train"].select(testset_range)

In [None]:
def filter_none(example):
    """Decide whether a dataset row is usable.

    Args:
        example: Mapping that provides at least the ``"exp"`` and
            ``"question"`` keys.

    Returns:
        bool: ``True`` only when the explanation is present, is longer than
        20 characters, and the question is not ``None``.
    """
    explanation = example["exp"]
    # Reject missing or very short explanations before looking at the question.
    if explanation is None or len(explanation) <= 20:
        return False
    return example["question"] is not None


# Drop rows without a usable explanation or question from every subset.
train_dataset = train_dataset.filter(filter_none)
test_dataset = test_dataset.filter(filter_none)
# Keep a handle on the filtered, still unformatted test rows for the JSON
# export at the end of the notebook. `test_dataset` was filtered on the line
# above, so running the same filter again would select exactly the same rows;
# later cells rebind `test_dataset` to new objects and leave this one as is.
test_dataset_raw = test_dataset
dataset["validation"] = dataset["validation"].filter(filter_none)

In [None]:
def _build_prompt_and_target(row):
    """Render one dataset row as the prompt/target text pair used by the model.

    The prompt lists the question, the four options and the row's explanation
    and ends with ``"Answer:"``; the target is ``"Answer: <letter>"`` where the
    letter is looked up from the row's ``cop`` index via ``opt_idx2str``.

    Args:
        row: Mapping with the keys ``question``, ``opa``, ``opb``, ``opc``,
            ``opd``, ``exp`` and ``cop``.

    Returns:
        dict: ``{"input_text": <prompt>, "target_text": <target>}``.

    Raises:
        KeyError: If ``row["cop"]`` is not a key of ``opt_idx2str``.
    """
    input_text = f"Question: {row['question']}\n\nOptions:\nA: {row['opa']}\nB: {row['opb']}\nC: {row['opc']}\nD: {row['opd']}\n\nExplanation: {row['exp']}\n\nAnswer:"
    target_text = f"Answer: {opt_idx2str[row['cop']]}"
    return {"input_text": input_text, "target_text": target_text}


def format_example_training(row):
    """Format a training row; see ``_build_prompt_and_target`` for the layout."""
    return _build_prompt_and_target(row)


def format_example_validation(row):
    """Format a validation row with exactly the same layout as training rows."""
    return _build_prompt_and_target(row)


# Replace the raw columns of every subset with the "input_text"/"target_text"
# pair produced by the formatter; all original columns are removed.
train_dataset = train_dataset.map(
    format_example_training, remove_columns=train_dataset.column_names
)
# The held-out test subset goes through the training formatter as well.
test_dataset = test_dataset.map(
    format_example_training, remove_columns=test_dataset.column_names
)
dataset["validation"] = dataset["validation"].map(
    format_example_validation, remove_columns=dataset["validation"].column_names
)

In [None]:
def map_function(row):
    """Tokenize prompt and target text for seq2seq training.

    Reads the module-level ``tokenizer``.

    Args:
        row: Mapping with ``"input_text"`` and ``"target_text"`` entries (it is
            applied with ``batched=True`` below, so each entry is a batch).

    Returns:
        dict: Every field produced by tokenizing the inputs (truncated to at
        most 1024 tokens) plus ``"labels"`` holding the target token ids.
    """
    encoded_inputs = tokenizer(row["input_text"], truncation=True, max_length=1024)
    encoded_targets = tokenizer(row["target_text"])
    return dict(encoded_inputs, labels=encoded_targets.input_ids)


# Tokenize each subset in batches, then expose only the three columns that
# the collator and the model consume, formatted as torch tensors.
train_dataset = train_dataset.map(map_function, batched=True)
train_dataset.set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"]
)

test_dataset = test_dataset.map(map_function, batched=True)
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

dataset["validation"] = dataset["validation"].map(map_function, batched=True)
dataset["validation"].set_format(
    type="torch", columns=["input_ids", "attention_mask", "labels"]
)

In [None]:
# Collator that pads each batch to the longest sequence in that batch and
# returns PyTorch tensors.
col_fn = DataCollatorForSeq2Seq(tokenizer, return_tensors="pt", padding="longest")

# Only the training loader shuffles; the evaluation loaders keep dataset order.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, collate_fn=col_fn, shuffle=True
)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=col_fn)
val_loader = DataLoader(dataset["validation"], batch_size=batch_size, collate_fn=col_fn)

In [None]:
# Release cached CUDA memory and run a garbage-collection pass before the
# training helpers and adapters are set up.
torch.cuda.empty_cache()
gc.collect()

In [None]:
def train_loop(model, loader, optimizer):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module whose forward call returns an object with a ``loss``
            attribute and which exposes ``device``.
        loader: Iterable of batches that support ``.to(device)`` and ``**``
            unpacking into the model call.
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        dict: ``{"train_loss": <mean of the per-batch loss values>}``.
    """
    model.train()

    losses = []
    for batch in tqdm(loader):
        batch = batch.to(model.device)

        optimizer.zero_grad()
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()

        # Store a plain Python float rather than the loss tensor.
        losses.append(loss.item())

    return {"train_loss": np.mean(losses)}


def predict(model, row):
    """Generate output token ids for one batch.

    Args:
        model: Object exposing a ``generate`` method.
        row: Batch exposing ``input_ids`` and ``attention_mask`` attributes.

    Returns:
        Whatever ``model.generate`` returns when limited to ``max_length=5``.
    """
    generation_kwargs = {
        "input_ids": row.input_ids,
        "attention_mask": row.attention_mask,
        "max_length": 5,
    }
    return model.generate(**generation_kwargs)


def tokenizer_ids_to_label(all_input_ids):
    """Decode sequences of token ids to text, ignoring out-of-vocabulary ids.

    Reads the module-level ``tokenizer``.

    Args:
        all_input_ids: Iterable of token-id sequences.

    Returns:
        The result of ``tokenizer.batch_decode`` (special tokens skipped) on
        the sequences after removing every id outside
        ``[0, tokenizer.vocab_size)``.
    """
    upper_bound = tokenizer.vocab_size

    def _in_vocab(token_id):
        # Negative ids and ids at or above the vocabulary size are dropped;
        # presumably this removes label padding values - confirm.
        return 0 <= token_id < upper_bound

    cleaned = [list(filter(_in_vocab, seq)) for seq in all_input_ids]
    return tokenizer.batch_decode(cleaned, skip_special_tokens=True)


def valid_loop(model, loader, compute_metric):
    """Score generated answers against the reference labels of ``loader``.

    Args:
        model: Model passed to ``predict``; switched to eval mode here.
        loader: Iterable of batches that support ``.to(device)`` and expose
            ``labels``.
        compute_metric: Callable invoked as
            ``compute_metric(y_true=..., y_pred=...)`` on the decoded strings.

    Returns:
        dict: ``{"valid_acc": <value returned by compute_metric>}``.
    """
    model.eval()

    true_ids, pred_ids = [], []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(model.device)
            generated = predict(model, batch)

            true_ids.extend(batch.labels.detach().cpu().tolist())
            pred_ids.extend(generated.detach().cpu().tolist())

    # Compare decoded text rather than raw ids.
    true_text = tokenizer_ids_to_label(true_ids)
    pred_text = tokenizer_ids_to_label(pred_ids)
    score = compute_metric(y_true=true_text, y_pred=pred_text)
    return {"valid_acc": score}

In [None]:
# Adapter layer
class AdapterLayer(nn.Module):
    """Residual bottleneck block: ``x + up(relu(down(x)))``.

    Args:
        emb_dim: Size of the last dimension of the input (and output).
        bottleneck_size: Width of the hidden projection between the two
            linear layers.
    """

    def __init__(self, emb_dim: int, bottleneck_size: int):
        super().__init__()

        down_projection = nn.Linear(emb_dim, bottleneck_size)
        up_projection = nn.Linear(bottleneck_size, emb_dim)
        # The attribute name is load-bearing: freeze_non_adapter is called with
        # peft_key="sharif_llm" and keeps only parameters whose name contains it.
        self.sharif_llm_adapter = nn.Sequential(
            down_projection, nn.ReLU(), up_projection
        )

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the adapter's transformation of ``x``."""
        return x + self.sharif_llm_adapter(x)


class FeedForwardAdapterWrapper(nn.Module):
    """Wrap a ``T5LayerFF`` so its output is passed through an ``AdapterLayer``.

    Args:
        original_module: The feed-forward layer to wrap; kept as the
            ``original_module`` submodule.
        bottleneck_size: Bottleneck width handed to the new ``AdapterLayer``.
    """

    def __init__(self, original_module: T5LayerFF, bottleneck_size: int):
        super().__init__()
        assert isinstance(original_module, T5LayerFF)

        self.original_module = original_module
        # The adapter width follows the input width of the wrapped layer's
        # first dense projection.
        self.adapter = AdapterLayer(
            emb_dim=original_module.DenseReluDense.wi.in_features,
            bottleneck_size=bottleneck_size,
        )

    def forward(self, x: torch.Tensor):
        """Apply the wrapped layer, then the adapter, to ``x``."""
        return self.adapter(self.original_module(x))

In [None]:
# Add adapter to the model
def mutate_model_recursive(model: nn.Module, bottleneck_size: int):
    """Replace every ``T5LayerFF`` below ``model`` with an adapter wrapper.

    Walks the module tree depth-first. A ``T5LayerFF`` child is swapped in
    place for a ``FeedForwardAdapterWrapper`` around it (and is not descended
    into); any other child is searched recursively.

    Args:
        model: Module whose children are inspected and possibly replaced.
        bottleneck_size: Bottleneck width passed to each wrapper.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, T5LayerFF):
            mutate_model_recursive(child, bottleneck_size)
            continue

        setattr(model, child_name, FeedForwardAdapterWrapper(child, bottleneck_size))
        print(f"Replaced {child_name} with FeedForwardAdapterWrapper layer.")


def mutate_model(model: nn.Module, bottleneck_size: int):
    """Inject adapter layers into ``model`` once.

    A ``_mutated`` attribute is set on the model after the first successful
    call; if it is already present the model is left untouched and a notice is
    printed instead.

    Args:
        model: Model to modify in place.
        bottleneck_size: Bottleneck width forwarded to
            ``mutate_model_recursive``.
    """
    already_mutated = hasattr(model, "_mutated")
    if not already_mutated:
        mutate_model_recursive(model, bottleneck_size)
        model._mutated = True
    else:
        print("Model already contains adapter layers! \n Try reloading the model.")


# Inject adapters into the loaded model; this modifies `model` in place.
mutate_model(model, bottleneck_size=bottleneck_size)

In [None]:
# Freeze non-adapter parameters
def freeze_non_adapter(model, peft_key):
    """Leave only parameters whose name contains ``peft_key`` trainable.

    Sets ``requires_grad`` on every parameter of ``model`` and prints the names
    of the ones that stay trainable followed by their total element count.

    Args:
        model: Module whose parameters are updated in place.
        peft_key: Substring that marks a parameter name as trainable.
    """
    print("Non freezed weights:")
    trainable_count = 0
    for param_name, param in model.named_parameters():
        is_adapter_param = peft_key in param_name
        param.requires_grad = is_adapter_param
        if not is_adapter_param:
            continue
        print(param_name)
        trainable_count += param.numel()
    print(f"Total number of parameters should be update: {trainable_count}")


# "sharif_llm" matches the `sharif_llm_adapter` attribute of AdapterLayer, so
# only the adapter weights stay trainable.
freeze_non_adapter(model, peft_key="sharif_llm")

In [None]:
# Put the model on its final device before building the optimizer, as the
# torch.optim documentation recommends.
model.to(device)

# Hand the optimizer only the parameters that are still trainable (the ones
# left with requires_grad=True by freeze_non_adapter); frozen weights never
# receive gradients, so tracking them in the optimizer serves no purpose.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=lr)

# One dict per epoch: epoch index, mean training loss, validation accuracy.
all_results = []
for epoch in range(num_epochs):
    epoch_results = {"epoch": epoch}

    epoch_results.update(
        train_loop(model=model, loader=train_loader, optimizer=optimizer)
    )
    epoch_results.update(
        valid_loop(model=model, loader=val_loader, compute_metric=accuracy_score)
    )
    all_results.append(epoch_results)

    # Redraw the running results table in place after every epoch.
    display.clear_output()
    display.display(pd.DataFrame(all_results).set_index("epoch"))

In [None]:
# Option index <-> option letter lookups used when scoring model output.
opt_idx2str = {
    0: "A",
    1: "B",
    2: "C",
    3: "D",
}
opt_str2idx = {s: i for i, s in opt_idx2str.items()}

# Class index reported for any text that is not exactly "Answer: <A-D>".
INVALID_ANSWER_IDX = 100


def convert_answer(answer):
    """Map a decoded answer string to its option index.

    Args:
        answer: Text such as ``"Answer: B"``.

    Returns:
        int: ``0``-``3`` when ``answer`` is exactly ``"Answer: "`` followed by
        one of the letters ``A``-``D``; ``100`` for anything else (missing
        prefix, unknown letter, extra characters).
    """
    prefix_str = "Answer: "
    if not answer.startswith(prefix_str):
        return INVALID_ANSWER_IDX
    option = answer[len(prefix_str) :]
    # A plain dictionary lookup with a default replaces the former bare
    # `except:`, which would also have hidden unrelated errors.
    return opt_str2idx.get(option, INVALID_ANSWER_IDX)


def show_classification_metrics(trues, preds):
    """Print summary metrics and plot a confusion matrix for answer strings.

    Both inputs are converted with ``convert_answer`` first, so any string that
    is not a valid ``"Answer: <A-D>"`` becomes class ``100`` (shown as "None").

    Args:
        trues: Iterable of reference answer strings.
        preds: Iterable of predicted answer strings, aligned with ``trues``.

    Returns:
        None. Metrics are printed and the heatmap is shown with matplotlib.
    """
    preds = [convert_answer(answer) for answer in preds]
    trues = [convert_answer(answer) for answer in trues]

    # Pin the label order so the confusion matrix always has exactly one
    # row/column per tick label, even if an option never occurs in this sample
    # or the invalid class shows up on only one side.
    labels = [0, 1, 2, 3]
    class_names = ["Option A", "Option B", "Option C", "Option D"]
    if 100 in preds or 100 in trues:
        # Some answer is not among the question options.
        labels.append(100)
        class_names.append("None")

    accuracy = accuracy_score(trues, preds)
    macro_f1 = f1_score(trues, preds, average="macro", zero_division=0)
    micro_f1 = f1_score(trues, preds, average="micro", zero_division=0)
    macro_precision = precision_score(trues, preds, average="macro", zero_division=0)
    macro_recall = recall_score(trues, preds, average="macro", zero_division=0)
    conf_matrix = confusion_matrix(trues, preds, labels=labels)

    print(f"Accuracy        =  {accuracy * 100:.2f}%")
    print(f"Macro F1-score  =  {macro_f1 * 100:.2f}%")
    print(f"Micro F1-score  =  {micro_f1 * 100:.2f}%")
    print(f"Macro Precision =  {macro_precision * 100:.2f}%")
    print(f"Macro Recall    =  {macro_recall * 100:.2f}%")

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted labels")
    plt.ylabel("True labels")
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
# Final evaluation on the held-out test loader. The decoded `all_true` and
# `all_pred` lists stay in the notebook namespace because the next cell prints
# samples from them.
model.eval()

all_true = []
all_pred = []

# Gradients are not needed for generation.
with torch.no_grad():
    for row in tqdm(test_loader):
        row = row.to(model.device)
        pred = predict(model, row)

        all_true += row.labels.detach().cpu().tolist()
        all_pred += pred.detach().cpu().tolist()

# Turn token ids back into "Answer: X" style strings before scoring.
all_true = tokenizer_ids_to_label(all_true)
all_pred = tokenizer_ids_to_label(all_pred)

show_classification_metrics(all_true, all_pred)

In [None]:
# Eyeball the first 20 reference strings against the first 20 predictions.
print(all_true[:20])
print()
print(all_pred[:20])

In [None]:
# Persist the model's complete state dict under the "model_state_dict" key.
torch.save({"model_state_dict": model.state_dict()}, checkpoint_file)

In [None]:
if push_model_to_huggingface:
    # Generate a token from Profile > Settings > Access Tokens with write
    # access and export it as the HF_TOKEN environment variable. Reading it
    # from the environment keeps the secret out of the saved notebook.
    # NOTE(review): when HF_TOKEN is unset this passes token=None; presumably
    # huggingface_hub then falls back to a locally stored login - confirm.
    api = HfApi(
        token=os.environ.get("HF_TOKEN"),
    )
    api.upload_file(
        path_or_fileobj=f"./{checkpoint_file}",
        path_in_repo=checkpoint_file,
        repo_id=repo_id,
        repo_type="model",
    )

In [None]:
# code to generate answer based on model
def generate_answer(row):
    """Generate the model's answer string for a single question.

    Builds the same prompt layout used for training, tokenizes it with
    truncation to at most 1024 tokens, and decodes the generated ids. Reads the
    module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        row: Mapping with the keys ``question``, ``opa``, ``opb``, ``opc``,
            ``opd`` and ``exp``.

    Returns:
        str: The decoded generation with special tokens removed.
    """
    input_text = f"Question: {row['question']}\n\nOptions:\nA: {row['opa']}\nB: {row['opb']}\nC: {row['opc']}\nD: {row['opd']}\n\nExplanation: {row['exp']}\n\nAnswer:"
    # Tokenize once. Previously the truncated encoding was immediately
    # overwritten by an untruncated `tokenizer.encode` call, so the length
    # limit never took effect and no attention mask reached `generate`.
    encoded = tokenizer(
        input_text, truncation=True, max_length=1024, return_tensors="pt"
    ).to(device)
    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=5,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

In [None]:
# Hand-written example row with the same keys that generate_answer reads,
# plus `cop` (index of the correct option) for the comparison below.
sample_query = {
    "question": "Which of the following is most characteristic of diabetic neuropathy?",
    "opa": "it is usually bilateral",
    "opb": "pain is not a feature",
    "opc": "it most commonly affects the brain",
    "opd": "it spares the autonomic system",
    "cop": 0,
    "exp": "Diabetic neuropathy usually presents as peripheral polyneuropathy, usually bilateral, including symptoms of numbness, paresthesia, severe hyperesthesia, and pain. Impairment of proprioceptive fibers can lead to gait abnormalities and Charcot's joints. Mononeuropathy is less common and is often spontaneously reversible. Common syndromes include wrist or foot drop and third, fourth, or sixth cranial nerve palsies. Autonomic neuropathy may cause gastroesophageal dysfunction, bladder dysfunction, and orthostatic hypotension.",
}

# Show the model's generated answer next to the expected target string.
model_ans = generate_answer(sample_query)
print(f'Model\'s output =  "{model_ans}"')

correct_ans = f"Answer: {opt_idx2str[sample_query['cop']]}"
print(f'Correct output =  "{correct_ans}"')

In [None]:
# Write the filtered, unformatted test rows to a local JSON file.
test_dataset_raw_df = pd.DataFrame(test_dataset_raw.to_dict())
test_dataset_raw_df.to_json(test_dataset_file_name)

if push_dataset_to_huggingface:
    # Generate a token from Profile > Settings > Access Tokens with write
    # access and export it as the HF_TOKEN environment variable. Reading it
    # from the environment keeps the secret out of the saved notebook.
    # NOTE(review): when HF_TOKEN is unset this passes token=None; presumably
    # huggingface_hub then falls back to a locally stored login - confirm.
    api = HfApi(
        token=os.environ.get("HF_TOKEN"),
    )
    # The JSON file is uploaded next to the checkpoint in the same model repo.
    api.upload_file(
        path_or_fileobj=f"./{test_dataset_file_name}",
        path_in_repo=test_dataset_file_name,
        repo_id=repo_id,
        repo_type="model",
    )