In [1]:
#whisper only
import os
import json
import whisper
import torch

# === Configuration ===
audio_dir = r"C:\Users\user\Downloads\Private_dataset\private"
output_dir = r"C:\Users\user\Downloads\outputs"
log_path = os.path.join(output_dir, "processed_files.log")
os.makedirs(output_dir, exist_ok=True)

# === Load Whisper model ===
model = whisper.load_model("medium").to("cuda")

# === Load log of already processed files ===
if os.path.exists(log_path):
    with open(log_path, "r", encoding="utf-8") as log_file:
        processed_files = set(line.strip() for line in log_file if line.strip())
else:
    processed_files = set()

# === Transcribe each WAV file ===
for fname in os.listdir(audio_dir):
    if not fname.endswith(".wav"):
        continue

    file_id = os.path.splitext(fname)[0]
    if file_id in processed_files:
        print(f"⏩ Skipping {fname} (already transcribed)")
        continue

    audio_path = os.path.join(audio_dir, fname)
    transcript_path = os.path.join(output_dir, f"{file_id}.txt")
    json_path = os.path.join(output_dir, f"{file_id}.json")

    try:
        print(f"▶ Transcribing {fname}...")
        result = model.transcribe(audio_path)
        transcript = result["text"].strip()

        # Save transcript (.txt)
        with open(transcript_path, "w", encoding="utf-8") as f:
            f.write(transcript)

        # Save full Whisper output (.json)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)

        # Log success
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.write(file_id + "\n")

        print(f"✅ Done: {file_id}")

    except Exception as e:
        print(f"❌ Error processing {fname}: {e}")

print("🎉 All audio files transcribed.")


  checkpoint = torch.load(fp, map_location=device)


⏩ Skipping 60014.wav (already transcribed)
⏩ Skipping 60015.wav (already transcribed)
⏩ Skipping 60018.wav (already transcribed)
⏩ Skipping 60022.wav (already transcribed)
⏩ Skipping 60049.wav (already transcribed)
⏩ Skipping 60079.wav (already transcribed)
⏩ Skipping 60084.wav (already transcribed)
⏩ Skipping 60102.wav (already transcribed)
⏩ Skipping 60147.wav (already transcribed)
⏩ Skipping 60167.wav (already transcribed)
⏩ Skipping 60180.wav (already transcribed)
⏩ Skipping 60184.wav (already transcribed)
⏩ Skipping 60197.wav (already transcribed)
⏩ Skipping 60226.wav (already transcribed)
⏩ Skipping 60239.wav (already transcribed)
⏩ Skipping 60243.wav (already transcribed)
⏩ Skipping 60245.wav (already transcribed)
⏩ Skipping 60253.wav (already transcribed)
⏩ Skipping 60267.wav (already transcribed)
⏩ Skipping 60275.wav (already transcribed)
⏩ Skipping 60286.wav (already transcribed)
⏩ Skipping 60290.wav (already transcribed)
⏩ Skipping 60330.wav (already transcribed)
⏩ Skipping 

In [2]:
import os
import shutil

# Paths
audio_dir = r"C:\Users\user\Downloads\Private_dataset\private"
transcript_dir = r"C:\Users\user\Downloads\outputs"

# Generate .lab files in audio_dir
for fname in os.listdir(audio_dir):
    if not fname.endswith(".wav"):
        continue

    base = os.path.splitext(fname)[0]
    txt_path = os.path.join(transcript_dir, base + ".txt")
    lab_path = os.path.join(audio_dir, base + ".lab")

    if os.path.exists(txt_path):
        shutil.copy(txt_path, lab_path)
        print(f"✅ Created: {lab_path}")
    else:
        print(f"⚠️ Missing transcript for: {base}")


✅ Created: C:\Users\user\Downloads\Private_dataset\private\60014.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60015.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60018.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60022.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60049.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60079.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60084.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60102.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60147.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60167.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60180.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60184.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60197.lab
✅ Created: C:\Users\user\Downloads\Private_dataset\private\60226.lab
✅ Created: C:\Users\user\Downloads

In [3]:
import os
import json

json_dir = r"C:\Users\user\Downloads\outputs"
output_path = os.path.join(json_dir, "task1_answer.txt")

# Collect file_id → text mapping
entries = []
for fname in os.listdir(json_dir):
    if not fname.endswith(".json"):
        continue
    file_id = os.path.splitext(fname)[0]
    if not file_id.isdigit():
        continue
    with open(os.path.join(json_dir, fname), "r", encoding="utf-8") as f:
        data = json.load(f)
        transcript = data.get("text", "").strip()
        entries.append((int(file_id), transcript))

