In [1]:
import torch
import time
import json
import re
from faster_whisper import WhisperModel,BatchedInferencePipeline
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


  import pkg_resources


In [3]:
whisper_model = WhisperModel(
     "./models/large-v3",
     device = "cuda",
     compute_type="float16",
     download_root="./models/whisper"
)
print("Whisper model loaded")
batched_model = BatchedInferencePipeline(model=whisper_model)
print("COnverted to batched")

Whisper model loaded
COnverted to batched


In [4]:
segments, info = batched_model.transcribe(    
    language='en',
    audio= r'E:\Projects\Med_Scribe\Testing\Mr_Patil_Medical_converstaino.m4a',
    beam_size=4,
    vad_filter=True,
    batch_size=8,
    word_timestamps=False)


In [5]:
for segment in segments:
    print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))


[1.23s -> 28.99s]  Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high. But, don't worry. It's the early stage. Take Metformin 500mg in the morning and evening after eating one tablet. And take 2 tsp of live 1252 syrup twice a day. Stop eating oily and sugary foods.
[28.99s -> 54.26s]  Walk for 30 minutes. One more thing, do ultrasound of abdomen on next visit. To check your liver condition. Take medicine continuously for 30 days and follow up. And yes, take food on time. Don't eat late at night. Otherwise sugar control won't happen. Okay? Let's see.


In [6]:
for segment in segments:
    for word in segment.words:
        print("[%.2fs -> %.2fs] %s" % (word.start, word.end, word.word))

In [7]:
segments_list = list(segments)
print(segments_list)

[]


In [8]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
    
)

model_dir = "./models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(model_dir,
                                                   quantization_config=bnb_config,
                                                   device_map = "auto",
                                                   torch_dtype=torch.float16,
                                                   low_cpu_mem_usage=True)
print("Gemma loaded")

`torch_dtype` is deprecated! Use `dtype` instead!


Gemma loaded


In [9]:
segments,_ = whisper_model.transcribe(
    language='en',
    audio= r'E:\Projects\Med_Scribe\Testing\Mr_Patil_Medical_converstaino.m4a',
    beam_size=4,
    vad_filter=True,
    

)

print(segments)
transcript_chunks = [segment.text.strip() for segment in segments if segment.text.strip()]
print(58*"=")
print(transcript_chunks)
full_transcript = " ".join(transcript_chunks)
print(58*"=")
print(full_transcript)
print(58*"=")
segments_list = list(segments)

for segment in segments:
    print(f"Start: {segment['start']}, End: {segment['end']}, Text: {segment['text']}")



<generator object restore_speech_timestamps at 0x000002D95E639120>
['Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high.', "But don't worry. It's the early stage.", 'You take Metformin 500mg in the morning and evening after eating one tablet.', 'And take 2-3 spoons of Live 1252 syrup every day.', 'Stop eating oily and sugary foods.', 'Take a walk every day for 30 minutes.', 'One more thing, do an ultrasound of your abdomen for the next visit.', "I want to see your liver condition and what's going on with it.", 'And take medicine continuously for 30 days and follow up.', "And yes, don't eat late at night.", "Otherwise, you won't have sugar control.", "Okay? Let's see."]
Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high. But don't worry. It's the early stage. You take Metformin 500mg in the morning and evening after eating one tablet. And take 2-3 spoons of Live 1252 syrup every day. Stop eating 

In [10]:
segments_list = list(segments)

# Now you can access each segment's data
for segment in segments_list:
    print(segment)
segments_list

[]

In [11]:
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnBackticks(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequence="AAA"):
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        self.stop_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Check if the last tokens match the stop sequence
        if len(input_ids[0]) >= len(self.stop_ids):
            if (input_ids[0][-len(self.stop_ids):] == torch.tensor(self.stop_ids, device=input_ids.device)).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnBackticks(tokenizer)])

In [12]:
system_prompt = """
You are a medical prescription parser. Extract ONLY information explicitly stated.

Rules:
1. Extract medicines with EXACT dosages mentioned
2. If dosage/frequency unclear, mark as "unspecified"
3. Do NOT infer or assume any information
4. If doctor says "continue previous meds", extract NOTHING
5. Output only one valid JSON object and stop
6. At the end of the Output print AAA


Output format:
{
  "medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}],
  "diseases": [str],
  "tests": [{"name": str, "timing": str}]
}
"""

