<a href="https://colab.research.google.com/github/ankraj1234/MediGuide/blob/master/mediguide_qlora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**IMPORTS AND INSTALLATIONS**

In [None]:
%%capture
%pip install accelerate peft bitsandbytes transformers trl evaluate

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

Mount Google Drive and load the uploaded training dataset (`sampled_6000.json`)

In [None]:
# Mount Google Drive so checkpoints and results persist across Colab sessions.
from google.colab import drive

drive_path = "/content/drive/MyDrive/"
file_path = "/content/sampled_6000.json"  # local VM disk (/content), not on Drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
import json

# Load the training examples. Some exports wrap the JSON array between
# "[file content begin]" / "[file content end]" markers, so if plain parsing
# fails, parse only the text between those markers.
try:
    with open(file_path, encoding="utf-8") as f:
        # Read once: json.load() would consume the handle and leave nothing
        # for the fallback path below.
        raw_text = f.read()

    try:
        medical_data = json.loads(raw_text)
    except json.JSONDecodeError:
        content = raw_text.split('[file content end]')[0].split('[file content begin]')[-1].strip()
        medical_data = json.loads(content)

    print(f"Successfully loaded {len(medical_data)} medical examples")

except FileNotFoundError:
    print(f"Error: File not found at {file_path}")
    raise  # later cells need `medical_data`; stop here instead of failing later with a NameError
except Exception as e:
    print(f"An error occurred: {str(e)}")
    raise

Successfully loaded 6000 medical examples


In [None]:
# Peek at the first record; the long input/output fields are cut to 50 characters.
first_example = medical_data[0]
print("\nSample data item:")
print(f"Instruction: {first_example['instruction']}")
for field_key, field_label in (("input", "Input"), ("output", "Output")):
    print(f"{field_label}: {first_example[field_key][:50]}...")


Sample data item:
Instruction: If you are a doctor, please answer the medical questions based on the patient's description.
Input: I wake in the night, usually about 2-3 hours after...
Output: Dear patient Here are the possibilities of what yo...


In [None]:
# Training-text template used for fine-tuning the Mistral model
def format_data(sample):
    """Render one instruction/input/output record as a single training string.

    The text is "[MED] <instruction>", then "Patient: <input>" and
    "Doctor: <output>" on their own lines, returned under the "text" key.
    """
    instruction = sample["instruction"]
    patient = sample["input"]
    doctor = sample["output"]
    text = f"[MED] {instruction}\nPatient: {patient}\nDoctor: {doctor}"
    return {"text": text}

In [None]:
from datasets import Dataset

# Hugging Face Dataset with a single "text" column built from the formatted records.
dataset = Dataset.from_list(list(map(format_data, medical_data)))

In [None]:
# Trainer checkpoints are written to Google Drive so an interrupted run can resume later.
output_dir = os.path.join("/content/drive/MyDrive", "medical_qlora")