# Sort numerically by file_id
entries.sort()

# Write output
with open(output_path, "w", encoding="utf-8") as f:
    for file_id, text in entries:
        f.write(f"{file_id}\t{text}\n")

print("✅ task1_answer.txt written in correct format and order.")


✅ task1_answer.txt written in correct format and order.


In [6]:
input_path = "task1_answer.txt"         # Your original Task 1 output
output_path = "task1_answer_normalized.txt"  # Output file with normalized tabs

with open(input_path, "r", encoding="utf-8") as infile, \
     open(output_path, "w", encoding="utf-8") as outfile:
    
    for line in infile:
        parts = line.strip().split(None, 1)  # Split on first whitespace
        if len(parts) == 2:
            utt_id, text = parts
            outfile.write(f"{utt_id}\t{text.strip()}\n")  # Force single tab
        elif parts:
            outfile.write(f"{parts[0]}\n")  # Just the ID (edge case)

print("✅ Normalization complete.")


✅ Normalization complete.


# Old Code

In [1]:
#gentle script
import os
import json
import whisper
import requests

# === Configuration ===
audio_dir = r"C:\Users\user\Downloads\Private_dataset\private"
output_dir = r"C:\Users\user\Downloads\outputs"
gentle_url = "http://localhost:8765/transcriptions?async=false"
log_path = os.path.join(output_dir, "processed_files.log")
os.makedirs(output_dir, exist_ok=True)

# === Load Whisper model ===
model = whisper.load_model("medium").to("cuda")

# === Load log of already processed files ===
if os.path.exists(log_path):
    with open(log_path, "r", encoding="utf-8") as log_file:
        processed_files = set(line.strip() for line in log_file if line.strip())
else:
    processed_files = set()

# === Process each WAV file ===
for fname in os.listdir(audio_dir):
    if not fname.endswith(".wav"):
        continue

    file_id = os.path.splitext(fname)[0]
    if file_id in processed_files:
        print(f"⏩ Skipping {fname} (already logged)")
        continue

    audio_path = os.path.join(audio_dir, fname)
    transcript_path = os.path.join(output_dir, f"{file_id}.txt")
    gentle_json_path = os.path.join(output_dir, f"{file_id}_gentle.json")
    output_tsv = os.path.join(output_dir, f"{file_id}_words.tsv")

    try:
        # === Whisper Transcription ===
        print(f"▶ Transcribing {fname}...")
        result = model.transcribe(audio_path)
        transcript = result["text"].strip()

        with open(transcript_path, "w", encoding="utf-8") as f:
            f.write(transcript)

        # === Gentle Forced Alignment ===
        print(f"🕒 Aligning with Gentle: {fname}")
        with open(audio_path, "rb") as audio_file, open(transcript_path, "r", encoding="utf-8") as transcript_file:
            files = {
                "audio": (fname, audio_file, "audio/wav"),
                "transcript": (f"{file_id}.txt", transcript_file, "text/plain"),
            }

            response = requests.post(gentle_url, files=files)
            response.raise_for_status()
            alignment = response.json()

        with open(gentle_json_path, "w", encoding="utf-8") as f:
            json.dump(alignment, f, indent=2)

        # === Extract Word-Level Timing ===
        with open(output_tsv, "w", encoding="utf-8") as f:
            f.write("file_id\tword\tstart\tend\n")
            for word in alignment["words"]:
                if word["case"] != "success":
                    continue
                f.write(f"{file_id}\t{word['word']}\t{word['start']:.2f}\t{word['end']:.2f}\n")

        # === Log as processed ===
        with open(log_path, "a", encoding="utf-8") as log_file:
            log_file.write(file_id + "\n")

        print(f"✅ Completed: {file_id}")

    except Exception as e:
        print(f"❌ Error processing {fname}: {e}")

print("🎉 All files processed.")


  checkpoint = torch.load(fp, map_location=device)


▶ Transcribing 60014.wav...
🕒 Aligning with Gentle: 60014.wav
✅ Completed: 60014
▶ Transcribing 60015.wav...
🕒 Aligning with Gentle: 60015.wav
✅ Completed: 60015
▶ Transcribing 60018.wav...
🕒 Aligning with Gentle: 60018.wav
✅ Completed: 60018
▶ Transcribing 60022.wav...
🕒 Aligning with Gentle: 60022.wav
✅ Completed: 60022
▶ Transcribing 60049.wav...
🕒 Aligning with Gentle: 60049.wav
✅ Completed: 60049
▶ Transcribing 60079.wav...


