<a href="https://colab.research.google.com/github/Long2511/ai-project/blob/main/colab/Project_AI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

0) Install dependencies

In [None]:
import sys, subprocess, os
def sh(cmd):
    """Echo *cmd*, then execute it through the shell.

    Raises subprocess.CalledProcessError when the command exits non-zero
    (``check=True``); returns None otherwise.
    """
    print(cmd)
    subprocess.run(cmd, shell=True, check=True)

# NOTE: lines starting with "!" are IPython/Colab shell escapes, not Python syntax;
# this cell only runs inside a notebook kernel.
# CUDA 12.1 wheels of torch/torchaudio/torchvision, pinned to matching versions (Colab).
!python3 -m pip -q install --index-url https://download.pytorch.org/whl/cu121 "torch==2.5.1" "torchaudio==2.5.1" "torchvision==0.20.1"
# Core libs: transformers (provides ColPaliForRetrieval/ColPaliProcessor, pinned to 4.53.x),
# colpali-engine, qdrant-client (multi-vector), OCR (pytesseract) and PDF (pdf2image) wrappers.
!python3 -m pip -q install "transformers>=4.53.1,<4.54.0" colpali-engine==0.3.12 "qdrant-client>=1.7.3,<2" accelerate sentencepiece pdf2image pytesseract
# System deps for OCR/PDF: the tesseract-ocr and poppler-utils packages via apt.
sh('apt-get -y update && apt-get -y install tesseract-ocr poppler-utils')

1) Mount Google Drive

In [None]:
# Mount Google Drive when running in Colab; elsewhere just report why it was skipped.
try:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)
    print("[colab] Drive mounted at /content/drive")
except Exception as err:  # ImportError outside Colab, or a failed mount
    print("[note] Not in Colab or Drive mount failed:", err)

2) Config

In [None]:
SOURCE_DIR         = "/content/drive/MyDrive/Project-AI/PDF-Data"
MODEL_NAME         = "vidore/colpali-v1.2-hf"

# Outputs
PAGES_JSONL        = "/content/clinical_cases_index.jsonl"
CASES_JSONL        = "/content/clinical_cases_cases.jsonl"
STRUCT_JSONL       = "/content/clinical_cases_cases_structured.jsonl"
SFT_EXTRACT_JSONL  = "/content/clinical_cases_extract_sft.jsonl"
SFT_DX_JSONL       = "/content/clinical_cases_dx_sft.jsonl"

# Qdrant collections
COLLECTION_PAGES   = "tropical_cases_colpali_pages"
COLLECTION_CASES   = "tropical_cases_colpali_cases"
VECTOR_SIZE        = 128  # ColPali subvector dim

# Toggles
INDEX_PAGES        = True
INDEX_CASES        = True
ENABLE_OCR         = True
ENABLE_PDF         = True
BATCH              = 2      # embedding batch size
MAX_FILES          = None   # cap on prepared *pages* (not files); set small int to smoke-test

# Qdrant remote (REST-only). Qdrant serves REST on 6333 and gRPC on 6334; the client
# is created with prefer_grpc=False, so the REST port is the one that must be used.
# Host, port and API key can be overridden through environment variables.
QDRANT_HOST        = os.getenv("QDRANT_HOST", "165.22.56.15")
QDRANT_PORT        = int(os.getenv("QDRANT_PORT", "6333"))
QDRANT_API_KEY     = os.getenv("QDRANT_API_KEY") or None   # None -> no auth header
QDRANT_TIMEOUT     = 1200.0  # large to be safe
UPSERT_BATCH       = 12      # points per upsert() call; keep small for reliability
UPSERT_MAX_RETRIES = 6

3) Imports

In [None]:
import re, json, glob, hashlib, io, time
from typing import List, Dict, Any, Tuple
from PIL import Image
from tqdm import tqdm
import torch

from transformers import ColPaliForRetrieval, ColPaliProcessor
from qdrant_client import QdrantClient, models

# Optional OCR backend: when it cannot be imported, OCR is switched off and
# ocr_text_from_image() returns "".
try:
    import pytesseract
except Exception as exc:
    print("[warn] OCR disabled:", exc)
    ENABLE_OCR = False

# Optional PDF rasteriser: when it cannot be imported, PDF files are skipped by the scan.
try:
    from pdf2image import convert_from_path
except Exception as exc:
    print("[warn] PDF→image disabled:", exc)
    ENABLE_PDF = False

4) Scan / OCR / sectionize

In [None]:
def list_media_recursive(root: str, exts: Tuple[str, ...] = (".png", ".jpg", ".jpeg", ".pdf")) -> List[str]:
    """Return a sorted list of every image/PDF *file* below *root*, recursively.

    Extensions are compared case-insensitively: glob patterns are case-sensitive on
    Linux, so per-extension patterns silently skipped e.g. "scan.PNG" or "Case.PDF".
    Directories whose names happen to end in a media extension are excluded.
    Dot-prefixed (hidden) entries are skipped, as glob does by default.

    Args:
        root: directory to scan.
        exts: accepted file extensions (with leading dot, any case).
    """
    wanted = {e.lower() for e in exts}
    candidates = glob.glob(os.path.join(root, "**", "*"), recursive=True)
    return sorted(
        p for p in candidates
        if os.path.splitext(p)[1].lower() in wanted and os.path.isfile(p)
    )

def load_image(path: str) -> Image.Image:
    """Open the image file at *path* and return an RGB copy of it.

    The source file is opened in a ``with`` block, so its handle is closed before
    returning. convert() loads the pixel data into a new image that no longer
    depends on the file.
    """
    with Image.open(path) as im:
        return im.convert("RGB")

def parse_page_number_from_path(path: str, default: int = 1) -> int:
    """Extract a 1-based page number from *path*, or return *default*.

    Recognises the "<file>#page=N" form produced for rasterised PDF pages and
    the "..._page_N" file-name convention (case-insensitive). A "#page=" suffix
    that is not an integer yields *default*; it does not fall back to the
    "_page_N" pattern.
    """
    if "#page=" in path:
        try:
            return int(path.split("#page=")[1])
        except ValueError:  # e.g. "#page=abc"; was a bare except
            return default
    m = re.search(r"_page_(\d+)", path, re.I)
    return int(m.group(1)) if m else default

def derive_case_id_from_path(path: str) -> int:
    """Derive an integer case id from a page/file path.

    The "#page=N" suffix, the directory, the extension and a trailing "_page_N"
    are dropped from the path. If the remaining file stem starts with a number
    followed by a word boundary, that number is the id. Otherwise the id is the
    first 6 hex digits of the stem's SHA-1.
    """
    file_part = path.split("#page=")[0]
    name = os.path.splitext(os.path.basename(file_part))[0]
    name = re.sub(r"_page_\d+$", "", name, flags=re.I)
    lead = re.match(r"^\s*(\d+)\b", name)
    if lead is None:
        return int(hashlib.sha1(name.encode()).hexdigest()[:6], 16)
    return int(lead.group(1))

