In [None]:
import re
import time
import requests
import random
from pathlib import Path
from datetime import datetime
import pandas as pd
from tqdm.notebook import tqdm
import openai

# Setup

In [None]:
# URL that the generation helpers POST completion requests to.
VLLM_URL = "http://localhost:8000/v1/completions"
# Directory name for prompt files (not read by any cell visible in this notebook).
PROMPTS_DIR = Path("prompts")
# Destination directory for generated note excerpts.
OUTPUT_DIR = Path("../data/synthetic/note-excerpts")
# parents=True: the path is nested, so without it a fresh checkout that lacks
# ../data/synthetic raises FileNotFoundError. exist_ok=True keeps re-runs idempotent.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Build Prompt

**Define Prompt Template**

In [None]:
# Prompt template for one synthetic note excerpt.
# It carries two str.format placeholders, {n_sentences} and {scenario}, which
# make_prompt fills in. The text asks the model to finish with ***END NOTE***,
# the same marker that remove_end_tag strips by default.
# NOTE: the string is passed through str.format, so any literal brace added to it
# must be doubled ({{ / }}).
PROMPT_TMPL = """Only output one paragraph—no comments, no lists.  Begin with exactly one bold header (choose one):

**Brief Hospital Course:**  
**Major Procedures:**  
**Discharge Summary:**

Then, in approximately {n_sentences} sentences or sentence fragments, document a realistic hospital note excerpt in a rushed, semi‑structured style that covers:
- admission reason
- key findings (labs/imaging/procedures)
- interventions/treatments
- discharge plan

Include common EHR shorthand (HTN, DM2, CAD, WNL, BP, HR, SpO2, CXR, CT scan, ICU, OR, PO, IV, BiPAP).  
Inject “noise” (~10% of sentences) by omitting commas, double‑spacing words, or dangling fragments.  
Vary sentence length (2–3 words up to ~25 words).  
Use demographics only as “A ##‑year‑old M” or “A ##‑year‑old F” inside a sentence—never full identifiers. 
Focus on: {scenario}

End with ***END NOTE***"""

**Candidates: Headers**

In [None]:
# Bold markdown headers a note may open with; each title is wrapped as "**<title>:**".
HEADERS = [
    f"**{title}:**"
    for title in ("Brief Hospital Course", "Major Procedures", "Discharge Summary")
]

**Candidates: Scenarios**

In [None]:
# Clinical scenarios; make_prompt draws one at random for the "Focus on:" line.
SCENARIOS = [
    "COPD exacerbation in a smoker",
    "Community‑acquired pneumonia",
    "Acute decompensated heart failure",
    "STEMI post‑PCI",
    "Ischemic stroke on tPA",
    "DKA in type 1 diabetic",
    "GI bleed from PUD",
    "Sepsis from UTI",
    "Post‑op hip fracture repair",
    "AKI on CKD",
    "Preeclampsia in 3rd trimester",
    "Liver transplant post‑op day 2",
    "Traumatic brain injury",
    "ARDS on ventilator",
    "Vascular surgery post‑op",
]

**Candidates: Number of Sentences**

In [None]:
# Exploratory draw using the same two-step scheme as make_prompt:
# pick a (low, high) length bucket, then a sentence count inside it (both ends inclusive).
bucket = random.choice(((8, 9), (10, 12), (14, 17)))
n_sentences = random.randint(bucket[0], bucket[1])

**Building the prompt**

In [None]:
def make_prompt(template, scenarios):
    """
    Build one generation prompt by filling ``template`` with random choices.

    Args:
      template: A ``str.format`` template; it may reference the ``{scenario}``
        and ``{n_sentences}`` placeholders.
      scenarios: Non-empty sequence of scenario descriptions to draw from.
    Returns:
      The formatted prompt string.
    Raises:
      IndexError: If ``scenarios`` is empty (raised by ``random.choice``).
    """
    # Two-step draw: first a length bucket, then a count inside it (inclusive).
    bucket = random.choice([(8, 9), (10, 12), (14, 17)])
    n_sentences = random.randint(*bucket)
    scenario = random.choice(scenarios)

    # Format the template that was passed in. The previous version ignored the
    # argument and always formatted the global PROMPT_TMPL.
    return template.format(scenario=scenario, n_sentences=n_sentences)

In [None]:
# Preview one randomly filled prompt (the value is shown as the cell output).
make_prompt(PROMPT_TMPL, SCENARIOS)

# Generation functionality

