# Importing Libraries

In [None]:
%%capture
%pip install accelerate peft bitsandbytes transformers trl evaluate datasets

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)

2025-06-25 05:02:58.947910: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750827779.081006      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750827779.118026      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Preparing Dataset

In [None]:
from google.colab import drive
drive.mount("/content/drive")
file_path = '/content/drive/MyDrive/MediGuideDataset/sampled_6000.json'
drive_path = "/content/drive/MyDrive/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
if not os.path.exists(file_path):
    print(f"Error: File not found at {file_path}")
    print("Available files in directory:")
    print(os.listdir(drive_path))
else:
    print(f"Found file at {file_path}")

Found file at /content/drive/MyDrive/MediGuideDataset/sampled_6000.json


In [None]:
import json
try:
    with open(file_path) as f:
        try:
            medical_data = json.load(f)
        except json.JSONDecodeError:

            content = f.read().split('[file content end]')[0].split('[file content begin]')[-1].strip()
            medical_data = json.loads(content)

    print(f"Successfully loaded {len(medical_data)} medical examples")

except FileNotFoundError:
    print(f"Error: File not found at {file_path}")
except Exception as e:
    print(f"An error occurred: {str(e)}")

Successfully loaded 6000 medical examples


In [None]:
def format_data(sample):
    instruction = sample.get("instruction", "").strip()
    input_text = sample.get("input", "").strip()
    output_text = sample.get("output", "").strip()

    return {
        "text": f"[MED] {instruction}\nPatient: {input_text}\nDoctor: {output_text}"
    }

In [None]:
from datasets import Dataset
dataset = [format_data(d) for d in medical_data]
dataset = Dataset.from_list(dataset)

In [None]:
output_dir = "/content/drive/MyDrive/medical_Adapter"

In [None]:
def find_latest_checkpoint(output_dir):
    try:

        if not os.path.exists(output_dir):
            print(f"Output directory {output_dir} does not exist")
            return None


        if not os.listdir(output_dir):
            print(f"Output directory {output_dir} is empty")
            return None


        checkpoints = [d for d in os.listdir(output_dir)
                      if d.startswith("checkpoint") and os.path.isdir(os.path.join(output_dir, d))]

        if not checkpoints:
            print("No checkpoint directories found")
            return None


        checkpoints.sort(key=lambda x: int(x.split("-")[1]))
        latest = os.path.join(output_dir, checkpoints[-1])
        print(f"Found checkpoint: {latest}")
        return latest

    except Exception as e:
        print(f"Error finding checkpoint: {e}")
        return None

In [None]:
latest_checkpoint = find_latest_checkpoint(output_dir)
print(f"Latest checkpoint: {latest_checkpoint}")

Output directory /content/drive/MyDrive/medical_Adapter is empty
Latest checkpoint: None


# Model Training

In [None]:
from huggingface_hub import login
login("YOUR HF TOKEN HERE")

In [None]:
from transformers import BitsAndBytesConfig

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_name = "mistralai/Mistral-7B-Instruct-v0.3"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    offload_folder="./offload",
    offload_state_dict=True
)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, PrefixTuningConfig, TaskType
from datasets import load_dataset, Dataset
import torch

In [None]:
# Setup Prefix Tuning ----
peft_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=model.config.hidden_size,
    prefix_projection=True
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 272,719,872 || all params: 7,520,743,424 || trainable%: 3.6262


In [None]:
# Tokenization ----
def tokenize(example):
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=512,
        padding="max_length"
    )

tokenized_dataset = dataset.map(tokenize, batched=True)

Map:   0%|          | 0/6000 [00:00<?, ? examples/s]

In [None]:
#  Training Setup ----
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    learning_rate=5e-5,
    logging_steps=1000,
    save_strategy="steps",
    save_steps=500,
    # bf16=True,
    fp16=True,
    report_to="none"
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
model.to("cuda")