KeyboardInterrupt: 

In [6]:
import os

# === Config ===
output_dir = r"C:\Users\user\Downloads\outputs"
final_output_path = os.path.join(output_dir, "task1_answer.txt")

# === Gather all Whisper transcripts ===
lines = []

for fname in os.listdir(output_dir):
    if fname.endswith(".txt") and "_gentle" not in fname:
        file_id = os.path.splitext(fname)[0]
        txt_path = os.path.join(output_dir, fname)

        with open(txt_path, "r", encoding="utf-8") as f:
            transcript = f.read().strip().replace("\n", " ")
            lines.append(f"{file_id}\t{transcript}")

# === Sort numerically by file_id if needed ===
lines.sort(key=lambda x: int(x.split("\t")[0]))

# === Save as final output ===
with open(final_output_path, "w", encoding="utf-8") as f:
    for line in lines:
        f.write(line + "\n")

print(f"✅ Task 1 output saved to:\n{final_output_path}")


✅ Task 1 output saved to:
C:\Users\user\Downloads\outputs\task1_answer.txt


In [15]:
import os
import time
import nltk
from nltk.tokenize import sent_tokenize
from deepmultilingualpunctuation import PunctuationModel

# === Set NLTK data path and download 'punkt' tokenizer ===
nltk.data.path.insert(0, "C:/Users/user/nltk_data")
nltk.download("punkt", download_dir="C:/Users/user/nltk_data")

# === Configuration ===
input_dir = r"C:\Users\user\Downloads\outputs"
output_file = os.path.join(input_dir, "task1_answer_punctuated.txt")
model = PunctuationModel()

# === Prepare file list ===
file_list = [f for f in os.listdir(input_dir) if f.endswith(".txt") and f.split(".")[0].isdigit()]
total = len(file_list)
results = []

# === Process each file ===
for idx, fname in enumerate(file_list, 1):
    start = time.time()

    file_id = os.path.splitext(fname)[0]
    path = os.path.join(input_dir, fname)

    print(f"🔄 [{idx}/{total}] Processing {fname}...")

    with open(path, "r", encoding="utf-8") as f:
        raw = f.read().strip().replace("\n", " ")

    try:
        punctuated = model.restore_punctuation(raw)
        capitalized = " ".join(
            s[0].upper() + s[1:] if s else "" for s in sent_tokenize(punctuated)
        )
    except Exception as e:
        print(f"⚠️ Failed to punctuate {file_id}: {e}")
        capitalized = raw  # fallback

    results.append(f"{file_id}\t{capitalized}")

    elapsed = time.time() - start
    print(f"✅ [{idx}/{total}] Finished {fname} in {elapsed:.2f}s")

# === Save final output ===
with open(output_file, "w", encoding="utf-8") as f:
    for line in sorted(results, key=lambda x: int(x.split("\t")[0])):
        f.write(line + "\n")

print(f"\n✅ All done. Final output saved to:\n{output_file}")


[nltk_data] Downloading package punkt to C:/Users/user/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Device set to use cpu


🔄 [1/775] Processing 24016.txt...
✅ [1/775] Finished 24016.txt in 2.32s
🔄 [2/775] Processing 24018.txt...
✅ [2/775] Finished 24018.txt in 1.93s
🔄 [3/775] Processing 24023.txt...
✅ [3/775] Finished 24023.txt in 1.47s
🔄 [4/775] Processing 24055.txt...
✅ [4/775] Finished 24055.txt in 1.59s
🔄 [5/775] Processing 24063.txt...
✅ [5/775] Finished 24063.txt in 1.11s
🔄 [6/775] Processing 24081.txt...
✅ [6/775] Finished 24081.txt in 1.07s
🔄 [7/775] Processing 24098.txt...
✅ [7/775] Finished 24098.txt in 0.99s
🔄 [8/775] Processing 24198.txt...
✅ [8/775] Finished 24198.txt in 1.28s
🔄 [9/775] Processing 24224.txt...
✅ [9/775] Finished 24224.txt in 1.20s
🔄 [10/775] Processing 2423.txt...
✅ [10/775] Finished 2423.txt in 1.63s
🔄 [11/775] Processing 24244.txt...
✅ [11/775] Finished 24244.txt in 1.54s
🔄 [12/775] Processing 24264.txt...
✅ [12/775] Finished 24264.txt in 1.20s
🔄 [13/775] Processing 24274.txt...
✅ [13/775] Finished 24274.txt in 1.15s
🔄 [14/775] Processing 24290.txt...
✅ [14/775] Finished 242