def guess_title_from_path(p: str) -> str:
    """Build a human-readable case title from a file/page path.

    Drops the "#page=N" suffix, the directory, the extension and a trailing
    "_page_N" (any case). Each run of "-"/"_" becomes one space, and the result
    is stripped.
    """
    stem, _ext = os.path.splitext(os.path.basename(p.split("#page=")[0]))
    stem = re.sub(r"_page_\d+$", "", stem, flags=re.I)
    return re.sub(r"[-_]+", " ", stem).strip()

SECTION_HEADERS = [
    "Clinical Presentation","History","Clinical Findings","Laboratory Findings",
    "Laboratory Results","Laboratory Investigations","Additional Investigations",
    "Investigations","Questions","Discussion","Answer to Question",
    "Summary Box","Further Reading","The Case Continued","The Case Continued…"
]
# A header is a line consisting only of one of SECTION_HEADERS (any case, optional
# surrounding blanks). Anchoring on line boundaries with re.M, instead of consuming
# the surrounding "\n" characters, lets headers on consecutive lines split as well.
# Longer names are tried first, so "The Case Continued…" wins over its prefix.
SEC_RX = re.compile(
    r"^[ \t]*("
    + "|".join(re.escape(h) for h in sorted(SECTION_HEADERS, key=len, reverse=True))
    + r")[ \t\r]*$",
    re.I | re.M,
)

def sectionize(text: str) -> Dict[str,str]:
    """Split OCR *text* into a {section header: body text} mapping.

    Text before the first recognised header is stored under "body". Headers are
    matched case-insensitively (OCR often yields e.g. "HISTORY") and stored under
    their canonical spelling from SECTION_HEADERS. When a header occurs more than
    once, its bodies are joined with a blank line rather than the later one
    overwriting the earlier. Empty/None input gives {}.
    """
    canonical = {h.lower(): h for h in SECTION_HEADERS}
    out: Dict[str, str] = {}
    cur, buf = "body", []

    def _flush(key, parts):
        # Store the accumulated text for *key*, appending to any earlier body.
        piece = "\n".join(parts).strip()
        prev = out.get(key)
        out[key] = f"{prev}\n\n{piece}".strip() if prev else piece

    for chunk in SEC_RX.split(text or ""):
        if not chunk:
            continue
        header = canonical.get(chunk.strip().lower())
        if header is None:
            buf.append(chunk)
            continue
        if buf:
            _flush(cur, buf)
            buf = []
        cur = header
    if buf:
        _flush(cur, buf)
    return out

def ocr_text_from_image(img: Image.Image) -> str:
    """Best-effort OCR of *img* via pytesseract.

    Returns "" when OCR is disabled (ENABLE_OCR falsy) or when the OCR call raises,
    so a single unreadable page never aborts the scan.
    """
    if ENABLE_OCR:
        try:
            return pytesseract.image_to_string(img)
        except Exception:
            pass  # deliberate: an OCR failure degrades to "no text"
    return ""

def load_all_pages(root: str) -> Tuple[List[Image.Image], List[Dict[str, Any]]]:
    """Load every page image under *root* together with a metadata record per page.

    PDFs (only when ENABLE_PDF) are rasterised at 220 dpi, one entry per page with a
    "<file>#page=N" path. PNG/JPEG files become one entry each. Files that fail to
    load are reported and skipped. When MAX_FILES is set, collection stops once that
    many *pages* have been prepared.

    Returns:
        (images, metadata): parallel lists, one element per prepared page.
    """
    candidates = list_media_recursive(root)
    print(f"[scan] total candidates under {root} -> {len(candidates)}")
    images: List[Image.Image] = []
    metas: List[Dict[str, Any]] = []

    def at_cap() -> bool:
        # MAX_FILES counts prepared pages, not source files.
        return bool(MAX_FILES) and len(images) >= MAX_FILES

    def record(img, path_key, page_no, src_path, source):
        # Append one page and its metadata (key order matters for the JSONL output).
        images.append(img)
        metas.append({
            "case_title": guess_title_from_path(src_path),
            "path": path_key,
            "page_number": page_no,
            "case_id": derive_case_id_from_path(src_path),
            "source": source,
            "modality": "page_image"
        })

    for path in candidates:
        if at_cap():
            break
        suffix = os.path.splitext(path)[1].lower()

        if suffix == ".pdf" and ENABLE_PDF:
            try:
                rendered = convert_from_path(path, dpi=220)
                for page_no, page_img in enumerate(rendered, start=1):
                    if at_cap():
                        break
                    record(page_img, f"{path}#page={page_no}", page_no, path, "PDF page")
            except Exception as err:
                print("[pdf] failed ->", path, "|", err)
        elif suffix in (".png", ".jpg", ".jpeg"):
            try:
                record(load_image(path), path, parse_page_number_from_path(path, 1), path, "image")
            except Exception as err:
                print("[img] failed ->", path, "|", err)

    print(f"[scan] pages prepared: {len(images)}")
    return images, metas

5) ColPali embed

In [None]:
def load_colpali(device=None):
    """Load the ColPali retrieval model (MODEL_NAME) and its processor.

    Args:
        device: torch device (string or torch.device). Defaults to "cuda:0" when
            CUDA is available, else "cpu".

    Returns:
        (model, processor, device): the model is in eval mode and placed on *device*.
    """
    device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    # Choose the dtype from the device actually used, not from CUDA availability,
    # so an explicit device="cpu" on a GPU machine gets float32 instead of bfloat16.
    on_cuda = str(device).startswith("cuda")
    print(f"[env] device={device}\n[env] loading ColPali…")

    model = ColPaliForRetrieval.from_pretrained(
        MODEL_NAME,
        torch_dtype=(torch.bfloat16 if on_cuda else torch.float32),
    ).eval()
    model.to(device)

    processor = ColPaliProcessor.from_pretrained(MODEL_NAME)
    print("[env] ColPali loaded.")
    return model, processor, device

@torch.no_grad()
def embed_images(model, processor, device, pil_images: List[Image.Image]):
    """Embed a batch of page images with ColPali (gradients disabled).

    Returns one multivector per input image, in input order. Each is a nested
    Python list, moved to CPU and cast to float32 (List[List[List[float]]]).
    """
    inputs = processor(images=pil_images).to(device)
    result = model(**inputs)
    vectors = []
    for per_image in result.embeddings:  # [B, N, 128] multivectors
        vectors.append(per_image.to("cpu").float().tolist())
    return vectors

