In [8]:
! pip install datasets



In [9]:
import os
import torch
import pandas as pd
import wandb
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model

# # Initialize Weights & Biases
# wandb.login()  # Ensure you're logged into WandB
# wandb.init(project="medical-chatbot", name="fine-tuning", config={})
os.environ["WANDB_DISABLED"] = "true"
# ---------------------------
# Step 1: Prepare Dummy Data
# ---------------------------
# Load and sample dataset
df = pd.read_csv("/kaggle/input/processed-medquad2/processed_medquad.csv")
df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)
df_small = df.iloc[:100]  # ✅ Use small sample (change number if needed)

# Convert to Hugging Face Dataset
dataset = Dataset.from_pandas(df_small)

# Train/test split
split = dataset.train_test_split(test_size=0.2)
train_dataset = split["train"]
test_dataset = split["test"]


# ---------------------------
# Step 2: Model and Tokenizer Setup
# ---------------------------
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
OUTPUT_DIR = "./models/medical_chatbot_finetuned"
USE_QLORA = False  # Set to False for LoRA, True for QLoRA if you want quantized weights

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # Set padding token

# ---------------------------
# Step 3: Prepare Dataset
# ---------------------------
def format_instruction(example):
    """Format each example into an instruction format for medical Q&A."""
    question = example["Question_Sentences"]
    answer = example["Answer_Sentences"]

    instruction = (
        f"### Instruction:\nAnswer the following medical question in a concise and accurate manner.\n\n"
        f"### Question:\n{question}\n\n"
        f"### Answer:\n{answer}"
    )
    return {"formatted_text": instruction}

print("Formatting dataset...")
train_formatted = train_dataset.map(format_instruction)
test_formatted = test_dataset.map(format_instruction)

def tokenize_function(examples):
    """Tokenize the dataset."""
    return tokenizer(
        examples["formatted_text"],
        truncation=True,
        max_length=512,
        padding="max_length",
    )

print("Tokenizing dataset...")
train_tokenized = train_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=train_formatted.column_names
)

test_tokenized = test_formatted.map(
    tokenize_function,
    batched=True,
    remove_columns=test_formatted.column_names
)

# ---------------------------
# Step 4: Apply LoRA or QLoRA
# ---------------------------
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map="auto")

# Apply LoRA Configuration
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# ---------------------------
# Step 5: Training Setup
# ---------------------------
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=50,
    logging_dir='./logs',
    logging_steps=10,
    eval_steps=100,  # Evaluate every 100 steps
    save_steps=100,  # Save the model every 100 steps
    save_total_limit=1,
    fp16=True,  # Use if you're on a GPU that supports it
)


data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# ---------------------------
# Step 6: Trainer Setup
# ---------------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    tokenizer=tokenizer,
    data_collator=data_collator
)

# ---------------------------
# Step 7: Train the Model
# ---------------------------
trainer.train()

# ---------------------------
# Step 8: Save the Model
# ---------------------------
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# # Log final model information to WandB
# wandb.log({"model_output_dir": OUTPUT_DIR})

print("✅ Fine-tuning complete and model saved.")


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Formatting dataset...


Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

Tokenizing dataset...


Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/20 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


trainable params: 1,126,400 || all params: 1,101,174,784 || trainable%: 0.1023


  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
10,1.6278
20,1.7088
30,1.6671
40,1.6518
50,1.6097
60,1.5815


✅ Fine-tuning complete and model saved.


In [10]:
# Evaluate the fine-tuned model on the test dataset
eval_results = trainer.evaluate(eval_dataset=test_tokenized)

# Print the evaluation results (loss, metrics)
print(f"Evaluation results: {eval_results}")


Evaluation results: {'eval_loss': 1.5870946645736694, 'eval_runtime': 3.4051, 'eval_samples_per_second': 5.873, 'eval_steps_per_second': 1.468, 'epoch': 3.0}


In [11]:
!pip install evaluate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [12]:
!pip install sacrebleu

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [13]:
!pip install rouge_score

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [14]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from evaluate import load
import sacrebleu
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