In [1]:
import os
import requests

# === Configuration ===
gentle_url = "http://localhost:8765/transcriptions?async=false"
audio_dir = r"C:\Users\user\Downloads\Validation_Dataset\audio"
transcript_dir = r"C:\Users\user\Downloads\outputs"
output_dir = r"C:\Users\user\Downloads\outputs"  # save JSONs here

# === Batch process all matching file pairs ===
for fname in os.listdir(audio_dir):
    if fname.endswith(".wav"):
        file_id = os.path.splitext(fname)[0]
        audio_path = os.path.join(audio_dir, fname)
        txt_path = os.path.join(transcript_dir, f"{file_id}.txt")
        output_path = os.path.join(output_dir, f"{file_id}_words.json")

        if not os.path.exists(txt_path):
            print(f"⚠️ Transcript not found: {txt_path}")
            continue

        print(f"🔄 Aligning {file_id}...")

        with open(audio_path, "rb") as audio_f, open(txt_path, "r", encoding="utf-8") as txt_f:
            files = {
                "audio": ("audio.wav", audio_f, "audio/wav"),
                "transcript": ("transcript.txt", txt_f, "text/plain")
            }

            try:
                response = requests.post(gentle_url, files=files)
                response.raise_for_status()

                with open(output_path, "w", encoding="utf-8") as out_f:
                    out_f.write(response.text)

                print(f"✅ Saved: {output_path}")

            except Exception as e:
                print(f"❌ Failed for {file_id}: {e}")


🔄 Aligning 24016...
✅ Saved: C:\Users\user\Downloads\outputs\24016_words.json
🔄 Aligning 24018...
✅ Saved: C:\Users\user\Downloads\outputs\24018_words.json
🔄 Aligning 24023...
✅ Saved: C:\Users\user\Downloads\outputs\24023_words.json
🔄 Aligning 24055...
✅ Saved: C:\Users\user\Downloads\outputs\24055_words.json
🔄 Aligning 24063...
✅ Saved: C:\Users\user\Downloads\outputs\24063_words.json
🔄 Aligning 24081...
✅ Saved: C:\Users\user\Downloads\outputs\24081_words.json
🔄 Aligning 24098...
✅ Saved: C:\Users\user\Downloads\outputs\24098_words.json
🔄 Aligning 24198...
✅ Saved: C:\Users\user\Downloads\outputs\24198_words.json
🔄 Aligning 24224...
✅ Saved: C:\Users\user\Downloads\outputs\24224_words.json
🔄 Aligning 2423...
✅ Saved: C:\Users\user\Downloads\outputs\2423_words.json
🔄 Aligning 24244...
✅ Saved: C:\Users\user\Downloads\outputs\24244_words.json
🔄 Aligning 24264...
✅ Saved: C:\Users\user\Downloads\outputs\24264_words.json
🔄 Aligning 24274...
✅ Saved: C:\Users\user\Downloads\outputs\24274

# Task2

In [3]:
from textgrid import TextGrid
import os
import json

input_dir = r"C:\Users\user\Downloads\outputs"
output_dir = r"C:\Users\user\Downloads\outputs"

for file in os.listdir(input_dir):
    if file.endswith(".TextGrid"):
        tg = TextGrid.fromFile(os.path.join(input_dir, file))
        words = []
        for tier in tg.tiers:
            if "word" in tier.name.lower():
                for interval in tier.intervals:
                    word = interval.mark.strip()
                    if word:
                        words.append({
                            "word": word,
                            "start": interval.minTime,
                            "end": interval.maxTime,
                            "case": "success"
                        })
        base = os.path.splitext(file)[0]
        with open(os.path.join(output_dir, f"{base}_words.json"), "w", encoding="utf-8") as f:
            json.dump({"words": words}, f, indent=2)


In [5]:
import os
import json
import re
import spacy

# === Paths ===
json_folder = r"C:\Users\user\Downloads\outputs"
output_file = os.path.join(json_folder, "task2_answer.txt")

# === Load models ===
nlp_sci = spacy.load("en_core_web_sm")
nlp_bc5cdr = spacy.load("en_ner_bc5cdr_md")

# === Regex patterns ===
duration_regex = re.compile(r"\b(\d+|a|an|one|two|few|several)\s+(seconds?|minutes?|hours?|days?|weeks?|months?|years?)\b", re.IGNORECASE)
time_regex = re.compile(r"\b(\d{1,2}(:\d{2})?\s*(AM|PM|am|pm)|tonight|this (morning|evening|afternoon)|yesterday|today|tomorrow|now)\b", re.IGNORECASE)
set_regex = re.compile(r"\b(every|each|daily|weekly|monthly|once a|twice a)\b", re.IGNORECASE)
doctor_regex = re.compile(r"\b(Dr\.?\s+[A-Z][a-z]+|doctor\s+[A-Z][a-z]+)\b")
patient_regex = re.compile(r"\b(the\s+patient|patient\s+[A-Z][a-z]+)\b")