@torch.no_grad()
def embed_queries(model, processor, device, queries: List[str]):
    """Embed text queries with ColPali, one multivector per query (no gradients).

    Returns List[List[List[float]]] on CPU. Queries of different length are padded
    inside a batch. When the processor output carries an "attention_mask" that is
    aligned with the embeddings, the padded positions are dropped, so a short query
    does not bring extra non-token sub-vectors into MaxSim scoring. Without such a
    mask, every position is kept, as before.
    """
    batch = processor(text=queries).to(device)
    emb = model(**batch).embeddings  # presumably [B, T, 128]
    mask = batch.get("attention_mask") if hasattr(batch, "get") else None
    if mask is None or tuple(mask.shape) != tuple(emb.shape[:2]):
        return [e.to("cpu").float().tolist() for e in emb]
    keep = mask.to(emb.device).bool()
    return [e[k].to("cpu").float().tolist() for e, k in zip(emb, keep)]

6) Qdrant connect (REST-only)

In [None]:
def connect_qdrant_rest():
    """Create a REST-only Qdrant client for QDRANT_HOST:QDRANT_PORT and check it responds.

    Uses QDRANT_API_KEY and QDRANT_TIMEOUT from the config. gRPC is disabled
    (prefer_grpc=False). A get_collections() round-trip is made immediately, so
    connection problems surface here as the client's exception.
    """
    endpoint = f"http://{QDRANT_HOST}:{QDRANT_PORT}"
    qc = QdrantClient(
        url=endpoint,
        api_key=QDRANT_API_KEY,
        timeout=QDRANT_TIMEOUT,
        prefer_grpc=False,  # <- HARD disable gRPC
    )
    qc.get_collections()  # sanity round-trip; failures propagate
    print(f"[qdrant] connected (REST-only): {endpoint}")
    return qc

def ensure_collection(client: QdrantClient, name: str):
    """Create the multivector collection *name* unless it already exists.

    New collections store VECTOR_SIZE-dim cosine sub-vectors compared with MAX_SIM.
    They also keep payloads on disk and use HNSW m=32 / ef_construct=128 with 2
    default segments. An existing collection is left as is; its configuration is
    not checked.
    """
    def _already_there() -> bool:
        # Preferred check first. If that call errors, try the raw REST API, and
        # treat any failure there as "does not exist".
        try:
            return client.collection_exists(name)
        except Exception:
            pass
        try:
            client.http.collections_api.get_collection(name)
        except Exception:
            return False
        return True

    if _already_there():
        print(f"[qdrant] collection exists: {name}")
        return

    print(f"[qdrant] creating collection: {name}")
    client.create_collection(
        collection_name=name,
        vectors_config=models.VectorParams(
            size=VECTOR_SIZE,
            distance=models.Distance.COSINE,
            multivector_config=models.MultiVectorConfig(
                comparator=models.MultiVectorComparator.MAX_SIM
            ),
        ),
        on_disk_payload=True,
        hnsw_config=models.HnswConfigDiff(m=32, ef_construct=128),
        optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
    )

# Deterministic point IDs derived from a string key, so re-runs are true updates.
# The ID is the top 60 bits of the key's SHA-1 (its first 15 hex digits), which
# always fits in a non-negative 63-bit integer.
def stable_point_id(key: str) -> int:
    """Map *key* to a stable, non-negative integer point ID."""
    digest = hashlib.sha1(key.encode("utf-8")).digest()
    return int.from_bytes(digest, "big") >> (len(digest) * 8 - 60)

7) Upsert helper (uses client.upsert in batches)

In [None]:
def upsert_points(client: QdrantClient, name: str, vectors, payloads, id_keys, wait: bool = True):
    """Upsert points into collection *name* in UPSERT_BATCH-sized requests.

    Point IDs come from stable_point_id(id_key), so re-sending a batch overwrites the
    same points; retries are therefore safe. Each request is retried up to
    UPSERT_MAX_RETRIES times with exponential backoff (1.5 s, 3 s, 6 s, …). The last
    error is re-raised.

    Args:
        client: connected Qdrant client.
        name: target collection.
        vectors: one multivector per point (List[List[float]]).
        payloads: one payload dict per point.
        id_keys: one string key per point, hashed into the point ID.
        wait: passed to client.upsert(). Defaults to True, so the server applies the
            batch before answering and failures surface here, where the retry loop
            can handle them. With False, the request is only acknowledged as received.

    Raises:
        ValueError: if the three input sequences differ in length.
    """
    if not (len(vectors) == len(payloads) == len(id_keys)):
        raise ValueError(
            f"length mismatch: vectors={len(vectors)} payloads={len(payloads)} id_keys={len(id_keys)}"
        )
    total = len(id_keys)
    print(f"[qdrant] upserting {total} points → {name} with client.upsert, batch={UPSERT_BATCH}")
    for i in range(0, total, UPSERT_BATCH):
        batch_no = i // UPSERT_BATCH + 1
        stop = i + UPSERT_BATCH
        points = [
            models.PointStruct(
                id=stable_point_id(k),
                vector=v,           # multivector: List[List[float]] (subvector size=128)
                payload=p
            ) for v, p, k in zip(vectors[i:stop], payloads[i:stop], id_keys[i:stop])
        ]

        # retry with exponential backoff
        for attempt in range(UPSERT_MAX_RETRIES):
            try:
                client.upsert(collection_name=name, points=points, wait=wait)
                break
            except Exception as e:
                if attempt == UPSERT_MAX_RETRIES - 1:
                    print(f"[qdrant] upsert batch {batch_no} FAILED permanently:", repr(e))
                    raise
                sleep_s = 1.5 * (2 ** attempt)
                print(f"[qdrant] upsert batch {batch_no} retry {attempt+1} in {sleep_s:.1f}s →", repr(e))
                time.sleep(sleep_s)

8) Build PAGE-level index + JSONL

