In [17]:
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import Dataset
import pandas as pd

# Configuration
DATA_PATH = "combined_clinical_notes.csv"
MODEL_PATH = "gpt2"
MEDICAL_PROMPTS = [
    "Generate a clinical summary focusing on medications: ",
    "List key diagnostic findings from this case: ",
    "Create a treatment plan outline: "
]

# Load dataset
df = pd.read_csv(DATA_PATH)
dataset = Dataset.from_pandas(df.rename(columns={"dialogue": "review"}))

# Tokenizer setup
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')
tokenizer.pad_token = tokenizer.eos_token

# PPO Model loading
model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_PATH).to("cuda")

# Initialize PPO Trainer
ppo_config = PPOConfig(
    model_name=MODEL_PATH,
    batch_size=1,
    learning_rate=1e-5,
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=None,
    tokenizer=tokenizer,
)

# Generation parameters
generation_kwargs = {
    "do_sample": True,
    "top_p": 0.9,
    "max_new_tokens": 128,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}

# Iterate through dataset and prompts
for example in dataset.select(range(1)):  # First 3 examples
    original_notes = example["review"]
    
    for prompt in MEDICAL_PROMPTS:
        # Combine prompt with original notes
        full_prompt = f"{prompt}\n{original_notes}"
        
        # Tokenize and generate
        inputs = tokenizer.encode(
            full_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=768
        ).squeeze(0).to("cuda")  # Remove batch dimension
        
        # Generate response
        response = ppo_trainer.generate(inputs, **generation_kwargs)
        response = response[:, :generation_kwargs["max_new_tokens"]]
        
        # Decode for evaluation
        decoded_response = tokenizer.decode(response.squeeze(), skip_special_tokens=True)
        
        # Decode and print
        # decoded_response = tokenizer.decode(response, skip_special_tokens=True)
        
        print(f"\n=== Prompt: {prompt} ===")
        print(f"Original notes: {original_notes}...")
        print(f"\nGenerated response:\n{decoded_response[len(prompt):]}")  # Show only new text
        print("\n" + "="*50 + "\n")




=== Prompt: Generate a clinical summary focusing on medications:  ===
Original notes: [doctor] hi diane , how are you ?
[patient] i'm doing okay , how are you ?
[doctor] i'm doin' okay . so i know the nurse told you about dax and i'd like to tell dax a little bit about you okay ?
[patient] okay .
[doctor] diane is a 28 year old female with a past medical history , significant for , depression and hypertension who presents for emergency room follow-up .
[doctor] so diane what's going on ? i heard that your- your blood pressure was really high in the emergency room . what happened ?
[patient] yeah , so i ended up going for a walk , um , yesterday 'cause it was sunny and it was really great . and i just felt really light-headed , um , and i started to fall a bit , and , um , luckily i was with my boyfriend and he caught me , um , and then we went right to the e , to the er .
[doctor] yeah , okay . yeah , i saw that the blood pressure was pretty high , like in , like , the , almost 200 .


In [1]:
import os
os.environ["HF_HOME"] = r"D:\hf-cache"

In [3]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], 
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from trl import AutoModelForCausalLMWithValueHead
from transformers import BitsAndBytesConfig  

# ---- Device Setup ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---- Paths ----
MODEL_PATH = r"D:\kshitij-weights-folder\qwen-aloe-9-4-base-fine-tune"
PEFT_ADAPTER_PATH = r"D:\kshitij-weights-folder\qwen-aloe-9-4-base-fine-tune-peft-adapaters"
REF_MODEL_PATH = r"D:\kshitij-weights-folder\qwen-aloe-9-4-base-fine-tune"

# ---- 1) 4-bit Quantization Configuration ----
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

# ---- 2) Load Base Model in 4-bit ----
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    quantization_config=bnb_config,
    device_map="auto",
)
# Prepare the model for k-bit training (this typically freezes most parameters except adapter ones)
base_model = prepare_model_for_kbit_training(base_model)
base_model.gradient_checkpointing_disable()  # Disable checkpointing

# ---- 3) Load Tokenizer ----
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# ---- 4) Load the PEFT Adapter (LoRA) ----
# This reloads your fine-tuned adapter weights onto your base model.
model_with_lora = PeftModel.from_pretrained(base_model, PEFT_ADAPTER_PATH)

# ---- 5) Convert to PPO-Compatible ValueHead Model ----
# When converting, pass the peft_config from the adapter model to ensure proper initialization.
starcoder_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_with_lora,
    peft_config=lora_config
).to(device)

print("done")



Using device: cuda