final_entities = {
    "medicines" : [],
    "diseases" : [],
    "tests" : []
}

user_prompt = f"""
{system_prompt}
Extract from this prescription conversation:
{full_transcript}

Remember: Only extract explicitly stated information. No assumptions.
"""
# user_prompt = f"{system_prompt}\n\n Text:{full_transcript}"
inputs = tokenizer(user_prompt,return_tensors = 'pt' ).to(gemma_model.device)
outputs = gemma_model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.01,
    pad_token_id = tokenizer.pad_token_id,
    eos_token_id = tokenizer.eos_token_id,
    stopping_criteria=stopping_criteria,
    do_sample=False

)
result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result_text)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.




You are a medical prescription parser. Extract ONLY information explicitly stated.

Rules:
1. Extract medicines with EXACT dosages mentioned
2. If dosage/frequency unclear, mark as "unspecified"
3. Do NOT infer or assume any information
4. If doctor says "continue previous meds", extract NOTHING
5. Output only one valid JSON object and stop
6. At the end of the Output print AAA


Output format:
{
  "medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}],
  "diseases": [str],
  "tests": [{"name": str, "timing": str}]
}

Extract from this prescription conversation:
Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high. But don't worry. It's the early stage. You take Metformin 500mg in the morning and evening after eating one tablet. And take 2-3 spoons of Live 1252 syrup every day. Stop eating oily and sugary foods. Take a walk every day for 30 minutes. One more thing, do an ultrasound of your abdomen for the next visit.

In [None]:
import torchaudio
print(torchaudio.list_audio_backends())
print(torchaudio.get_audio_backend())

import torchaudio
torchaudio.set_audio_backend("soundfile")




['soundfile']
None


  print(torchaudio.get_audio_backend())
  torchaudio.set_audio_backend("soundfile")


loads the diarization part

In [3]:
import torchaudio
import torch
from speechbrain.pretrained import EncoderClassifier


torchaudio.set_audio_backend("soundfile")

signal, fs = torchaudio.load(r"E:\Projects\Med_Scribe\Testing\audio.wav")
print(signal.shape, fs)

# Load local ECAPA model
model_dir = "pretrained_models/spkrec-ecapa-voxceleb"
classifier = EncoderClassifier.from_hparams(source=model_dir, savedir=None, run_opts={"device":"cuda"})

# Load audio
signal, fs = torchaudio.load(r"E:\Projects\Med_Scribe\Testing\audio.wav")

# Get speaker embeddings
embeddings = classifier.encode_batch(signal)  # shape: (1, embedding_size, frames)
print("Embeddings shape:", embeddings.shape)

# You can now cluster embeddings for speaker diarization


  torchaudio.set_audio_backend("soundfile")


torch.Size([1, 2642944]) 48000
Embeddings shape: torch.Size([1, 1, 192])


diarization part working

In [5]:
import torch
from speechbrain.pretrained import EncoderClassifier
from sklearn.cluster import AgglomerativeClustering
import torchaudio

# Load your local model
model_dir = "pretrained_models/spkrec-ecapa-voxceleb"
classifier = EncoderClassifier.from_hparams(source=model_dir, savedir=None, run_opts={"device":"cuda"})

# Load audio
signal, fs = torchaudio.load(r"E:\Projects\Med_Scribe\Testing\audio.wav")

# Get embeddings
embeddings = classifier.encode_batch(signal)  # shape: (1, embedding_size, frames)
embeddings = embeddings.squeeze(0).T.cpu().detach().numpy()  # shape: (frames, embedding_size)

# Cluster embeddings to diarize
num_speakers = 2  # change as needed
clustering = AgglomerativeClustering(n_clusters=num_speakers)
labels = clustering.fit_predict(embeddings)

# Convert labels to timestamps (rough approximation)
frame_duration = signal.shape[1] / fs / embeddings.shape[0]
for i, label in enumerate(labels):
    start_time = i * frame_duration
    end_time = (i + 1) * frame_duration
    print(f"{start_time:.2f}s - {end_time:.2f}s: Speaker {label}")