common_family_names = {
    "tanya", "ivan", "james", "emma", "david", "kelly", "franco", "audrey",
    "jess", "sydney", "jack", "sophia", "alex", "maria"
}

city_list = {"chicago", "london", "paris", "new york", "beijing", "tokyo", "boston", "houston", "seattle"}

# === Label logic ===
def classify_label(label, text):
    text_clean = text.strip()
    text_lower = text_clean.lower()

    if re.search(r"\b\d+\s*(years?|months?|days?)\s+old\b", text_lower) or re.search(r"\baged\s+\d+\b", text_lower):
        return "AGE"
    if re.fullmatch(r"[A-Z0-9\-]{5,}", text_clean) or re.fullmatch(r"\d{4,}", text_clean):
        return "ID_NUMBER"
    if duration_regex.search(text_clean):
        return "DURATION"
    if time_regex.search(text_clean):
        return "TIME"
    if set_regex.search(text_clean):
        return "SET"
    if doctor_regex.search(text_clean):
        return "DOCTOR"
    if patient_regex.search(text_clean):
        return "PATIENT"
    if "hospital" in text_lower:
        return "HOSPITAL"
    if text_lower in city_list:
        return "CITY"
    if text_lower in common_family_names:
        return "FAMILYNAME"
    if label == "PERSON":
        if len(text_clean.split()) == 1:
            return "FAMILYNAME"
        if len(text_clean.split()) >= 2:
            return "PERSONALNAME"
    if text_clean.istitle() and len(text_clean.split()) == 2:
        return "PERSONALNAME"
    if "clinic" in text_lower or "bank" in text_lower:
        return "ORGANIZATION"
    if label in {"DATE", "TIME", "DURATION", "SET"}:
        return label
    return "OTHER"

# === Align spans ===
def align_to_time(ent_start_char, ent_end_char, word_spans):
    start_time = end_time = None
    for start, end, w in word_spans:
        if start <= ent_start_char < end:
            start_time = w.get("start")
        if start < ent_end_char <= end:
            end_time = w.get("end")
        if start_time is not None and end_time is not None:
            break
    return start_time, end_time

# === Process all files ===
results = set()
for fname in os.listdir(json_folder):
    if not fname.endswith("_words.json"):
        continue
    file_id = fname.split("_")[0]

    try:
        with open(os.path.join(json_folder, fname), "r", encoding="utf-8") as f:
            data = json.load(f)
            words = data.get("words", [])
            words = [w for w in words if isinstance(w, dict) and w.get("case") != "not-found-in-audio"]
    except Exception:
        continue
    if not words:
        continue

    full_text = ""
    word_spans = []
    offset = 0
    for w in words:
        word = w["word"]
        start = offset
        end = start + len(word)
        word_spans.append((start, end, w))
        full_text += word + " "
        offset = end + 1
    full_text = full_text.strip()

    try:
        doc1 = nlp_sci(full_text)
        doc2 = nlp_bc5cdr(full_text)
    except:
        continue

    seen = set()
    for doc in [doc1, doc2]:
        for ent in doc.ents:
            key = (ent.start_char, ent.end_char)
            if key in seen:
                continue
            seen.add(key)
            ent_text = ent.text.strip()
            label = classify_label(ent.label_, ent_text)

            # KEEP all entities, including "OTHER"
            start_time, end_time = align_to_time(ent.start_char, ent.end_char, word_spans)
            if start_time is not None and end_time is not None:
                line = f"{file_id}\t{label}\t{start_time:.2f}\t{end_time:.2f}\t{ent_text}"
                results.add(line)

# === Save output ===
with open(output_file, "w", encoding="utf-8") as f:
    for line in sorted(results):
        f.write(line + "\n")

print(f"✅ Output saved to: {output_file}")


✅ Output saved to: C:\Users\user\Downloads\outputs\task2_answer.txt


In [1]:
import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# === Paths ===
input_file = r"C:\Users\user\Downloads\outputs\task2_answer.txt"
output_file = r"C:\Users\user\Downloads\outputs\task2_llm_verified_batched.txt"

# === Config ===
BATCH_SIZE = 8

