# MedGemma Impact Challenge: model usage examples

This notebook contains **example code patterns** for each model family listed in `gamma-model-list.md`.

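- Set the `*_MODEL_ID` environment variables (most cells read them with `os.getenv`, falling back to a `<fill-me>` placeholder) to the exact checkpoint IDs or local model folders your environment supports.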
- Some model families are *encoders* (produce embeddings), others are *generators* (produce text).


In [None]:
# If running on Kaggle/Colab you likely already have torch installed.
# If running locally and imports fail, this cell will attempt to install minimal deps.

import os
import sys
import math
import importlib
import subprocess
from typing import List

def _pip_install(packages: List[str]) -> None:
    cmd = [sys.executable, "-m", "pip", "install", "-q", "-U", *packages]
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd)

# Import torch, installing it on demand when it is missing from this environment.
try:
    import torch
except ModuleNotFoundError:
    # CPU-only torch from PyPI typically works on Linux/macOS/Windows.
    # If you have CUDA and want GPU wheels locally, install torch per https://pytorch.org/get-started/locally/
    _pip_install(["torch"])
    # Refresh the import system's finder caches so the freshly installed package can be found.
    importlib.invalidate_caches()
    import torch

# Only probe for transformers here (installing it plus loader deps if absent); the names this
# notebook uses are imported by the `from transformers import (...)` statement below.
try:
    import transformers  # noqa: F401
except ModuleNotFoundError:
    _pip_install(["transformers>=4.40", "accelerate", "safetensors", "sentencepiece"])
    importlib.invalidate_caches()

# Pillow is treated as optional: if importing or installing it fails for any reason, `Image`
# is set to None, and the image cells check `Image is None` before use.
try:
    from PIL import Image  # type: ignore
except Exception:
    try:
        _pip_install(["pillow"])
        importlib.invalidate_caches()
        from PIL import Image  # type: ignore
    except Exception:
        Image = None

from transformers import (
    AutoTokenizer,
    AutoProcessor,
    AutoModel,
    AutoModelForCausalLM,
    pipeline,
)

# Pick the compute device and a matching default dtype for model weights.
if torch.cuda.is_available():
    DEVICE = "cuda"
    # Prefer bfloat16 where the GPU supports it; otherwise fall back to float16.
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    DEVICE = "cpu"
    # float16 on CPU is often unsupported/slow; float32 is the safe default.
    DTYPE = torch.float32

print({"device": DEVICE, "dtype": str(DTYPE), "torch": torch.__version__})

Running: /mnt/cgm-atlas/onrh/projects/med-gamma-kaggle/.venv/bin/python -m pip install -q -U torch


  cpu = _conversion_method_template(device=torch.device("cpu"))


Running: /mnt/cgm-atlas/onrh/projects/med-gamma-kaggle/.venv/bin/python -m pip install -q -U transformers>=4.40 accelerate safetensors sentencepiece
Running: /mnt/cgm-atlas/onrh/projects/med-gamma-kaggle/.venv/bin/python -m pip install -q -U pillow


  from .autonotebook import tqdm as notebook_tqdm


{'device': 'cuda', 'dtype': 'torch.bfloat16', 'torch': '2.10.0+cu128'}


In [3]:
# Optional (Kaggle): discover locally-mounted model folders under /kaggle/input
# Many Kaggle competitions provide models as local folders containing a config.json.

from pathlib import Path
import json

def find_hf_model_dirs(root: str = "/kaggle/input") -> list[str]:
    """Return sorted, de-duplicated folders under ``root`` that look like HF Transformers checkpoints.

    A folder qualifies when it holds a non-empty ``config.json`` that parses to a JSON
    object with a ``model_type`` or ``architectures`` key. This skips other files that
    happen to be called ``config.json`` (for example the one inside KaggleHub Keras
    checkpoints), which the ``Auto*`` classes cannot load.

    Args:
        root: directory to search recursively; a missing directory yields ``[]``.

    Returns:
        Folder paths as strings, in sorted order.
    """
    root_path = Path(root)
    if not root_path.exists():
        return []

    def _looks_like_hf_config(cfg: Path) -> bool:
        try:
            if not cfg.is_file() or cfg.stat().st_size == 0:
                return False
            data = json.loads(cfg.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Unreadable file, bad encoding or invalid JSON (the latter two are ValueErrors).
            return False
        return isinstance(data, dict) and ("model_type" in data or "architectures" in data)

    # A set de-duplicates; sorted() gives a stable order across runs.
    return sorted(
        {str(cfg.parent) for cfg in root_path.rglob("config.json") if _looks_like_hf_config(cfg)}
    )

candidates = find_hf_model_dirs()
print(f"Found {len(candidates)} local HF model folders")
for p in candidates[:30]:
    print(" -", p)

# Tip: set these env vars to one of the printed folders (or a HF Hub id if you have access).
# os.environ["MEDGEMMA_TEXT_MODEL_ID"] = "<paste-path-or-hf-id>"
# os.environ["MEDGEMMA_MM_MODEL_ID"]   = "<paste-path-or-hf-id>"


Found 0 local HF model folders


## Model ID helper (local Kaggle folders or HF Hub)

The exact checkpoint IDs vary by competition setup. The cell below will:
- list local model folders in `/kaggle/input` (if present)
- suggest a mapping to the env vars used in this notebook
- show how to set the env vars

In [4]:
# Suggest model-id mapping from local folders (if available).
from pathlib import Path

def _suggest_model_ids(paths: list[str]) -> dict[str, str]:
    # naive keyword-based suggestions; edit as needed
    keys = {
        "medgemma_text": "MEDGEMMA_TEXT_MODEL_ID",
        "medgemma_mm": "MEDGEMMA_MM_MODEL_ID",
        "medasr": "MEDASR_MODEL_ID",
        "medsiglip": "MEDSIGLIP_MODEL_ID",
        "cxr": "FOUNDATION_IMAGE_MODEL_ID",
        "derm": "FOUNDATION_IMAGE_MODEL_ID",
        "path": "FOUNDATION_IMAGE_MODEL_ID",
        "hear": "HEAR_MODEL_ID",
    }
    suggestions: dict[str, str] = {}
    for p in paths:
        name = Path(p).name.lower()
        for kw, env in keys.items():
            if kw in name and env not in suggestions:
                suggestions[env] = p
    return suggestions

# Reuse the list from the previous cell if present
try:
    _candidates = candidates  # type: ignore[name-defined]
except Exception:
    _candidates = find_hf_model_dirs() if "find_hf_model_dirs" in globals() else []

print("Available local model folders:")
for p in _candidates[:50]:
    print(" -", p)
if len(_candidates) > 50:
    print(f"... and {len(_candidates) - 50} more")

suggested = _suggest_model_ids(_candidates)
print("\nSuggested env var mapping (edit as needed):")
for k in (
    "MEDGEMMA_TEXT_MODEL_ID",
    "MEDGEMMA_MM_MODEL_ID",
    "MEDASR_MODEL_ID",
    "MEDSIGLIP_MODEL_ID",
    "FOUNDATION_IMAGE_MODEL_ID",
    "HEAR_MODEL_ID",
):
    print(f"{k} = {suggested.get(k, '<fill-me>')}")

# Example: set env vars to local paths
# os.environ["MEDGEMMA_TEXT_MODEL_ID"] = "<paste-path-or-hf-id>"
# os.environ["MEDGEMMA_MM_MODEL_ID"]   = "<paste-path-or-hf-id>"

Available local model folders:

Suggested env var mapping (edit as needed):
MEDGEMMA_TEXT_MODEL_ID = <fill-me>
MEDGEMMA_MM_MODEL_ID = <fill-me>
MEDASR_MODEL_ID = <fill-me>
MEDSIGLIP_MODEL_ID = <fill-me>
FOUNDATION_IMAGE_MODEL_ID = <fill-me>
HEAR_MODEL_ID = <fill-me>


In [9]:
# Optional: download a Kaggle Model via kagglehub and auto-set MEDGEMMA_TEXT_MODEL_ID.
# This is useful if you have a Kaggle model handle like:
#   keras/medgemma/keras/medgemma_1.5_instruct_4b
#
# Note: Transformers expects a Hugging Face-style folder (with config.json).
# Some KaggleHub downloads are Keras-native; in that case we can't use AutoModelForCausalLM directly.
import os
from pathlib import Path
import importlib
import subprocess
import sys

def _pip_install(packages):
    cmd = [sys.executable, "-m", "pip", "install", "-q", "-U", *packages]
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd)

try:
    import kagglehub  # type: ignore
except ModuleNotFoundError:
    # kagglehub may be missing outside Kaggle: install it, refresh import caches, then retry.
    _pip_install(["kagglehub"])
    importlib.invalidate_caches()
    import kagglehub  # type: ignore

# The handle can be overridden through the KAGGLEHUB_MEDGEMMA_TEXT_HANDLE env var.
KAGGLEHUB_HANDLE = os.getenv(
    "KAGGLEHUB_MEDGEMMA_TEXT_HANDLE",
    "keras/medgemma/keras/medgemma_1.5_instruct_4b",
)
print("kagglehub handle:", KAGGLEHUB_HANDLE)

# Best-effort download: any failure is reported and leaves `downloaded_path` as None.
downloaded_path = None
try:
    _download_dir = kagglehub.model_download(KAGGLEHUB_HANDLE)
except Exception as e:
    print("kagglehub download failed:", repr(e))
    print("If you're not on Kaggle, you may need Kaggle credentials configured.")
else:
    downloaded_path = Path(_download_dir)
    print("Downloaded to:", downloaded_path)

def _find_hf_dir(root: Path) -> str | None:
    if (root / "config.json").exists():
        return str(root)
    # Search for a HF-style config.json under the downloaded directory.
    for cfg in root.rglob("config.json"):
        try:
            if cfg.is_file() and cfg.stat().st_size > 0:
                return str(cfg.parent)
        except Exception:
            continue
    return None

# Point the Transformers examples at the download only when it is a Transformers checkpoint.
if downloaded_path is not None:
    hf_dir = _find_hf_dir(downloaded_path)
    if not hf_dir:
        print(
            "\n".join(
                (
                    "No HuggingFace-style config.json found under the downloaded model.",
                    "This likely means it's a Keras-native checkpoint; the Transformers examples won't load it.",
                    "In that case, use Keras/KerasNLP examples for MedGemma instead.",
                )
            )
        )
    else:
        os.environ["MEDGEMMA_TEXT_MODEL_ID"] = hf_dir
        print("Set MEDGEMMA_TEXT_MODEL_ID =", hf_dir)

Running: /mnt/cgm-atlas/onrh/projects/med-gamma-kaggle/.venv/bin/python -m pip install -q -U kagglehub


  from .autonotebook import tqdm as notebook_tqdm


kagglehub handle: keras/medgemma/keras/medgemma_1.5_instruct_4b


Downloading 10 files:   0%|          | 0/10 [00:00<?, ?it/s]

Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/model_00000.weights.h5...




Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/config.json...



100%|██████████| 2.09k/2.09k [00:00<00:00, 1.88MB/s]

Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/image_converter.json...





Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/preprocessor.json...
Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/metadata.json...



[A

100%|██████████| 931/931 [00:00<00:00, 163kB/s]




  0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/assets/tokenizer/vocabulary.spm...


100%|██████████| 2.72k/2.72k [00:00<00:00, 537kB/s]
100%|██████████| 190/190 [00:00<00:00, 116kB/s]

[A

Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/model_00001.weights.h5...
Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/model.weights.json...




[A[A


[A[A[A


100%|██████████| 79.4k/79.4k [00:00<00:00, 359kB/s]


Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/tokenizer.json...





100%|██████████| 633/633 [00:00<00:00, 872kB/s]


Downloading to /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1/task.json...





100%|██████████| 6.19k/6.19k [00:00<00:00, 14.9MB/s]

[A

[A[A
[A

[A[A
100%|██████████| 4.47M/4.47M [00:01<00:00, 4.56MB/s]
Downloading 10 files:  10%|█         | 1/10 [00:02<00:18,  2.02s/it]

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A

Downloaded to: /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1
Set MEDGEMMA_TEXT_MODEL_ID = /home/onrh/.cache/kagglehub/models/keras/medgemma/keras/medgemma_1.5_instruct_4b/1





In [10]:
# Download a couple of example images and store paths for later cells.
# Note: some hosts (including Wikimedia) may return HTTP 403/429 to automated downloads;
# we use generated placeholder images here to keep the notebook runnable everywhere.
from pathlib import Path
import urllib.request
import urllib.error

SAMPLE_DIR = Path("sample_images")
SAMPLE_DIR.mkdir(parents=True, exist_ok=True)

# Generated placeholder images (no special access rules, no licensing concerns).
example_images = dict(
    image_a="https://placehold.co/512x512/png?text=Sample+Image+A",
    image_b="https://placehold.co/512x512/png?text=Sample+Image+B",
)

def _download(url: str, dest: Path) -> Path:
    if dest.exists():
        return dest

    print(f"Downloading {url} -> {dest}")
    req = urllib.request.Request(
        url,
        headers={
            # A minimal UA is usually enough to prevent 403 from some CDNs.
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) Python-urllib/3",
            "Accept": "*/*",
        },
    )
    try:
        with urllib.request.urlopen(req, timeout=60) as resp, open(dest, "wb") as f:
            f.write(resp.read())
    except urllib.error.HTTPError as e:
        # Keep the notebook actionable: explain what to do next.
        raise RuntimeError(
            f"Failed to download sample image (HTTP {e.code}). "
            "If you are offline or behind a restrictive network, "
            "set SAMPLE_IMAGE_PATH to a local file instead."
        ) from e
    return dest

# Fetch both placeholders in order (cached copies are reused) and keep their paths as strings.
SAMPLE_IMAGE_PATH, SAMPLE_IMAGE_PATH_B = (
    str(_download(example_images[key], SAMPLE_DIR / filename))
    for key, filename in (("image_a", "sample_a.png"), ("image_b", "sample_b.png"))
)

print(
    "Sample images:",
    *(f" - {sample}" for sample in (SAMPLE_IMAGE_PATH, SAMPLE_IMAGE_PATH_B)),
    sep="\n",
)

Sample images:
 - sample_images/sample_a.png
 - sample_images/sample_b.png


## 1) MedGemma 27B (text-only) — text generation

Use this pattern for any **text-only** MedGemma checkpoint (instruction-tuned or base).


In [1]:
import os
from pathlib import Path

# Try (in order): explicit env var -> auto-suggested local Kaggle folder -> user fill-in.
MEDGEMMA_TEXT_MODEL_ID = os.getenv("MEDGEMMA_TEXT_MODEL_ID", "").strip()

if not MEDGEMMA_TEXT_MODEL_ID:
    # If earlier helper cells ran, they may have populated a `suggested` dict.
    try:
        _suggested = suggested  # type: ignore[name-defined]
    except NameError:
        _suggested = {}

    candidate = _suggested.get("MEDGEMMA_TEXT_MODEL_ID") if isinstance(_suggested, dict) else None
    if isinstance(candidate, str) and candidate and Path(candidate).exists():
        MEDGEMMA_TEXT_MODEL_ID = candidate
        print(f"Auto-selected MEDGEMMA_TEXT_MODEL_ID = {MEDGEMMA_TEXT_MODEL_ID}")

if not MEDGEMMA_TEXT_MODEL_ID:
    print("MEDGEMMA_TEXT_MODEL_ID is not set.")
    print("- On Kaggle: run the model-discovery cells above and then set:")
    print("  os.environ['MEDGEMMA_TEXT_MODEL_ID'] = '<path under /kaggle/input>'")
    print("- Or set it in the notebook environment variables.")
    print("Skipping text-generation example in this cell.")
else:
    tokenizer = AutoTokenizer.from_pretrained(MEDGEMMA_TEXT_MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MEDGEMMA_TEXT_MODEL_ID,
        device_map="auto" if DEVICE == "cuda" else None,
        torch_dtype=DTYPE if DEVICE == "cuda" else None,
    )

    prompt = "You are a medical assistant. Summarize: patient has fever and cough for 3 days."  # example only
    if getattr(tokenizer, "chat_template", None):
        # Instruction-tuned checkpoints expect their chat format. The rendered template already
        # contains the model's special tokens, so don't let the tokenizer add a second BOS.
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False)
    else:
        # Base checkpoints (no chat template) are prompted with raw text.
        inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        out = model.generate(**inputs, max_new_tokens=200, do_sample=False)

    # For causal LMs generate() returns prompt + completion; print only the new tokens.
    print(tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True))

MEDGEMMA_TEXT_MODEL_ID is not set.
- On Kaggle: run the model-discovery cells above and then set:
  os.environ['MEDGEMMA_TEXT_MODEL_ID'] = '<path under /kaggle/input>'
- Or set it in the notebook environment variables.
Skipping text-generation example in this cell.


## 2) MedGemma 4B / 27B (multimodal) — image + text → text

This pattern is typical for **multimodal** vision-language checkpoints.

- You need a local image file (or a PIL image).
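- The model class name differs between Transformers releases; the cell below tries `AutoModelForImageTextToText`, then `AutoModelForVision2Seq`, then `AutoModelForCausalLM`, and uses the first one available.
- The cell is pinned to `google/medgemma-1.5-4b-it` (and sets `MEDGEMMA_MM_MODEL_ID` itself); edit `MODEL_ID` to use a different checkpoint.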


In [7]:
# MedGemma 1.5 (multimodal) example using Hugging Face Transformers.
# Requested model: google/medgemma-1.5-4b-it
#
# Notes:
# - This HF Hub repo may require access approval and/or a token.
# - On Kaggle, add your HF token as a Secret named HUGGINGFACE_HUB_TOKEN (or HF_TOKEN).
import os
import sys
import subprocess
import importlib
from pathlib import Path

def _pip_install(packages):
    cmd = [sys.executable, "-m", "pip", "install", "-q", "-U", *packages]
    print("Running:", " ".join(cmd))
    subprocess.check_call(cmd)

# Make this cell runnable even if earlier setup cells weren't executed.
try:
    from transformers import AutoProcessor, AutoTokenizer  # type: ignore
    import transformers  # noqa: F401
except ModuleNotFoundError:
    _pip_install(["transformers>=4.40", "accelerate", "safetensors", "sentencepiece"])
    importlib.invalidate_caches()
    from transformers import AutoProcessor, AutoTokenizer  # type: ignore
    import transformers  # noqa: F401

try:
    from PIL import Image  # type: ignore
except Exception:
    _pip_install(["pillow"])
    importlib.invalidate_caches()
    from PIL import Image  # type: ignore

# The loading/generation code in this cell also uses `torch`, `DEVICE` and `DTYPE`, which
# normally come from the setup cell; recreate them when that cell has not run in this kernel.
try:
    import torch
except ModuleNotFoundError:
    _pip_install(["torch"])
    importlib.invalidate_caches()
    import torch

if "DEVICE" not in globals() or "DTYPE" not in globals():
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    if DEVICE == "cuda":
        # Same policy as the setup cell: bf16 when supported, else fp16; fp32 on CPU.
        DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        DTYPE = torch.float32
    print({"device": DEVICE, "dtype": str(DTYPE)})

# Pin the checkpoint and publish it both as a notebook variable and as an env var.
MODEL_ID = MEDGEMMA_MM_MODEL_ID = "google/medgemma-1.5-4b-it"
os.environ["MEDGEMMA_MM_MODEL_ID"] = MODEL_ID
print("Using model:", MODEL_ID)

if not any(os.getenv(token_var) for token_var in ("HUGGINGFACE_HUB_TOKEN", "HF_TOKEN")):
    print("Tip: if model download fails with 401/403, set HUGGINGFACE_HUB_TOKEN (or HF_TOKEN).")

# Walk from the newest auto-class to older fallbacks; the first one that resolves wins.
MMModel = None
for _name in ("AutoModelForImageTextToText", "AutoModelForVision2Seq", "AutoModelForCausalLM"):
    try:
        MMModel = getattr(transformers, _name)
    except Exception:
        continue
    break

if MMModel is None:
    raise ImportError(
        "Could not find a compatible multimodal model class in your installed transformers. "
        "Try upgrading: pip install -U 'transformers>=4.40'"
    )

processor = AutoProcessor.from_pretrained(MODEL_ID)
# Gated repos: Transformers picks up HUGGINGFACE_HUB_TOKEN / HF_TOKEN from the environment.
mm_model = MMModel.from_pretrained(
    MODEL_ID,
    **(
        {"device_map": "auto", "torch_dtype": DTYPE}
        if DEVICE == "cuda"
        else {"device_map": None, "torch_dtype": None}
    ),
)

# Use sample image if available; otherwise set your own path.
image_path = globals().get("SAMPLE_IMAGE_PATH", "")
if not image_path or image_path == "<path-to-image>":
    # If the sample-image cell hasn't been run, fall back to a local file if present.
    fallback = Path("sample_images/sample_a.png")
    image_path = str(fallback) if fallback.exists() else "<path-to-image>"
print("Image path:", image_path)
if not Path(image_path).is_file():
    raise FileNotFoundError(
        f"Image not found: {image_path!r}. Run the sample-image cell or set SAMPLE_IMAGE_PATH to a local file."
    )

image = Image.open(image_path).convert("RGB")
prompt = "Describe the key findings in this image."  # example only

if getattr(processor, "chat_template", None):
    # The MedGemma model card builds inputs through the processor's chat template, which
    # places the image placeholder token where the {"type": "image"} entry appears.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image", "image": image},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
else:
    inputs = processor(text=prompt, images=image, return_tensors="pt")

# Move tensors to the model device (works for both CPU and GPU). Floating-point tensors
# (pixel values) are also cast to the model dtype so fp32 inputs don't meet bf16/fp16 weights.
model_dtype = getattr(mm_model, "dtype", None)
inputs = {
    k: (
        v.to(mm_model.device, dtype=model_dtype)
        if isinstance(v, torch.Tensor) and v.is_floating_point() and model_dtype is not None
        else v.to(mm_model.device) if hasattr(v, "to") else v
    )
    for k, v in inputs.items()
}
prompt_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    out = mm_model.generate(**inputs, max_new_tokens=200, do_sample=False)

# generate() returns the prompt tokens followed by the new ones; decode only the new ones.
new_tokens = out[:, prompt_len:]
if hasattr(processor, "batch_decode"):
    text = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]
else:
    mm_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    text = mm_tokenizer.decode(new_tokens[0], skip_special_tokens=True)

print(text)

NameError: name 'Image' is not defined

## 3) MedASR — speech → text

Use the `automatic-speech-recognition` pipeline.


In [None]:
MEDASR_MODEL_ID = os.getenv("MEDASR_MODEL_ID", "<fill-me>")

# Fail fast with actionable messages instead of failing inside `pipeline(...)` or the ASR call.
if MEDASR_MODEL_ID in ("", "<fill-me>"):
    raise ValueError(
        "Set MEDASR_MODEL_ID to a local model folder or a Hugging Face model id. "
        "(On Kaggle, run the earlier cell to list /kaggle/input candidates.)"
    )

# Replace with a real WAV/FLAC file path (an http(s) URL is also accepted by the pipeline).
audio_path = "<path-to-audio>"
if not audio_path.startswith(("http://", "https://")) and not os.path.isfile(audio_path):
    raise FileNotFoundError(
        f"Audio file not found: {audio_path!r}. Set audio_path to a real WAV/FLAC file."
    )

asr = pipeline(
    task="automatic-speech-recognition",
    model=MEDASR_MODEL_ID,
    device=0 if DEVICE == "cuda" else -1,
)

result = asr(audio_path)
print(result["text"] if isinstance(result, dict) and "text" in result else result)


## 4) MedSigLIP — image encoder embeddings

This pattern works for CLIP/SigLIP-like encoders that produce an image embedding.


In [None]:
MEDSIGLIP_MODEL_ID = os.getenv("MEDSIGLIP_MODEL_ID", "<fill-me>")

if MEDSIGLIP_MODEL_ID in ("", "<fill-me>"):
    raise ValueError(
        "Set MEDSIGLIP_MODEL_ID to a local model folder or a Hugging Face model id. "
        "(On Kaggle, run the earlier cell to list /kaggle/input candidates.)"
    )

if Image is None:
    raise RuntimeError("Pillow is required for image examples. Install pillow.")

# Resolve the image before loading the model so a bad path fails fast.
# Use sample image if available; otherwise set your own path.
image_path = globals().get("SAMPLE_IMAGE_PATH", "<path-to-image>")
if not os.path.isfile(image_path):
    raise FileNotFoundError(
        f"Image not found: {image_path!r}. Run the sample-image cell or set SAMPLE_IMAGE_PATH to a local file."
    )

siglip_processor = AutoProcessor.from_pretrained(MEDSIGLIP_MODEL_ID)
siglip_model = AutoModel.from_pretrained(
    MEDSIGLIP_MODEL_ID,
    device_map="auto" if DEVICE == "cuda" else None,
    torch_dtype=DTYPE if DEVICE == "cuda" else None,
)

image = Image.open(image_path).convert("RGB")

inputs = siglip_processor(images=image, return_tensors="pt")
# Move to the model's device; floating-point tensors (pixel values) are also cast to the
# model's dtype so half-precision weights don't receive float32 inputs.
inputs = {
    k: (
        v.to(siglip_model.device, dtype=siglip_model.dtype)
        if isinstance(v, torch.Tensor) and v.is_floating_point()
        else v.to(siglip_model.device) if hasattr(v, "to") else v
    )
    for k, v in inputs.items()
}

with torch.inference_mode():
    if hasattr(siglip_model, "get_image_features"):
        emb = siglip_model.get_image_features(**inputs)
    else:
        out = siglip_model(**inputs)
        last_hidden = out.last_hidden_state  # [batch, seq, dim]
        emb = last_hidden.mean(dim=1)

# Upcast to float32 before normalising so downstream similarity math isn't done in bf16/fp16.
emb = torch.nn.functional.normalize(emb.float(), dim=-1)
print("embedding shape:", tuple(emb.shape))

## 5) CXR Foundation / Derm Foundation / Path Foundation — image embeddings

These are typically **image encoders**. Use the same embedding pattern, then do similarity search or train a small classifier head.


In [None]:
FOUNDATION_IMAGE_MODEL_ID = os.getenv("FOUNDATION_IMAGE_MODEL_ID", "<fill-me>")

# Guard clauses: stop early with actionable messages before anything is downloaded/loaded.
if FOUNDATION_IMAGE_MODEL_ID in {"", "<fill-me>"}:
    raise ValueError(
        "Set FOUNDATION_IMAGE_MODEL_ID to a local model folder or a Hugging Face model id. "
        "(On Kaggle, run the earlier cell to list /kaggle/input candidates.)"
    )
if Image is None:
    raise RuntimeError("Pillow is required for image examples. Install pillow.")

foundation_processor = AutoProcessor.from_pretrained(FOUNDATION_IMAGE_MODEL_ID)
# On GPU let accelerate place the weights in the notebook-wide DTYPE; on CPU use library defaults.
foundation_model = AutoModel.from_pretrained(
    FOUNDATION_IMAGE_MODEL_ID,
    **(
        {"device_map": "auto", "torch_dtype": DTYPE}
        if DEVICE == "cuda"
        else {"device_map": None, "torch_dtype": None}
    ),
)

def embed_image(path: str) -> torch.Tensor:
    """Embed one image file with the foundation encoder.

    Uses ``get_image_features`` when the model provides it; otherwise mean-pools
    ``last_hidden_state`` over dimension 1.

    Args:
        path: path to an image file readable by Pillow.

    Returns:
        An L2-normalised float32 tensor on the CPU.
    """
    img = Image.open(path).convert("RGB")
    batch = foundation_processor(images=img, return_tensors="pt")
    # Move to the model's device; floating-point tensors (pixel values) are also cast to the
    # model's dtype so half-precision weights don't receive float32 inputs.
    batch = {
        k: (
            v.to(foundation_model.device, dtype=foundation_model.dtype)
            if isinstance(v, torch.Tensor) and v.is_floating_point()
            else v.to(foundation_model.device) if hasattr(v, "to") else v
        )
        for k, v in batch.items()
    }
    with torch.inference_mode():
        if hasattr(foundation_model, "get_image_features"):
            vec = foundation_model.get_image_features(**batch)
        else:
            out = foundation_model(**batch)
            vec = out.last_hidden_state.mean(dim=1)
    # Upcast before normalising so the cosine similarity below is not computed in bf16/fp16.
    return torch.nn.functional.normalize(vec.float(), dim=-1).cpu()

# Example: cosine similarity between two images
img_a = globals().get("SAMPLE_IMAGE_PATH", "<path-to-image-a>")
img_b = globals().get("SAMPLE_IMAGE_PATH_B", "<path-to-image-b>")
va = embed_image(img_a)
vb = embed_image(img_b)
# Both vectors are unit-length, so their dot product is the cosine similarity.
cos_sim = float((va * vb).sum(dim=-1))
print("cosine similarity:", cos_sim)

## 6) HeAR (Lung Acoustics) — audio embeddings

HeAR-style models are typically **audio encoders**. The exact processor signature varies, so this cell includes a fallback.


In [None]:
HEAR_MODEL_ID = os.getenv("HEAR_MODEL_ID", "<fill-me>")

# Fail fast with an actionable message instead of failing inside `from_pretrained`.
if HEAR_MODEL_ID in ("", "<fill-me>"):
    raise ValueError(
        "Set HEAR_MODEL_ID to a local model folder or a Hugging Face model id. "
        "(On Kaggle, run the earlier cell to list /kaggle/input candidates.)"
    )

hear_processor = AutoProcessor.from_pretrained(HEAR_MODEL_ID)
hear_model = AutoModel.from_pretrained(
    HEAR_MODEL_ID,
    device_map="auto" if DEVICE == "cuda" else None,
    torch_dtype=DTYPE if DEVICE == "cuda" else None,
)

# Provide audio as a file path or as a 1D float array. Many processors accept raw audio.
audio_path = "<path-to-audio>"

def embed_audio(
    path: "str | os.PathLike[str] | list[float] | torch.Tensor",
    sampling_rate: int | None = None,
) -> torch.Tensor:
    """Return an L2-normalised float32 CPU embedding for one audio clip.

    Feature extractors consume waveforms rather than file names, so a path is decoded
    here first.

    Args:
        path: path to a 16-bit PCM WAV file (decoded with the stdlib ``wave`` module), or an
            already-loaded 1D float waveform (list / array / tensor), which is passed through as-is.
        sampling_rate: sample rate of a pre-loaded waveform. For WAV paths, the rate from the
            file header is used instead. When known, it is forwarded to the processor.

    Raises:
        FileNotFoundError: ``path`` is a string/PathLike that is not an existing file.
        ValueError: the WAV file is empty or not 16-bit PCM. In that case load it yourself
            (e.g. with librosa) and pass the waveform plus ``sampling_rate`` instead.
    """
    import wave  # stdlib WAV reader; only this helper needs it

    if isinstance(path, (str, os.PathLike)):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Audio file not found: {path!r}")
        with wave.open(os.fspath(path), "rb") as wav:
            n_channels = wav.getnchannels()
            sample_width = wav.getsampwidth()
            sampling_rate = wav.getframerate()
            frames = wav.readframes(wav.getnframes())
        if sample_width != 2 or not frames:
            raise ValueError(
                f"Cannot decode {path!r} here (need non-empty 16-bit PCM WAV, got "
                f"{8 * sample_width}-bit). Load the audio yourself (e.g., librosa) and pass the "
                "waveform + sampling_rate instead."
            )
        # WAV stores interleaved little-endian int16 samples (assumes a little-endian host).
        # Scale them to [-1, 1) and average the channels down to mono.
        pcm = torch.frombuffer(bytearray(frames), dtype=torch.int16).to(torch.float32) / 32768.0
        waveform = pcm.reshape(-1, n_channels).mean(dim=1).tolist()
    else:
        waveform = path

    proc_kwargs = {"return_tensors": "pt"}
    if sampling_rate is not None:
        proc_kwargs["sampling_rate"] = sampling_rate

    # Most audio processors accept `audio=...` or `raw_speech=...`.
    try:
        batch = hear_processor(audio=waveform, **proc_kwargs)
    except TypeError:
        batch = hear_processor(raw_speech=waveform, **proc_kwargs)

    # Move to the model's device (a no-op on CPU); cast floating-point features to the model dtype.
    batch = {
        k: (
            v.to(hear_model.device, dtype=hear_model.dtype)
            if isinstance(v, torch.Tensor) and v.is_floating_point()
            else v.to(hear_model.device) if hasattr(v, "to") else v
        )
        for k, v in batch.items()
    }

    with torch.inference_mode():
        out = hear_model(**batch)
        if hasattr(out, "pooler_output") and out.pooler_output is not None:
            vec = out.pooler_output
        elif hasattr(out, "last_hidden_state"):
            vec = out.last_hidden_state.mean(dim=1)
        else:
            vec = out[0].mean(dim=1)

    # Upcast before normalising so similarity math downstream isn't done in bf16/fp16.
    return torch.nn.functional.normalize(vec.float(), dim=-1).cpu()

emb = embed_audio(audio_path)
print("embedding shape:", tuple(emb.shape))


## Optional: make a single config dict

Collect the checkpoint IDs used by the cells above into a single dict, so they can be reviewed or reused in one place.


In [None]:
# Collect the IDs from the example cells. Any cell that has not run in this kernel falls back
# to its env var (or "<fill-me>"), so this summary never raises NameError.
MODEL_IDS = {
    key: globals().get(var_name, os.getenv(var_name, "<fill-me>"))
    for key, var_name in (
        ("medgemma_text", "MEDGEMMA_TEXT_MODEL_ID"),
        ("medgemma_mm", "MEDGEMMA_MM_MODEL_ID"),
        ("medasr", "MEDASR_MODEL_ID"),
        ("medsiglip", "MEDSIGLIP_MODEL_ID"),
        ("foundation_image", "FOUNDATION_IMAGE_MODEL_ID"),  # cxr/derm/path
        ("hear", "HEAR_MODEL_ID"),
    )
}

MODEL_IDS