In [None]:
def generate_from_vllm(
    prompt: str,
    temperature: float = 0.7,
    max_tokens: int = 1024,
    top_p: float = 0.95,
    seed: int = None,
    retries: int = 3,
    delay: float = 2,
    stop_sequences: "list[str] | None" = None,
    timeout: "float | None" = None,
) -> str:
    """
    Send prompt to local vLLM server and return the generated text.

    Args:
      prompt: The text prompt to complete.
      temperature: Sampling temperature.
      max_tokens: Maximum number of tokens to generate.
      top_p: Nucleus sampling cutoff.
      seed: Optional RNG seed; only sent to the server when not None.
      retries: How many times to attempt the request.
      delay: Seconds to wait between attempts.
      stop_sequences: Strings at which generation should stop. Defaults to
        ["***END NOTE***"]; pass an empty list to send no stop strings.
      timeout: Optional per-request timeout in seconds, forwarded to
        ``requests.post``. None (the default) waits indefinitely.
    Returns:
      The completed text, stripped of surrounding whitespace
      (empty string if all attempts failed).
    """
    if stop_sequences is None:
        stop_sequences = ["***END NOTE***"]

    payload: dict = {
        "prompt": prompt,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_p": top_p,
    }
    if seed is not None:
        payload["seed"] = seed
    if stop_sequences:
        # The OpenAI-compatible completions API names this field "stop";
        # the previous "stop_sequences" key is not a recognised field there.
        payload["stop"] = list(stop_sequences)

    for attempt in range(1, retries + 1):
        try:
            r = requests.post(VLLM_URL, json=payload, timeout=timeout)
            r.raise_for_status()
            text = r.json()["choices"][0]["text"]
            return text.strip()
        except Exception as e:
            # Deliberately broad: any transport, HTTP or response-shape problem
            # is treated as a failed attempt and retried.
            print(f"[Warning] vLLM call failed (attempt {attempt}): {e}")
            if attempt < retries:
                # No point in waiting after the last attempt.
                time.sleep(delay)

    print("[Error] All vLLM attempts failed, returning empty string.")
    return ""


In [None]:
def remove_end_tag(raw_output: str, end_token: str = "***END NOTE***") -> str:
    """
    Return the part of a raw LLM response that precedes ``end_token``.

    The text is cut at the first occurrence of ``end_token``; when the token
    is absent the whole response is kept. Surrounding whitespace is stripped
    in both cases.
    """
    cut = raw_output.find(end_token)
    # find() reports -1 when the marker never appears.
    kept = raw_output if cut == -1 else raw_output[:cut]
    return kept.strip()

# Generate

In [None]:
# Number of notes to generate in the loop below.
n = 15

In [None]:
# Spot-check a freshly sampled prompt before starting the generation run.
make_prompt(PROMPT_TMPL, SCENARIOS)

In [None]:
# Generate n notes, each from a fresh random prompt and fresh random sampling settings.
generated_notes = []
for i in tqdm(range(n), desc="Generating notes", total=n):
    prompt = make_prompt(PROMPT_TMPL, SCENARIOS)

    # Dict values are evaluated top to bottom, so the random draws happen in
    # the order temperature, top_p, max_tokens, seed.
    sampling = {
        "temperature": random.uniform(0.7, 1.0),
        "top_p": random.uniform(0.7, 1.0),
        "max_tokens": random.randint(700, 1800),
        "seed": random.randint(0, 2**30),
    }

    raw = generate_from_vllm(prompt=prompt, retries=3, delay=1.0, **sampling)

    # Cut the response at the end marker and store the cleaned note.
    generated_notes.append(remove_end_tag(raw, end_token="***END NOTE***"))

In [None]:
# Disabled cell: the code below is wrapped in a string literal, so it never runs
# (the trailing ';' suppresses the cell output).
# NOTE(review): it references `base_prompt`, which is not defined in any visible
# cell — it would need to be defined before re-enabling this.
'''
outputs = []
for i in tqdm(range(n), total=n, desc="Generating notes "):
    note = generate_from_vllm(base_prompt)
    outputs.append({"id": i, "note-excerpt": remove_end_tag(note)})
''';

In [None]:
# Show at most 15 notes. Also cap by the real list length, so that editing `n`
# after generation (or a partial run) cannot make random.sample raise ValueError.
sampled_outputs = random.sample(generated_notes, min(len(generated_notes), n, 15))

In [None]:
# Dump each sampled note followed by its character count and a divider line.
for sample in sampled_outputs:
    print(
        sample,
        len(sample),
        "\n",
        "--------------------------------------------------",
        "\n",
        sep="\n",
    )

# Save

In [None]:
# Disabled cell: the save step is wrapped in a string literal, so nothing is written.
# NOTE(review): it references `prompt_path`, which is not defined in any visible
# cell — it would need to be defined (or the filename changed) before re-enabling.
'''
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = OUTPUT_DIR / f"generated_{prompt_path.stem}_{timestamp}.csv"
pd.DataFrame(generated_notes).to_csv(out_path, index=False)
print(f"✅ Saved: {out_path}")
''';

# Backup

In [None]:
def generate_from_vllm_old(prompt: str, temperature=0.7, max_tokens=1024, retries=3, delay=2, top_p = 0.95) -> str:
    """
    Earlier variant of the vLLM helper, kept as a backup.

    POSTs the prompt and sampling settings to VLLM_URL and returns the stripped
    text of the first choice. Each failed attempt prints a warning and sleeps
    ``delay`` seconds; after ``retries`` failed attempts the sentinel string
    "[Generation Failed]" is returned.
    """
    request_body = dict(
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
    )
    for attempt in range(retries):
        try:
            reply = requests.post(VLLM_URL, json=request_body)
            reply.raise_for_status()
            completion = reply.json()["choices"][0]["text"].strip()
        except Exception as e:
            print(f"[Warning] Generation failed (attempt {attempt+1}/{retries}): {e}")
            time.sleep(delay)
        else:
            return completion
    return "[Generation Failed]"