## Initial Setup

In [None]:
pip install trl

Collecting trl
  Downloading trl-0.16.1-py3-none-any.whl.metadata (12 kB)
Collecting datasets>=3.0.0 (from trl)
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=3.0.0->trl)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets>=3.0.0->trl)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets>=3.0.0->trl)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets>=3.0.0->trl)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.0.0->accel

In [None]:
pip install datasets



In [None]:
# Install required packages
import subprocess
import sys

def install_packages(packages=None):
    """Install the notebook's Python dependencies with pip.

    Each requirement is installed with its own ``pip install`` invocation,
    run through the current interpreter (``sys.executable -m pip``) so the
    packages are installed for the interpreter running this kernel.

    Args:
        packages: Optional iterable of pip requirement strings. A single
            string is treated as one requirement. When omitted (or ``None``)
            the notebook's default requirement list is installed.

    Raises:
        subprocess.CalledProcessError: If a ``pip install`` call exits with a
            non-zero status; the remaining packages are not attempted.
    """
    if packages is None:
        packages = [
            "bitsandbytes>=0.41.1",
            "transformers>=4.35.0",
            "peft>=0.6.0",
            "accelerate>=0.23.0",
            "datasets>=2.14.0",
            "trl>=0.7.2",
            "scipy>=1.11.3",
            "sentencepiece>=0.1.99",
            "protobuf>=4.23.4",
            "einops>=0.7.0"
        ]
    elif isinstance(packages, str):
        # A bare string would otherwise be iterated character by character.
        packages = [packages]

    print("Installing required packages...")
    for package in packages:
        print(f"Installing {package}")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    print("All packages installed successfully!")

In [None]:
# Run package installation
# Runs before the imports below so the packages are present when they are imported.
install_packages()

import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import (
    prepare_model_for_kbit_training,
    LoraConfig,
    get_peft_model,
    TaskType
)
from trl import SFTTrainer

# Configuration
# Identifiers passed to from_pretrained() / load_dataset() in later cells.
MODEL_ID = "mistralai/Mistral-7B-v0.1"
DATASET_ID = "LinhDuong/chatdoctor-200k"
# Directory the final model is saved to after training (created at the end of this cell).
OUTPUT_DIR = "./lora_medical_adapter"
# NOTE(review): the LoraConfig cell does not read these three constants; it
# hard-codes r=16, lora_alpha=32, lora_dropout=0.05, so LORA_R / LORA_ALPHA
# here do not match what is actually applied. Confirm which values are intended.
LORA_R = 8
LORA_ALPHA = 16
LORA_DROPOUT = 0.05
# Consumed by the TrainingArguments cell.
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 2e-4
# NOTE(review): NUM_TRAIN_EPOCHS, MAX_SEQ_LENGTH and DEVICE are not referenced
# by any cell visible in this notebook (training is bounded by max_steps) —
# confirm whether they are still needed.
NUM_TRAIN_EPOCHS = 1
MAX_SEQ_LENGTH = 512
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
os.makedirs(OUTPUT_DIR, exist_ok=True)

Installing required packages...
Installing bitsandbytes>=0.41.1
Installing transformers>=4.35.0
Installing peft>=0.6.0
Installing accelerate>=0.23.0
Installing datasets>=2.14.0
Installing trl>=0.7.2
Installing scipy>=1.11.3
Installing sentencepiece>=0.1.99
Installing protobuf>=4.23.4
Installing einops>=0.7.0
All packages installed successfully!


In [None]:
pip install -U bitsandbytes



In [None]:
!pip install -U trl



## Dataset Preparation

In [None]:
# Fetch the dataset named by DATASET_ID and show its splits and columns.
dataset = load_dataset(DATASET_ID)
print("Dataset loaded: {}".format(dataset))

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/273 [00:00<?, ?B/s]

chatdoctor200k.json:   0%|          | 0.00/226M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/207408 [00:00<?, ? examples/s]

Dataset loaded: DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 207408
    })
})


In [None]:
# Peek at the first raw record to see the 'input' / 'output' / 'instruction' fields.
dataset["train"][0]