# === Logging ===
def log_message(msg, level="DEBUG"):
    try:
        print(f"[{level}] {msg}", flush=True)
    except UnicodeEncodeError:
        print(f"[{level}] [UnicodeEncodeError suppressed message]", flush=True)

log_message("Starting LLM-based NER verification script (Attempting corrections)...", "INFO")
log_message(f"Input file: {input_file}")
log_message(f"Output file: {output_file}")
log_message(f"Batch size: {BATCH_SIZE}")

# === LLM Setup ===
model_id = "unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit"
log_message(f"Loading tokenizer and model from: {model_id}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Ensure correct dtype for potential Unsloth optimizations if available
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
    model.eval()
    log_message(f"✅ Model loaded successfully on device: {model.device}", "INFO")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
        log_message("Set tokenizer.pad_token_id to tokenizer.eos_token_id.")

except Exception as e:
    log_message(f"❌ Model load failed: {e}", "ERROR")
    sys.exit(1)

# === Labels ===
VALID_LABELS = [
    "PATIENT", "DOCTOR", "PERSONALNAME", "FAMILYNAME", "PROFESSION", "ROOM", "DEPARTMENT", "HOSPITAL",
    "ORGANIZATION", "STREET", "CITY", "STATE", "COUNTRY", "COUNTY", "ZIP", "LOCATION-OTHER", "DISTRICT",
    "AGE", "DATE", "TIME", "DURATION", "SET", "PHONE", "FAX", "EMAIL", "URL", "IPADDRESS",
    "SOCIAL_SECURITY_NUMBER", "MEDICAL_RECORD_NUMBER", "HEALTH_PLAN_NUMBER", "ACCOUNT_NUMBER", "LICENSE_NUMBER",
    "VEHICLE_ID", "DEVICE_ID", "BIOMETRIC_ID", "ID_NUMBER", "OTHER"
]

# === Prompt Template ===
def make_prompt(file_id, start, end, text, current_label):
    """
    Revised prompt to be more directive about assigning the *correct* label.
    It uses the current label as context but prioritizes the task of correct assignment.
    """
    labels_list_str = ", ".join([lbl for lbl in VALID_LABELS if lbl != "OTHER"])

    # Notice the "Think step" and the clearer instruction to provide only the line.
    prompt = f"""You are a highly analytical and precise medical Named Entity Recognition (NER) expert. Your single goal is to identify the **most accurate entity label** for the provided text snippet from the `VALID_LABELS` list.

**VALID_LABELS:** {labels_list_str} (Use 'OTHER' only if absolutely no other label applies).

**Instructions:**
1.  Carefully examine the "Text" field.
2.  Consider the "Current Label" as a suggestion, but do not automatically assume it is correct.
3.  Choose the *single best* label from `VALID_LABELS` that perfectly describes the "Text".

**Output Format:**
File ID\tCORRECT_LABEL\tStart\tEnd\tText

---
**Example of Required Correction:**
Input Data:
File ID: 99999
Start: 10.00
End: 12.34
Text: john doe
Current Label: FAMILYNAME
Output: 99999\tPERSONALNAME\t10.00\t12.34\tjohn doe

---
**Your Task - Input Data:**
File ID: {file_id}
Start: {start}
End: {end}
Text: {text}
Current Label: {current_label}
Output:"""
    return prompt

# === Read Input ===
seen = set()
entries = []
prompts = []
log_message(f"🔎 Reading from: {input_file}")
try:
    if not os.path.exists(input_file):
        log_message(f"🚨 ERROR: Input file not found at {input_file}.", "ERROR")
        sys.exit(1)

    with open(input_file, "r", encoding="utf-8") as f:
        for line_num, line in enumerate(f):
            line = line.strip()
            if not line:
                log_message(f"Skipping empty line at input line {line_num + 1}.")
                continue

            parts = line.split("\t")
            if len(parts) != 5:
                log_message(f"Skipping malformed line at input line {line_num + 1}: '{line}'. Expected 5 parts, got {len(parts)}.", "WARNING")
                continue

            file_id, label, start, end, text = parts
            key = (file_id, start, end, text.lower())
            if key in seen:
                log_message(f"Skipping duplicate entry at input line {line_num + 1}: '{line}'.", "DEBUG")
                continue
            seen.add(key)
            entries.append((file_id, label, start, end, text))
            prompts.append(make_prompt(file_id, start, end, text, label))
    log_message(f"Finished reading {len(entries)} unique entities from input file.", "INFO")

except FileNotFoundError:
    log_message(f"🚨 ERROR: Input file not found at '{input_file}'. Please verify the path.", "ERROR")
    sys.exit(1)
except Exception as e:
    log_message(f"🚨 An unexpected error occurred during input file reading: {e}", "ERROR")
    sys.exit(1)

# === Batched Inference ===
corrected_output_lines = []
total_entities = len(prompts)
log_message(f"🚀 Beginning batched LLM inference on {total_entities} entities...", "INFO")

for i in range(0, total_entities, BATCH_SIZE):
    batch_prompts = prompts[i:i + BATCH_SIZE]
    batch_entries = entries[i:i + BATCH_SIZE]
    batch_start_idx = i

    log_message(f"Processing batch {int(i/BATCH_SIZE) + 1} of {(total_entities + BATCH_SIZE - 1) // BATCH_SIZE} ({batch_start_idx}-{min(i + BATCH_SIZE, total_entities) - 1})", "INFO")

    try:
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=tokenizer.model_max_length
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100, # Increased max_new_tokens for more flexibility
                do_sample=False,    # Keep greedy for consistent output format
                temperature=0.3,    # Slightly increased temperature to encourage more varied responses if unsure
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )
        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    except Exception as e:
        log_message(f"❌ Batch {int(i/BATCH_SIZE) + 1} failed during LLM inference: {e}. Falling back to original entries for this batch.", "ERROR")
        for entry in batch_entries:
            corrected_output_lines.append("\t".join(entry))
        continue

    # Process each individual output from the batch
    for j, (original_entry, full_llm_output) in enumerate(zip(batch_entries, decoded_outputs)):
        current_entity_idx = batch_start_idx + j
        original_line = "\t".join(original_entry)

        try:
            # The prompt now ends with "Output:\n"
            prompt_end_marker = "Output:"
            generated_text_start_index = full_llm_output.rfind(prompt_end_marker)

            if generated_text_start_index == -1:
                log_message(f"⚠️ Prompt end marker not found in LLM output for entity {current_entity_idx}. Using original entry.", "WARNING")
                corrected_output_lines.append(original_line)
                continue

            generated_content = full_llm_output[generated_text_start_index + len(prompt_end_marker):].strip()

            valid_line_found = False
            for line_candidate in generated_content.splitlines():
                line_candidate = line_candidate.strip()
                parts = line_candidate.split("\t")
                
                # Validate the structure and consistency of the generated line
                if len(parts) == 5 and \
                   parts[0] == original_entry[0] and \
                   parts[2] == original_entry[2] and \
                   parts[3] == original_entry[3] and \
                   parts[4] == original_entry[4]:
                    
                    predicted_label = parts[1].strip()
                    
                    # Normalize and validate predicted label
                    canonical_label = "OTHER" # Default if no match found
                    is_valid_label = False
                    for valid_lbl in VALID_LABELS:
                        if predicted_label.lower() == valid_lbl.lower():
                            canonical_label = valid_lbl
                            is_valid_label = True
                            break
                    
                    if not is_valid_label: # If LLM generated a label not in our valid list
                         log_message(f"⚠️ Invalid label '{predicted_label}' generated for entity {current_entity_idx}. Falling back to original. LLM Output: '{line_candidate}'", "WARNING")
                         corrected_output_lines.append(original_line)
                         valid_line_found = True # Treat as processed, but used fallback
                         break

                    parts[1] = canonical_label # Use the canonical form
                    corrected_output_lines.append("\t".join(parts))
                    
                    # Log whether a correction was made
                    if canonical_label == original_entry[1]:
                        log_message(f"✅ [{current_entity_idx + 1}/{total_entities}] Text: '{original_entry[4]}' → Label: '{canonical_label}' (No change)", "INFO")
                    else:
                        log_message(f"✅ [{current_entity_idx + 1}/{total_entities}] Text: '{original_entry[4]}' → Corrected: '{original_entry[1]}' -> '{canonical_label}'", "INFO")
                    
                    valid_line_found = True
                    break # Found a valid line, stop processing candidates for this entry

            if not valid_line_found:
                log_message(f"⚠️ No valid 5-part tab-separated output line found after parsing LLM output for entity {current_entity_idx}. Using original entry: '{original_entry[4]}'", "WARNING")
                corrected_output_lines.append(original_line) # Fallback if no valid line could be extracted

        except Exception as parse_error:
            log_message(f"❌ Error processing LLM output for entity {current_entity_idx}. Using original entry: '{original_entry[4]}'. Error: {parse_error}", "ERROR")
            corrected_output_lines.append(original_line)

    # ⛔ Testing mode: Stop after the first 3 batches (remove this for full run)
    if i >= (BATCH_SIZE * 2) and (total_entities > BATCH_SIZE * 2):
        log_message("🧪 Testing mode: Script stopped after processing 3 batches. Remove this `if` block for full run.", "INFO")
        break

