In [12]:
!pip install -q datasets
!pip install -q bitsandbytes
!pip install -q peft
!pip install -q accelerate
!pip install -q trl
!pip install -q wandb

In [None]:
# !pip install --upgrade pyarrow



In [22]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from ast import literal_eval

In [14]:
# General parameters
model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # The model that you want to train from the Hugging Face hub
new_model = "mistral-presc-json-generator"  # The name for fine-tuned LoRA Adaptor

In [15]:
# LoRA parameters
lora_r = 64
lora_alpha = lora_r * 2
lora_dropout = 0.1
target_modules = ["q_proj", "v_proj", "k_proj"]

In [16]:
# QLoRA parameters
load_in_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
bnb_4bit_use_double_quant = False

In [17]:
# TrainingArguments parameters
num_train_epochs = 1
fp16 = False
bf16 = False
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
gradient_checkpointing = True
learning_rate = 0.00015
weight_decay = 0.01
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 25
logging_steps = 10

# SFT parameters
max_seq_length = None
packing = False
device_map = {"": 0}

# Dataset parameters
use_special_template = True
response_template = " ### Answer:"
instruction_prompt_template = '"### Human:"'
use_llama_like_model = True

In [18]:
def load_data(file_path):
    percent_of_train_dataset = 0.90
    data = load_dataset("json", data_files=file_path, split="train")

    split_dataset = data.train_test_split(
        train_size=int(data.num_rows * percent_of_train_dataset), seed=19, shuffle=False
    )
    train_dataset = split_dataset["train"]
    eval_dataset = split_dataset["test"]
    print(f"Size of the train set: {len(train_dataset)}. Size of the validation set: {len(eval_dataset)}")
    return train_dataset, eval_dataset

dataset_dict = {"seizure": {}, "prescription": {}}
dataset_dict["prescription"]["train_dataset"], dataset_dict["prescription"]["eval_dataset"] = load_data(
    file_path="/content/prescription.jsonl")

# dataset_dict["seizure"]["train_dataset"], dataset_dict["seizure"]["eval_dataset"] = load_data(
#     file_path="/content/seizure_frequency.jsonl")

Size of the train set: 264. Size of the validation set: 30


In [19]:
# Load LoRA configuration
peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=target_modules,
)

In [20]:
# Load QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
)

In [23]:
# Load base model
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             quantization_config=bnb_config, device_map=device_map)
model.config.use_cache = False

config.json:   0%|          | 0.00/596 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [24]:
# Set training parameters
training_arguments = TrainingArguments(
    output_dir=new_model,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    gradient_checkpointing=gradient_checkpointing,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    do_eval=True,
    evaluation_strategy="steps",
    max_steps=250
)



In [25]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training
if not tokenizer.chat_template:
    tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"

tokenizer_config.json:   0%|          | 0.00/2.10k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [26]:
def special_formatting_prompts(example):
    output_texts = []
    for i in range(len(example["instruction"])):
        text = f"{instruction_prompt_template}{example['instruction'][i]}\n{response_template} {example['output'][i]}"
        output_texts.append(text)
    return output_texts


def normal_formatting_prompts(example):
    output_texts = []
    for i in range(len(example["instruction"])):
        chat_temp = [
            {"role": "user", "content": example["instruction"][i]},
            {"role": "assistant", "content": example["output"][i]},
        ]
        text = tokenizer.apply_chat_template(chat_temp, tokenize=False)
        output_texts.append(text)
    return output_texts

In [27]:
if use_special_template:
    formatting_func = special_formatting_prompts
    if use_llama_like_model:
        response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]
        collator = DataCollatorForCompletionOnlyLM(response_template=response_template_ids, tokenizer=tokenizer)
    else:
        collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)
else:
    formatting_func = normal_formatting_prompts

In [28]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_dict["prescription"]["train_dataset"],
    eval_dataset=dataset_dict["prescription"]["eval_dataset"],
    peft_config=peft_config,
    formatting_func=formatting_func,
    data_collator=collator,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)



Map:   0%|          | 0/264 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/30 [00:00<?, ? examples/s]

In [29]:
if torch.cuda.device_count() > 1:
  model.is_parallelizable = True
  model.model_parallel = True

# Train model
trainer.train()