In [None]:
def build_pages_index(root: str):
    """Embed every page under *root* into COLLECTION_PAGES and write the per-page JSONL.

    Loads ColPali and connects to Qdrant, then prepares all pages. When INDEX_PAGES
    is set, pages are embedded BATCH at a time and upserted with IDs
    "page::<path>". Finally one JSON line per page is written to PAGES_JSONL,
    holding its metadata plus OCR text and sections when OCR produced text.

    Returns:
        (client, (model, processor, device)) so later steps can reuse them.

    Raises:
        RuntimeError: if no pages could be prepared.
    """
    model, processor, device = load_colpali()
    client = connect_qdrant_rest()
    imgs, meta = load_all_pages(root)
    if not imgs:
        raise RuntimeError("No pages prepared. Check SOURCE_DIR and file types (png/jpg/pdf).")

    if INDEX_PAGES:
        ensure_collection(client, COLLECTION_PAGES)
        print("[embed] pages…")
        # Upsert as embeddings are produced instead of holding every page's
        # multivector (as nested Python float lists) in memory until the end.
        # IDs are deterministic, so re-running after a partial failure simply
        # overwrites the same points.
        pend_vecs, pend_meta, pend_ids = [], [], []
        for i in tqdm(range(0, len(imgs), BATCH)):
            pend_vecs.extend(embed_images(model, processor, device, imgs[i:i+BATCH]))  # list of multivectors
            batch_meta = meta[i:i+BATCH]
            pend_meta.extend(batch_meta)
            # Deterministic IDs from the page path
            pend_ids.extend(f"page::{m['path']}" for m in batch_meta)
            if len(pend_ids) >= UPSERT_BATCH:
                upsert_points(client, COLLECTION_PAGES, pend_vecs, pend_meta, pend_ids)
                pend_vecs, pend_meta, pend_ids = [], [], []
        if pend_ids:
            upsert_points(client, COLLECTION_PAGES, pend_vecs, pend_meta, pend_ids)

    print("[jsonl] writing per-page:", PAGES_JSONL)
    with open(PAGES_JSONL, "w", encoding="utf-8") as fh:
        for im, m in zip(imgs, meta):
            rec = dict(m)
            txt = ocr_text_from_image(im)
            if txt:
                rec["ocr_text"] = txt
                rec["sections"] = sectionize(txt)
            fh.write(json.dumps(rec, ensure_ascii=False) + "\n")

    print(f"[done] pages → {PAGES_JSONL}")
    return client, (model, processor, device)

9) Merge PAGES → unique CASES

In [None]:
def merge_pages_to_cases(pages_jsonl: str, out_cases_jsonl: str) -> str:
    """Group per-page records by case_id and write one merged record per case.

    Pages within a case are ordered by page number. Their OCR texts are concatenated
    under "--- PAGE N ---" markers, and same-named sections are joined with blank
    lines. A case_id that is not an int-convertible value is recomputed from the
    page path, and a missing/null page_number is parsed from the path. Blank input
    lines are ignored.

    Returns:
        *out_cases_jsonl*.
    """
    def _page_no(rec: Dict[str, Any]) -> int:
        # Lazily fall back to the path only when page_number is absent or null
        # (a dict.get default would evaluate the fallback every time).
        pn = rec.get("page_number")
        if pn is None:
            return parse_page_number_from_path(rec["path"], 1)
        return int(pn)

    print("[merge] reading:", pages_jsonl)
    groups: Dict[int, List[Dict[str, Any]]] = {}
    with open(pages_jsonl, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            rec = json.loads(line)
            try:
                cid = int(rec.get("case_id"))
            except (TypeError, ValueError):  # missing/null or non-numeric id
                cid = derive_case_id_from_path(rec["path"])
            rec["case_id"] = cid
            groups.setdefault(cid, []).append(rec)

    n_cases = 0
    with open(out_cases_jsonl, "w", encoding="utf-8") as out:
        for cid, records in groups.items():
            pages = sorted(records, key=_page_no)
            title = pages[0].get("case_title") or guess_title_from_path(pages[0]["path"])
            page_texts = [
                f"--- PAGE {_page_no(p)} ---\n{p['ocr_text']}".strip()
                for p in pages if p.get("ocr_text")
            ]
            merged_text = "\n\n".join(page_texts).strip() if page_texts else None

            merged_sections: Dict[str, str] = {}
            for p in pages:
                for k, v in (p.get("sections") or {}).items():
                    if not v:
                        continue
                    prev = merged_sections.get(k)
                    merged_sections[k] = f"{prev}\n\n{v}".strip() if prev else v.strip()

            case_obj = {
                "case_id": cid,
                "case_title": title,
                "n_pages": len(pages),
                "page_paths": [p["path"] for p in pages],
                "merged": {"ocr_text": merged_text, "sections": merged_sections or None}
            }
            out.write(json.dumps(case_obj, ensure_ascii=False) + "\n")
            n_cases += 1

    print(f"[done] cases → {out_cases_jsonl} | total unique cases: {n_cases}")
    return out_cases_jsonl

10) Structured extractor