{'input': 'i had what feels like a muscle cramp about an hour ago under the left bottom rib. it lasted about a minute and then went away. i had no other pains or dificulties since. could this have been a simptom of a minor heart attack. do heart attack symptoms come one at a time or are there more than one symptom when they occur?',
 'output': 'No this is not a symptom of great attack.... it is normal after some stressful activity. No need to worry. If same thing happen again let me know',
 'instruction': "If you are a doctor, please answer the medical questions based on the patient's description."}

In [None]:

def format_instruction(example):
    """Render one dataset record as a two-turn, ChatML-style training string.

    Only the ``"input"`` (user turn) and ``"output"`` (assistant turn) fields
    are read; any other field on the record, such as ``"instruction"``, is
    ignored.

    Args:
        example: Mapping with at least the keys ``"input"`` and ``"output"``.

    Returns:
        A dict with a single ``"text"`` key holding the formatted exchange.
    """
    question, answer = example["input"], example["output"]

    # One entry per conversational turn; turns are separated by a newline.
    turns = (
        f"<|im_start|>user\n{question}<|im_end|>",
        f"<|im_start|>assistant\n{answer}<|im_end|>",
    )
    return {"text": "\n".join(turns)}

first_split = "train"

# Work on at most SUBSET_SIZE rows; never request more rows than the split holds.
SUBSET_SIZE = 10000
# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42

subset_size = min(SUBSET_SIZE, len(dataset[first_split]))
dataset[first_split] = dataset[first_split].select(range(subset_size))

# Replace the raw columns with the single formatted "text" column.
formatted_dataset = dataset[first_split].map(
    format_instruction,
    remove_columns=dataset[first_split].column_names
)

# Hold out 10% of the formatted rows for validation.
train_val_split = formatted_dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_dataset = train_val_split["train"]
val_dataset = train_val_split["test"]

print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(val_dataset)}")

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Training examples: 9000
Validation examples: 1000


In [None]:
# Show one formatted training example to check the rendered prompt/response template.
train_dataset[0]

{'text': '<|im_start|>user\nI recently had a rusty nail puncture my hand between my fingers quite deeply, very little blood when it happened but I could see down inside the puncture wound the bits inside. I went to the hospital where they washed my hand in iodine and glued me back together also had my first tetanus jab. now the swelling has mostly subsided I am experiencing shooting pains up along my fingers and a crunching sound in my jaw. the pain in my fingers is quite substantial. do I need an xray to see if any damage has been done<|im_end|>\n<|im_start|>assistant\nHi, I value your concern regarding the symptoms. I have gone through your symptoms, and in my opinion you should first take Tab Tramadol 100\xa0mg to relieve you of your pain, then an X-Ray is a must to see if there is any phalanges fracture. Also I would recommend a short course of antibiotics to prevent any infection from the injury. Hope this answers your question. If you have additional questions or follow-up questi

## Loading the Base Model

In [None]:
# Log in to the Hugging Face Hub; the training cell is configured with push_to_hub=True.
!huggingface-cli login

In [None]:

# dtype handed to bitsandbytes as bnb_4bit_compute_dtype below.
compute_dtype = torch.float16

# 4-bit NF4 quantization settings with double quantization enabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Preferred path: load the base model with the quantization config above.
# NOTE(review): the handler catches every Exception, so unrelated failures
# (e.g. authentication or network errors) also trigger the unquantized
# float16 fallback — confirm that is intended.
try:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        device_map="auto"
    )
except Exception as e:
    print(f"Error loading model with BitsAndBytes: {e}")
    print("Trying to load without quantization...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto"
    )

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the EOS token for padding and pad on the right.
# (presumably the tokenizer ships without a pad token — TODO confirm)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Attempted regardless of which loading path succeeded above; a failure is
# only printed and the notebook continues with the model as loaded.
try:
    model = prepare_model_for_kbit_training(model)
except Exception as e:
    print(f"Error preparing model for kbit training: {e}")
    print("Continuing without prepare_model_for_kbit_training...")

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [None]:
# Applying LoRA to the model
#
# NOTE(review): rank / alpha / dropout are written inline here instead of being
# read from LORA_R / LORA_ALPHA / LORA_DROPOUT in the configuration cell, and
# the inline r=16 / lora_alpha=32 differ from LORA_R=8 / LORA_ALPHA=16.
# No target_modules are given, so the choice of adapted layers is left to PEFT.

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

