# Run the refactored CXR MIL project in Jupyter

This notebook:
1) Unzips the project  
2) Installs it in editable mode  
3) Verifies PyTorch + CUDA  
4) Runs training via the CLI module

Set `ZIP_PATH`, `PROJECT_DIR`, and `DATA_ROOT` in the first code cell below to match your environment.

In [None]:
# --- 0) Set paths ---
import os, sys, pathlib

# Edit these three locations to match your environment.
ZIP_PATH = "/workspace/pyproject.zip"                              # project archive to extract
PROJECT_DIR = "/workspace/project_flexible_preserve_aspect_ratio"  # extraction target
DATA_ROOT = "/workspace/data"                                      # dataset root passed to training (--root)

# Quick environment report: which interpreter runs this kernel and whether the inputs exist.
for label, value in (
    ("Python:", sys.executable),
    ("ZIP_PATH exists?", os.path.exists(ZIP_PATH)),
    ("DATA_ROOT exists?", os.path.exists(DATA_ROOT)),
):
    print(label, value)

In [None]:
# --- 1) Unzip project ---
# Extracts ZIP_PATH into PROJECT_DIR (skipped when PROJECT_DIR already has content)
# and sets PROJECT_ROOT to the directory that actually holds the project files.
import shutil, subprocess, pathlib, os, zipfile

# Archive artefacts that must not count as project content when detecting the root
# (zips created on macOS typically add a "__MACOSX" folder and ".DS_Store" files).
_ZIP_METADATA_NAMES = {"__MACOSX", ".DS_Store"}

# If project already exists → skip unzip
if os.path.exists(PROJECT_DIR) and os.listdir(PROJECT_DIR):
    print("Project already exists, skipping unzip.")
else:
    if not os.path.isfile(ZIP_PATH):
        raise FileNotFoundError(f"Project archive not found: {ZIP_PATH}")
    print("Extracting project to:", PROJECT_DIR)
    os.makedirs(PROJECT_DIR, exist_ok=True)
    try:
        # zipfile needs neither the `unzip` binary nor shell quoting of the paths.
        # Note: unlike `unzip`, it does not restore Unix permission bits — presumably
        # fine for this Python project; confirm if the archive ships executables.
        with zipfile.ZipFile(ZIP_PATH) as archive:
            archive.extractall(PROJECT_DIR)
    except BaseException:
        # A half-extracted tree would make the next run skip extraction; start clean.
        shutil.rmtree(PROJECT_DIR, ignore_errors=True)
        raise

# Detect actual project root (some zips include a single top-level folder)
entries = [p for p in pathlib.Path(PROJECT_DIR).iterdir() if p.name not in _ZIP_METADATA_NAMES]
if len(entries) == 1 and entries[0].is_dir():
    PROJECT_ROOT = str(entries[0])
else:
    PROJECT_ROOT = PROJECT_DIR

print("PROJECT_ROOT:", PROJECT_ROOT)


In [None]:
# --- 2) Install project (editable) ---
import subprocess, sys

# Installs PROJECT_ROOT with `pip install -e` using THIS kernel's interpreter.
# If torch/torchvision are missing in this kernel's environment, install them
# first and then rerun this cell.
pip_cmd = [sys.executable, "-m", "pip", "install", "-e", PROJECT_ROOT]
print("Running:", " ".join(pip_cmd))
subprocess.check_call(pip_cmd)

In [None]:
# --- 3) Verify torch + CUDA ---
import torch

# Report the torch build and, when a GPU is visible, the name/compute capability of device 0.
cuda_ok = torch.cuda.is_available()
print("torch:", torch.__version__)
print("cuda available:", cuda_ok)
if cuda_ok:
    for label, query in (
        ("gpu:", torch.cuda.get_device_name),
        ("capability:", torch.cuda.get_device_capability),
    ):
        print(label, query(0))

In [None]:
# Make PROJECT_ROOT importable from this kernel even if the editable install is not
# picked up without a restart. Guarded so re-running the cell does not keep
# appending duplicate entries to sys.path.
import sys
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [None]:
# --- 4) Quick sanity: import the package ---
import cxr_mil

# Show where Python resolved the package from (should be inside PROJECT_ROOT).
pkg_location = cxr_mil.__file__
print("cxr_mil imported from:", pkg_location)

## Optional: download dataset from Kaggle

This matches the original notebook’s Kaggle download step.

**Dataset slug:** `alexandrostsikalas/grand-xray-slam-resized-512`

### What you need
- A Kaggle API token (`access_token`).
  - In Kaggle: *Account → API → Create New Token*.
  - Upload the `access_token` file to `/workspace/`; the next cell moves it to `/root/.kaggle/access_token` and restricts its permissions.

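If your data already exists under `DATA_ROOT`, you can skip this section. Note that the download cell below writes to the hard-coded path `/workspace/data`; if you changed `DATA_ROOT`, edit `DATA_DIR` in that cell to match.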


In [None]:
# Move the Kaggle token uploaded to /workspace into /root/.kaggle and make it
# owner-read/write only. The move is guarded so re-running the cell after the
# token was already moved does not print a spurious "cannot stat" error.
!mkdir -p /root/.kaggle
!test -f /workspace/access_token && mv /workspace/access_token /root/.kaggle/access_token
!chmod 600 /root/.kaggle/access_token
!pip install kaggle


In [None]:
%%bash
# Download the Kaggle dataset into DATA_DIR, extract every *.zip found there, and
# delete each ZIP once it has been extracted. Session messages go to LOG_FILE.
# Requires the `kaggle` CLI on PATH; it presumably authenticates with the token the
# previous cell placed in /root/.kaggle — confirm for your kaggle CLI version.
# NOTE(review): DATA_DIR is hard-coded here and is NOT derived from the Python
# DATA_ROOT variable; keep the two in sync by hand.
set -euo pipefail

DATA_DIR=/workspace/data
LOG_FILE="$DATA_DIR/download_progress.log"
mkdir -p "$DATA_DIR"
# Truncate the log so it only covers the current session.
: > "$LOG_FILE"

# Print a message and also append it to LOG_FILE.
log(){ echo "$@" | tee -a "$LOG_FILE"; }

log "==== DOWNLOAD SESSION START $(date) ===="
log "DATA_DIR=$DATA_DIR"
log "LOG_FILE=$LOG_FILE"

# -----------------------------
# Monitor download size: append to LOG_FILE every +5GB downloaded (log-only)
# Sums sizes of ALL .zip files in DATA_DIR (handles multi-zip downloads).
# Runs in a background subshell; its PID is captured in MON_PID below.
# -----------------------------
(
  THRESHOLD=$((5*1024*1024*1024))  # 5GB in bytes
  NEXT=$THRESHOLD

  while true; do
    total=0
    # nullglob: an unmatched *.zip pattern expands to nothing instead of itself.
    shopt -s nullglob
    for f in "$DATA_DIR"/*.zip; do
      # If stat fails (e.g. the file disappeared), count it as 0 bytes.
      sz=$(stat -c%s "$f" 2>/dev/null || echo 0)
      total=$((total + sz))
    done
    shopt -u nullglob

    # Loop so a single poll can log several thresholds if growth exceeded 5GB.
    while [ "$total" -ge "$NEXT" ]; do
      gb=$((NEXT/1024/1024/1024))
      echo "Download progress: ${gb}GB reached" >> "$LOG_FILE"
      NEXT=$((NEXT + THRESHOLD))
    done

    sleep 5
  done
) &
MON_PID=$!

# Stop and reap the background monitor; tolerant of it having already exited.
cleanup_monitor() {
  kill "$MON_PID" 2>/dev/null || true
  wait "$MON_PID" 2>/dev/null || true
}
# Ensure the monitor is stopped even if the script exits early (e.g. a set -e failure).
trap cleanup_monitor EXIT

# -----------------------------
# 1) Download (choose ONE). Progress bar stays in notebook.
# -----------------------------
log "Downloading..."

# OPTION A: Competition
# kaggle competitions download -c grand-xray-slam-division-a -p "$DATA_DIR"

# OPTION B: Dataset
kaggle datasets download -d alexandrostsikalas/grand-xray-slam-resized-512 -p "$DATA_DIR"

# Download done: stop the monitor now and drop the EXIT trap (no longer needed).
cleanup_monitor
trap - EXIT
log "Download finished."

# -----------------------------
# 2) Unzip quietly + remove ZIP(s)
# -----------------------------
log "Looking for ZIP files in $DATA_DIR ..."
shopt -s nullglob
zips=( "$DATA_DIR"/*.zip )
shopt -u nullglob

if [ ${#zips[@]} -eq 0 ]; then
  log "ERROR: No ZIP files found in $DATA_DIR"
  exit 1
fi

log "Unzipping ${#zips[@]} file(s)..."
for z in "${zips[@]}"; do
  log "Extracting $(basename "$z") ..."
  # -o: overwrite existing files without prompting; -q: quiet. Any output goes to the log.
  unzip -oq "$z" -d "$DATA_DIR" >> "$LOG_FILE" 2>&1
  # NOTE: removes every ZIP found in DATA_DIR after extraction, not only this session's download.
  rm -f "$z"
  log "Removed $(basename "$z")"
done

log "DONE $(date)"

In [None]:
# Empty the desktop trash to reclaim disk space — presumably needed because files
# deleted from the Jupyter file browser end up here instead of being freed.
# -f makes this a no-op (instead of an error) when the trash directory is absent.
!rm -rf ~/.local/share/Trash

In [None]:
import os, sys, subprocess, shlex, textwrap

# All checkpoints are downloaded to, and converted inside, this directory.
WEIGHTS_DIR = "/workspace/weights"
os.makedirs(WEIGHTS_DIR, exist_ok=True)

# -----------------------------
# URLs (official sources)
# -----------------------------
_XRV_RELEASE_V1 = "https://github.com/mlmed/torchxrayvision/releases/download/v1/"

CHEX_URL = (
    "https://huggingface.co/torchxrayvision/densenet121-res224-chex"
    "/resolve/main/model.pt"
)
ALL_DENSENET_URL = (
    _XRV_RELEASE_V1
    + "nih-pc-chex-mimic_ch-google-openi-kaggle-densenet121-d121-tw-lr001-rot45-tr15-sc15-seed0-best.pt"
)
ALL_RESNET50_512_URL = _XRV_RELEASE_V1 + "pc-nih-rsna-siim-vin-resnet50-test512-e400-state.pt"

# -----------------------------
# Local paths: downloaded originals (SRC_*) and converted state_dicts (OUT_*)
# -----------------------------
SRC_CHEX    = os.path.join(WEIGHTS_DIR, "densenet121-res224-chex__model.pt")
SRC_ALL_DN  = os.path.join(WEIGHTS_DIR, "densenet121-res224-all.pt")
SRC_ALL_R50 = os.path.join(WEIGHTS_DIR, "resnet50-res512-all.pt")

OUT_CHEX    = os.path.join(WEIGHTS_DIR, "densenet121-res224-chex__state_dict.pth")
OUT_ALL_DN  = os.path.join(WEIGHTS_DIR, "densenet121-res224-all__state_dict.pth")
OUT_ALL_R50 = os.path.join(WEIGHTS_DIR, "resnet50-res512-all__state_dict.pth")

# Lower bounds (bytes) used to reject HTML error pages / truncated downloads.
# Deliberately far below the expected sizes noted on each line.
MIN_SIZES = dict([
    (SRC_CHEX, 5_000_000),      # ~28MB expected
    (SRC_ALL_DN, 5_000_000),    # ~28MB expected
    (SRC_ALL_R50, 10_000_000),  # ~90MB expected
])

def have(cmd):
    return subprocess.call(["bash","-lc", f"command -v {shlex.quote(cmd)} >/dev/null 2>&1"]) == 0

def run_bash(cmd):
    """Echo `cmd`, then execute it with `bash -lc`.

    Raises:
        subprocess.CalledProcessError: the command exited with a non-zero status.
    """
    print("\n>>> bash -lc", cmd)
    argv = ["bash", "-lc", cmd]
    subprocess.check_call(argv)

def ok_file(path, min_size=1):
    """Return True when `path` exists and its size is at least `min_size` bytes."""
    if not os.path.exists(path):
        return False
    return os.path.getsize(path) >= min_size

def download(url, dst, min_size=1):
    """Download `url` to `dst` unless a file of at least `min_size` bytes is already there.

    The payload is written to ``dst + ".part"`` and renamed onto `dst` only after the
    size check passes, so `dst` is never left truncated. Uses curl when available,
    then wget, and falls back to Python's urllib when neither tool is installed.

    Raises:
        subprocess.CalledProcessError: curl/wget exited with a non-zero status.
        urllib.error.URLError: the urllib fallback failed (includes HTTP errors).
        RuntimeError: the downloaded file is smaller than `min_size`.
    """
    import shutil
    import urllib.request

    # If already good, keep it
    if ok_file(dst, min_size):
        print(f"OK: exists {dst} ({os.path.getsize(dst)/1e6:.1f} MB)")
        return

    # Remove bad/partial file
    if os.path.exists(dst):
        print(f"Removing incomplete file: {dst}")
        os.remove(dst)

    tmp = dst + ".part"
    if os.path.exists(tmp):
        os.remove(tmp)

    try:
        if have("curl"):
            # -f: fail on HTTP errors instead of saving the error page; -L: follow redirects.
            run_bash(f"curl -fL --retry 3 --retry-delay 2 -o {shlex.quote(tmp)} {shlex.quote(url)}")
        elif have("wget"):
            # wget follows redirects by default (its -L flag means --relative, not "follow").
            run_bash(f"wget -q --show-progress --progress=bar:force:noscroll -O {shlex.quote(tmp)} {shlex.quote(url)}")
        else:
            print(f"curl/wget not found; downloading with urllib: {url}")
            with urllib.request.urlopen(url) as resp, open(tmp, "wb") as out:
                shutil.copyfileobj(resp, out)

        if not ok_file(tmp, min_size):
            raise RuntimeError(f"Download failed or file too small: {url} -> {tmp}")
    except BaseException:
        # Don't leave a stale partial file behind after a failed/interrupted download.
        if os.path.exists(tmp):
            os.remove(tmp)
        raise

    os.replace(tmp, dst)
    print(f"Saved: {dst} ({os.path.getsize(dst)/1e6:.1f} MB)")

# -----------------------------
# 1) Download (idempotent) — each checkpoint is skipped when already present and large enough
# -----------------------------
for url, dst in (
    (CHEX_URL, SRC_CHEX),
    (ALL_DENSENET_URL, SRC_ALL_DN),
    (ALL_RESNET50_512_URL, SRC_ALL_R50),
):
    download(url, dst, MIN_SIZES[dst])

# -----------------------------
# 2) Install minimal deps only if needed (idempotent)
# -----------------------------
# Each package is pip-installed only when its import name cannot be found, so a
# re-run neither hits the network nor upgrades packages (e.g. pandas/pillow) that
# the rest of the environment may depend on.
import importlib
import importlib.util

importlib.invalidate_caches()  # see packages pip-installed earlier in this kernel session

# We need torchxrayvision ONLY if a file is pickled: its classes must be importable
# to unpickle it. --no-deps stops pip from resolving/upgrading its dependencies.
if importlib.util.find_spec("torchxrayvision") is None:
    run_bash(f"{shlex.quote(sys.executable)} -m pip install -U --no-deps torchxrayvision")
else:
    print("torchxrayvision already installed.")

# Small deps torchxrayvision imports: pip distribution name -> import name.
_XRV_DEPS = {
    "imageio": "imageio",
    "scikit-image": "skimage",
    "pandas": "pandas",
    "tqdm": "tqdm",
    "requests": "requests",
    "pillow": "PIL",
}
_missing_deps = [pip_name for pip_name, mod in _XRV_DEPS.items()
                 if importlib.util.find_spec(mod) is None]
if _missing_deps:
    run_bash(f"{shlex.quote(sys.executable)} -m pip install -U " + " ".join(_missing_deps))
else:
    print("torchxrayvision deps already installed:", ", ".join(_XRV_DEPS))

# -----------------------------
# 3) Convert to pure state_dict (idempotent)
# -----------------------------
def convert_if_needed(src, dst):
    """Convert the checkpoint at `src` into a plain ``state_dict`` saved at `dst`.

    Skipped when `dst` already exists and is at least 1 MB. The conversion runs in a
    fresh Python subprocess, so packages pip-installed earlier in this session are
    importable there. The child saves to ``dst + ".part"`` and renames it onto `dst`
    only after ``torch.save`` completes, so an interrupted conversion cannot leave a
    truncated `dst` that would pass the size check on the next run.

    Raises:
        subprocess.CalledProcessError: the conversion subprocess failed.
        RuntimeError: the converted file is smaller than 1 MB.
    """
    if ok_file(dst, 1_000_000):  # converted state_dicts should not be tiny
        print(f"OK: converted exists {dst} ({os.path.getsize(dst)/1e6:.1f} MB)")
        return

    tmp = dst + ".part"
    if os.path.exists(tmp):  # leftover from an earlier interrupted conversion
        os.remove(tmp)

    py = textwrap.dedent(r"""
    import os, sys
    import torch

    SRC = sys.argv[1]
    DST = sys.argv[2]
    TMP = sys.argv[3]

    def is_state_dict(d):
        if not isinstance(d, dict) or len(d) == 0:
            return False
        v = next(iter(d.values()))
        return torch.is_tensor(v)

    def extract_state_dict(obj):
        if is_state_dict(obj):
            return obj
        if isinstance(obj, dict):
            for k in ["state_dict","model_state_dict","model","net","backbone","encoder"]:
                if k in obj:
                    cand = obj[k]
                    if hasattr(cand, "state_dict"):
                        return cand.state_dict()
                    if is_state_dict(cand):
                        return cand
            raise ValueError("Can't find state_dict in checkpoint. Keys: " + str(list(obj.keys())[:30]))
        if hasattr(obj, "state_dict"):
            return obj.state_dict()
        raise TypeError("Unsupported loaded object type: " + str(type(obj)))

    def trusted_full_unpickle_load(path):
        # SECURITY: weights_only=False uses the full pickle machinery, which can run
        # arbitrary code stored in the file -- only do this for trusted checkpoints.
        # (torch.serialization.safe_globals allowlists only apply to weights_only=True
        # loads, so they would add no protection here.)
        # torchxrayvision must be importable so pickled xrv model classes resolve.
        import torchxrayvision  # noqa: F401
        return torch.load(path, map_location="cpu", weights_only=False)

    print("\n=== Converting ===")
    print("SRC:", SRC)
    print("DST:", DST)

    # Try safe weights-only load first
    try:
        obj = torch.load(SRC, map_location="cpu", weights_only=True)
        print("Loaded with weights_only=True (safe).")
    except Exception as e:
        msg = str(e)
        if ("Weights only load failed" in msg) or ("Unsupported global" in msg) or ("safe_globals" in msg):
            print("Pickle detected. Using full unpickle (trusted source only)...")
            obj = trusted_full_unpickle_load(SRC)
        else:
            raise

    sd = extract_state_dict(obj)
    # Save under a temporary name, then rename, so DST is never half-written.
    torch.save(sd, TMP)
    os.replace(TMP, DST)

    print("Saved:", DST)
    print("Num keys:", len(sd))
    print("First keys:", list(sd.keys())[:8])
    """)

    try:
        subprocess.check_call([sys.executable, "-c", py, src, dst, tmp])
    finally:
        # Only still present if the child failed between torch.save and os.replace.
        if os.path.exists(tmp):
            os.remove(tmp)

    if not ok_file(dst, 1_000_000):
        raise RuntimeError(f"Conversion produced suspiciously small file: {dst}")

# (downloaded checkpoint, converted state_dict) pairs, in processing order.
_CONVERSIONS = (
    (SRC_CHEX, OUT_CHEX),
    (SRC_ALL_DN, OUT_ALL_DN),
    (SRC_ALL_R50, OUT_ALL_R50),
)
for src_path, out_path in _CONVERSIONS:
    convert_if_needed(src_path, out_path)

print("\n✅ Final files (use these in YAML checkpoint_path):")
for _, out_path in _CONVERSIONS:
    print("  ", out_path)


In [None]:
import importlib, importlib.util, subprocess, sys

# The weights cell above also pip-installs imageio; only call pip here when the
# module is genuinely missing, so re-runs stay offline and fast.
importlib.invalidate_caches()  # see packages pip-installed earlier in this session
if importlib.util.find_spec("imageio") is None:
    subprocess.check_call([sys.executable, "-m", "pip", "install", "imageio"])
print("✔ imageio installed")


## Run training

This calls the CLI using the **same Python** as the notebook kernel (`sys.executable`).

In [None]:
import sys, subprocess, shlex, os

CONFIGS_DIR = os.path.join(PROJECT_ROOT, "configs")

# Experiment config files (relative to CONFIGS_DIR), run in this order.
yamls = [
    "exp12_imagenet_densenet121_512_resize_preserve_ratio.yaml",
]

# Validate every config up front: a typo in a later entry should fail now,
# not after the earlier (potentially hours-long) runs have finished.
config_paths = [os.path.join(CONFIGS_DIR, y) for y in yamls]
missing = [p for p in config_paths if not os.path.exists(p)]
if missing:
    raise FileNotFoundError("Missing YAML: " + ", ".join(missing))

for y, CONFIG_PATH in zip(yamls, config_paths):
    # Same interpreter as this kernel (sys.executable), so the same installed package runs.
    cmd = [
        sys.executable, "-m", "cxr_mil.train_cv",
        "--root", DATA_ROOT,
        "--config", CONFIG_PATH,
    ]

    print("\n" + "="*100)
    print("Running:", y)
    print("Command:\n", " ".join(shlex.quote(c) for c in cmd))

    rc = subprocess.call(cmd)
    print("Process finished with return code:", rc)

    if rc != 0:
        raise RuntimeError(f"Training failed for {y} (return code {rc}).")

print("\n✅ All experiments finished.")


## Optional: check GPU status

The cell below prints the current GPU status and utilization via `nvidia-smi`; if the tool is unavailable, it reports that instead of failing.

In [None]:
import subprocess

def _show_gpu_status():
    """Print `nvidia-smi` output; report (rather than raise) any failure to run it."""
    try:
        subprocess.check_call("nvidia-smi", shell=True)
    except Exception as err:
        print("nvidia-smi not available:", err)

_show_gpu_status()