## Supervised fine-tuning of Mistral on Medical Report generation datasets

In [None]:
!pip install -q transformers[torch]
!pip install -q datasets
!pip install -q huggingface_hub
!pip install -q accelerate -U
!pip install -q bitsandbytes
!pip install -q peft
!pip install -q trl
!pip install --q evaluate[evaluator]
!pip install --q rouge_score
!pip install --q bert_score

In [1]:
import torch
import glob
import pandas as pd
import numpy as np
import re
from peft import get_peft_model, PeftConfig, PeftModel, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, GenerationConfig, pipeline
import huggingface_hub
from trl import SFTTrainer
from datasets import Dataset, load_dataset
import evaluate

In [2]:
# model_name = "mistralai/Mistral-7B-v0.1"
model_name = "mistralai/Mistral-7B-Instruct-v0.2"

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
huggingface_hub.login(token="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")

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


## Fine-tuning Mistral 7B on MTS-Dialog

#### Sanity check

In [None]:
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, torch_dtype=torch.bfloat16, device_map="auto")

In [None]:
prompt = "As a data scientist, can you explain the concept of regularization in machine learning?"

sequences = pipe(
    prompt,
    do_sample=True,
    max_new_tokens=100,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
    num_return_sequences=1,
)
print(sequences[0]['generated_text'])

#### Dataset preparation

In [5]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

!git clone https://github.com/abachaa/MTS-Dialog.git

fatal: destination path 'MTS-Dialog' already exists and is not an empty directory.


In [6]:
mts_dialog_files = {"train": "MTS-Dialog/Main-Dataset/MTS-Dialog-TrainingSet.csv", "valid": "MTS-Dialog/Main-Dataset/MTS-Dialog-ValidationSet.csv", "test":  "MTS-Dialog/Main-Dataset/MTS-Dialog-TestSet-1-MEDIQA-Chat-2023.csv"}
mts_dialog_dataset = load_dataset("csv", data_files=mts_dialog_files)

print(mts_dialog_dataset)

DatasetDict({
    train: Dataset({
        features: ['ID', 'section_header', 'section_text', 'dialogue'],
        num_rows: 1201
    })
    valid: Dataset({
        features: ['ID', 'section_header', 'section_text', 'dialogue'],
        num_rows: 100
    })
    test: Dataset({
        features: ['ID', 'section_header', 'section_text', 'dialogue'],
        num_rows: 200
    })
})


In [7]:
def create_prompt(conversation, summary):
  prompt = f"<s>[INST] Write a resume of the following conversation between a doctor and a patient:{conversation}[/INST]{summary}"

  return prompt

In [8]:
train_data = []
for line in mts_dialog_dataset["train"]:
  prompt = create_prompt(line["dialogue"], line["section_text"])
  train_data.append({"text": prompt})

mts_dialog_processed_dataset = Dataset.from_list(train_data)

#### Training Prep

In [9]:
model.config.use_cache = False
model.config_pretraining_tp = 1
model.gradient_checkpointing_enable()

In [10]:
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token = True

In [11]:
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)

#### Training

In [41]:
epochs = 1

training_arguments = TrainingArguments(
    output_dir="./mistral-mtsdialog-finetune",
    num_train_epochs=epochs,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    push_to_hub=True,
    predict_with_generate=True,
)

TypeError: TrainingArguments.__init__() got an unexpected keyword argument 'predict_with_generate'

In [38]:
trainer = SFTTrainer(
    model=model,
    train_dataset=mts_dialog_processed_dataset,
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)

Map:   0%|          | 0/1201 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [39]:
trainer.train()

Step,Training Loss,Validation Loss


TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

## ACI-Bench Supervised finetune

In [None]:
!git clone https://github.com/wyim/aci-bench

In [None]:
aci_bench_files = {"train": "aci-bench/data/challenge_data/train.csv", "valid": "aci-bench/data/challenge_data/valid.csv", "test":  "aci-bench/data/challenge_data/clinicalnlp_taskB_test1.csv"}
aci_bench_dataset = load_dataset("csv", data_files=aci_bench_files)

print(aci_bench_dataset)

## Combining the adapters

In [None]:
print("Hey")

## Evaluation

#### MTS-Dialog

In [None]:
print("MTS-Dialog")

#### ACI-Bench

In [None]:
print("ACI-Bench")

#### Psychiatry

In [None]:
!git clone https://github.com/nazmulkazi/dataset_automated_medical_transcription.git

transcripts_path = "./dataset_automated_medical_transcription/transcripts/transcribed/"
casenotes_path = "./dataset_automated_medical_transcription/casenotes/annotator_1/"

def prepare_transcripts(path):
    files = os.listdir(path)
    result = []

    for file in files:
        with open(path + file, "r") as raw_data:
            data = json.load(raw_data)

            conversation = ""
            for turn in data:
                speaker = "Doctor" if turn["speaker"] == 1 else "Patient"
                dialogue = " ".join(turn["dialogue"])
                conversation += f"{speaker}: {dialogue}\n"

            result.append(conversation)

    return result

def prepare_casenotes(path):
    files = os.listdir(path)
    annotators = os.listdir("./dataset_automated_medical_transcription/casenotes/")
    result = []

    categories = ["Client Details", "Chief Complaint", "History of Present Illness", "Past Psychiatric History", "History of Substance Use", "Social History", "Family History", "Review of Systems"]

    # Iterating over annotator #1
    for file in files:
        with open(path + file, "r") as raw_data:
            data = json.load(raw_data)

            casenotes = {}
            for category in categories:
                casenotes[category] = []

            for element in data:
                category = categories[int(element["categoryId"])]
                casenotes[category].append(element["formalText"])


            # Iterating over other annotators to find casenotes for the same case
            for annotator in annotators:
                if annotator != "annotator_1":
                    annotations = os.listdir("./dataset_automated_medical_transcription/casenotes/" + annotator)
                    if file in annotations:
                        with open("./dataset_automated_medical_transcription/casenotes/" + annotator + "/" + file, "r") as raw_data:
                            data = json.load(raw_data)

                            for element in data:
                                category = categories[int(element["categoryId"])]
                                if len(casenotes[category]) == 0:
                                    casenotes[category].append(element["formalText"])

            casenote = ""
            for category in casenotes.keys():
                casenote += f"{category}:\n"

                content = casenotes[category] if len(casenotes[category]) > 0 else ["None"]

                casenote += " ".join(content)
                casenote += "\n\n"
            result.append(casenote)



    return result

transcripts = prepare_transcripts(transcripts_path)
case_notes = prepare_casenotes(casenotes_path)

dataset = {"transcripts": transcripts, "case_notes": case_notes}
dataset = Dataset.from_dict(dataset)