# Save fine tuned Lora Adaptor
trainer.model.save_pretrained(new_model)



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss
10,0.5766,0.214707
20,0.2216,0.205422
30,0.1848,0.182913
40,0.1785,0.175195
50,0.1531,0.169945
60,0.1597,0.168654


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


In [None]:
# import torch
# import gc


# def clear_hardwares():
#     torch.clear_autocast_cache()
#     torch.cuda.ipc_collect()
#     torch.cuda.empty_cache()
#     gc.collect()


# clear_hardwares()
# clear_hardwares()

In [30]:
def generate(model, prompt: str, kwargs):
    tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)

    prompt_length = len(tokenized_prompt.get("input_ids")[0])

    with torch.cuda.amp.autocast():
        output_tokens = model.generate(**tokenized_prompt, **kwargs) if kwargs else model.generate(**tokenized_prompt)
        output = tokenizer.decode(output_tokens[0][prompt_length:], skip_special_tokens=True)

    return output

In [32]:
new_model

'mistral-presc-json-generator'

In [31]:
new_model_path = "/content/mistral-presc-json-generator"

In [33]:
ft_model = AutoModelForCausalLM.from_pretrained(new_model_path)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [34]:
peft_model = PeftModel.from_pretrained(ft_model, new_model)
# del base_model

In [35]:
sample = dataset_dict["prescription"]["eval_dataset"][0]
if use_special_template:
    prompt = f"{instruction_prompt_template}{sample['instruction']}\n{response_template}"
else:
    chat_temp = [{"role": "system", "content": sample["instruction"]}]
    prompt = tokenizer.apply_chat_template(chat_temp, tokenize=False, add_generation_prompt=True)

prompt

'"### Human:"\nHere is the clinical text from the doctor in the delimiter.\nclinical text: <<<Dear Dr Pooled  Re: Ms Haana Habley D.O.B 30/01/1972  I reviewed this 41 year old lady with symptomatic epilepsy due to previous neurocysticercosis. She gets frequent focal dyscognitive   seizures   in clusters. Last week she had around 10-15 of these   seizures   over 2 days. There was no obvious provoking factor.  As you know she had an abnormal CT scan in 2000 which showed calcifications consistent with her neurocysticercosis. A recent MRI in 2011 was normal. She has previously tried levetiracetam but it caused mood disturbance and is now taking   lamotrigine   150mg bd. We discussed various treatment options. Mrs Habley does not want any further children and this will leave more options in terms of drug treatment.  We decided to try and increase in the   lamotrigine   in the first instance. Please increase by 25mg every fortnight until she is taking 200mg bd. If this increase is not succes

In [37]:
gen_kwargs = {"max_new_tokens": 256}
generated_texts = generate(model=peft_model, prompt=prompt, kwargs=gen_kwargs)
print(generated_texts)

  with torch.cuda.amp.autocast():
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


{'entity': 'Prescription', 'start_index': '193', 'end_index': '212', 'text': 'lamotrigine-150mg-bd', 'attributes': {'DrugName': 'lamotrigine', 'DrugDose': '150', 'DoseUnit': 'mg', 'Frequency': '2'}}
 {'entity': 'Prescription', 'start_index': '233', 'end_index': '258', 'text': 'zonisamide-25mg-od', 'attributes': {'DrugName': 'zonisamide', 'Dose': '25', 'DoseUnit': 'mg', 'Frequency': '1'}}
 {'entity': 'Prescription', 'start_index': '261', 'end_index': '286', 'text': 'levetiracetam', 'attributes': {'DrugName': 'levetiracetam', 'Dose': 'N/A', 'DoseUnit': 'N/A', 'Frequency': 'N/A'}}
 {'entity': 'Prescription


In [38]:
sample["output"]

{'entity': 'Prescription',
 'start_index': '547',
 'end_index': '558',
 'text': 'lamotrigine',
 'attributes': {'DrugName': 'Lamotrigine',
  'DrugDose': '150',
  'DoseUnit': 'mg',
  'Frequency': '2'}}

In [43]:
from ast import literal_eval

In [44]:
gen_text = literal_eval(generated_texts.split("\n")[0])

In [45]:
gen_text

{'entity': 'Prescription',
 'start_index': '193',
 'end_index': '212',
 'text': 'lamotrigine-150mg-bd',
 'attributes': {'DrugName': 'lamotrigine',
  'DrugDose': '150',
  'DoseUnit': 'mg',
  'Frequency': '2'}}