In [None]:
# Canonical symptom -> surface forms, matched as whole lower-cased words by
# find_symptoms(). Inflected forms that a bare word-boundary match cannot reach
# (e.g. "vomiting" from "vomit", "itchy" from "itch") and British spellings are listed explicitly.
SYMPTOM_LEXICON = {
    "fever":["fever","pyrexia","febrile"], "cough":["cough","coughing"], "headache":["headache","headaches","cephalgia"],
    "rash":["rash","rashes","maculopapular","petechiae","urticaria","vesicular"], "diarrhea":["diarrhea","diarrhoea","loose stools"],
    "vomiting":["vomit","vomiting","vomited","emesis"], "abdominal pain":["abdominal pain","abd pain","stomach pain"],
    "jaundice":["jaundice","icterus"], "weight loss":["weight loss"], "night sweats":["night sweats"],
    "myalgia":["myalgia","muscle pain"], "arthralgia":["arthralgia","joint pain"], "dyspnea":["dyspnea","dyspnoea","shortness of breath"],
    "chest pain":["chest pain"], "pruritus":["pruritus","itch","itchy","itching"], "hematuria":["hematuria","haematuria","blood in urine"],
    "hematochezia":["hematochezia"], "melena":["melena","melaena"], "conjunctivitis":["conjunctivitis","red eyes"],
    "proptosis":["proptosis","eye bulging"], "bleeding":["bleeding","hemorrhage","haemorrhage","gum bleeding"],
    "confusion":["confusion","altered mental state"], "lymphadenopathy":["lymphadenopathy","swollen nodes"],
    "hepatosplenomegaly":["hepatosplenomegaly","hepatomegaly","splenomegaly"], "ulcer":["ulcer","ulcers","eschar","chancre"],
    "lesion":["lesion","lesions","plaque","plaques","nodule","nodules","papule","papules","pustule","pustules"],
    "itchy boil":["boil","boils","furuncle","myiasis"],
}
NEGATION_RX = re.compile(r"\b(no|not|denies?|without|absence of)\b", re.I)
# Applied to lower-cased text. The leading \b stops the short cues ("t", "hr", "ht",
# "wt") from matching inside other words, e.g. "at 25 cases" read as 25 °C.
VITALS_PATTERNS = {
    "temperature_c": r"\b(?:temp(?:erature)?|t)\s*[:=]?\s*(\d{1,2}(?:\.\d)?)\s*°?\s*c",
    "temperature_f": r"\b(?:temp(?:erature)?|t)\s*[:=]?\s*(\d{2,3}(?:\.\d)?)\s*°?\s*f",
    "hr_bpm":        r"\b(?:hr|heart\s*rate|pulse)\s*[:=]?\s*(\d{2,3})\s*bpm?",
    "bp":            r"\b(?:bp|blood\s*pressure)\s*[:=]?\s*(\d{2,3})\s*/\s*(\d{2,3})",
    "rr":            r"\b(?:rr|respiratory\s*rate)\s*[:=]?\s*(\d{1,2})\s*/?min|\brr\s*(\d{1,2})\b",
    "spo2":          r"\b(?:spo2|sat(?:uration)?|oxygen\s*saturation)\s*[:=]?\s*(\d{2,3})\s*%",
    "height_cm":     r"\b(?:height|ht)\s*[:=]?\s*(\d{2,3})\s*cm",
    "weight_kg":     r"\b(?:weight|wt)\s*[:=]?\s*(\d{1,3}(?:\.\d)?)\s*kg",
}
# key -> (regex applied to lower-cased text, unit label stored with the value)
LAB_PATTERNS = {
    "hb_g_dl":          (r"\b(?:hb|hemoglobin|haemoglobin)\s*[:=]?\s*(\d{1,2}(?:\.\d)?)\s*g/?dl", "g/dL"),
    "wbc_10e9_l":       (r"\b(?:wbc|white\s*blood\s*cell[s]?)\s*[:=]?\s*(\d{1,2}(?:\.\d)?)\s*x?\s*10\^?9\s*/?\s*l", "10^9/L"),
    "plt_10e9_l":       (r"\b(?:plt|platelet[s]?)\s*[:=]?\s*(\d{2,3}(?:\.\d+)?)\s*x?\s*10\^?9\s*/?\s*l", "10^9/L"),
    "crp_mg_l":         (r"\bcrp\s*[:=]?\s*(\d{1,3}(?:\.\d)?)\s*mg/?l", "mg/L"),
    "esr_mm_h":         (r"\besr\s*[:=]?\s*(\d{1,3})\s*mm/?h", "mm/h"),
    "alt_u_l":          (r"\balt\s*[:=]?\s*(\d{1,4})\s*u/?l", "U/L"),
    "ast_u_l":          (r"\bast\s*[:=]?\s*(\d{1,4})\s*u/?l", "U/L"),
    "bilirubin_umol_l": (r"\bbilirubin\s*[:=]?\s*(\d{1,4})\s*(?:µ?mol/?l|umol/?l)", "µmol/L"),
    "creatinine_umol_l":(r"\bcreatinine\s*[:=]?\s*(\d{1,4})\s*(?:µ?mol/?l|umol/?l)", "µmol/L"),
    "sodium_mmol_l":    (r"\b(?:na|sodium)\s*[:=]?\s*(\d{2,3})\s*mmol/?l", "mmol/L"),
    "potassium_mmol_l": (r"\b(?:k|potassium)\s*[:=]?\s*(\d\.\d|\d{1,2})\s*mmol/?l", "mmol/L"),
    "glucose_mmol_l":   (r"\bglucose\s*[:=]?\s*(\d{1,2}(?:\.\d)?)\s*mmol/?l", "mmol/L"),
}
# (regex, test name). Lazy ".*?" takes the result nearest to the test name. A greedy
# ".*" took the LAST positive/negative on the line, so in "thick smear positive,
# HIV negative" the smear was reported negative.
MICRO_PATTERNS = [
    (r"\bthick\s*smear\b.*?\b(positive|negative)\b", "malaria_thick_smear"),
    (r"\bthin\s*smear\b.*?\b(positive|negative)\b",  "malaria_thin_smear"),
    (r"\brdt\b.*?\b(positive|negative)\b",            "malaria_RDT"),
    (r"\bhiv\b.*?\b(positive|negative)\b",            "HIV_test"),
    (r"\bbrucella\b.*?\b(agglutination|serology|pcr|culture)\b.*?\b(positive|negative)\b", "Brucella_test"),
    (r"\bdengue\b.*?\b(ns1|igg|igm|pcr)\b.*?\b(positive|negative)\b", "Dengue_test"),
    (r"\bblood\s*culture[s]?\b.*?\b(?:for\s+)?([A-Z][a-zA-Z]+)\b.*?\b(positive|negative)\b","blood_culture"),
]
IMAGING_KEYS = ["x-ray","cxr","ultrasound","u/s","ct","mri"]

def find_symptoms(text: str) -> List[str]:
    """Return the sorted canonical symptom names mentioned, and not negated, in *text*.

    Each surface form in SYMPTOM_LEXICON is matched (regex-escaped) as a whole word
    in the lower-cased text, optionally followed by a plural "s"/"es". A hit is
    ignored when NEGATION_RX matches within the 25 characters before it. Only the
    part after the last ".", ";", "!", "?" or newline is checked, so
    "No cough. Fever" still reports fever while "no fever, cough" negates both.
    """
    t = " " + (text or "").lower() + " "
    found = set()
    for canon, syns in SYMPTOM_LEXICON.items():
        for s in syns:
            if canon in found:
                break
            for m in re.finditer(rf"\b{re.escape(s)}(?:e?s)?\b", t):
                window = t[max(0, m.start() - 25):m.start()]
                window = re.split(r"[.;!?\n]", window)[-1]  # stay within the clause
                if NEGATION_RX.search(window):
                    continue
                found.add(canon)
                break
    return sorted(found)

def parse_vitals(text: str) -> Dict[str, Any]:
    """Extract vital signs from free *text* using VITALS_PATTERNS (first match each).

    A Celsius temperature wins. Otherwise a Fahrenheit reading is converted to °C
    (1 decimal). BMI is added when both height (cm) and weight (kg) are present and
    non-zero. Keys whose value was not found are omitted.
    """
    low = (text or "").lower()

    def grab(key):
        return re.search(VITALS_PATTERNS[key], low)

    vit: Dict[str, Any] = {}
    hit = grab("temperature_c")
    if hit:
        vit["temperature_c"] = float(hit.group(1))
    hit = grab("temperature_f")
    if hit and "temperature_c" not in vit:
        vit["temperature_c"] = round((float(hit.group(1)) - 32) * 5 / 9, 1)

    # (pattern key, output key, converter) — insertion order kept for the JSON output
    simple = (
        ("hr_bpm", "hr_bpm", lambda h: int(h.group(1))),
        ("bp", "bp_mmHg", lambda h: f"{h.group(1)}/{h.group(2)}"),
        ("rr", "rr_min", lambda h: int(h.group(1) or h.group(2))),
        ("spo2", "spo2_pct", lambda h: int(h.group(1))),
        ("height_cm", "height_cm", lambda h: int(h.group(1))),
        ("weight_kg", "weight_kg", lambda h: float(h.group(1))),
    )
    for pat_key, out_key, conv in simple:
        hit = grab(pat_key)
        vit[out_key] = conv(hit) if hit else None

    if vit.get("height_cm") and vit.get("weight_kg"):
        metres = vit["height_cm"] / 100.0
        vit["bmi"] = round(vit["weight_kg"] / (metres * metres), 1)
    return {k: v for k, v in vit.items() if v is not None}