0.00s - 0.29s: Speaker 1
0.29s - 0.57s: Speaker 1
0.57s - 0.86s: Speaker 0
0.86s - 1.15s: Speaker 0
1.15s - 1.43s: Speaker 1
1.43s - 1.72s: Speaker 1
1.72s - 2.01s: Speaker 1
2.01s - 2.29s: Speaker 0
2.29s - 2.58s: Speaker 0
2.58s - 2.87s: Speaker 0
2.87s - 3.15s: Speaker 1
3.15s - 3.44s: Speaker 1
3.44s - 3.73s: Speaker 0
3.73s - 4.01s: Speaker 1
4.01s - 4.30s: Speaker 1
4.30s - 4.59s: Speaker 1
4.59s - 4.88s: Speaker 0
4.88s - 5.16s: Speaker 1
5.16s - 5.45s: Speaker 0
5.45s - 5.74s: Speaker 1
5.74s - 6.02s: Speaker 1
6.02s - 6.31s: Speaker 0
6.31s - 6.60s: Speaker 1
6.60s - 6.88s: Speaker 1
6.88s - 7.17s: Speaker 1
7.17s - 7.46s: Speaker 0
7.46s - 7.74s: Speaker 0
7.74s - 8.03s: Speaker 1
8.03s - 8.32s: Speaker 1
8.32s - 8.60s: Speaker 0
8.60s - 8.89s: Speaker 0
8.89s - 9.18s: Speaker 0
9.18s - 9.46s: Speaker 0
9.46s - 9.75s: Speaker 0
9.75s - 10.04s: Speaker 1
10.04s - 10.32s: Speaker 1
10.32s - 10.61s: Speaker 1
10.61s - 10.90s: Speaker 1
10.90s - 11.18s: Speaker 0
11.18s - 11.47s:

In [27]:
import torch
from speechbrain.pretrained import EncoderClassifier
from sklearn.cluster import AgglomerativeClustering
import torchaudio
from faster_whisper import WhisperModel, BatchedInferencePipeline
import numpy as np

# 1️⃣ Load speaker embedding model
model_dir = "pretrained_models/spkrec-ecapa-voxceleb"
classifier = EncoderClassifier.from_hparams(source=model_dir, savedir=None, run_opts={"device":"cuda"})

# 2️⃣ Load audio
audio_path = r"E:\Projects\Med_Scribe\Testing\audio.wav"
signal, fs = torchaudio.load(audio_path)

# 3️⃣ Compute speaker embeddings
window = int(fs * 3.0)   # 3 sec window
stride = int(fs * 1.5)   # 50% overlap
embeddings = []

for start in range(0, signal.shape[1] - window, stride):
    chunk = signal[:, start:start + window]
    emb = classifier.encode_batch(chunk).squeeze(0).cpu().detach().numpy()
    embeddings.append(emb)

embeddings = np.vstack(embeddings)


# 4️⃣ Cluster embeddings to assign speaker labels
from sklearn.metrics import silhouette_score
best_score, best_k = -1, 1
for k in range(1, 6):  # try up to 5 speakers
    tmp_labels = AgglomerativeClustering(n_clusters=k).fit_predict(embeddings)
    if len(set(tmp_labels)) > 1:
        score = silhouette_score(embeddings, tmp_labels)
        if score > best_score:
            best_score, best_k = score, k

clustering = AgglomerativeClustering(n_clusters=best_k)
labels = clustering.fit_predict(embeddings)
print(f"Detected ~{best_k} speakers")


# 5️⃣ Frame-to-time mapping
frame_duration = signal.shape[1] / fs / embeddings.shape[0]
diarization_segments = []
current_label = labels[0]
start_time = 0.0

for i, label in enumerate(labels):
    if label != current_label:
        end_time = i * frame_duration
        diarization_segments.append((start_time, end_time, current_label))
        start_time = end_time
        current_label = label
diarization_segments.append((start_time, len(labels) * frame_duration, current_label))

# 6️⃣ Load Faster-Whisper
whisper_model = WhisperModel(
    "./models/large-v3",
    device="cuda",
    compute_type="float16",
    download_root="./models/whisper",
)
batched_model = BatchedInferencePipeline(model=whisper_model)

# 7️⃣ Transcribe audio





  state_dict = torch.load(path, map_location=device)
  stats = torch.load(path, map_location=device)
  stats = torch.load(path, map_location=device)


Detected ~5 speakers


In [28]:
segments, info = batched_model.transcribe(audio_path, batch_size=16, language='en')  # adjust batch_size if needed