Loading checkpoint shards: 100%|███████████████████████████████████████████████| 4/4 [00:07<00:00,  1.92s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


done


In [69]:
MEDICAL_PROMPTS = {
    "chief_complaint": {
        "instruction": "Extract the patient's primary concern or reason for seeking care (e.g., 'chest pain', 'fatigue'). Include duration if noted.",
        "example": "Chief Complaint: 'Sharp headache for 3 days'"
    },
    "history_of_illness": {
        "instruction": "Detail onset, timing, severity, and progression of symptoms. Include aggravating/alleviating factors and prior treatments.",
        "example": "HPI: Headache started 3 days ago, throbbing, worsens with light. Took ibuprofen with partial relief."
    },
    "medications": {
        "instruction": "List all current medications with dosages and purposes (e.g., 'Lisinopril 10mg daily for HTN').",
        "example": "Meds: Aspirin 81mg daily, Metformin 500mg BID"
    },
    "vitals": {
        "instruction": "Extract vital signs with units and timing (e.g., 'BP 120/80 mmHg'). Flag abnormalities.",
        "example": "Vitals: Temp 98.6°F, HR 72, BP 118/76"
    },
    "physical_exam": {
        "instruction": "Summarize key exam findings (abnormal/normal). Include systems examined (e.g., 'Lungs: clear bilaterally').",
        "example": "PE: Heart RRR, no murmurs. Abdomen soft, non-tender."
    },
    "assessment": {
        "instruction": "State the clinician's diagnosis/differentials (e.g., 'Migraine vs tension headache'). Include supporting evidence.",
        "example": "Assessment: Likely tension headache. No red flags."
    },
    "treatment_plan": {
        "instruction": "Outline next steps: medications, referrals, follow-up (e.g., 'Start amitriptyline, follow up in 4 weeks').",
        "example": "Plan: Hydration, ibuprofen PRN, return if worsening."
    },
    "patient_instructions": {
        "instruction": "Extract discharge/follow-up instructions (e.g., 'Avoid NSAIDs, call if fever develops').",
        "example": "Instructions: Rest, monitor for confusion, return if headache persists >48h."
    }
}

In [107]:


import pandas as pd
from datasets import Dataset

# ---- Load Dataset ----
DATA_PATH = "combined_clinical_notes.csv"  # Update with your actual path
df = pd.read_csv(DATA_PATH)
dataset = Dataset.from_pandas(df.rename(columns={"dialogue": "review"}))

# ---- Medical Prompts ----
generation_kwargs = {
    "do_sample": False,
    "num_beams": 1,          # Greedy search
    "temperature": 0.0,      # Remove randomness
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 32
}

# ---- Generation Function ----
def generate_responses(model, tokenizer, dataset, prompts, device, max_new_tokens=256):
    responses = {}
    model.eval()
    
    for example in dataset.select(range(1)):  # Test with 3 examples
        original_notes = example["review"]
        print(f"\nOriginal Notes:\n{original_notes}...")
        full_prompt = None
        
        
        for prompt_key, prompt in MEDICAL_PROMPTS.items():  # <-- Changed to .items()
    # Combine prompt with original notes
            full_prompt = (
                f"{prompt['instruction']} Extract the information from below conversation. "
                "Only include the relevant information in your response please. Keep responses "
                "as short as you can while including all important information.\n"
                f"Example: {prompt['example']} Follow the example's format for generating your "
                "response in that format and don't include any facts from the example in your "
                f"response. Don't include any dialogues from the conversation\n\n"
                f"{original_notes}\n\n"
            )
            
            # Tokenize with proper formatting
            inputs = tokenizer(
                full_prompt,
                return_tensors="pt",
                truncation=True,
                padding="max_length",
                max_length=2048  # Adjust based on model's max context
            ).to(device)

            
            # Generate response
            with torch.no_grad():
                response = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    **generation_kwargs
                )
            
            # Decode and clean output
            full_text = tokenizer.decode(response[0], skip_special_tokens=True)
            generated_part = full_text

            responses[prompt_key] = generated_part[len(full_prompt):]
        
            # print(f"\n=== Prompt ===\n{prompt}")
            # # print(f"\nOriginal Notes:\n{original_notes[:500]}...")
            # print(f"\nGenerated Response:\n{generated_part[len(full_prompt):]}")
            # print("\n" + "="*80 + "\n")

    print(responses)

# ---- Execute Generation ----
generate_responses(
    model=starcoder_model,
    tokenizer=tokenizer,
    dataset=dataset,
    prompts=MEDICAL_PROMPTS,
    device=device,
    max_new_tokens=3000
)