def parse_labs(text: str) -> Dict[str, Dict[str, Any]]:
    """Extract lab values from free *text* using LAB_PATTERNS.

    For each analyte the first match in the lower-cased text is used and stored as
    {"value": float, "unit": <unit label from LAB_PATTERNS>}. Analytes without a
    match, or with an unparsable number, are omitted.
    """
    t = (text or "").lower()
    labs = {}
    for key, (rx, unit) in LAB_PATTERNS.items():
        m = re.search(rx, t)
        if not m:
            continue
        try:
            val = float(m.group(1).replace(",", ""))
        except ValueError:  # was a bare except
            continue
        labs[key] = {"value": val, "unit": unit}
    return labs

def parse_micro(text: str) -> List[Dict[str, Any]]:
    """Extract microbiology/serology results from free *text* using MICRO_PATTERNS.

    Every match of every pattern (in pattern order) yields {"test", "result"}.
    Blood cultures also carry the captured "organism".
    """
    lowered = (text or "").lower()
    results: List[Dict[str, Any]] = []
    for pattern, test_name in MICRO_PATTERNS:
        for hit in re.finditer(pattern, lowered, re.I):
            groups = hit.groups()
            entry: Dict[str, Any] = {"test": test_name}
            if test_name == "blood_culture":
                entry["organism"] = groups[0]
                entry["result"] = groups[1].lower()
            else:
                entry["result"] = groups[-1].lower()
            results.append(entry)
    return results

def extract_imaging(text: str) -> List[str]:
    """Return up to 10 snippets of *text* that mention an imaging modality.

    A line counts as a hit when it contains one of IMAGING_KEYS (case-insensitive)
    as a standalone token, optionally pluralised ("X-rays", "CTs"). The snippet is
    that line joined with the next two. A plain substring test made "ct" match
    "infection", "doctor", "contact", etc.
    """
    if not IMAGING_KEYS:
        return []
    keys_rx = re.compile(
        r"(?<![a-z0-9])(?:" + "|".join(re.escape(k) for k in IMAGING_KEYS) + r")s?(?![a-z0-9])",
        re.I,
    )
    lines = (text or "").splitlines()
    hits = [" ".join(lines[i:i + 3]).strip() for i, line in enumerate(lines) if keys_rx.search(line)]
    return hits[:10]

def extract_demographics(text: str) -> Dict[str, Any]:
    """Heuristically extract patient demographics from free *text*.

    Keys present only when found:
      age            – from "NN-year-old" / "NN years old";
      sex            – "female" if a female cue word occurs, else "male" if a male cue
                       word occurs (whole words, so "management"/"human" don't count);
      pregnant       – True on "pregnant"/"pregnancy";
      from_location  – capitalised phrase after the first "from";
      travel         – capitalised places after returned/travel(l)ed/migrant/expatriate
                       + from/to, de-duplicated in order;
      hiv_status     – "negative"/"positive"/"unknown" when the word "HIV" occurs.
    """
    raw = text or ""
    tl = raw.lower()
    out = {}
    m = re.search(r"(\d{1,3})\s*[-–]?\s*years?[- ]old", raw, re.I)
    if m:
        out["age"] = int(m.group(1))
    # Whole-word cues: substring tests made "man" fire on "management"/"human".
    if re.search(r"\b(?:female|woman|women|girl)\b", tl):
        out["sex"] = "female"
    elif re.search(r"\b(?:male|man|men|boy)\b", tl):
        out["sex"] = "male"
    if re.search(r"\b(pregnan(t|cy))\b", tl):
        out["pregnant"] = True
    m = re.search(r"\bfrom\s+([A-Z][A-Za-z]+(?:\s+[A-Z][A-Za-z]+)*)", raw)
    if m:
        out["from_location"] = m.group(1)
    # Place continuation words must be capitalised too (as for from_location);
    # otherwise "returned from Nigeria with fever" captured "Nigeria with fever".
    travels = re.findall(
        r"\b(?i:returned|travel(?:ed|led)?|migrant|expatriate)\s+(?:from|to)\s+([A-Z][A-Za-z]+(?:\s+[A-Z][A-Za-z]+)*)",
        raw,
    )
    if travels:
        out["travel"] = list(dict.fromkeys(travels))
    if re.search(r"\bhiv\b", tl):  # whole word: "archive" is not HIV
        stat = "unknown"
        if re.search(r"hiv.*positive", tl):
            stat = "positive"
        if re.search(r"hiv.*negative", tl):
            stat = "negative"
        out["hiv_status"] = stat
    return out

def extract_diagnoses(full_text: str, sections: Dict[str,str]) -> Dict[str, Any]:
    """Heuristically pull provisional / differential / final diagnoses from *full_text*.

    For each kind, the first matching cue wins. The value is the rest of that line,
    or of the next line when the cue ends its line, cut at the first ".  ". After the
    cue, an optional short lead-in ending in ":" on the same line is skipped (e.g.
    "Final diagnosis (PCR-confirmed): dengue"). If no final diagnosis is found, the
    "Discussion" section of *sections* is searched for a "diagnos…" phrase.
    """
    DIAG_KEYS = {
        "final": [r"\bfinal\s*diagnosis\b", r"\bdefinitive\s*diagnosis\b",
                  # bare "Diagnosis:" — but not "Differential/Provisional diagnosis:"
                  r"(?<!differential )(?<!provisional )\bdiagnosis\s*:"],
        "provisional": [r"\bprovisional\s*diagnosis\b", r"\bimpression\b"],
        "differential": [r"\bdifferential[s]?\b", r"\bdifferential\s*diagnoses?\b"],
    }
    # The old tail ".{0,30}[:]?\s*(.+)" was greedy, so it swallowed up to the first
    # 30 characters of the diagnosis itself; the lead-in is now lazy and colon-bounded.
    tail = r"(?:[^\n:]{0,30}?:)?\s*(.+)"
    out = {"provisional": None, "differential": None, "final": None}
    t = full_text or ""
    for k, patterns in DIAG_KEYS.items():
        for p in patterns:
            m = re.search(p + tail, t, re.I)
            if m:
                val = re.split(r"\n|\.  ", m.group(1).strip())[0]
                out[k] = val
                break
    if not out["final"]:
        disc = (sections or {}).get("Discussion", "") or ""
        m = re.search(r"\bdiagnos(e|is|ed)\b.*?:?\s*(.+)", disc, re.I)
        if m:
            out["final"] = m.group(2).split("\n")[0].strip()
    return out