In [None]:
def find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` sub-directory.

    Args:
        output_dir: Trainer output directory to scan.

    Returns:
        Full path of the checkpoint directory with the largest step number, or
        ``None`` if ``output_dir`` is missing, unreadable, empty, or contains no
        ``checkpoint-<step>`` sub-directories.
    """
    import re

    if not os.path.isdir(output_dir):
        print(f"Output directory {output_dir} does not exist")
        return None

    try:
        entries = os.listdir(output_dir)
    except OSError as e:
        print(f"Error finding checkpoint: {e}")
        return None

    if not entries:
        print(f"Output directory {output_dir} is empty")
        return None

    # Only accept names like "checkpoint-375". Any other "checkpoint*" entry
    # (e.g. "checkpoint-best") used to make int() raise, and the broad except
    # then discarded every valid checkpoint.
    steps = {}
    for name in entries:
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps[int(match.group(1))] = name

    if not steps:
        print("No checkpoint directories found")
        return None

    latest = os.path.join(output_dir, steps[max(steps)])
    print(f"Found checkpoint: {latest}")
    return latest

In [None]:
latest_checkpoint = find_latest_checkpoint(output_dir)  # None when there is nothing to resume from
print("Latest checkpoint: " + str(latest_checkpoint))

Found checkpoint: /content/drive/MyDrive/medical_qlora/checkpoint-375
Latest checkpoint: /content/drive/MyDrive/medical_qlora/checkpoint-375


Set up the model and hyperparameters

In [None]:
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
new_model = "mistral-chat-finetune"

# QLoRA adapter parameters (passed to LoraConfig)
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

# bitsandbytes 4-bit quantization (passed to BitsAndBytesConfig)
use_4bit = True
bnb_4bit_compute_dtype = "float16"
bnb_4bit_quant_type = "nf4"
use_nested_quant = False  # used as bnb_4bit_use_double_quant
num_train_epochs = 1

# Mixed precision: fp16 matches the float16 compute dtype above
fp16 = True
bf16 = False

# 4 sequences x 4 accumulation steps = 16 sequences per optimizer step
# (6000 examples / 16 = 375 steps per epoch)
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
gradient_checkpointing = True
max_grad_norm = 0.3

# Optimizer and learning-rate schedule
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_8bit"
lr_scheduler_type = "cosine"

max_steps = -1  # -1: training length is set by num_train_epochs

warmup_ratio = 0.03
group_by_length = True

# Checkpointing. These must be plain scalars: a trailing comma
# (`save_steps=200,`) silently turns the value into a one-element tuple.
save_strategy = "steps"
save_steps = 200
save_total_limit = 3

logging_steps = 25
# NOTE(review): not referenced elsewhere in the visible notebook; tokenize_function truncates at 256.
max_seq_length = 1024
packing = False

# Place the whole model on GPU 0
device_map = {"": 0}

In [None]:
# Load the 4-bit quantized base model, its tokenizer and the LoRA configuration.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Base model; the LoRA adapter is attached in a later cell.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    quantization_config=bnb_config,
)
model.config.use_cache = False  # disable the KV cache while training
model.config.pretraining_tp = 1

# Tokenizer: pad with the EOS token, padding on the right.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM",
)


# Guard against hyperparameters written with a trailing comma (`x = 1,`),
# which makes them one-element tuples.
def _unwrap_tuple(value):
    """Return the first element of a tuple, or the value unchanged otherwise."""
    return value[0] if isinstance(value, tuple) else value


save_steps = _unwrap_tuple(save_steps)
logging_steps = _unwrap_tuple(logging_steps)
max_steps = _unwrap_tuple(max_steps)
save_total_limit = _unwrap_tuple(save_total_limit)
num_train_epochs = _unwrap_tuple(num_train_epochs)
per_device_train_batch_size = _unwrap_tuple(per_device_train_batch_size)
gradient_accumulation_steps = _unwrap_tuple(gradient_accumulation_steps)
learning_rate = _unwrap_tuple(learning_rate)
weight_decay = _unwrap_tuple(weight_decay)
max_grad_norm = _unwrap_tuple(max_grad_norm)
warmup_ratio = _unwrap_tuple(warmup_ratio)



config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/141k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [None]:
# Attach the LoRA adapter: resume from the latest saved checkpoint when one
# exists, otherwise start from a freshly initialised adapter.
from peft import get_peft_model, PeftModel

if latest_checkpoint is not None:
    # Load the adapter saved inside the checkpoint directory itself.
    # `revision=` selects a Hub git revision, not a local sub-folder, so the
    # previous call read the adapter from `output_dir` instead. Calling it on
    # top of get_peft_model() also nested one PeftModel inside another.
    model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)
else:
    model = get_peft_model(model, peft_config)




In [None]:
# Trainer configuration, grouped by concern.
training_arguments = TrainingArguments(
    # output & checkpointing
    output_dir=output_dir,
    save_strategy="steps",
    save_steps=save_steps,
    save_total_limit=save_total_limit,
    # Per the TrainingArguments docs this field is not consumed by Trainer itself;
    # the actual resume happens via trainer.train(resume_from_checkpoint=...).
    resume_from_checkpoint=latest_checkpoint,
    # training length & batch size
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # optimizer & schedule
    optim=optim,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    max_grad_norm=max_grad_norm,
    # precision & batching
    fp16=fp16,
    bf16=bf16,
    group_by_length=group_by_length,
    # logging
    logging_steps=logging_steps,
    report_to="tensorboard",
)

In [None]:
def tokenize_function(examples, max_length=256):
    """Tokenize a batch of formatted training texts into fixed-length sequences.

    Args:
        examples: Batch dict from ``Dataset.map(batched=True)``; only the
            ``"text"`` column is read.
        max_length: Length every sequence is truncated/padded to. Defaults to
            256 (the value previously hard-coded). Pass ``max_seq_length`` via
            ``fn_kwargs`` to train on longer contexts.

    Returns:
        The tokenizer output as plain Python lists (``return_tensors=None``).
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors=None,
    )

In [None]:
# Tokenize, then add a "labels" column mirroring input_ids for causal-LM training.
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) rebuilds `labels` from
# input_ids when collating, so this column is likely overwritten — verify.
tokenized_dataset = dataset.map(tokenize_function, batched=True)
tokenized_dataset = tokenized_dataset.map(
    lambda batch: {"labels": list(batch["input_ids"])},
    batched=True,
)

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

In [None]:
# Causal-LM collator (mlm=False): next-token prediction, no masking of inputs.
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
# Wire model, adapter config, arguments, data and collator into TRL's SFTTrainer.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    peft_config=peft_config,
    data_collator=data_collator,
    train_dataset=tokenized_dataset,
)

Truncating train dataset:   0%|          | 0/6000 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Train, resuming from the latest checkpoint found earlier (None starts from scratch).
# NOTE(review): the recorded output shows global_step=375, a ~0.04 s runtime and
# loss 0.0. The checkpoint apparently was already at the final step, so this call
# did no new training.
trainer.train(resume_from_checkpoint=latest_checkpoint)

Step,Training Loss


TrainOutput(global_step=375, training_loss=0.0, metrics={'train_runtime': 0.0352, 'train_samples_per_second': 170621.54, 'train_steps_per_second': 10663.846, 'total_flos': 6.58120900608e+16, 'train_loss': 0.0})

In [None]:
# Persist the trained adapter and tokenizer to the Drive output directory.
# Order matters: tokenizer.save_pretrained runs after trainer.save_model, so this
# notebook's tokenizer files are the ones left on disk.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
# NOTE(review): likely redundant with trainer.save_model above — same target directory.
trainer.model.save_pretrained(output_dir)

Push the fully trained model to the Hugging Face Hub

In [None]:
# Authenticate with the Hugging Face Hub without hard-coding a secret in the
# notebook: use the HF_TOKEN environment variable when set, otherwise prompt.
import os
from getpass import getpass

from huggingface_hub import login

hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(token=hf_token)

In [None]:
# Fold the LoRA adapter weights into the base model and drop the PEFT wrapper.
# NOTE(review): the base was loaded in 4-bit. Confirm the installed PEFT version
# supports merging into bitsandbytes 4-bit layers, and sanity-check the merged output.
merged_model = trainer.model.merge_and_unload()

In [None]:
# Save the merged model and its tokenizer side by side in a local folder.
merged_dir = "merged_mistral_mediguide"
for artifact in (merged_model, tokenizer):
    artifact.save_pretrained(merged_dir)

In [None]:
# Create (or reuse) the target model repository on the Hub.
from huggingface_hub import HfApi, HfFolder, create_repo, notebook_login

repo_name = "Greyitis/mediguide_new"
create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)

In [None]:
# Upload the merged weights, then the tokenizer files, to the Hub repository.
merged_model.push_to_hub(repo_name)
tokenizer.push_to_hub(repo_name)

In [None]:
# Push the complete fine-tuned model (LoRA adapters merged into the base
# weights) instead of only the adapters.
from huggingface_hub import HfApi, HfFolder, create_repo

# 1) Reload the base model and attach the trained adapter saved in output_dir
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
peft_model = PeftModel.from_pretrained(base, output_dir, is_trainable=False)

# 2) Merge LoRA adapters into the base weights
merged_model = peft_model.merge_and_unload()

# 3) Save merged model & tokenizer
merged_dir = "merged_mediguide"
os.makedirs(merged_dir, exist_ok=True)
merged_model.save_pretrained(merged_dir)
tokenizer.save_pretrained(merged_dir)

# 4) Upload the folder written in step 3. This previously uploaded
#    "merged_mistral_mediguide", an earlier cell's output, so the model merged
#    here never reached the Hub.
token = HfFolder.get_token()
repo_id = "Greyitis/mediguide_new"
create_repo(repo_id, repo_type="model", exist_ok=True)

api = HfApi()
api.upload_folder(
    folder_path=merged_dir,
    repo_id=repo_id,
    repo_type="model",
    path_in_repo="",
    token=token,
)

print(f"Everything from {merged_dir}/ is now in", repo_id)

**INFERENCING**

In [None]:
# Custom stopping criterion: end generation once the decoded text ends with a
# sign-off string, so the model stops after the useful part of the answer.
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnThanks(StoppingCriteria):
    """Stop generation when the first sequence in the batch ends with ``stop_str``.

    Despite the class name, the default trigger is "Regards".

    Args:
        tokenizer: Tokenizer used to decode generated ids.
        stop_str: Text that triggers the stop; trailing whitespace is ignored.
        lookback_tokens: Number of trailing tokens decoded per step. Only the
            end of the sequence matters, so decoding a short tail keeps each
            step's cost constant instead of re-decoding prompt + output every time.
    """

    def __init__(self, tokenizer, stop_str="Regards", lookback_tokens=16):
        super().__init__()
        self.tokenizer = tokenizer
        self.stop_str = stop_str
        self.lookback_tokens = lookback_tokens

    def __call__(self, input_ids, scores, **kwargs):
        tail = input_ids[0][-self.lookback_tokens:]
        text = self.tokenizer.decode(tail, skip_special_tokens=True)
        return text.rstrip().endswith(self.stop_str)

def medical_response(
    patient_input,
    instruction="If you are a doctor, please answer the medical questions based on the patient's description.",
    max_new_tokens=200,
):
    """Generate a doctor-style answer from the fine-tuned model for one patient message.

    The prompt mirrors the training template produced by ``format_data``:
    "[MED] <instruction>", then "Patient: <message>", then "Doctor:". This keeps
    the model on the layout it was fine-tuned on; chat tags such as
    ``<|user|>``/``<|assistant|>`` never occur in that training text.

    Args:
        patient_input: The patient's description or question.
        instruction: Instruction line; the default is the instruction text seen
            in the training data sample.
        max_new_tokens: Upper bound on generated tokens.

    Returns:
        The decoded answer, cut after the last "Thanks" if that word appears.
    """
    prompt = f"[MED] {instruction}\nPatient: {patient_input}\nDoctor:"
    # Keep the attention mask: pad_token == eos_token here, so generate() cannot infer it.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    stop_criteria = StoppingCriteriaList([StopOnThanks(tokenizer)])

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stop_criteria,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    gen = outputs[0][inputs["input_ids"].shape[-1]:]
    answer = tokenizer.decode(gen, skip_special_tokens=True).strip()
    if "Thanks" in answer:
        answer = answer[: answer.rfind("Thanks") + len("Thanks")]
    return answer


In [None]:
print(medical_response(patient_input="I have fever and body aches"))

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


I'm not a doctor but I can try to help. Your symptoms could be indicative of a viral infection such as the flu or a cold. However, it's important to get professional medical advice for accurate diagnosis and treatment. If you are concerned about your health, please contact a healthcare provider immediately.

Remember to rest, stay hydrated, and take over-the-counter pain relievers to manage your symptoms until you can see a doctor. Avoid spreading germs by washing your hands regularly and covering your mouth when coughing or sneezing.


In [None]:
case_description = (
    "A 45-year-old male presents with fatigue, weight gain, and cold intolerance. "
    "What are the possible differential diagnoses, and which lab tests would you recommend?"
)
print(medical_response(case_description))

Given the symptoms of fatigue (tiredness), weight gain without apparent cause, and Cold Intolerance (sensitivity to cold temperatures), there are several potential conditions that could be causing these symptoms. Here are some possible differential diagnosis options:

1. Hypothyroidism - Underactive thyroid gland leading to decreased metabolism, fatigue and weight gain.
2. Type 2 Diabetes Mellitus - Impaired insulin production or resistance leading to high blood sugar levels, fatique, weight change and cold sensitivity.
3. Cushing's Syndrome - Overproduction of cortisol hormone, resulting in fatigue weight gain and cold tolerance.
4. Anemia - Low red blood cell count can cause fatigue.
5. Chronic Fatigue Syndromes - A complex disorder characterized by prolonged fatigue not relieved by rest.
6. Depression - Can present with


**Evaluation**

In [None]:
# Held-out evaluation data: the first 1,000 rows of the dataset's test split.
from datasets import load_dataset

raw_datasets = load_dataset("Mohammed-Altaf/medical-instruction-120k")
test_split = raw_datasets["test"].select(range(1_000))

README.md:   0%|          | 0.00/965 [00:00<?, ?B/s]

medicare_110k_train.json:   0%|          | 0.00/126M [00:00<?, ?B/s]

medicare_110k_test.json:   0%|          | 0.00/6.60M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/106556 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5609 [00:00<?, ? examples/s]

In [None]:
print(str(test_split))  # schema and row count of the evaluation subset

Dataset({
    features: ['Conversation'],
    num_rows: 1000
})


In [None]:
# Reload the evaluation subset so this cell can run on its own.
from datasets import load_dataset

raw_datasets = load_dataset("Mohammed-Altaf/medical-instruction-120k")
test_split = raw_datasets["test"].select(range(1_000))

import re
def extract_prompt_response(example):
    """Split a "[|Human|] ... [|AI|] ..." conversation string into prompt and answer.

    The conversation is read from the example's first column.

    Returns:
        dict with
          - "instruction": text between the first "[|Human|]" marker and the
            following "[|AI|]" marker. Without that pair, the text before the
            first "[|AI|]" marker (minus any "[|Human|]" tags), so the reference
            answer never leaks into the prompt.
          - "response": text after the last "[|AI|]" marker, or "" if none.
    """
    convo = str(example[next(iter(example.keys()))]).strip()
    ai_parts = re.split(r"\[\|AI\|\]", convo)

    human_match = re.search(r"\[\|Human\|\]\s*(.*?)\s*(?=\[\|AI\|\])", convo, re.DOTALL)
    if human_match:
        instruction = human_match.group(1).strip()
    else:
        # Previously the whole conversation was used here, which put the
        # answer into the prompt whenever the "[|Human|]" tag was missing.
        instruction = re.sub(r"\[\|Human\|\]", "", ai_parts[0]).strip()

    response = ai_parts[-1].strip() if len(ai_parts) > 1 else ""
    return {"instruction": instruction, "response": response}

# Parse every conversation (4 worker processes) and expose prompts/references as lists.
test_df = test_split.map(
    extract_prompt_response,
    num_proc=4,
    remove_columns=test_split.column_names,
)
test_prompts, test_references = test_df["instruction"], test_df["response"]

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
# First evaluation prompt (the patient's message)
test_prompts[0]

'I wake in the night, usually about 2-3 hours after going to sleep, with both feet and legs to mid calf feeling like they are on fire. slight red discolorization, minor swelling. This is very painful but after getting up, I can walk it off in about 30 minutes.'

In [None]:
# Reference doctor answer for the first evaluation prompt
test_references[0]

'Dear patient Here are the possibilities of what you might have.1)PhlebitisPhlebitis means inflammation of the veins, and can cause redness, itching, irritation, pain, and swelling. A simple Doppler can rule this out.2Blood clot in the lifeblood clots in the leg can become very dangerous, symptoms include swelling, redness, tenderness in the leg. Coagulation profile with an angiography may be required3)Cellulitis: Initial stage. Only can be clinically ruled out Hope this helped'

In [None]:
# 1. PERPLEXITY

from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling


checkpoint_dir = "/content/drive/MyDrive/medical_qlora"
eval_ckpt_path = os.path.join(checkpoint_dir, "test_eval_state.pth")  # running totals, for resuming
batch_size = 2
save_every_n_batches = 50


class LMTestDataset(torch.utils.data.Dataset):
    """Fixed-length causal-LM evaluation dataset built from raw strings.

    All texts are tokenized once up front, truncated and padded to
    ``max_length``. Each item holds ``input_ids``, the matching
    ``attention_mask``, and ``labels`` with padding positions set to -100 so
    they are never scored.
    """

    def __init__(self, texts, tokenizer, max_length=1024):
        encodings = tokenizer(
            texts,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length"
        )
        self.input_ids = encodings["input_ids"]
        # Keep the mask: in this notebook pad_token is the EOS token, so padding
        # cannot be told apart from a real EOS by token id alone.
        attention_mask = encodings.get("attention_mask")
        if attention_mask is None:
            attention_mask = torch.ones_like(self.input_ids)
        self.attention_mask = attention_mask

    def __len__(self):
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        labels = self.input_ids[idx].clone()
        labels[self.attention_mask[idx] == 0] = -100  # ignore padding in the loss
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": labels,
        }


# Score "<prompt>\n\n<reference>" strings, one fixed-length sequence each.
test_texts = [
    f"{instr}\n\n{resp}"
    for instr, resp in zip(test_prompts, test_references)
]
lm_test_dataset = LMTestDataset(test_texts, tokenizer, max_length=1024)


data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
test_loader = DataLoader(
    lm_test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator
)

device = model.device


# Resume the running totals if an interrupted run saved them.
# NOTE(review): a resume file written by an older version of this cell used a
# different token weighting; delete it before re-running from scratch.
if os.path.isfile(eval_ckpt_path):
    state = torch.load(eval_ckpt_path)
    start_batch = state["last_batch"] + 1
    accumulated_loss = state["accumulated_loss"]
    total_tokens = state["total_tokens"]
    print(f"Resuming test-eval from batch {start_batch} (saved on disk).")
else:
    start_batch = 0
    accumulated_loss = 0.0
    total_tokens = 0
    print("Starting test-eval from batch 0.")


model.eval()
with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        if batch_idx < start_batch:
            continue

        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        # outputs.loss is the mean over the *shifted* label positions that are not
        # -100, so weight it by exactly that count. Counting
        # `labels != pad_token_id` also counted the -100 positions and gave every
        # batch the same weight.
        scored_tokens = (labels[:, 1:] != -100).sum().item()
        if scored_tokens > 0:
            accumulated_loss += outputs.loss.detach().cpu().item() * scored_tokens
            total_tokens += scored_tokens

        if (batch_idx + 1) % save_every_n_batches == 0:
            state = {
                "last_batch": batch_idx,
                "accumulated_loss": accumulated_loss,
                "total_tokens": total_tokens
            }
            torch.save(state, eval_ckpt_path)
            print(f" Saved eval state at batch {batch_idx} → tokens={total_tokens}")

if total_tokens == 0:
    raise ValueError("No scorable tokens in the test set; cannot compute perplexity.")

final_avg_loss = accumulated_loss / total_tokens
test_perplexity = torch.exp(torch.tensor(final_avg_loss)).item()
# The resume file only exists if at least one periodic save happened.
if os.path.isfile(eval_ckpt_path):
    os.remove(eval_ckpt_path)

print(f"\n→ Test complete. Avg. token-loss = {final_avg_loss:.4f}")
print(f"→ Test Perplexity = {test_perplexity:.2f}")

→ Test complete. Avg. token-loss = 2.7282
→ Test Perplexity = 15.30

In [None]:
# 2. LATENCY
#
# Average wall-clock generation time per prompt. Prompts are generated in
# batches (greedy, up to `max_new_tokens` new tokens) and each batch's time is
# divided by its size, so the figure is amortized over the batch.

import time

n_samples = min(50, len(test_prompts))
max_new_tokens = 128
batch_size = 4


def _cuda_sync():
    """Wait for queued CUDA work so timings cover the whole generate call."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


# Batched generation here uses left padding. The previous side is restored
# afterwards so this cell does not change the tokenizer for later cells.
previous_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

latencies = []
model.eval()
device = model.device

try:
    # Warm-up call so one-time initialisation is not timed.
    dummy_input = tokenizer("Hello", return_tensors="pt").to(device)
    _ = model.generate(
        **dummy_input,
        max_new_tokens=10,
        do_sample=False,
        use_cache=True,
        return_dict_in_generate=False
    )
    _cuda_sync()

    for i in range(0, n_samples, batch_size):
        # Clamp to n_samples; the old slice ran past it on the last batch (50 -> 52 prompts).
        batch_prompts = test_prompts[i : min(i + batch_size, n_samples)]
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            truncation=True,
            padding=True
        ).to(device)

        _cuda_sync()
        start = time.perf_counter()  # monotonic, high-resolution interval timer
        _ = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            use_cache=True,
            return_dict_in_generate=False
        )
        _cuda_sync()
        elapsed = time.perf_counter() - start

        latencies.append(elapsed / len(batch_prompts))
finally:
    tokenizer.padding_side = previous_padding_side

if not latencies:
    raise ValueError("No test prompts available to measure latency.")

avg_latency = sum(latencies) / len(latencies)
print(f"Average Latency (per prompt, {max_new_tokens} new tokens): {avg_latency:.4f} seconds")

Average Latency (per prompt, 128 new tokens): 7.2467 seconds

In [None]:
# 3. MODEL SIZE ON DISK

def folder_size_in_mb(path: str) -> float:
    """Return the total size of all files under ``path`` (recursively), in MiB (1024**2 bytes)."""
    size_bytes = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _dirnames, filenames in os.walk(path)
        for filename in filenames
    )
    return size_bytes / (1024 ** 2)

# NOTE(review): a Trainer checkpoint directory may also hold optimizer/scheduler
# state, so this can overstate the adapter size — inspect the folder contents.
model_size_mb = folder_size_in_mb(latest_checkpoint)
size_report = f"Model Size on Disk (under “{latest_checkpoint}”): {model_size_mb:.1f} MB"
print(size_report)

Model Size on Disk (under “/content/drive/MyDrive/medical_qlora/checkpoint-375”): 161.2 MB

In [None]:
# Generation for ROUGE calculation
import json
from tqdm import tqdm


def _resolve_state_file(checkpoint_path):
    """Map an existing directory to its resume file; any other path is the file itself."""
    if os.path.isdir(checkpoint_path):
        return os.path.join(checkpoint_path, "generation_resume.json")
    return checkpoint_path


def _load_generation_state(state_file):
    """Return (completed_indices, predictions) saved by a previous run, or two empty lists."""
    if not os.path.exists(state_file):
        return [], []
    with open(state_file, 'r') as f:
        state = json.load(f)
    return state.get('completed_indices', []), state.get('predictions', [])


def _save_generation_state(state_file, completed_indices, predictions):
    """Write progress via a temp file + rename so an interrupted write cannot corrupt it."""
    tmp_file = state_file + ".tmp"
    with open(tmp_file, 'w') as f:
        json.dump({
            'completed_indices': completed_indices,
            'predictions': predictions
        }, f)
    os.replace(tmp_file, state_file)


def optimized_medical_generation(
    model, tokenizer, prompts, references,
    checkpoint_path, batch_size=8, max_length=512,
    max_new_tokens=256, prompt_template="MEDICAL PROMPT: {prompt}",
):
    """Greedily generate answers for ``prompts`` in batches, with resumable progress.

    Completed prompt indices and their predictions are written to a JSON state
    file after every batch. Re-running with the same ``checkpoint_path`` skips
    prompts that are already done.

    Args:
        model: Causal LM with ``generate``; switched to eval mode here.
        tokenizer: Matching tokenizer. Its ``padding_side`` is set to "left"
            while generating and restored afterwards.
        prompts: Sequence of prompt strings.
        references: Reference answers aligned index-by-index with ``prompts``.
        checkpoint_path: Existing directory (state goes to
            ``generation_resume.json`` inside it) or a path to the state file.
        batch_size: Prompts per ``generate`` call.
        max_length: Token limit for the templated prompt; longer prompts are truncated.
        max_new_tokens: Generation budget per prompt.
        prompt_template: ``str.format`` template with a ``{prompt}`` field.

    Returns:
        ``(predictions, references)`` lists in completion order, so element
        ``k`` of each refers to the same prompt.
    """
    state_file = _resolve_state_file(checkpoint_path)
    resumed = os.path.exists(state_file)
    completed_indices, predictions = _load_generation_state(state_file)
    if resumed:
        print(f"Resuming from {len(completed_indices)} completed samples")

    completed_set = set(completed_indices)
    remaining_indices = [i for i in range(len(prompts)) if i not in completed_set]

    if not remaining_indices:
        print("All samples already processed!")
        return predictions, [references[i] for i in completed_indices]

    model.eval()
    device = model.device

    # Left padding makes every prompt end at the same position, so the new
    # tokens start at one common offset for the whole batch.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The pad-token input embedding is deliberately NOT zeroed. In this notebook
    # pad_token is the EOS token, so zeroing it overwrote the model's EOS
    # embedding for every later cell. The attention mask already hides padding.
    try:
        for batch_start in tqdm(range(0, len(remaining_indices), batch_size),
                                desc="Medical Generation"):
            batch_indices = remaining_indices[batch_start : batch_start + batch_size]
            batch_prompts = [prompts[i] for i in batch_indices]

            inputs = tokenizer(
                [prompt_template.format(prompt=p) for p in batch_prompts],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
                return_attention_mask=True,
            ).to(device)

            with torch.no_grad():
                outputs = model.generate(
                    input_ids      = inputs["input_ids"],
                    attention_mask = inputs["attention_mask"],
                    pad_token_id   = tokenizer.pad_token_id,
                    max_new_tokens = max_new_tokens,
                    do_sample      = False,
                    num_beams      = 1,
                    use_cache      = True,
                )

            # Slice off the prompt tokens instead of string-matching the decoded
            # prompt. With truncation (max_length) the decoded prefix did not match
            # the prompt text, and the prompt leaked into the prediction.
            prompt_len = inputs["input_ids"].shape[1]
            decoded = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
            batch_preds = [text.strip() for text in decoded]

            predictions.extend(batch_preds)
            completed_indices.extend(batch_indices)
            _save_generation_state(state_file, completed_indices, predictions)
    finally:
        tokenizer.padding_side = original_padding_side

    return predictions, [references[i] for i in completed_indices]


In [None]:
# Greedy generations for the test prompts. Progress is stored inside the latest
# training checkpoint directory so an interrupted run can resume.
predictions, processed_refs = optimized_medical_generation(
    model, tokenizer, test_prompts, test_references, latest_checkpoint,
    max_length=120, batch_size=12,
)

Resuming from 1000 completed samples
All samples already processed!


In [None]:
# 4. ROUGE score
!pip install rouge_score
from evaluate import load
rouge = load("rouge")
results = rouge.compute(
        predictions=predictions,
        references=processed_refs,
        use_stemmer=True,
        use_aggregator=True
    )

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=8680567035a0fe0aebd7875e4bd6a4ae636c080056b2461c4032b88e3139778b
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
print("\nMedical ROUGE Scores:")
for metric_key, metric_label in (
    ("rouge1", "ROUGE-1"),
    ("rouge2", "ROUGE-2"),
    ("rougeL", "ROUGE-L"),
    ("rougeLsum", "ROUGE-Lsum"),
):
    print(f"{metric_label}: {results[metric_key]:.4f}")


Medical ROUGE Scores:
ROUGE-1: 0.2398
ROUGE-2: 0.0307
ROUGE-L: 0.1212
ROUGE-Lsum: 0.1469


In [None]:
# Persist the scores together with the raw predictions/references for later inspection.
rouge_report = {
    "scores": results,
    "predictions": predictions,
    "references": processed_refs,
}
with open("/content/drive/MyDrive/medical_rouge_results.json", "w") as f:
    json.dump(rouge_report, f)