In [29]:
# Process each transcription segment
for trans_segment in segments:
    start = trans_segment.start
    end = trans_segment.end
    
    # Find the speaker segment that overlaps the most with this transcription
    max_overlap = 0
    best_speaker = None
    
    # Use the diarization segments (stored earlier) for speaker labels
    for spk_segment in diarization_segments:  # This should be the list created earlier with (start, end, label) tuples
        spk_start, spk_end, spk_label = spk_segment  # Now this unpacking will work
        
        # Calculate overlap
        overlap_start = max(start, spk_start)
        overlap_end = min(end, spk_end)
        
        if overlap_end > overlap_start:
            overlap_duration = overlap_end - overlap_start
            if overlap_duration > max_overlap:
                max_overlap = overlap_duration
                best_speaker = spk_label
    
    if best_speaker is not None:
        print(f"Speaker {best_speaker + 1}: {trans_segment.text.strip()}")
    else:
        print(f"Unknown Speaker: {trans_segment.text.strip()}")

Speaker 4: Mr. Patil, after reading your reports, I can see that your fatty liver and sugar levels are high. But, don't worry. It's the early stage. Take Metformin 500mg in the morning and evening after eating one tablet. And take 2 tsp of live 1252 syrup twice a day. Stop eating oily and sugary foods.
Speaker 1: Walk for 30 minutes. One more thing, do Ultrasound of Abdomen on next visit. I want to see your liver condition, what is the condition. And take medicine continuously for 30 days and follow up. And yes, take food on time. Don't eat late at night. Otherwise, sugar control won't happen. Okay? Let's see.


In [33]:
# robust_diarize_and_label.py
import os
import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import EncoderClassifier
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score
from scipy.signal import medfilt
from faster_whisper import WhisperModel, BatchedInferencePipeline
from math import ceil

# -------- CONFIG --------
MODEL_DIR = "pretrained_models/spkrec-ecapa-voxceleb"   # local ECAPA folder
AUDIO_PATH = r"E:\Projects\Med_Scribe\Testing\audio.wav"  # local audio (wav or m4a if ffmpeg backend)
WHISPER_MODEL_PATH = "./models/large-v3"                 # your faster-whisper model folder
DEVICE = "cuda"                                         # or "cpu"
WINDOW_SEC = 30.0        # longer context to reduce pitch-sensitivity
OVERLAP_SEC = 4.0       # 50% overlap
MAX_SPEAKERS = 5
SIM_THRESHOLD = 0.80   # cosine similarity threshold for reassignment
MEDIAN_KERNEL_SEC = 3.0 # smoothing kernel in seconds (odd integer multiple of frames)
# ------------------------

def load_audio(path):
    sig, sr = torchaudio.load(path)
    # ensure mono
    if sig.shape[0] > 1:
        sig = sig.mean(dim=0, keepdim=True)
    return sig, sr

def chunk_audio(signal, sr, window_sec, overlap_sec):
    win = int(window_sec * sr)
    stride = int((window_sec - overlap_sec) * sr)
    total = signal.shape[1]
    starts = list(range(0, max(1, total - win + 1), stride))
    if starts[-1] + win < total:
        starts.append(max(0, total - win))
    chunks = []
    times = []
    for s in starts:
        e = min(s + win, total)
        chunks.append(signal[:, s:e])
        times.append((s / sr, e / sr))
    return chunks, times

def get_embeddings(classifier, chunks):
    embs = []
    for chunk in chunks:
        # classifier.encode_batch accepts waveform like (batch, time) or (channels, time) as you used earlier
        with torch.no_grad():
            emb = classifier.encode_batch(chunk)  # returns tensor-like
        # normalize & flatten to vector
        emb = torch.tensor(emb).squeeze().cpu().numpy()
        emb = emb.reshape(-1)  # ensure 1D
        emb = emb / (np.linalg.norm(emb) + 1e-8)
        embs.append(emb)
    return np.vstack(embs)  # shape: (n_chunks, emb_dim)

def estimate_num_speakers(embeddings, max_k=5):
    from sklearn.metrics import silhouette_score
    from sklearn.cluster import AgglomerativeClustering
    best_k, best_score = 2, -1

    n_samples = len(embeddings)
    max_k = min(max_k, n_samples - 1)  # ✅ prevent invalid k

    for k in range(2, max_k + 1):
        labels = AgglomerativeClustering(n_clusters=k).fit_predict(embeddings)
        score = silhouette_score(embeddings, labels)
        if score > best_score:
            best_k, best_score = k, score
    return best_k