def extract_management(text: str) -> Dict[str, Any]:
    """Heuristically extract treatments and outcome from free *text*.

    medications: de-duplicated phrases following treated/given/started/therapy/
                 administered (optionally "with"), or None.
    outcome:     "died" if a death cue word occurs, else "improved" if an improvement
                 cue word occurs, else None.
    """
    tl = (text or "").lower()
    rx = r"(?:treated|given|started|therapy|administered)\s+(?:with\s+)?([A-Za-z][A-Za-z0-9\- ]+)"
    meds = [m.group(1).strip() for m in re.finditer(rx, text or "", re.I)]
    meds = list(dict.fromkeys(meds)) or None
    outcome = None
    # Group the alternations so \b guards every word: "\ba|b|c\b" only anchored the
    # first and last alternative, so e.g. "unresolved" counted as an improvement.
    if re.search(r"\b(?:improved|recovered|resolved|discharged)\b", tl):
        outcome = "improved"
    if re.search(r"\b(?:died|deaths?|fatal)\b", tl):
        outcome = "died"
    return {"medications": meds, "outcome": outcome}

def build_structured_cases(cases_jsonl: str, out_jsonl: str) -> str:
    """Turn each merged case in *cases_jsonl* into a structured record in *out_jsonl*.

    Sections come from the merged record, or from sectionize() over the merged OCR
    text when none are stored. The regex extractors then run over the relevant
    sections: demographics, vitals, symptoms/signs, labs, microbiology, imaging,
    diagnoses and management. Empty results are written as None.

    Returns:
        *out_jsonl*.
    """
    def first_section(secs, *names):
        # Body of the first listed section that is present and non-empty, else "".
        for nm in names:
            body = secs.get(nm, "")
            if body:
                return body
        return ""

    sign_keys = ("rash", "lymphadenopathy", "hepatosplenomegaly", "ulcer", "lesion", "proptosis", "jaundice")
    written = 0
    with open(cases_jsonl, "r", encoding="utf-8") as fin, \
         open(out_jsonl, "w", encoding="utf-8") as fout:
        for raw in fin:
            case = json.loads(raw)
            merged = case.get("merged") or {}
            text = merged.get("ocr_text") or ""
            secs = merged.get("sections") or sectionize(text)

            hx = first_section(secs, "History")
            cf = first_section(secs, "Clinical Findings")
            labs_text = first_section(secs, "Laboratory Findings", "Laboratory Results", "Laboratory Investigations")
            inv_text = first_section(secs, "Additional Investigations", "Investigations")
            disc = first_section(secs, "Discussion")
            summary = first_section(secs, "Summary Box")

            presenting = hx + "\n" + cf
            test_text = labs_text + "\n" + inv_text
            demographics = extract_demographics(text + "\n" + hx)
            vitals = parse_vitals(presenting)
            symptoms = find_symptoms(presenting)
            labs = parse_labs(test_text)
            microbiology = parse_micro(test_text)
            imaging = extract_imaging(inv_text + "\n" + disc)
            diagnoses = extract_diagnoses(text + "\n" + disc + "\n" + summary, secs)
            management = extract_management(disc + "\n" + summary)

            # A "sign" is a symptom that is also named in the Clinical Findings section.
            cf_lower = (cf or "").lower()
            signs = sorted({k for k in sign_keys if k in symptoms and re.search(rf"\b{k}\b", cf_lower)})

            record = {
                "case_id": case["case_id"],
                "case_title": case.get("case_title"),
                "n_pages": case.get("n_pages"),
                "page_paths": case.get("page_paths"),
                "patient": demographics or None,
                "presentation": {"symptoms": symptoms or None, "signs": signs or None, "history_text": hx or None},
                "vitals": vitals or None,
                "tests": {"labs": labs or None, "microbiology": microbiology or None, "imaging_findings": imaging or None},
                "diagnoses": diagnoses or None,
                "management": management or None,
                "free_text": {"discussion": disc or None, "summary": summary or None}
            }
            fout.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1
    print(f"[done] structured → {out_jsonl} (cases={written})")
    return out_jsonl

11) SFT writers

In [None]:
INSTR_EXTRACT = (
    "Extract the following fields from the clinical case text and return only JSON with keys: "
    "patient (age, sex, from_location, travel, hiv_status, pregnant), "
    "presentation (symptoms[], signs[], history_text), vitals, tests (labs, microbiology, imaging_findings), "
    "diagnoses (provisional, differential, final), management (medications, outcome)."
)

def write_sft_extraction(cases_jsonl: str, struct_jsonl: str, out_sft: str, max_ctx_chars: int = 12000) -> str:
    """Write chat-format SFT pairs: merged case OCR text -> structured JSON target.

    Structured records are matched to cases by case_id, not by line position, so
    a reordered, filtered or truncated file cannot silently pair the wrong records.
    Cases with no structured record are skipped and counted in a warning. The
    case text is truncated to *max_ctx_chars* characters.

    Returns:
        *out_sft*.
    """
    struct_by_id = {}
    with open(struct_jsonl, "r", encoding="utf-8") as fin_struct:
        for line in fin_struct:
            if line.strip():
                rec = json.loads(line)
                struct_by_id[rec["case_id"]] = rec

    n = 0
    missing = 0
    with open(cases_jsonl, "r", encoding="utf-8") as fin_cases, \
         open(out_sft, "w", encoding="utf-8") as fout:
        for case_line in fin_cases:
            if not case_line.strip():
                continue
            case = json.loads(case_line)
            struct = struct_by_id.get(case.get("case_id"))
            if struct is None:
                missing += 1
                continue
            text = (case.get("merged") or {}).get("ocr_text") or ""
            text = text[:max_ctx_chars]
            ex = {
                "case_id": struct["case_id"],
                "messages": [
                    {"role":"system","content":"You are an accurate clinical information extraction model."},
                    {"role":"user","content": f"{INSTR_EXTRACT}\n\nCASE TITLE: {case.get('case_title')}\n\nTEXT:\n{text}"},
                    {"role":"assistant","content": json.dumps({
                        k:struct[k] for k in ["patient","presentation","vitals","tests","diagnoses","management"]
                    }, ensure_ascii=False)}
                ]
            }
            fout.write(json.dumps(ex, ensure_ascii=False) + "\n")
            n += 1
    if missing:
        print(f"[warn] SFT-extract: {missing} case(s) had no structured record; skipped")
    print(f"[done] SFT-extract → {out_sft} (pairs={n})")
    return out_sft