In [15]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def generate_answer(question):
    prompt = (
        f"### Instruction:\nAnswer the following medical question in a concise and accurate manner.\n\n"
        f"### Question:\n{question}\n\n### Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs["input_ids"], max_length=256)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

sample_q = "How to diagnose Crimean-Congo Hemorrhagic Fever (CCHF)?"

print("🔍 Sample Question:", sample_q)
print("💬 Model Answer:", generate_answer(sample_q))


# Ensure NLTK tokenizer is ready
nltk.download("punkt")

# Expected answer (ground truth)
expected_answer = """Laboratory tests that are used to diagnose CCHF include antigen-capture enzyme-linked immunosorbent assay (ELISA), real time polymerase chain reaction (RT-PCR), virus isolation attempts, and detection of antibody by ELISA (IgG and IgM). Laboratory diagnosis of a patient with a clinical history compatible with CCHF can be made during the acute phase of the disease by using the combination of detection of the viral antigen (ELISA antigen capture), viral RNA sequence (RT-PCR) in the blood or in tissues collected from a fatal case and virus isolation. Immunohistochemical staining can also show evidence of viral antigen in formalin-fixed tissues. Later in the course of the disease, in people surviving, antibodies can be found in the blood. But antigen, viral RNA and virus are no more present and detectable."""

# Generate answer from fine-tuned model
sample_q = "How to diagnose Crimean-Congo Hemorrhagic Fever (CCHF)?"
generated_answer = generate_answer(sample_q)

# Clean decoded output
generated_answer_clean = generated_answer.replace(sample_q, "").replace("### Instruction:", "").replace("### Question:", "").replace("### Answer:", "").strip()

print("🔍 Sample Question:", sample_q)
print("💬 Generated Answer:\n", generated_answer_clean)
print("📘 Expected Answer:\n", expected_answer)

# Tokenization
ref_tokens = nltk.word_tokenize(expected_answer.lower())
gen_tokens = nltk.word_tokenize(generated_answer_clean.lower())

# BLEU Scores
smoothing = SmoothingFunction().method1
print("\n📊 BLEU Scores:")
print("BLEU-1:", sentence_bleu([ref_tokens], gen_tokens, weights=(1, 0, 0, 0), smoothing_function=smoothing))
print("BLEU-2:", sentence_bleu([ref_tokens], gen_tokens, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing))
print("BLEU-3:", sentence_bleu([ref_tokens], gen_tokens, weights=(0.33, 0.33, 0.33, 0), smoothing_function=smoothing))
print("BLEU-4:", sentence_bleu([ref_tokens], gen_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing))

# ROUGE
rouge = load("rouge")
rouge_score = rouge.compute(predictions=[generated_answer_clean], references=[expected_answer])
print("\n📊 ROUGE Score:", rouge_score)

# F1 Score (word-level overlap)
def f1_from_tokens(ref_tokens, gen_tokens):
    ref_set = set(ref_tokens)
    gen_set = set(gen_tokens)
    true_positives = len(ref_set & gen_set)
    precision = true_positives / len(gen_set) if gen_set else 0
    recall = true_positives / len(ref_set) if ref_set else 0
    return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

f1_score = f1_from_tokens(ref_tokens, gen_tokens)
print("📊 F1 Score:", round(f1_score, 4))

# SacreBLEU
sacrebleu_score = sacrebleu.corpus_bleu([generated_answer_clean], [[expected_answer]])
print("📊 SacreBLEU Score:", round(sacrebleu_score.score, 2))


🔍 Sample Question: How to diagnose Crimean-Congo Hemorrhagic Fever (CCHF)?
💬 Model Answer: ### Instruction:
Answer the following medical question in a concise and accurate manner.

### Question:
How to diagnose Crimean-Congo Hemorrhagic Fever (CCHF)?

### Answer:
1. Blood tests: Blood tests can be done to check for the presence of antibodies to the virus.

2. Serology: Serology is a blood test that can detect antibodies to the virus.

3. PCR: Polymerase chain reaction (PCR) is a test that can detect the virus's RNA.

4. Viral culture: Viral culture is a test that can detect the virus's DNA.

5. Immunofluorescence: Immunofluorescence is a test that can detect the virus's antibodies.

6. Viral isolation: Viral isolation is a test that can detect the virus's DNA.

7. Viral nucleic acid amplification: Viral nucleic acid amplification is a test that can detect the virus's RNA.

8. Viral antigen detection:


[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


🔍 Sample Question: How to diagnose Crimean-Congo Hemorrhagic Fever (CCHF)?
💬 Generated Answer:
 Answer the following medical question in a concise and accurate manner.





1. Blood tests: Blood tests can be done to check for the presence of antibodies to the virus.

2. Serology: Serology is a blood test that can detect antibodies to the virus.

3. PCR: Polymerase chain reaction (PCR) is a test that can detect the virus's RNA.

4. Viral culture: Viral culture is a test that can detect the virus's DNA.

5. Immunofluorescence: Immunofluorescence is a test that can detect the virus's antibodies.

6. Viral isolation: Viral isolation is a test that can detect the virus's DNA.

7. Viral nucleic acid amplification: Viral nucleic acid amplification is a test that can detect the virus's RNA.

8. Viral antigen detection:
📘 Expected Answer:
 Laboratory tests that are used to diagnose CCHF include antigen-capture enzyme-linked immunosorbent assay (ELISA), real time polymerase chain reaction (RT-PC

Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]


📊 ROUGE Score: {'rouge1': 0.32812500000000006, 'rouge2': 0.031496062992125984, 'rougeL': 0.16406250000000003, 'rougeLsum': 0.22656250000000003}
📊 F1 Score: 0.3504
📊 SacreBLEU Score: 1.9


In [16]:
pip install transformers datasets wandb


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Note: you may need to restart the kernel to use updated packages.


In [17]:
pip install --upgrade transformers


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting transformers
  Downloading transformers-4.51.3-py3-none-any.whl.metadata (38 kB)
Downloading transformers-4.51.3-py3-none-any.whl (10.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.4/10.4 MB[0m [31m70.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0mm
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.51.1
    Uninstalling transformers-4.51.1:
      Successfully uninstalled transformers-4.51.1
Successfully installed transformers-4.51.3
Note: you may need to restart the kernel to use updated packages.