def cluster_and_reassign(embeddings, sim_threshold, max_k):
    # estimate K
    k = estimate_num_speakers(embeddings, max_k)
    if k == 1:
        labels = np.zeros(len(embeddings), dtype=int)
    else:
        labels = AgglomerativeClustering(n_clusters=k).fit_predict(embeddings)

    # compute means
    unique = sorted(set(labels))
    means = {u: embeddings[labels == u].mean(axis=0) for u in unique}
    for u in means:
        means[u] = means[u] / (np.linalg.norm(means[u]) + 1e-8)

    # reassignment by cosine similarity to means (stability hack)
    for i, emb in enumerate(embeddings):
        sims = {u: float(np.dot(emb, means[u])) for u in unique}
        best_u, best_sim = max(sims.items(), key=lambda x: x[1])
        if best_sim >= sim_threshold:
            labels[i] = best_u
        # else keep original cluster (rare)
    return labels, k

def labels_to_segments(labels, times):
    segs = []
    cur_label = labels[0]
    cur_start = times[0][0]
    for i in range(1, len(labels)):
        if labels[i] != cur_label:
            cur_end = times[i][0]  # end is start of current chunk
            segs.append((cur_start, cur_end, cur_label))
            cur_label = labels[i]
            cur_start = times[i][0]
    # finish
    segs.append((cur_start, times[-1][1], cur_label))
    return segs

def median_smooth_labels(labels, times, kernel_sec):
    # map labels to per-chunk sequence and apply medfilt with odd kernel
    if kernel_sec <= 0:
        return labels
    avg_chunk_dur = (times[0][1] - times[0][0])
    kernel = int(round(kernel_sec / avg_chunk_dur))
    if kernel % 2 == 0:
        kernel += 1
    if kernel < 1:
        kernel = 1
    if kernel == 1:
        return labels
    return medfilt(labels.astype(float), kernel_size=kernel).astype(int)

def align_transcript_with_speakers(trans_segments, diarization_segments):
    # diarization_segments: list of (start,end,label)
    out = []
    for t in trans_segments:
        s, e, text = t.start, t.end, t.text.strip()
        # find diarization segment with max overlap
        best_label, best_overlap = None, 0.0
        for ds in diarization_segments:
            ds_s, ds_e, ds_label = ds
            overlap = max(0.0, min(e, ds_e) - max(s, ds_s))
            if overlap > best_overlap:
                best_overlap = overlap
                best_label = ds_label
        out.append((best_label, s, e, text))
    return out

