In [1]:
from tqdm.auto import tqdm
tqdm.__init__ = lambda *a, **k: __import__('tqdm').tqdm(*a, **{**k, "disable": False})


In [2]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [4]:
import torch
import time
import json
import re
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ============================================
# GPU Memory Optimization
# ============================================
def print_gpu_usage(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[{label}] GPU Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

# ============================================
# Load Whisper (GPU-optimized)
# ============================================
print("Loading Whisper model (large-v3 on GPU)...")
start = time.time()
model_dir = "./models/large-v3" 
whisper_model = WhisperModel(
    model_dir,
    device="cuda",
    compute_type="float16",  # FP16 for speed
    download_root="./models/whisper"
)

load_time = time.time() - start
print(f"Whisper loaded in {load_time:.2f} sec")
print_gpu_usage("After Whisper Load")

# ============================================
# Load Gemma (4-bit quantized on GPU)
# ============================================
print("\nLoading Gemma 1B (4-bit quantized on GPU)...")
start = time.time()

model_dir = "./models/gemma-3-1b-it"

# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

tokenizer = AutoTokenizer.from_pretrained(model_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=bnb_config,
    device_map="cuda:0",  # FORCE GPU!
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)

gemma_load_time = time.time() - start
print(f"Gemma loaded in {gemma_load_time:.2f} sec")
print_gpu_usage("After Gemma Load")

# ============================================
# Transcription Function
# ============================================
def transcribe_audio(audio_path, language="en"):
    """Transcribe audio with timing"""
    start = time.time()
    segments, info = whisper_model.transcribe(
        audio_path,
        language=language,
        beam_size=5,
        vad_filter=True  # Voice activity detection for better accuracy
    )
    
    transcript = " ".join([segment.text for segment in segments])
    elapsed = time.time() - start
    
    return transcript, elapsed

# ============================================
# Entity Extraction Function
# ============================================
def extract_entities(transcript):
    """Extract medical entities using Gemma"""
    
    system_prompt = """You are a medical prescription parser. Extract ONLY explicitly stated information.

Rules:
1. Extract medicines with EXACT dosages mentioned
2. If dosage/frequency unclear, mark as "unspecified"
3. Do NOT infer or assume any information
4. Output ONLY valid JSON, no markdown formatting

Output format:
{"medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}], "diseases": [str], "tests": [{"name": str, "timing": str}]}"""

    user_prompt = f"{system_prompt}\n\nExtract from:\n{transcript}\n\nJSON output:"
    
    start = time.time()
    
    inputs = tokenizer(user_prompt, return_tensors="pt").to(gemma_model.device)
    
    with torch.no_grad():
        outputs = gemma_model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
            temperature=0.1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    elapsed = time.time() - start
    
    # Clean output (remove markdown code fences)
    result_text = result_text.split("JSON output:")[-1].strip()
    result_text = re.sub(r'```json\s*|\s*```', '', result_text).strip()
    
    # Parse JSON
    try:
        entities = json.loads(result_text)
        return entities, elapsed, None
    except json.JSONDecodeError as e:
        return None, elapsed, result_text

# ============================================
# Full Pipeline
# ============================================
def process_prescription(audio_path):
    """Complete pipeline with timing"""
    
    print("\n" + "="*80)
    print("PROCESSING PRESCRIPTION")
    print("="*80)
    
    pipeline_start = time.time()
    
    # English transcription
    print("\n1. Transcribing English pass...")
    english_transcript, english_time = transcribe_audio(audio_path, language="en")
    print(f"   ✓ English done in {english_time:.2f} sec")
    
    # Marathi transcription
    print("\n2. Transcribing Marathi pass...")
    marathi_transcript, marathi_time = transcribe_audio(audio_path, language="hi")  # 'hi' for Hindi/Marathi
    print(f"   ✓ Marathi done in {marathi_time:.2f} sec")
    
    # Entity extraction (use English transcript)
    print("\n3. Extracting medical entities...")
    entities, extraction_time, error = extract_entities(english_transcript)
    print(f"   ✓ Extraction done in {extraction_time:.2f} sec")
    
    total_time = time.time() - pipeline_start
    
    # ============================================
    # SUMMARY
    # ============================================
    print("\n" + "="*80)
    print("FINAL SUMMARY")
    print("="*80)
    print(f"Model Load Time    : {load_time + gemma_load_time:.2f} sec")
    print(f"English Pass Time  : {english_time:.2f} sec")
    print(f"Marathi Pass Time  : {marathi_time:.2f} sec")
    print(f"Entity Extraction  : {extraction_time:.2f} sec")
    print(f"TOTAL PIPELINE     : {total_time:.2f} sec")
    print("="*80)
    
    print_gpu_usage("Final")
    
    # Results
    print("\nENGLISH TRANSCRIPT:")
    print(english_transcript)
    print("\nMARATHI TRANSCRIPT:")
    print(marathi_transcript)
    print("\nEXTRACTED ENTITIES:")
    if entities:
        print(json.dumps(entities, indent=2))
    else:
        print(f"ERROR: {error}")
    
    return {
        "timings": {
            "model_load": load_time + gemma_load_time,
            "english": english_time,
            "marathi": marathi_time,
            "extraction": extraction_time,
            "total": total_time
        },
        "transcripts": {
            "english": english_transcript,
            "marathi": marathi_transcript
        },
        "entities": entities
    }

# ============================================
# Example Usage
# ============================================
if __name__ == "__main__":
    # Test with your audio file
    audio_file = r"E:\Projects\Med_Scribe\Testing\Mr_Patil_Medical_converstaino.m4a"
    
    result = process_prescription(audio_file)
    
    # Save results
    with open("prescription_output.json", "w") as f:
        json.dump(result, f, indent=2)
    
    print("\n✓ Results saved to prescription_output.json")


# ============================================
# EXPECTED PERFORMANCE (GPU)
# ============================================
"""
NVIDIA RTX 3050 (4GB VRAM) or better:
- Model Load: ~5-8 sec (one-time)
- English Pass: ~2-3 sec
- Marathi Pass: ~3-4 sec
- Entity Extraction: ~3-5 sec
- TOTAL: ~10-15 sec ✅ ACCEPTABLE

VRAM Usage:
- Whisper large-v3 (FP16): 3.8 GB
- Gemma 1B (4-bit): 0.4 GB
- Total: ~4.2 GB

If VRAM insufficient, switch to Whisper Medium:
whisper_model = WhisperModel("medium", device="cuda", compute_type="float16")
- VRAM: 1.5 GB (saves 2.3 GB!)
- Accuracy: ~5-10% worse
"""

Loading Whisper model (large-v3 on GPU)...
Whisper loaded in 6.06 sec
[After Whisper Load] GPU Allocated: 0.91 GB, Reserved: 1.66 GB

Loading Gemma 1B (4-bit quantized on GPU)...
Gemma loaded in 4.10 sec
[After Gemma Load] GPU Allocated: 1.24 GB, Reserved: 2.25 GB

PROCESSING PRESCRIPTION

1. Transcribing English pass...
   ✓ English done in 12.71 sec

2. Transcribing Marathi pass...
   ✓ Marathi done in 27.81 sec

3. Extracting medical entities...
   ✓ Extraction done in 60.85 sec

FINAL SUMMARY
Model Load Time    : 10.16 sec
English Pass Time  : 12.71 sec
Marathi Pass Time  : 27.81 sec
Entity Extraction  : 60.85 sec
TOTAL PIPELINE     : 101.37 sec
[Final] GPU Allocated: 0.91 GB, Reserved: 2.25 GB

ENGLISH TRANSCRIPT:
 Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high.  But don't worry. It's the early stage.  You take Metformin 500mg in the morning and evening after eating one tablet.  And take 2-3 spoons of Live 1252 syrup every day.  St

'\nNVIDIA RTX 3050 (4GB VRAM) or better:\n- Model Load: ~5-8 sec (one-time)\n- English Pass: ~2-3 sec\n- Marathi Pass: ~3-4 sec\n- Entity Extraction: ~3-5 sec\n- TOTAL: ~10-15 sec ✅ ACCEPTABLE\n\nVRAM Usage:\n- Whisper large-v3 (FP16): 3.8 GB\n- Gemma 1B (4-bit): 0.4 GB\n- Total: ~4.2 GB\n\nIf VRAM insufficient, switch to Whisper Medium:\nwhisper_model = WhisperModel("medium", device="cuda", compute_type="float16")\n- VRAM: 1.5 GB (saves 2.3 GB!)\n- Accuracy: ~5-10% worse\n'

In [9]:
import torch
import time
import json
import re
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList

# ============================================
# GPU UTILITIES
# ============================================
def print_gpu_usage(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[{label}] GPU Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

# ============================================
# LOAD WHISPER
# ============================================
print("Loading Whisper model (large-v3)...")
start = time.time()
whisper_model = WhisperModel(
    "./models/large-v3",
    device="cuda",
    compute_type="float16",
    download_root="./models/whisper"
)
load_time_whisper = time.time() - start
print(f"Whisper loaded in {load_time_whisper:.2f} sec")
print_gpu_usage("After Whisper Load")

# ============================================
# LOAD GEMMA (4-BIT QUANTIZED)
# ============================================
print("\nLoading Gemma 1B (4-bit quantized)...")
start = time.time()
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model_dir = "./models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
load_time_gemma = time.time() - start
print(f"Gemma loaded in {load_time_gemma:.2f} sec")
print_gpu_usage("After Gemma Load")

# ============================================
# TRANSCRIPTION (CHUNKED)
# ============================================
def transcribe_audio(audio_path, language="en"):
    """Chunked transcription with Whisper"""
    start = time.time()
    segments, _ = whisper_model.transcribe(
        audio_path,
        language=language,
        beam_size=5,
        vad_filter=True
    )
    # Combine segments for chunked processing
    transcript_chunks = [segment.text.strip() for segment in segments if segment.text.strip()]
    elapsed = time.time() - start
    return transcript_chunks, elapsed

# ============================================
# STREAMING GEMMA ENTITY EXTRACTION
# ============================================
def extract_entities(transcript_chunks):
    """Chunk-wise entity extraction with strict JSON output"""
    system_prompt = """You are a medical prescription parser. Extract ONLY explicitly stated information.

Rules:
1. Extract medicines with EXACT dosages mentioned
2. If dosage/frequency unclear, mark as "unspecified"
3. Do NOT infer or assume any information
4. Output ONLY valid JSON with the following format:
{
  "medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}],
  "diseases": [str],
  "tests": [{"name": str, "timing": str}]
}"""

    final_entities = {
        "medicines": [],
        "diseases": [],
        "tests": []
    }

    total_time = 0
    for chunk in transcript_chunks:
        user_prompt = f"{system_prompt}\n\nExtract from:\n{chunk}\n\nJSON output:"

        inputs = tokenizer(user_prompt, return_tensors="pt").to(gemma_model.device)

        start = time.time()
        with torch.no_grad():
            outputs = gemma_model.generate(
                **inputs,
                max_new_tokens=256,  # limit per chunk
                do_sample=False,
                temperature=0.1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        elapsed = time.time() - start
        total_time += elapsed

        result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        result_text = result_text.split("JSON output:")[-1].strip()
        result_text = re.sub(r'```json\s*|\s*```', '', result_text).strip()

        try:
            entities = json.loads(result_text)
            # Merge results
            final_entities["medicines"].extend(entities.get("medicines", []))
            final_entities["diseases"].extend([d for d in entities.get("diseases", []) if d not in final_entities["diseases"]])
            final_entities["tests"].extend(entities.get("tests", []))
        except json.JSONDecodeError:
            # Skip invalid chunk, optionally log
            print("⚠️ JSON decode failed for chunk, skipping")

    return final_entities, total_time

# ============================================
# FULL PIPELINE
# ============================================
def process_prescription(audio_path):
    print("\n" + "="*80)
    print("PROCESSING PRESCRIPTION")
    print("="*80)
    pipeline_start = time.time()

    # English transcription
    print("\n1. Transcribing English pass...")
    english_chunks, english_time = transcribe_audio(audio_path, language="en")
    english_transcript = " ".join(english_chunks)
    print(f"   ✓ English done in {english_time:.2f} sec")

    # Marathi transcription (optional, skip if not required)
    print("\n2. Transcribing Marathi pass...")
    # marathi_chunks, marathi_time = transcribe_audio(audio_path, language="hi")
    # marathi_transcript = " ".join(marathi_chunks)
    # print(f"   ✓ Marathi done in {marathi_time:.2f} sec")

    # Entity extraction (chunked)
    print("\n3. Extracting medical entities...")
    # entities, extraction_time = extract_entities(english_chunks)
    # print(f"   ✓ Extraction done in {extraction_time:.2f} sec")

    total_time = time.time() - pipeline_start

    # ============================================
    # SUMMARY
    # ============================================
    print("\n" + "="*80)
    print("FINAL SUMMARY")
    print("="*80)
    print(f"Model Load Time    : {load_time_whisper + load_time_gemma:.2f} sec")
    print(f"English Pass Time  : {english_time:.2f} sec")
    # print(f"Marathi Pass Time  : {marathi_time:.2f} sec")
    # print(f"Entity Extraction  : {extraction_time:.2f} sec")
    print(f"TOTAL PIPELINE     : {total_time:.2f} sec")
    print("="*80)
    print_gpu_usage("Final")

    # Results
    print("\nENGLISH TRANSCRIPT:")
    print(english_transcript)
    print("\nMARATHI TRANSCRIPT:")
    print(marathi_transcript)
    print("\nEXTRACTED ENTITIES:")
    print(json.dumps(entities, indent=2))

    return {
        "timings": {
            "model_load": load_time_whisper + load_time_gemma,
            "english": english_time,
            # "marathi": marathi_time,
            # "extraction": extraction_time,
            "total": total_time
        },
        "transcripts": {
            "english": english_transcript,
            "marathi": marathi_transcript
        },
        "entities": entities
    }

# ============================================
# MAIN
# ============================================
if __name__ == "__main__":
    audio_file = r"E:\Projects\Med_Scribe\Testing\Mr_Patil_Medical_converstaino.m4a"
    result = process_prescription(audio_file)

    # Save output
    with open("prescription_output.json", "w") as f:
        json.dump(result, f, indent=2)

    print("\n✓ Results saved to prescription_output.json")


Loading Whisper model (large-v3)...
Whisper loaded in 5.61 sec
[After Whisper Load] GPU Allocated: 0.91 GB, Reserved: 3.43 GB

Loading Gemma 1B (4-bit quantized)...
Gemma loaded in 7.81 sec
[After Gemma Load] GPU Allocated: 1.24 GB, Reserved: 3.43 GB

PROCESSING PRESCRIPTION

1. Transcribing English pass...
   ✓ English done in 30.28 sec

2. Transcribing Marathi pass...

3. Extracting medical entities...

FINAL SUMMARY
Model Load Time    : 13.42 sec
English Pass Time  : 30.28 sec
TOTAL PIPELINE     : 30.28 sec
[Final] GPU Allocated: 0.91 GB, Reserved: 3.43 GB

ENGLISH TRANSCRIPT:
Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high. But don't worry. It's the early stage. You take Metformin 500mg in the morning and evening after eating one tablet. And take 2-3 spoons of Live 1252 syrup every day. Stop eating oily and sugary foods. Take a walk every day for 30 minutes. One more thing, do an ultrasound of your abdomen for the next visit. I want 

NameError: name 'marathi_transcript' is not defined

In [10]:
import torch
import time
import json
import re
from faster_whisper import WhisperModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# ============================================
# GPU UTILITIES
# ============================================
def print_gpu_usage(label=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"[{label}] GPU Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

# ============================================
# LOAD WHISPER
# ============================================
print("Loading Whisper model (large-v3)...")
start = time.time()
whisper_model = WhisperModel(
    "./models/large-v3",
    device="cuda",
    compute_type="float16",
    download_root="./models/whisper"
)
load_time_whisper = time.time() - start
print(f"Whisper loaded in {load_time_whisper:.2f} sec")
print_gpu_usage("After Whisper Load")

# ============================================
# LOAD GEMMA (4-BIT QUANTIZED)
# ============================================
print("\nLoading Gemma 1B (4-bit quantized)...")
start = time.time()
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model_dir = "./models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
load_time_gemma = time.time() - start
print(f"Gemma loaded in {load_time_gemma:.2f} sec")
print_gpu_usage("After Gemma Load")

# ============================================
# TRANSCRIPTION (CHUNKED)
# ============================================
def transcribe_audio(audio_path, language="en"):
    """Chunked transcription with Whisper"""
    start = time.time()
    segments, _ = whisper_model.transcribe(
        audio_path,
        language=language,
        beam_size=5,
        vad_filter=True
    )
    # Combine segments for chunked processing
    transcript_chunks = [segment.text.strip() for segment in segments if segment.text.strip()]
    elapsed = time.time() - start
    return transcript_chunks, elapsed

# ============================================
# JSON EXTRACTION UTILITIES
# ============================================
def extract_json_from_text(text):
    """Extract JSON from text with multiple fallback strategies"""
    # Strategy 1: Find JSON between braces
    json_match = re.search(r'\{[\s\S]*\}', text)
    if json_match:
        try:
            return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass
    
    # Strategy 2: Try to find and fix common issues
    try:
        # Remove markdown code blocks
        cleaned = re.sub(r'```(?:json)?\s*|\s*```', '', text)
        # Remove text before first { and after last }
        cleaned = re.search(r'\{[\s\S]*\}', cleaned)
        if cleaned:
            return json.loads(cleaned.group())
    except (json.JSONDecodeError, AttributeError):
        pass
    
    # Strategy 3: Return empty structure
    return None

def merge_entities(target, source):
    """Merge entity dictionaries without duplicates"""
    # Merge medicines (check for duplicate names)
    existing_med_names = {m.get("name", "").lower() for m in target["medicines"]}
    for med in source.get("medicines", []):
        if med.get("name", "").lower() not in existing_med_names:
            target["medicines"].append(med)
            existing_med_names.add(med.get("name", "").lower())
    
    # Merge diseases (unique only)
    for disease in source.get("diseases", []):
        if disease and disease not in target["diseases"]:
            target["diseases"].append(disease)
    
    # Merge tests (check for duplicate names)
    existing_test_names = {t.get("name", "").lower() for t in target["tests"]}
    for test in source.get("tests", []):
        if test.get("name", "").lower() not in existing_test_names:
            target["tests"].append(test)
            existing_test_names.add(test.get("name", "").lower())

# ============================================
# STREAMING GEMMA ENTITY EXTRACTION
# ============================================
def extract_entities(transcript_chunks, max_retries=2):
    """Chunk-wise entity extraction with robust JSON parsing"""
    
    # Optimized: Create prompt template once
    system_prompt = """Extract medical information in JSON format:
{"medicines":[{"name":"","dosage":"","frequency":"","duration":""}],"diseases":[],"tests":[{"name":"","timing":""}]}

Rules: Only extract explicitly stated info. Mark unclear info as "unspecified"."""

    final_entities = {
        "medicines": [],
        "diseases": [],
        "tests": []
    }

    total_time = 0
    failed_chunks = 0
    
    for idx, chunk in enumerate(transcript_chunks):
        if not chunk or len(chunk.strip()) < 10:  # Skip very short chunks
            continue
            
        # Optimized: Shorter, clearer prompt
        user_prompt = f"{system_prompt}\n\nText: {chunk}\n\nJSON:"

        inputs = tokenizer(user_prompt, return_tensors="pt").to(gemma_model.device)

        retry_count = 0
        chunk_success = False
        
        while retry_count < max_retries and not chunk_success:
            start = time.time()
            
            with torch.no_grad():
                outputs = gemma_model.generate(
                    **inputs,
                    max_new_tokens=256,
                    do_sample=(retry_count > 0),  # First try deterministic, then sample
                    temperature=0.3 if retry_count > 0 else 0.1,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
            
            elapsed = time.time() - start
            total_time += elapsed

            result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Extract only the JSON part (after "JSON:")
            if "JSON:" in result_text:
                result_text = result_text.split("JSON:")[-1].strip()
            
            # Try to extract and parse JSON
            entities = extract_json_from_text(result_text)
            
            if entities:
                # Validate structure
                if isinstance(entities, dict):
                    merge_entities(final_entities, entities)
                    chunk_success = True
                    print(f"   ✓ Chunk {idx+1}/{len(transcript_chunks)} processed")
                else:
                    retry_count += 1
            else:
                retry_count += 1
            
            if not chunk_success and retry_count >= max_retries:
                failed_chunks += 1
                print(f"   ⚠ Chunk {idx+1} failed after {max_retries} retries, skipping")
                break
        
        # Clear GPU cache periodically
        if idx % 5 == 0:
            torch.cuda.empty_cache()

    if failed_chunks > 0:
        print(f"\n⚠️ Warning: {failed_chunks}/{len(transcript_chunks)} chunks failed to parse")

    return final_entities, total_time

# ============================================
# FULL PIPELINE
# ============================================
def process_prescription(audio_path):
    print("\n" + "="*80)
    print("PROCESSING PRESCRIPTION")
    print("="*80)
    pipeline_start = time.time()

    # English transcription
    print("\n1. Transcribing English pass...")
    english_chunks, english_time = transcribe_audio(audio_path, language="en")
    english_transcript = " ".join(english_chunks)
    print(f"   ✓ English done in {english_time:.2f} sec ({len(english_chunks)} chunks)")

    # Marathi transcription (optional, skip if not required)
    print("\n2. Transcribing Marathi pass...")
    marathi_chunks, marathi_time = transcribe_audio(audio_path, language="hi")
    marathi_transcript = " ".join(marathi_chunks)
    print(f"   ✓ Marathi done in {marathi_time:.2f} sec ({len(marathi_chunks)} chunks)")

    # Entity extraction (chunked)
    print("\n3. Extracting medical entities...")
    entities, extraction_time = extract_entities(english_chunks, max_retries=2)
    print(f"   ✓ Extraction done in {extraction_time:.2f} sec")

    total_time = time.time() - pipeline_start

    # ============================================
    # SUMMARY
    # ============================================
    print("\n" + "="*80)
    print("FINAL SUMMARY")
    print("="*80)
    print(f"Model Load Time    : {load_time_whisper + load_time_gemma:.2f} sec")
    print(f"English Pass Time  : {english_time:.2f} sec")
    print(f"Marathi Pass Time  : {marathi_time:.2f} sec")
    print(f"Entity Extraction  : {extraction_time:.2f} sec")
    print(f"TOTAL PIPELINE     : {total_time:.2f} sec")
    print("="*80)
    print_gpu_usage("Final")

    # Results
    print("\nENGLISH TRANSCRIPT:")
    print(english_transcript)
    print("\nMARATHI TRANSCRIPT:")
    print(marathi_transcript)
    print("\nEXTRACTED ENTITIES:")
    print(json.dumps(entities, indent=2, ensure_ascii=False))

    return {
        "timings": {
            "model_load": load_time_whisper + load_time_gemma,
            "english": english_time,
            "marathi": marathi_time,
            "extraction": extraction_time,
            "total": total_time
        },
        "transcripts": {
            "english": english_transcript,
            "marathi": marathi_transcript
        },
        "entities": entities
    }

# ============================================
# MAIN
# ============================================
if __name__ == "__main__":
    audio_file = r"E:\Projects\Med_Scribe\Testing\Mr_Patil_Medical_converstaino.m4a"
    
    try:
        result = process_prescription(audio_file)

        # Save output
        with open("prescription_output.json", "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        print("\n✓ Results saved to prescription_output.json")
        
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Cleanup
        torch.cuda.empty_cache()

Loading Whisper model (large-v3)...
Whisper loaded in 7.06 sec
[After Whisper Load] GPU Allocated: 0.91 GB, Reserved: 3.43 GB

Loading Gemma 1B (4-bit quantized)...
Gemma loaded in 4.43 sec
[After Gemma Load] GPU Allocated: 1.24 GB, Reserved: 3.43 GB

PROCESSING PRESCRIPTION

1. Transcribing English pass...
   ✓ English done in 18.05 sec (12 chunks)

2. Transcribing Marathi pass...
   ✓ Marathi done in 54.94 sec (2 chunks)

3. Extracting medical entities...
   ✓ Chunk 1/12 processed
   ⚠ Chunk 2 failed after 2 retries, skipping
   ⚠ Chunk 3 failed after 2 retries, skipping
   ⚠ Chunk 4 failed after 2 retries, skipping

❌ Error: 'str' object has no attribute 'get'


Traceback (most recent call last):
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_36212\335684099.py", line 277, in <module>
    result = process_prescription(audio_file)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_36212\335684099.py", line 228, in process_prescription
    entities, extraction_time = extract_entities(english_chunks, max_retries=2)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_36212\335684099.py", line 183, in extract_entities
    merge_entities(final_entities, entities)
  File "C:\Users\Shiva\AppData\Local\Temp\ipykernel_36212\335684099.py", line 106, in merge_entities
    if med.get("name", "").lower() not in existing_med_names:
       ^^^^^^^
AttributeError: 'str' object has no attribute 'get'