PeftModelForCausalLM(
  (base_model): MistralForCausalLM(
    (model): MistralModel(
      (embed_tokens): Embedding(32768, 4096)
      (layers): ModuleList(
        (0-31): 32 x MistralDecoderLayer(
          (self_attn): MistralAttention(
            (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): MistralMLP(
            (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
          (post_attention_lay

In [None]:
# Train ----
trainer.train()

Step,Training Loss


Step,Training Loss
1000,2.1207
2000,2.0153
3000,1.9955


TrainOutput(global_step=3000, training_loss=2.043830485026042, metrics={'train_runtime': 6958.1293, 'train_samples_per_second': 0.862, 'train_steps_per_second': 0.431, 'total_flos': 1.31121668947968e+17, 'train_loss': 2.043830485026042, 'epoch': 1.0})

In [None]:
# Save PEFT adapter ----
model.save_pretrained(output_dir)

In [None]:
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

('/content/drive/MyDrive/medical_Adapter/tokenizer_config.json',
 '/content/drive/MyDrive/medical_Adapter/special_tokens_map.json',
 '/content/drive/MyDrive/medical_Adapter/chat_template.jinja',
 '/content/drive/MyDrive/medical_Adapter/tokenizer.model',
 '/content/drive/MyDrive/medical_Adapter/added_tokens.json',
 '/content/drive/MyDrive/medical_Adapter/tokenizer.json')

In [None]:
trainer.model.save_pretrained(output_dir)

# Uploading to HF

In [None]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from huggingface_hub import notebook_login, create_repo
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Step 2: Create repo (only once)
repo_name = "ankraj/mediguide"
create_repo(repo_name, repo_type="model", exist_ok=True)

# Step 3: Push model and tokenizer
model.push_to_hub(repo_name)
tokenizer.push_to_hub(repo_name)


adapter_model.safetensors:   0%|          | 0.00/5.24M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/ankraj/mediguide/commit/6bec6a650fbb561b8fa3c42be30e52404268a445', commit_message='Upload tokenizer', commit_description='', oid='6bec6a650fbb561b8fa3c42be30e52404268a445', pr_url=None, repo_url=RepoUrl('https://huggingface.co/ankraj/mediguide', endpoint='https://huggingface.co', repo_type='model', repo_id='ankraj/mediguide'), pr_revision=None, pr_num=None)

# Importing Model from HF

In [1]:
path = "ankraj/mediguide"

In [2]:
!pip install -q transformers accelerate bitsandbytes

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m27.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16
)

# Load base model first
base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    device_map="auto",
    quantization_config=bnb_config,
    offload_folder="./offload",
    offload_state_dict=True
)

# Inject adapter
model = PeftModel.from_pretrained(base_model, path)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(path)
tokenizer.pad_token = tokenizer.eos_token

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

adapter_config.json:   0%|          | 0.00/391 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/5.24M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

In [4]:
model.eval()

PeftModelForCausalLM(
  (base_model): MistralForCausalLM(
    (model): MistralModel(
      (embed_tokens): Embedding(32768, 4096)
      (layers): ModuleList(
        (0-31): 32 x MistralDecoderLayer(
          (self_attn): MistralAttention(
            (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (mlp): MistralMLP(
            (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
          (post_attention_lay

In [5]:
model = torch.compile(model)

In [6]:
from transformers import StoppingCriteria, StoppingCriteriaList
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_phrases, tokenizer):
        self.tokenizer = tokenizer
        self.stop_ids_list = [
            tokenizer(phrase, return_tensors="pt").input_ids[0][1:]  # remove BOS
            for phrase in stop_phrases
        ]

    def __call__(self, input_ids, scores, **kwargs):
        device = input_ids.device
        for stop_ids in self.stop_ids_list:
            stop_ids = stop_ids.to(device)  # ✅ Move to same device
            if len(input_ids[0]) >= len(stop_ids):
                if torch.equal(input_ids[0][-len(stop_ids):], stop_ids):
                    return True
        return False



In [7]:
def preprocess_input(input_text):
    instruction = "If you are a doctor, please answer the medical questions based on the patient's description."
    prompt = f"[MED] {instruction}\nPatient: {input_text} \nDoctor:"
    return prompt

In [8]:
import re

def clean_output(text):
    stop_patterns = [
        r"Take care Chat Doctor\.",
        r"Regards, Chat Doctor\.",
        r"Regards. Chat Doctor\.",
        r"Wishing you good health\.",
        r"Goodbye\.",
        r"Take care\.",
        r"\.com"
    ]

    doc_match = re.search(r"Doctor:\s*(.*)", text, re.DOTALL | re.IGNORECASE)
    if not doc_match:
        return text.strip()

    after_doctor = doc_match.group(1)

    stop_pattern = r"(.*?)(" + "|".join(stop_patterns) + ")"
    stop_match = re.search(stop_pattern, after_doctor, re.DOTALL | re.IGNORECASE)

    if stop_match:
        return stop_match.group(1).strip() + " " + stop_match.group(2)

    return after_doctor.strip()


In [9]:
def run_medical_bot(input_text, max_new_tokens=500):
    prompt = preprocess_input(input_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # List of phrases that should stop generation
    stop_phrases = [
        "Take care Chat Doctor.",
        "Regards, Chat Doctor.",
        "Regards. Chat Doctor.",
        "Wishing you good health.",
        "Goodbye.",
        "Take care."
    ]

    stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_phrases, tokenizer)])

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return clean_output(generated_text)

In [None]:
input_text = "I am feeling uneazy.. I have vomitted 3 times in the last 2 year. I am 26 female having no prior such health condition."
output = run_medical_bot(input_text)

In [None]:
output

"Thanks for your question on Chat Doctor. I can understand your concern. By your history, your symptoms are suggestive of recurrent episodes of vomiting. You are also having symptoms of bloating and weight gain. So my first advice to you is to consult gastroenterologist and get done clinical examination of abdomen. He will also advise you for investigations like blood test, urine test, stool test, ultrasound abdomen, CT scan abdomen etc. For vomiting, you should start treatment with proton pump inhibitor (PPI) like esomeprazole. Take it once in the morning. Don't forget to take it. You can also take antiemetic (anti-vomiting) Chat Doctor.  If you are still having vomiting, then you should take prokinetic like metoclopramide. It is also very effective in controlling vomiting. Take it in the morning. Don't forget to take it. Also take antacid like omeprazole once in the night. Don't forget to take it. If you are still having symptoms of bloating, then you should take alosetron. It is ver

In [None]:
input_text = "Hello, At the end of lacrosse practice about a week ago i recieved a nasty cross check to my deltoid. The check hit wierd, as it went under my pad. The pain came in right away, couldnt move my arm for the rest of the night. I was surprised to see that that there was a very small bruise , but my whole shoulder hurts to the point where i cannot do simple tasks such as passing the ball. I can slowly move my arm fine, but when i speed things up it stings all over. I have bern icing it every day. I have been on advil only to help with the pain, but is there anything else i can do? Do you know what could be wrong? Any methods of treating it faster? I need to get back on the field asap before try outs are over"
output = run_medical_bot(input_text)
print(output)

Hello, Welcome to Chat Doctor, I have gone through your query and understand your concern. The injury could be a fracture, dislocation, or a muscle tear. I would advise you to get an X-ray and MRI done. If there is a fracture, it can be treated with a sling. If there is a dislocation, it should be reduced under anesthesia. If there is a muscle tear, it can be treated with a sling and physiotherapy. You should rest the arm as much as possible and avoid any strenuous activities. You should also avoid lifting heavy objects. Take a painkiller to help with the pain. I would advise you to apply a cold pack to the affected area. You should also elevate the arm above the level of the heart. This will help to reduce the swelling. I would also advise you to take an antispasmodic to help with the pain. You should also take a muscle relaxant to help with the pain. You should also take a multivitamin and a calcium supplement to help with the healing process. I would advise you to consult an orthope

# Evaluation

In [10]:
!pip install rouge-score --quiet

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone


In [18]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [24]:
import json
from datasets import Dataset

# Read your JSONL file manually
data = []
with open("/content/drive/MyDrive/MediGuideDataset/medicare_110k_test.json", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            data.append(json.loads(line))

print(f"Loaded {len(data)} examples.")

# Convert directly to a plain Dataset object (not DatasetDict!)
dataset = Dataset.from_list(data)

# OPTIONAL: If you only want first 1000 examples
test_split = dataset.select(range(1000))

print(test_split)
print(test_split[0])

Loaded 5609 examples.
Dataset({
    features: ['Conversation'],
    num_rows: 1000
})
{'Conversation': 'The conversation between human and AI assistant.\n[|Human|] I wake in the night, usually about 2-3 hours after going to sleep, with both feet and legs to mid calf feeling like they are on fire. slight red discolorization, minor swelling. This is very painful but after getting up, I can walk it off in about 30 minutes.\n[|AI|]  Dear patient Here are the possibilities of what you might have.1)PhlebitisPhlebitis means inflammation of the veins, and can cause redness, itching, irritation, pain, and swelling. A simple Doppler can rule this out.2Blood clot in the lifeblood clots in the leg can become very dangerous, symptoms include swelling, redness, tenderness in the leg. Coagulation profile with an angiography may be required3)Cellulitis: Initial stage. Only can be clinically ruled out Hope this helped\n'}


In [25]:
import re
def extract_prompt_response(example):
    """
    Parses the single-string conversation field into:
      - instruction: text between “[|Human|]” and “[|AI|]” (or full text if no markers).
      - response: text after the last “[|AI|]” marker (or empty if none).
    """
    convo = str(example[next(iter(example.keys()))]).strip()
    human_match = re.search(r"\[\|Human\|\]\s*(.*?)\s*(?=\[\|AI\|\])", convo, re.DOTALL)
    instruction = human_match.group(1).strip() if human_match else convo
    parts = re.split(r"\[\|AI\|\]", convo)
    response = parts[-1].strip() if len(parts) > 1 else ""
    return {"instruction": instruction, "response": response}

test_df = test_split.map(
    extract_prompt_response,
    remove_columns=test_split.column_names,
    num_proc=4
)
test_prompts = test_df["instruction"]
test_references = test_df["response"]

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

  block_group = [InMemoryTable(cls._concat_blocks(list(block_group), axis=axis))]
  table = cls._concat_blocks(blocks, axis=0)


In [26]:
test_prompts[0]

'I wake in the night, usually about 2-3 hours after going to sleep, with both feet and legs to mid calf feeling like they are on fire. slight red discolorization, minor swelling. This is very painful but after getting up, I can walk it off in about 30 minutes.'

In [27]:
test_references[0]

'Dear patient Here are the possibilities of what you might have.1)PhlebitisPhlebitis means inflammation of the veins, and can cause redness, itching, irritation, pain, and swelling. A simple Doppler can rule this out.2Blood clot in the lifeblood clots in the leg can become very dangerous, symptoms include swelling, redness, tenderness in the leg. Coagulation profile with an angiography may be required3)Cellulitis: Initial stage. Only can be clinically ruled out Hope this helped'

In [30]:
# 1. PERPLEXITY

from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
import os

checkpoint_dir = "/content/drive/MyDrive/medical_prefix"
eval_ckpt_path = "/content/drive/MyDrive/medical_prefix/test_eval_state.pth"
batch_size = 2
save_every_n_batches = 50


class LMTestDataset(torch.utils.data.Dataset):
    def __init__(self, texts, tokenizer, max_length=1024):
        encodings = tokenizer(
            texts,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding="max_length"
        )
        self.input_ids = encodings["input_ids"]


    def __len__(self):
        return self.input_ids.size(0)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.input_ids[idx].clone()
        }


test_texts = [
    f"{instr}\n\n{resp}"
    for instr, resp in zip(test_prompts, test_references)
]
lm_test_dataset = LMTestDataset(test_texts, tokenizer, max_length=1024)


data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)
test_loader = DataLoader(
    lm_test_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator
)

device = model.device


if os.path.isfile(eval_ckpt_path):

    state = torch.load(eval_ckpt_path)
    start_batch = state["last_batch"] + 1
    accumulated_loss = state["accumulated_loss"]
    total_tokens = state["total_tokens"]
    print(f"Resuming test-eval from batch {start_batch} (saved on disk).")
else:

    start_batch = 0
    accumulated_loss = 0.0
    total_tokens = 0
    print("Starting test-eval from batch 0.")


model.eval()
with torch.no_grad():
    for batch_idx, batch in enumerate(test_loader):
        if batch_idx < start_batch:
            continue


        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)


        outputs = model(
            input_ids=input_ids,
            labels=labels
        )
        loss = outputs.loss.detach().cpu().item()


        nonpad_tokens = (labels != tokenizer.pad_token_id).sum().item()


        accumulated_loss += loss * nonpad_tokens
        total_tokens += nonpad_tokens


        if (batch_idx + 1) % save_every_n_batches == 0:
            state = {
                "last_batch": batch_idx,
                "accumulated_loss": accumulated_loss,
                "total_tokens": total_tokens
            }
            torch.save(state, eval_ckpt_path)
            print(f" Saved eval state at batch {batch_idx} → tokens={total_tokens}")

    final_avg_loss = accumulated_loss / total_tokens
    test_perplexity = torch.exp(torch.tensor(final_avg_loss)).item()
    os.remove(eval_ckpt_path)

    print(f"\n→ Test complete. Avg. token-loss = {final_avg_loss:.4f}")
    print(f"→ Test Perplexity = {test_perplexity:.2f}")

Starting test-eval from batch 0.
 Saved eval state at batch 49 → tokens=102400
 Saved eval state at batch 99 → tokens=204800
 Saved eval state at batch 149 → tokens=307200
 Saved eval state at batch 199 → tokens=409600
 Saved eval state at batch 249 → tokens=512000
 Saved eval state at batch 299 → tokens=614400
 Saved eval state at batch 349 → tokens=716800
 Saved eval state at batch 399 → tokens=819200
 Saved eval state at batch 449 → tokens=921600
 Saved eval state at batch 499 → tokens=1024000

→ Test complete. Avg. token-loss = 5.8102
→ Test Perplexity = 333.69


In [31]:
# 2. LATENCY

import time

n_samples = min(50, len(test_prompts))
max_new_tokens = 128
batch_size = 4


tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

latencies = []
model.eval()
device = model.device


dummy_input = tokenizer("Hello", return_tensors="pt").to(device)
_ = model.generate(
    **dummy_input,
    max_new_tokens=10,
    do_sample=False,
    use_cache=True,
    return_dict_in_generate=False
)
torch.cuda.synchronize()


for i in range(0, n_samples, batch_size):
    batch_prompts = test_prompts[i : i + batch_size]
    inputs = tokenizer(
        batch_prompts,
        return_tensors="pt",
        truncation=True,
        padding=True
    ).to(device)

    torch.cuda.synchronize()
    start = time.time()
    _ = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        use_cache=True,
        return_dict_in_generate=False
    )
    torch.cuda.synchronize()
    end = time.time()


    elapsed = end - start
    latencies.append(elapsed / len(batch_prompts))

avg_latency = sum(latencies) / len(latencies)
print(f"Average Latency (per prompt, {max_new_tokens} new tokens): {avg_latency:.4f} seconds")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 

Average Latency (per prompt, 128 new tokens): 7.3562 seconds


In [33]:
# 3. MODEL SIZE ON DISK

def folder_size_in_mb(path: str) -> float:
    total_bytes = 0
    for root, _, files in os.walk(path):
        for fname in files:
            fp = os.path.join(root, fname)
            total_bytes += os.path.getsize(fp)
    return total_bytes / (1024 ** 2)

model_size_mb = folder_size_in_mb("/content/drive/MyDrive/medical_Adapter")
print(f"Model Size on Disk : {model_size_mb:.1f} MB")

Model Size on Disk : 9.2 MB


In [34]:
# Generation for ROUGE calculation
import json
from tqdm import tqdm


def optimized_medical_generation(
    model, tokenizer, prompts, references,
    checkpoint_path, batch_size=8, max_length=512
):


    if os.path.isdir(checkpoint_path):
        os.makedirs(checkpoint_path, exist_ok=True)
        state_file = os.path.join(checkpoint_path, "generation_resume.json")
    else:
        state_file = checkpoint_path


    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            state = json.load(f)
        completed_indices = state.get('completed_indices', [])
        predictions = state.get('predictions', [])
        print(f"Resuming from {len(completed_indices)} completed samples")
    else:
        completed_indices = []
        predictions = []
        state = {'completed_indices': completed_indices, 'predictions': predictions}

    completed_set = set(completed_indices)
    remaining_indices = [i for i in range(len(prompts)) if i not in completed_set]

    if not remaining_indices:
        print("All samples already processed!")
        return predictions, [references[i] for i in completed_indices]

    model.eval()
    device = model.device


    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


    with torch.no_grad():
        emb = model.get_input_embeddings().weight
        emb[tokenizer.pad_token_id].zero_()


    for batch_start in tqdm(range(0, len(remaining_indices), batch_size),
                            desc="Medical Generation"):
        batch_end     = min(batch_start + batch_size, len(remaining_indices))
        batch_indices = remaining_indices[batch_start:batch_end]
        batch_prompts = [prompts[i] for i in batch_indices]


        inputs = tokenizer(
            [f"MEDICAL PROMPT: {p}" for p in batch_prompts],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids      = inputs["input_ids"],
                attention_mask = inputs["attention_mask"],
                pad_token_id   = tokenizer.pad_token_id,
                max_new_tokens = 256,
                do_sample      = False,
                num_beams      = 1,
                use_cache      = True,
            )

        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)


        batch_preds = []
        for prompt, text in zip(batch_prompts, decoded):
            prefix = f"MEDICAL PROMPT: {prompt}"
            if text.startswith(prefix):
                gen_text = text[len(prefix):].strip()
            else:
                gen_text = text.strip()
            batch_preds.append(gen_text)

        predictions.extend(batch_preds)
        completed_indices.extend(batch_indices)


        with open(state_file, 'w') as f:
            json.dump({
                'completed_indices': completed_indices,
                'predictions': predictions
            }, f)

    return predictions, [references[i] for i in completed_indices]


In [36]:
latest_checkpoint = "/content/drive/MyDrive/medical_Adapter"

In [37]:
predictions, processed_refs = optimized_medical_generation(
    model,
    tokenizer,
    test_prompts,
    test_references,
    latest_checkpoint,
    batch_size=12,
    max_length=120
)

Medical Generation: 100%|██████████| 84/84 [1:29:12<00:00, 63.72s/it]


In [38]:
!pip install evaluate --q

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [39]:
from evaluate import load
rouge = load("rouge")
results = rouge.compute(
        predictions=predictions,
        references=processed_refs,
        use_stemmer=True,
        use_aggregator=True
    )

Downloading builder script: 0.00B [00:00, ?B/s]

In [40]:
print("\nMedical ROUGE Scores:")
print(f"ROUGE-1: {results['rouge1']:.4f}")
print(f"ROUGE-2: {results['rouge2']:.4f}")
print(f"ROUGE-L: {results['rougeL']:.4f}")
print(f"ROUGE-Lsum: {results['rougeLsum']:.4f}")


Medical ROUGE Scores:
ROUGE-1: 0.1556
ROUGE-2: 0.0226
ROUGE-L: 0.0956
ROUGE-Lsum: 0.0988