def main():
    print("Loading models...")
    classifier = EncoderClassifier.from_hparams(source=MODEL_DIR, savedir=None, run_opts={"device": DEVICE})
    whisper_model = WhisperModel(WHISPER_MODEL_PATH, device=DEVICE, compute_type="float16", download_root="./models/whisper")
    batched_model = BatchedInferencePipeline(model=whisper_model)
    print("Models loaded.")

    signal, sr = load_audio(AUDIO_PATH)
    print(f"Audio loaded: {AUDIO_PATH} sr={sr} duration={signal.shape[1]/sr:.2f}s")

    # chunk audio
    chunks, times = chunk_audio(signal, sr, WINDOW_SEC, OVERLAP_SEC)
    print(f"{len(chunks)} chunks, window={WINDOW_SEC}s overlap={OVERLAP_SEC}s")

    # embeddings
    embeddings = get_embeddings(classifier, chunks)
    print("Embeddings computed:", embeddings.shape)

    # clustering + reassignment
    labels, detected_k = cluster_and_reassign(embeddings, SIM_THRESHOLD, MAX_SPEAKERS)
    print(f"Detected speakers (initial/after): {detected_k}, labels unique: {sorted(set(labels))}")

    # smoothing
    labels_sm = median_smooth_labels(labels, times, MEDIAN_KERNEL_SEC)
    diarization_segments = labels_to_segments(labels_sm, times)
    print("Diarization segments (merged):")
    for s,e,l in diarization_segments:
        print(f"  {s:.2f}s - {e:.2f}s : Speaker {l+1}")

    # Transcribe with faster-whisper BatchedInferencePipeline
    print("Transcribing with faster-whisper...")
    segments_gen, info = batched_model.transcribe(AUDIO_PATH, batch_size=16, language='en')
    segments = list(segments_gen)
    print(f"Transcribed {len(segments)} segments")
    transcript_chunks = [segment.text.strip() for segment in segments if segment.text.strip()]
    full_transcript = " ".join(transcript_chunks)
    # Align transcript sentences with diarization segments
    labeled = align_transcript_with_speakers(segments, diarization_segments)
    print("\n===== SPEAKER-LABELED TRANSCRIPT =====\n")
    for label, s, e, text in labeled:
        if label is None:
            print(f"Unknown Speaker: {text}")
        else:
            print(f"Speaker {label+1}: {text}")
    print("\nDone.")

    # -----------------------------
    # Entity extraction using Gemma
    # -----------------------------
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
    import torch

    # Configure 4-bit quantization exactly as before
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True
    )

    model_dir = "./models/gemma-3-1b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    gemma_model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True
    )
    print("Gemma loaded")

    # Stopping criteria for your AAA delimiter
    class StopOnBackticks(StoppingCriteria):
        def __init__(self, tokenizer, stop_sequence="AAA"):
            self.tokenizer = tokenizer
            self.stop_sequence = stop_sequence
            self.stop_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

        def __call__(self, input_ids, scores, **kwargs):
            if len(input_ids[0]) >= len(self.stop_ids):
                if (input_ids[0][-len(self.stop_ids):] == torch.tensor(self.stop_ids, device=input_ids.device)).all():
                    return True
            return False

    stopping_criteria = StoppingCriteriaList([StopOnBackticks(tokenizer)])

    # -----------------------------
    # Prepare system + user prompt
    # -----------------------------
    system_prompt = """
    You are a medical prescription parser. Extract ONLY information explicitly stated.

    Rules:
    1. Extract medicines with EXACT dosages mentioned
    2. If dosage/frequency unclear, mark as "unspecified"
    3. Do NOT infer or assume any information
    4. If doctor says "continue previous meds", extract NOTHING
    5. Output only one valid JSON object and stop
    6. At the end of the Output print AAA

    Output format:
    {
    "medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}],
    "diseases": [str],
    "tests": [{"name": str, "timing": str}]
    }
    """

    user_prompt = f"""
    {system_prompt}
    Extract from this prescription conversation:
    {full_transcript}

    Remember: Only extract explicitly stated information. No assumptions.
    """

    # -----------------------------
    # Run Gemma model
    # -----------------------------
    inputs = tokenizer(user_prompt, return_tensors='pt').to(gemma_model.device)
    outputs = gemma_model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.01,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=stopping_criteria,
        do_sample=False
    )

    result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(result_text)


if __name__ == "__main__":
    main()



Loading models...


RuntimeError: CUDA failed with error out of memory

In [None]:
# -----------------------------
# Entity extraction using Gemma
# -----------------------------
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, StoppingCriteria, StoppingCriteriaList
import torch

# Configure 4-bit quantization exactly as before
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

model_dir = "./models/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
gemma_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
print("Gemma loaded")

# Stopping criteria for your AAA delimiter
class StopOnBackticks(StoppingCriteria):
    def __init__(self, tokenizer, stop_sequence="AAA"):
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        self.stop_ids = tokenizer.encode(stop_sequence, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        if len(input_ids[0]) >= len(self.stop_ids):
            if (input_ids[0][-len(self.stop_ids):] == torch.tensor(self.stop_ids, device=input_ids.device)).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnBackticks(tokenizer)])

# -----------------------------
# Prepare system + user prompt
# -----------------------------
system_prompt = """
You are a medical prescription parser. Extract ONLY information explicitly stated.

Rules:
1. Extract medicines with EXACT dosages mentioned
2. If dosage/frequency unclear, mark as "unspecified"
3. Do NOT infer or assume any information
4. If doctor says "continue previous meds", extract NOTHING
5. Output only one valid JSON object and stop
6. At the end of the Output print AAA

Output format:
{
  "medicines": [{"name": str, "dosage": str, "frequency": str, "duration": str}],
  "diseases": [str],
  "tests": [{"name": str, "timing": str}]
}
"""

user_prompt = f"""
{system_prompt}
Extract from this prescription conversation:
{full_transcript}

Remember: Only extract explicitly stated information. No assumptions.
"""

# -----------------------------
# Run Gemma model
# -----------------------------
inputs = tokenizer(user_prompt, return_tensors='pt').to(gemma_model.device)
outputs = gemma_model.generate(
    **inputs,
    max_new_tokens=512,
    temperature=0.01,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    stopping_criteria=stopping_criteria,
    do_sample=False
)

result_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(result_text)