def write_sft_dx(struct_jsonl: str, out_sft: str) -> str:
    """Write chat-format SFT pairs: structured case -> diagnosis JSON target.

    Only cases with a truthy diagnoses.final are emitted. The prompt shows the
    patient/presentation/vitals/tests fields. The target holds the final diagnosis,
    the symptoms as key_features, and empty differentials/next_tests lists.

    Returns:
        *out_sft*.
    """
    written = 0
    with open(struct_jsonl, "r", encoding="utf-8") as src, open(out_sft, "w", encoding="utf-8") as dst:
        for raw in src:
            case = json.loads(raw)
            final_dx = (case.get("diagnoses") or {}).get("final")
            if not final_dx:
                continue
            visible = {key: case[key] for key in ("patient", "presentation", "vitals", "tests")}
            user_msg = {
                "role": "user",
                "content": (
                    "Given the structured case below, return JSON with keys: final_diagnosis, differentials[], "
                    "key_features, next_tests[].\n\nSTRUCTURED_CASE:\n"
                    + json.dumps(visible, ensure_ascii=False)
                ),
            }
            answer = {
                "final_diagnosis": final_dx,
                "differentials": [],
                "key_features": (case.get("presentation") or {}).get("symptoms") or [],
                "next_tests": [],
            }
            example = {"case_id": case["case_id"], "messages": [
                {"role": "system", "content": "You are a careful tropical medicine diagnostician."},
                user_msg,
                {"role": "assistant", "content": json.dumps(answer, ensure_ascii=False)},
            ]}
            dst.write(json.dumps(example, ensure_ascii=False) + "\n")
            written += 1
    print(f"[done] SFT-dx → {out_sft} (pairs={written})")
    return out_sft

12) Case-level Qdrant (one point per case; deterministic IDs; upsert)

In [None]:
def build_case_level_index(cases_jsonl: str, mpd=None):
    """Index each case as ONE multivector (all its pages' sub-vectors concatenated).

    Points go into COLLECTION_CASES with deterministic IDs "case::<case_id>".
    Pages are re-loaded from their "page_paths"; "<pdf>#page=N" entries are
    rasterised one page at a time at 220 dpi.

    Each case is upserted as soon as it is embedded. A case multivector holds every
    page's sub-vectors, so accumulating all cases and sending UPSERT_BATCH of them
    per request kept the whole collection in RAM and produced very large REST
    request bodies. Cases with no embeddable pages are skipped.

    Args:
        cases_jsonl: merged-cases JSONL (one case per line, with "page_paths").
        mpd: optional (model, processor, device) to reuse; loaded when None.
    """
    model, processor, device = mpd if mpd else load_colpali()
    client = connect_qdrant_rest()
    ensure_collection(client, COLLECTION_CASES)

    def load_one_page_image(path: str) -> Image.Image:
        # "<file.pdf>#page=N" -> rasterise just that page; anything else is an image file.
        if "#page=" in path:
            pdf, page = path.split("#page=")[0], int(path.split("#page=")[1])
            pages = convert_from_path(pdf, dpi=220, first_page=page, last_page=page)
            return pages[0]
        return load_image(path)

    with open(cases_jsonl, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="[case-level] embedding"):
            if not line.strip():
                continue
            case = json.loads(line)
            paths = case["page_paths"]

            # concatenate sub-vectors across pages → one multivector per case;
            # images are loaded per embedding batch, not all at once
            mv_all = []
            for i in range(0, len(paths), BATCH):
                batch_imgs = [load_one_page_image(p) for p in paths[i:i+BATCH]]
                for mv in embed_images(model, processor, device, batch_imgs):  # list of multivectors
                    mv_all.extend(mv)
            if not mv_all:
                print("[case-level] no pages to embed, skipping case", case.get("case_id"))
                continue

            payload = {
                "case_id": case["case_id"],
                "case_title": case["case_title"],
                "n_pages": case["n_pages"],
                "page_paths": case["page_paths"],
                "modality": "case_multivector",
            }
            upsert_points(client, COLLECTION_CASES, [mv_all], [payload], [f"case::{case['case_id']}"])

    print("[done] case-level collection built →", COLLECTION_CASES)

13) Run the pipeline

In [None]:
def head(path, n=1):
    """Print the first *n* lines of *path* with trailing whitespace stripped.

    Past end-of-file, empty lines are printed. A missing file prints
    "missing: <path>" instead of raising.
    """
    try:
        fh = open(path, "r", encoding="utf-8")
    except FileNotFoundError:
        print("missing:", path)
        return
    with fh:
        for _ in range(n):
            print(fh.readline().rstrip())

def main():
    """Run the whole pipeline and print a short summary.

    Steps:
      1. Build the page index and the per-page JSONL.
      2. Merge pages into cases.
      3. Build the structured cases.
      4. Write both SFT files.
      5. Optionally build the case-level index.
    Afterwards, print each output's line count and one sample line from the cases
    and structured files.
    """
    _client, mpd = build_pages_index(SOURCE_DIR)
    cases_path = merge_pages_to_cases(PAGES_JSONL, CASES_JSONL)
    struct_path = build_structured_cases(cases_path, STRUCT_JSONL)
    write_sft_extraction(cases_path, struct_path, SFT_EXTRACT_JSONL)
    write_sft_dx(struct_path, SFT_DX_JSONL)
    if INDEX_CASES:
        build_case_level_index(cases_path, mpd=mpd)

    print("\n[files]")
    for p in [PAGES_JSONL, CASES_JSONL, STRUCT_JSONL, SFT_EXTRACT_JSONL, SFT_DX_JSONL]:
        # Count lines in Python (same "N path" output as `wc -l`) rather than spawning
        # a subprocess per file; only an unreadable file is reported as missing.
        try:
            with open(p, "r", encoding="utf-8") as fh:
                n_lines = sum(1 for _ in fh)
        except (OSError, UnicodeDecodeError):
            print("missing:", p)
            continue
        print(f"{n_lines} {p}")

    print("\n[sample case jsonl]")
    head(CASES_JSONL, 1)
    print("\n[sample structured jsonl]")
    head(STRUCT_JSONL, 1)

if __name__ == "__main__":  # true in a notebook kernel; skipped when imported
    try:
        main()
    except Exception as e:
        # Deliberately no sys.exit()/re-raise, so the notebook kernel survives. Print
        # the full traceback too, since repr(e) alone hides where the failure happened.
        import traceback
        print("[fatal]", repr(e))
        traceback.print_exc()

In [None]:
# NOTE(review): this repeats the Drive mount from section 1 (without force_remount).
# Unlike that cell it has no try/except, so it fails with ImportError outside Colab.
# It looks like a leftover cell; confirm before relying on it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