log_message(f"Finished LLM inference. Processed {total_entities} entities. Total lines for output: {len(corrected_output_lines)}", "INFO")

# === Save Output ===
try:
    log_message(f"Writing {len(corrected_output_lines)} corrected entries to: {output_file}", "INFO")
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("\n".join(corrected_output_lines))
    log_message(f"🎉 Saved verified output to: {output_file}", "INFO")

except Exception as e:
    log_message(f"🚨 ERROR: Failed to save output file. Error: {e}", "ERROR")
    sys.exit(1)

log_message("Script execution finished.", "INFO")


  from .autonotebook import tqdm as notebook_tqdm


[INFO] Starting LLM-based NER verification script (Attempting corrections)...
[DEBUG] Input file: C:\Users\user\Downloads\outputs\task2_answer.txt
[DEBUG] Output file: C:\Users\user\Downloads\outputs\task2_llm_verified_batched.txt
[DEBUG] Batch size: 8
[DEBUG] Loading tokenizer and model from: unsloth/DeepSeek-R1-Distill-Qwen-7B-bnb-4bit...
[INFO] ✅ Model loaded successfully on device: cuda:0
[DEBUG] 🔎 Reading from: C:\Users\user\Downloads\outputs\task2_answer.txt
[INFO] Finished reading 2384 unique entities from input file.
[INFO] 🚀 Beginning batched LLM inference on 2384 entities...
[INFO] Processing batch 1 of 298 (0-7)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] ✅ [1/2384] Text: 'june 6 2063' → Label: 'DATE' (No change)
[INFO] ✅ [2/2384] Text: 'june 6 2063' → Label: 'DATE' (No change)
[INFO] ✅ [3/2384] Text: 'katherine' → Corrected: 'FAMILYNAME' -> 'PERSONALNAME'
[INFO] ✅ [4/2384] Text: 'dr martin' → Label: 'PERSONALNAME' (No change)
[INFO] ✅ [5/2384] Text: 'may 24 2063' → Label: 'DATE' (No change)
[INFO] ✅ [6/2384] Text: 'granuloma' → Label: 'OTHER' (No change)
[INFO] ✅ [8/2384] Text: '9 o'clock' → Label: 'TIME' (No change)
[INFO] Processing batch 2 of 298 (8-15)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] ✅ [9/2384] Text: 'february 12 2063' → Label: 'DATE' (No change)
[INFO] ✅ [10/2384] Text: 'sarah' → Corrected: 'FAMILYNAME' -> 'PERSONALNAME'
[INFO] ✅ [11/2384] Text: 'dr brook' → Label: 'PERSONALNAME' (No change)
[INFO] ✅ [12/2384] Text: 'october 12 2062' → Label: 'DATE' (No change)
[INFO] ✅ [13/2384] Text: '8 mm' → Label: 'OTHER' (No change)
[INFO] ✅ [14/2384] Text: '40 mm' → Label: 'OTHER' (No change)
[INFO] ✅ [15/2384] Text: '1' → Label: 'OTHER' (No change)
[INFO] ✅ [16/2384] Text: '4' → Label: 'OTHER' (No change)
[INFO] Processing batch 3 of 298 (16-23)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[INFO] ✅ [17/2384] Text: '2' → Label: 'OTHER' (No change)
[INFO] ✅ [18/2384] Text: '11' → Label: 'OTHER' (No change)
[INFO] ✅ [19/2384] Text: 'january 10 2006' → Label: 'DATE' (No change)
[INFO] ✅ [20/2384] Text: 'november 4 2013' → Label: 'DATE' (No change)
[INFO] ✅ [21/2384] Text: 'the day' → Label: 'DATE' (No change)
[INFO] ✅ [22/2384] Text: 'qld 3373' → Label: 'PERSONALNAME' (No change)
[INFO] ✅ [24/2384] Text: 'december 10 2063' → Label: 'DATE' (No change)
[INFO] 🧪 Testing mode: Script stopped after processing 3 batches. Remove this `if` block for full run.
[INFO] Finished LLM inference. Processed 2384 entities. Total lines for output: 24
[INFO] Writing 24 corrected entries to: C:\Users\user\Downloads\outputs\task2_llm_verified_batched.txt
[INFO] 🎉 Saved verified output to: C:\Users\user\Downloads\outputs\task2_llm_verified_batched.txt
[INFO] Script execution finished.
