# Gemma NER for Diagnosis Extraction (Interactive Notebook)

This notebook demonstrates how to use a Gemma instruction-tuned model (via Hugging Face `transformers`) to extract diagnosis names from clinical text based on a specific prompt.

**Steps:**
1.  Install the necessary libraries.
2.  Import dependencies and configure the model and file paths.
3.  Define the helper function for parsing model output.
4.  Load the tokenizer and model (this may take time and resources).
5.  Load the input dataset.
6.  Process each note: construct the prompt, generate text, parse the diagnoses.
7.  Display and save the results.

## 1. Install Libraries

In [1]:
!pip install --upgrade transformers torch pandas accelerate sentencepiece bitsandbytes -q

## 2. Imports and Configuration

In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import re
import logging
import os

# --- Configuration ---
INPUT_FILE = '/Users/adrish/Desktop/ehrcon-dataset/workgroup-notes.csv'
SELECTED_COLUMNS = ['SUBJECT_ID','TEXT'] 
OUTPUT_FILE = 'extracted_diagnoses_v2.csv'
# Choose model: 'google/gemma-3-1b-it' (faster, less VRAM) or 'google/gemma-7b-it' (potentially better, more VRAM)
MODEL_NAME = 'google/gemma-3-1b-it' 

# Optional: Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## 4. Define Helper Function for Parsing

In [2]:
# --- Function to clean model output ---
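def parse_diagnoses(model_output, prompt_text):
    """
    Parses the model's raw output to extract the comma-separated diagnoses.
    Locates the answer via the marker that ends the prompt, then cleans the result.
    Note: prompt_text is currently unused; it is kept so existing calls still work.
    """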
    # Find the start of the actual answer (after the prompt)
    # Look for the marker we put at the end of the prompt
    answer_marker = "Extracted Diagnoses (comma-separated):"
    try:
        # Find the position *after* the marker
        start_index = model_output.index(answer_marker) + len(answer_marker)
        # Extract the text after the marker
        extracted_part = model_output[start_index:].strip()

        # Sometimes models add extra text or explanations after the list.
        # Try to find the first newline character after the start, assuming the list is on one line.
        first_newline = extracted_part.find('\n')
        if first_newline != -1:
            extracted_part = extracted_part[:first_newline].strip()

        # Remove potential trailing tags or symbols often added by models
        extracted_part = re.sub(r'<eos>$|</s>$', '', extracted_part).strip() # Remove end-of-sequence tokens if present

        # Check if the model explicitly said "None"
        if extracted_part.lower() == 'none':
            return []

        # Split by comma, strip whitespace, and drop empty items
        diagnoses = [diag.strip() for diag in extracted_part.split(',') if diag.strip()]

        return diagnoses

    except ValueError:
        # If the marker isn't found, the model output format was unexpected.
        logging.warning(f"Could not find answer marker '{answer_marker}' in model output. Trying fallback. Raw output: {model_output[:500]}...") # Log truncated output
        # Attempt a simpler extraction based on the last line (less reliable)
        lines = model_output.strip().split('\n')
        if lines:
            last_line = lines[-1].strip()
            # Avoid taking the prompt itself as the answer if it appears last
            if answer_marker not in last_line and last_line.lower() != 'none':
                # Basic split and clean, might capture unwanted text
                return [diag.strip() for diag in last_line.split(',') if diag.strip()]
        return [] # Return empty list if parsing fails
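
To make the expected behaviour of the parser concrete, here is a condensed, self-contained sketch of the same marker-based idea, run on a made-up output string. The function name `split_after_marker` and the sample text are illustrative assumptions, not part of the notebook, and the sketch omits the logging and last-line fallback of `parse_diagnoses`.

```python
ANSWER_MARKER = "Extracted Diagnoses (comma-separated):"

def split_after_marker(model_output, marker=ANSWER_MARKER):
    """Return the comma-separated items on the first line after the marker."""
    # partition() splits at the first occurrence of the marker
    _, found, tail = model_output.partition(marker)
    if not found:
        return []  # marker missing: this sketch simply gives up
    stripped = tail.strip()
    # Keep only the first line, assuming the list is written on one line
    first_line = stripped.splitlines()[0].strip() if stripped else ""
    if not first_line or first_line.lower() == "none":
        return []
    return [item.strip() for item in first_line.split(",") if item.strip()]

# Hypothetical decoded output: echoed prompt, answer line, then extra chatter
sample_output = (
    "Clinical Text:\n---\nPt with pneumonia and HTN.\n---\n\n"
    "Extracted Diagnoses (comma-separated): pneumonia, HTN\n"
    "Explanation: both are stated in the note."
)

print(split_after_marker(sample_output))
print(split_after_marker(ANSWER_MARKER + " None"))
```

The first call keeps only the answer line and ignores the trailing explanation; the second shows that an explicit "None" becomes an empty list.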

## 5. Load Tokenizer and Model

In [3]:
logging.info(f"Loading model: {MODEL_NAME}")
device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {device}")

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Ensure pad_token is set if missing (common issue with some models)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        logging.info("Set tokenizer pad_token to eos_token")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto", # Automatically use GPU if available
        torch_dtype=torch.bfloat16 # Use bfloat16 for faster inference if supported, requires Ampere GPU or newer
        # torch_dtype=torch.float16 # Alternative if bfloat16 not supported
    )
    logging.info("Model and tokenizer loaded successfully.")
except Exception as e:
    logging.error(f"Error loading model: {e}")
    # Re-raise so later cells do not fail with a confusing NameError
    raise

2025-04-21 22:23:17,404 - INFO - Loading model: google/gemma-3-1b-it
2025-04-21 22:23:17,406 - INFO - Using device: cpu
2025-04-21 22:23:27,782 - INFO - Model and tokenizer loaded successfully.
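
The log above shows the model running on CPU, while the loading cell requests `bfloat16`, which its own comment ties to Ampere or newer GPUs. One possible way to pick the dtype from the detected hardware is sketched below. The helper name `choose_dtype_name` and the choice of `float32` as the conservative CPU default are assumptions for illustration; the helper takes plain booleans so it can be read and run without a GPU.

```python
def choose_dtype_name(cuda_available: bool, bf16_supported: bool) -> str:
    """Return the name of a torch dtype suited to the detected hardware."""
    if not cuda_available:
        return "float32"   # assumption: full precision is the safe CPU default
    if bf16_supported:
        return "bfloat16"  # GPUs with native bfloat16 support
    return "float16"       # older GPUs without bfloat16

# In the notebook the flags would come from torch, for example:
#   cuda = torch.cuda.is_available()
#   bf16 = cuda and torch.cuda.is_bf16_supported()
#   torch_dtype = getattr(torch, choose_dtype_name(cuda, bf16))
print(choose_dtype_name(cuda_available=False, bf16_supported=False))
```

The resulting `torch_dtype` could then be passed to `AutoModelForCausalLM.from_pretrained` in place of the hard-coded value.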


## 6. Load Input Data

Load the clinical notes from the CSV file.

In [4]:
logging.info(f"Loading input data from: {INPUT_FILE}")
try:
    df_input = pd.read_csv(INPUT_FILE)
    # Verify required columns exist before selecting them
    missing = [col for col in SELECTED_COLUMNS if col not in df_input.columns]
    if missing:
        raise ValueError(f"Input CSV is missing required columns: {missing}")
    df_input = df_input[SELECTED_COLUMNS].copy()
    # Handle potential missing text data
    df_input['TEXT'] = df_input['TEXT'].fillna('')
    logging.info(f"Loaded {len(df_input)} records.")
    print("Input Data Sample:")
    display(df_input.head()) # Display first 5 rows in Jupyter
except FileNotFoundError:
    logging.error(f"Input file not found: {INPUT_FILE}")
except Exception as e:
    logging.error(f"Error reading input CSV: {e}")

2025-04-21 22:23:32,276 - INFO - Loading input data from: /Users/adrish/Desktop/ehrcon-dataset/workgroup-notes.csv
2025-04-21 22:23:32,322 - INFO - Loaded 9 records.


Input Data Sample:


|   | SUBJECT_ID | TEXT |
|---|---|---|
| 0 | 100 | "Patient presents with symptoms indicative of ... |
| 1 | 200 | 89 yo M with a history of prostate CA and Alzh... |
| 2 | 300 | 52 yo male with Down's syndrome and NAFLD who ... |
| 3 | 400 | Patient is a 50yo woman with adenoid cystic ca... |
| 4 | 500 | The patient is a 78-year-old woman with a hist... |


## 7. Process Records

Iterate through each record, generate the diagnosis using the model, parse the output, and store the results.

In [None]:
results = []
total_records = len(df_input)
start_time = time.time()

for index, row in df_input.iterrows():
    patient_id = row['SUBJECT_ID']
    clinical_text = row['TEXT']

    print(f"\nProcessing record {index + 1}/{total_records} for patient ID: {patient_id}")

    if not clinical_text or pd.isna(clinical_text) or not clinical_text.strip():
        logging.warning(f"  Skipping record {index + 1} due to empty clinical text.")
        print("  Skipping due to empty text.")
        continue # Skip records with no text

    # --- Construct the Prompt ---
    prompt = f"""Objective: Identify all diagnosis names mentioned in the following clinical text.
Guidelines:
1. Extract only diagnosis names.
2. Extract the entity exactly as written in the note without modification.
3. Only extract diagnoses explicitly listed in the text. Do not infer or add conditions not present.
4. Ignore numeric values unless they are part of a specific diagnosis name (e.g., 'Type 2 Diabetes', 'stage 3').
5. Focus on conditions, diseases, syndromes, and specific medical problems mentioned.
6. Output Format: List the extracted diagnosis names separated by commas. If no diagnoses are found, output "None".

Clinical Text:
---
{clinical_text}
---

Extracted Diagnoses (comma-separated):"""

    # --- Generate Text with the Model ---
    try:
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=10000).to(model.device) # Adjust max_length if needed
        # Adjust generation parameters as needed
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,  # Cap on tokens for the *answer* part; raise it if lists get cut off
            do_sample=False,     # Greedy decoding for consistency (temperature/top_k only apply when sampling)
            pad_token_id=tokenizer.pad_token_id # Ensure padding token is set
        )
        # Decode the full output (including the prompt part)
        full_output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # --- Parse the Output ---
        # print(f"\nDEBUG: Raw model output:\n{full_output_text}\n---\n") # Uncomment for debugging raw output
        extracted_diagnoses = parse_diagnoses(full_output_text, prompt) # Pass the original prompt text

        # --- Store Results ---
        if extracted_diagnoses:
            print(f"  Found diagnoses: {extracted_diagnoses}")
            for diagnosis in extracted_diagnoses:
                results.append({'patient_id': patient_id, 'entity_name': diagnosis})
        else:
            print("  No diagnoses found.")
            # If you want to explicitly record patients with no diagnoses found:
            # results.append({'patient_id': patient_id, 'entity_name': 'None'}) # Optional based on requirements

    except Exception as e:
        logging.error(f"Error processing record {index + 1} for patient ID {patient_id}: {e}")
        print(f"  Error processing record: {e}")
        # Optionally add a placeholder for failed records
        # results.append({'patient_id': patient_id, 'entity_name': 'Processing_Error'})
    
    # Small delay to potentially help with GPU cooling or API rate limits if applicable
    # time.sleep(0.1)

end_time = time.time()
processing_time = end_time - start_time
print("\n-------------------------------------------------")
logging.info(f"Finished processing all records. Total time: {processing_time:.2f} seconds.")
print(f"Finished processing all records. Total time: {processing_time:.2f} seconds.")


Processing record 1/9 for patient ID: 100




  Found diagnoses: ['Pneumonia', 'Hypertension', 'Type 2 Diabetes']

Processing record 2/9 for patient ID: 200
  Found diagnoses: ['[Hospital1 18]', "Alzheimer's dementia", 'Prostate CA', 'Maroon', 'guaiac positive stool', 'Gerontology']

Processing record 3/9 for patient ID: 300
  Found diagnoses: ["Down's syndrome", 'NAFLD', 'facial rash', 'petechial rash', 'fevers']

Processing record 4/9 for patient ID: 400
  Found diagnoses: ['Adenoid cystic carcinoma', 'pneumonectomy', 'liver', 'kidney', 'PE', 'fevers', 'lethargy', 'pleuritic CP']

Processing record 5/9 for patient ID: 500
  Found diagnoses: ['Doctor Last Name 688', 'Diabetes', 'Carotid stenosis', 'Chronic back pain', 'Lethargy', 'Facial droop', 'Voice weakness', 'Slurred speech', 'Cough', 'Abdominal pain', 'Diarrhea', 'Constipation', 'Fevers', 'Chills', 'Chest pain', 'Shortness of breath']

Processing record 6/9 for patient ID: 600
  Found diagnoses: ['Last Name', 'Location', 'Acute abdominal back pain', 'systolic blood pressure
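
The output above contains entries such as `'[Hospital1 18]'`, `'Doctor Last Name 688'`, `'Last Name'`, and `'Location'`, which look like de-identification placeholders from the notes rather than diagnoses. One possible post-processing step is sketched below. The function name `drop_placeholders` and the three patterns are heuristic assumptions based only on the entries visible in this run; they would need tuning against a larger sample.

```python
import re

# Heuristic patterns (assumptions) for de-identification residue
PLACEHOLDER_PATTERNS = [
    re.compile(r"^\[.*\]$"),                                # fully bracketed, e.g. '[Hospital1 18]'
    re.compile(r"\b(?:first|last) name\b", re.IGNORECASE),  # 'Last Name', 'Doctor Last Name 688'
    re.compile(r"^location$", re.IGNORECASE),               # bare 'Location'
]

def drop_placeholders(diagnoses):
    """Return the diagnoses with placeholder-like entries removed."""
    kept = []
    for item in diagnoses:
        text = item.strip()
        if any(pattern.search(text) for pattern in PLACEHOLDER_PATTERNS):
            continue  # looks like de-identification residue, skip it
        kept.append(item)
    return kept

raw = ['[Hospital1 18]', "Alzheimer's dementia", 'Prostate CA',
       'Doctor Last Name 688', 'Last Name', 'Location', 'Diabetes']
print(drop_placeholders(raw))
```

A filter like this only removes obvious residue; entries such as `'liver'` or `'Maroon'` are extraction errors of a different kind and would need a prompt or model change instead.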

## 8. Display and Save Results

Convert the collected results into a DataFrame, display the first few rows, and save the complete list to a CSV file.

In [28]:
if results:
    df_output = pd.DataFrame(results)
    print("\nExtracted Diagnoses Sample:")
    display(df_output.head()) # Display first 5 rows
    
    logging.info(f"Saving {len(df_output)} extracted diagnosis entries to: {OUTPUT_FILE}")
    try:
        df_output.to_csv(OUTPUT_FILE, index=False)
        print(f"\nOutput saved successfully to {OUTPUT_FILE}")
        logging.info("Output file saved successfully.")
    except Exception as e:
        logging.error(f"Error saving output file: {e}")
        print(f"\nError saving output file: {e}")
else:
    logging.info("No diagnoses were extracted from any record.")
    print("\nNo diagnoses were extracted from any record.")
    # Optionally create an empty file or a file indicating no results
    try:
        with open(OUTPUT_FILE, 'w') as f:
            f.write("patient_id,entity_name\n") # Header matches the keys used in the results list
        logging.info(f"Empty output file created: {OUTPUT_FILE}")
        print(f"Empty output file created: {OUTPUT_FILE}")
    except Exception as e:
        logging.error(f"Error creating empty output file: {e}")
        print(f"Error creating empty output file: {e}")



Extracted Diagnoses Sample:


|   | patient id | entity_name |
|---|---|---|
| 0 | 4954 | on [\*\*2152-5-22\*\*]. |
| 1 | 13342 | intermittent mild grunting respirations. The |
| 2 | 66296 | [\*\*2141 |
| 3 | 60 | Prematurity |
| 4 | 60 | Triplet #3 |


2025-04-21 21:20:41,419 - INFO - Saving 10 extracted diagnosis entries to: extracted_diagnoses_v1.csv
2025-04-21 21:20:41,427 - INFO - Output file saved successfully.



Output saved successfully to extracted_diagnoses_v1.csv
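
The saving cell handles the empty and non-empty cases separately, which is where header names can drift apart. As a minimal sketch of an alternative, the standard-library `csv.DictWriter` writes the same header in both cases. The helper name `write_results` is hypothetical, and the field names assume the `patient_id` / `entity_name` keys used when results are appended in the processing loop.

```python
import csv
import io

FIELDNAMES = ["patient_id", "entity_name"]

def write_results(results, handle):
    """Write the header and one row per result dict to an open text handle."""
    writer = csv.DictWriter(handle, fieldnames=FIELDNAMES, lineterminator="\n")
    writer.writeheader()       # written even when results is empty
    writer.writerows(results)

# Demonstration with an in-memory buffer instead of a real file
buffer = io.StringIO()
write_results([{"patient_id": 100, "entity_name": "Pneumonia"}], buffer)
print(buffer.getvalue())

empty_buffer = io.StringIO()
write_results([], empty_buffer)
print(empty_buffer.getvalue())
```

When writing to disk, the file would be opened with `open(OUTPUT_FILE, "w", newline="")`, as the `csv` module documentation recommends.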