Original Notes:
[doctor] hi diane , how are you ?
[patient] i'm doing okay , how are you ?
[doctor] i'm doin' okay . so i know the nurse told you about dax and i'd like to tell dax a little bit about you okay ?
[patient] okay .
[doctor] diane is a 28 year old female with a past medical history , significant for , depression and hypertension who presents for emergency room follow-up .
[doctor] so diane what's going on ? i heard that your- your blood pressure was really high in the emergency room . what happened ?
[patient] yeah , so i ended up going for a walk , um , yesterday 'cause it was sunny and it was really great . and i just felt really light-headed , um , and i started to fall a bit , and , um , luckily i was with my boyfriend and he caught me , um , and then we went right to the e , to the er .
[doctor] yeah , okay . yeah , i saw that the blood pressure was pretty high , like in , like , the , almost 200 .
[patient] yeah .
[doctor] did you have a headache ?
[patient] yeah i d



{'chief_complaint': "The patient's primary concern is high blood pressure, which led to lightheadedness and a fall during a walk. She reports experiencing high blood pressure episodes", 'history_of_illness': 'HPI: The patient, Diane, a 28-year-old female with a history of depression and hypertension, presented for emergency room follow-up after experiencing high', 'medications': 'Meds: Lisinopril 40mg daily for HTN\nMeds: Lisinopril 40mg daily for HTN.\n\n', 'vitals': 'Vitals: BP 198/100 mmHg, HR not mentioned, Temperature not mentioned, Respiratory rate not mentioned, O2', 'physical_exam': 'PE: Heart - S2/6 systolic ejection murmur, clear lungs bilaterally, trace pitting edema in lower extremities. AB', 'assessment': "Assessment: Hypertension, likely uncontrolled despite medication adherence. Differential: Migraine or other headache disorder given the patient's history of episodic high", 'treatment_plan': 'Plan: Increase lisinopril to 40 mg daily, monitor blood pressures through patie

In [None]:

=== Prompt ===
{'instruction': "Extract the patient's primary concern or reason for seeking care (e.g., 'chest pain', 'fatigue'). Include duration if noted.", 'example': "Chief Complaint: 'Sharp headache for 3 days'"}

Generated Response:
Primary Concern: High blood pressure, lightheadedness, and headache. The patient reports that her blood pressure often "skyrockets" when she travels

================================================================================


=== Prompt ===
{'instruction': 'Detail onset, timing, severity, and progression of symptoms. Include aggravating/alleviating factors and prior treatments.', 'example': 'HPI: Headache started 3 days ago, throbbing, worsens with light. Took ibuprofen with partial relief.'}

Generated Response:
HPI: Headache started yesterday, worsening with light, improved with ibuprofen. Patient experienced lightheadedness and nearly fell during a walk,

================================================================================


=== Prompt ===
{'instruction': "List all current medications with dosages and purposes (e.g., 'Lisinopril 10mg daily for HTN').", 'example': 'Meds: Aspirin 81mg daily, Metformin 500mg BID'}

Generated Response:
Meds: Lisinopril 40mg daily for HTN. The patient also has a history of depression, which she manages with therapy. She

================================================================================


=== Prompt ===
{'instruction': "Extract vital signs with units and timing (e.g., 'BP 120/80 mmHg'). Flag abnormalities.", 'example': 'Vitals: Temp 98.6°F, HR 72, BP 118/76'}

Generated Response:
Vitals: BP 195/110 mmHg (high). 
Abnormalities: Slight 2/6 systolic e

================================================================================


=== Prompt ===
{'instruction': "Summarize key exam findings (abnormal/normal). Include systems examined (e.g., 'Lungs: clear bilaterally').", 'example': 'PE: Heart RRR, no murmurs. Abdomen soft, non-tender.'}

Generated Response:
Lungs: Clear bilaterally. Heart: Slight two out of six systolic ejection murmur. Pitting edema in lower extremities bil

================================================================================


=== Prompt ===
{'instruction': "State the clinician's diagnosis/differentials (e.g., 'Migraine vs tension headache'). Include supporting evidence.", 'example': 'Assessment: Likely tension headache. No red flags.'}

Generated Response:
Diagnosis: Hypertension, uncontrolled
Differential diagnoses: Migraine, Tension headache (considered less likely given the absence of other migraine

================================================================================


=== Prompt ===
{'instruction': "Outline next steps: medications, referrals, follow-up (e.g., 'Start amitriptyline, follow up in 4 weeks').", 'example': 'Plan: Hydration, ibuprofen PRN, return if worsening.'}

Generated Response:
Plan: Increase lisinopril to 40 mg daily, monitor blood pressure and adjust if necessary, possibly add a second agent; continue current depression management

================================================================================


=== Prompt ===
{'instruction': "Extract discharge/follow-up instructions (e.g., 'Avoid NSAIDs, call if fever develops').", 'example': 'Instructions: Rest, monitor for confusion, return if headache persists >48h.'}

Generated Response:
Instructions: Take lisinopril 40 mg daily. Monitor blood pressure and report through the portal. Increase lisinopril dose if BP is not

================================================================================