# Wrap the loaded model with the LoRA configuration; `model` is rebound to the wrapped model.
model = get_peft_model(model, config)

In [None]:
import torch

def format_prompt(example):
    """Build the inference prompt for one record.

    The prompt contains the user turn taken from ``example["input"]`` and ends
    with an opened (unanswered) assistant turn.

    Args:
        example: Mapping with at least the key ``"input"``.

    Returns:
        The prompt string.
    """
    template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
    return template.format(example["input"])

num_samples = 5
model.eval()

print("\nTesting model on few samples before fine-tuning...\n")

# NOTE: these rows come from the same pool the train/validation split is drawn
# from, so this is a qualitative smoke test rather than a held-out evaluation.
for i in range(num_samples):
    example = dataset['train'][i]
    prompt = format_prompt(example)

    # Keep the full encoding (not only input_ids) so generate() also receives
    # the attention mask, and move it to whichever device the model is on
    # instead of assuming CUDA.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            # Passed explicitly so generate() does not have to guess the pad token.
            pad_token_id=tokenizer.pad_token_id
        )

    # generate() returns prompt + continuation; decode only the new tokens so
    # the printed response does not repeat the prompt.
    generated_ids = output_ids[0][prompt_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    print(f"\n--- Example {i+1} ---")
    print(f"User Input:\n{prompt}\n")
    print(f"Human Reference:\n{example['output']}\n")
    print(f"Model Response:\n{response}\n")
    print("-" * 60)


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Testing model on few samples before fine-tuning...



The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- Example 1 ---
User Input:
<|im_start|>user
i had what feels like a muscle cramp about an hour ago under the left bottom rib. it lasted about a minute and then went away. i had no other pains or dificulties since. could this have been a simptom of a minor heart attack. do heart attack symptoms come one at a time or are there more than one symptom when they occur?<|im_end|>
<|im_start|>assistant


Human Reference:
No this is not a symptom of great attack.... it is normal after some stressful activity. No need to worry. If same thing happen again let me know

Model Response:
<|im_start|>user
i had what feels like a muscle cramp about an hour ago under the left bottom rib. it lasted about a minute and then went away. i had no other pains or dificulties since. could this have been a simptom of a minor heart attack. do heart attack symptoms come one at a time or are there more than one symptom when they occur?<|im_end|>
<|im_start|>assistant
Hello user,

It is not common for a heart att

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- Example 2 ---
User Input:
<|im_start|>user
I woke up this morning feeling the whole room is spinning when i was sitting down. I went to the bathroom walking unsteadily, as i tried to focus i feel nauseous. I try to vomit but it wont come out.. After taking panadol and sleep for few hours, i still feel the same.. By the way, if i lay down or sit down, my head do not spin, only when i want to move around then i feel the whole world is spinning.. And it is normal stomach discomfort at the same time? Earlier after i relieved myself, the spinning lessen so i am not sure whether its connected or coincidences.. Thank you doc!<|im_end|>
<|im_start|>assistant


Human Reference:
Hi, Thank you for posting your query. The most likely cause for your symptoms is benign paroxysmal positional vertigo (BPPV), a type of peripheral vertigo. In this condition, the most common symptom is dizziness or giddiness, which is made worse with movements. Accompanying nausea and vomiting are common. The condit

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- Example 3 ---
User Input:
<|im_start|>user
My baby has been pooing 5-6 times a day for a week. In the last few days it has increased to 7 and they are very watery with green stringy bits in them. He does not seem unwell i.e no temperature and still eating. He now has a very bad nappy rash from the pooing ...help!<|im_end|>
<|im_start|>assistant


Human Reference:
Hi... Thank you for consulting in Chat Doctor. It seems your kid is having viral diarrhea. Once it starts it will take 5-7 days to completely get better. Unless the kids having low urine output or very dull or excessively sleepy or blood in motion or green bilious vomiting...you need not worry. There is no need to use antibiotics unless there is blood in the motion. Antibiotics might worsen if unnecessarily used causing antibiotic associated diarrhea. I suggest you use zinc supplements (Z&D Chat Doctor. 

Model Response:
<|im_start|>user
My baby has been pooing 5-6 times a day for a week. In the last few days it has incre

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



--- Example 4 ---
User Input:
<|im_start|>user
my sone has left sided abd pain..lt pelvic and rt pelvic pain in the groin area.. can only stand for a short time and sitting impossible for the pain..ct showed inlarged lymph nodes..slight elevation in wt count. he has been 10 weeks and unable to get any relief other than rest and pain meds.. he is 32 6ft 3..approx 220 lb<|im_end|>
<|im_start|>assistant


Human Reference:
Hi. If there is no relief, it is mandatory to get the biopsy of the node done ASAP to get a correct diagnosis. It may be a serious problem. As a rule - anything not getting OK within 1 to 3 weeks maximum should be removed and tested. Noe-a-days we can do this by laparoscopy too, one day care surgery...

Model Response:
<|im_start|>user
my sone has left sided abd pain..lt pelvic and rt pelvic pain in the groin area.. can only stand for a short time and sitting impossible for the pain..ct showed inlarged lymph nodes..slight elevation in wt count. he has been 10 weeks and 

## LoRA Fine-Tuning

In [None]:
# print_trainable_parameters() prints the summary itself and returns None,
# so call it directly instead of interpolating its return value into a string.
model.print_trainable_parameters()

trainable params: 6,815,744 || all params: 7,248,547,840 || trainable%: 0.0940
Trainable parameters: None


In [None]:
# Load adapter weights from "vaibhav1/lora-mistral-medical" in trainable mode so
# training can continue from them.
# NOTE(review): `model` was already wrapped by get_peft_model in an earlier
# cell, so this wraps a PEFT model a second time — confirm this is intended
# rather than loading the adapter onto the bare base model.
# NOTE(review): this cell needs that adapter to already exist (as a local
# directory of that name or, presumably, as a Hub repo), so it cannot succeed
# on a first-ever run — verify the intended workflow.
from peft import PeftModel
model = PeftModel.from_pretrained(model, "vaibhav1/lora-mistral-medical",inference_mode=False,is_trainable=True)

In [None]:

# Single definition of the adapter repo id: it is used both as the local
# checkpoint directory and as the Hub repository the trainer pushes to.
HUB_MODEL_ID = "vaibhav1/lora-mistral-medical"

training_args = TrainingArguments(
    output_dir=HUB_MODEL_ID,
    # Training length is bounded by steps, not epochs.
    max_steps=1000,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    optim="adamw_torch",
    # Save, log and evaluate on the same 20-step cadence.
    save_strategy="steps",
    save_steps=20,
    logging_steps=20,
    learning_rate=LEARNING_RATE,
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    eval_strategy="steps",
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    eval_steps=20,
    # Push to the Hub repo on every checkpoint save.
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    label_names=["labels"]
)

# `model` is already a PEFT model (wrapped in the earlier LoRA cells), so no
# peft_config is passed: the trainer trains the adapter that is already
# attached instead of being asked to attach another one.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    args=training_args,
)


In [None]:
 # Training..
print("Starting training...")

# Resume only when a checkpoint actually exists in the output directory.
# `resume_from_checkpoint=True` raises when no checkpoint has been saved yet,
# which makes a first (fresh) run of this cell fail.
from transformers.trainer_utils import get_last_checkpoint

checkpoint_dir = training_args.output_dir
last_checkpoint = get_last_checkpoint(checkpoint_dir) if os.path.isdir(checkpoint_dir) else None
# None starts from step 0; a path resumes from that checkpoint.
trainer.train(resume_from_checkpoint=last_checkpoint)

print(f"Saving model to {OUTPUT_DIR}")
trainer.save_model(OUTPUT_DIR)
print("Training completed!")

Starting training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss,Validation Loss
20,2.3165,2.187199
40,2.1199,2.094706
60,2.1072,2.088509
80,2.1055,2.064076
100,1.9905,2.079358
120,2.0935,2.046397
140,2.0186,2.038318
160,2.0396,2.041651
180,2.0305,2.01986
200,1.9469,2.026335


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


KeyboardInterrupt: 