# In [None]:
# ===========================
# JUPYTER CELL: Launch Best_SSL_Model_Selection.py (run inside notebook)
# ===========================
import collections
import glob
import os
import os.path as op
import pathlib
import re
import subprocess
import sys

# Script to launch and the project directory that contains it.
SCRIPT = "Best_SSL_Model_Selection.py"
PROJECT_DIR = "/work/projects/myproj/Linear_Probing_For_SSL"
# Parent of the resolved project directory, stored as a plain string.
REPO_ROOT = str(pathlib.Path(PROJECT_DIR).resolve().parents[0])

def _detect_user_base():
    aau = glob.glob("/work/Member Files:*")
    if aau:
        return op.basename(aau[0])
    sdu = [d for d in glob.glob("/work/*#*") if op.isdir(d)]
    return op.basename(sdu[0]) if sdu else ""

# Resolve the user base directory name: a non-empty USER_BASE_DIR from the
# environment wins, auto-detection is the fallback, "" means nothing was found.
USER_BASE_DIR = os.environ.get("USER_BASE_DIR")
if not USER_BASE_DIR:
    USER_BASE_DIR = _detect_user_base() or ""

if USER_BASE_DIR:
    # Publish the resolved name in the environment and root all per-user
    # paths under /work/<name>.
    os.environ["USER_BASE_DIR"] = USER_BASE_DIR
    WORK_ROOT = pathlib.Path("/work") / USER_BASE_DIR
else:
    # No user directory known: fall back to the current working directory.
    WORK_ROOT = pathlib.Path.cwd()

# ---- Probe configuration ---------------------------------------------------
# RFDETR_PROBE_TARGET selects the classes to probe:
#   "leu" -> only Leucocyte
#   "epi" -> only Squamous Epithelial Cell
#   "all" -> both
#
# RFDETR_INPUT_MODE selects the input size:
#   "640"                      -> full-image mode (no patching)
#   any other positive integer -> patch mode with that patch size (e.g. "224")
RFDETR_INPUT_MODE = "640"
mode_key = RFDETR_INPUT_MODE.strip().lower()

# Remove the legacy size knobs so values left over from an earlier notebook
# run cannot affect this launch.
os.environ.pop("RFDETR_FULL_RESOLUTION", None)
os.environ.pop("RFDETR_PATCH_SIZE", None)

# Settings that are always overwritten for this launch.
os.environ.update(
    {
        "RFDETR_PROBE_TARGET": "epi",
        "RFDETR_INPUT_MODE": RFDETR_INPUT_MODE,
        # Backward-compatible toggle for older scripts: "0" for every
        # full-image alias, "1" for anything else.
        "RFDETR_USE_PATCH_224": (
            "1"
            if mode_key not in {"640", "640x640", "full", "full640", "full_640"}
            else "0"
        ),
        "RFDETR_TRAIN_FRACTION": "0.125",
        "RFDETR_FRACTION_SEED": "42",
        "SEED": "42",
    }
)

# Settings that are only filled in when not already present in the environment.
if "RFDETR_PARALLEL_GPUS" not in os.environ:
    os.environ["RFDETR_PARALLEL_GPUS"] = "8"
if "NUM_WORKERS" not in os.environ:
    os.environ["NUM_WORKERS"] = "8"
# ---- Paths and runtime defaults ---------------------------------------------
# Every entry below is a default: it is written only when the variable is not
# already present in the environment, so an existing value always wins.
os.environ.update(
    (name, fallback)
    for name, fallback in {
        # Required paths (mounted-root friendly defaults)
        "SSL_CKPT_ROOT": str(WORK_ROOT / "SSL_Checkpoints"),
        "STAT_DATASETS_ROOT": "/work/projects/myproj/SOLO_Supervised_RFDETR/Stat_Dataset",
        "DATASET_PREFIX_LEU": "QA-2025v2_Leucocyte_OVR",
        "DATASET_PREFIX_EPI": "QA-2025v2_SquamousEpithelialCell_OVR",
        # Where outputs are written: script creates session_YYYYMMDD_HHMMSS under this
        "OUTPUT_BASE": "/work/projects/myproj/Linear_Probing_For_SSL/SSL_SELECTION",
        # CUDA allocator setting, plus a TF32 flag whose exact meaning is
        # defined by the launched script (not visible from this cell).
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        "TF32": "1",
    }.items()
    if name not in os.environ
)

# Optional explicit dataset overrides (if set, they override STAT_DATASETS_ROOT+prefix)
# os.environ["DATASET_LEUCO_DIR"] = "/work/projects/myproj/SOLO_Supervised_RFDETR/Stat_Dataset/QA-2025v2_Leucocyte_OVR_..."
# os.environ["DATASET_EPI_DIR"] = "/work/projects/myproj/SOLO_Supervised_RFDETR/Stat_Dataset/QA-2025v2_SquamousEpithelialCell_OVR_..."

# Optional: explicit checkpoint list (comma-separated). If unset, script scans SSL_CKPT_ROOT.
# os.environ["SSL_CKPTS"] = "epoch_epoch-004.ckpt,epoch_epoch-009.ckpt,epoch_epoch-014.ckpt,epoch_epoch-029.ckpt,last.ckpt"

# Optional image root used for resolving file_name entries in COCO json
# os.environ["IMAGES_FALLBACK_ROOT"] = "/work/Member Files:yourname/CellScanData/Zoom10x - Quality Assessment_Cleaned"

# Command line: run the script with this kernel's interpreter, passing -u so
# the child's output is unbuffered.
py = sys.executable
wd = pathlib.Path(PROJECT_DIR).resolve()
cmd = [py, "-u", str(wd / SCRIPT)]

# Launch banner: working directory, command, and the effective settings.
print("\n[LAUNCH]")
print(" cwd:", wd)
print(" cmd:", cmd)
print(" env: USER_BASE_DIR=", os.environ.get("USER_BASE_DIR"))
# WORK_ROOT is a notebook variable rather than an environment variable.
print(" env: WORK_ROOT=", WORK_ROOT)
print(
    "\n".join(
        f" env: {name}= {os.environ.get(name)}"
        for name in (
            "RFDETR_PROBE_TARGET",
            "RFDETR_INPUT_MODE",
            "RFDETR_USE_PATCH_224",
            "SSL_CKPT_ROOT",
            "STAT_DATASETS_ROOT",
            "DATASET_PREFIX_LEU",
            "DATASET_PREFIX_EPI",
            "RFDETR_TRAIN_FRACTION",
            "RFDETR_FRACTION_SEED",
            "RFDETR_PARALLEL_GPUS",
            "OUTPUT_BASE",
        )
    )
)

# Stream the child's output live, condensed into compact progress lines.
print("\n========== LIVE OUTPUT ==========")
EPOCHS_EXPECTED = int(os.environ.get("RFDETR_EPOCHS_EXPECTED", "40"))

# Number of raw output lines kept for the dump that is printed on failure.
_TAIL_LINES = 80

# Lines starting with one of these prefixes are echoed unchanged.
_PASSTHROUGH_PREFIXES = ("[RUN]", "[PLAN]", "[PROBE]", "[PARALLEL]", "[DONE]", "[FINAL]")

# Checkpoint announcement. \u2192 is the right arrow; the second alternative is
# that arrow's UTF-8 bytes mis-decoded as a single-byte codepage, which is the
# only form the pattern accepted before and is kept for compatibility.
# NOTE(review): assumes the script prints "SSL=<name> <arrow> ..." - confirm.
_SSL_RE = re.compile(r"SSL=([^ ]+)\s+(?:\u2192|â†’)")

# Per-iteration training line. The iteration counter is matched with optional
# padding (e.g. "[  0/100]") because DETR-style metric loggers right-align it,
# and the loss may be non-finite.
_EPOCH_RE = re.compile(
    r"^Epoch:\s*\[(\d+)\]\s*\[\s*(\d+)\s*/\s*(\d+)\s*\]"
    r".*?lr:\s*([0-9.eE+-]+)"
    r".*?loss:\s*([+-]?(?:nan|inf)|[0-9.eE+-]+)"
)

# Per-batch evaluation line, with the same tolerance for a padded counter.
_TEST_RE = re.compile(r"^Test:\s*\[\s*(\d+)\s*/\s*(\d+)\s*\]")


def _condense(line, ssl_name):
    """Translate one raw output line into the messages to display.

    The checks run in a fixed order and the first one that applies wins.

    Args:
        line: One line of child output without its trailing newline.
        ssl_name: Name of the SSL checkpoint currently being tracked.

    Returns:
        A ``(ssl_name, messages)`` tuple. ``ssl_name`` is replaced when the
        line announces a new checkpoint; ``messages`` is the list of strings
        to print and is empty for lines that are filtered out.
    """
    m_ssl = _SSL_RE.search(line)
    if m_ssl:
        ssl_name = m_ssl.group(1)
        return ssl_name, [f"\n[TRACK] SSL checkpoint: {ssl_name}", line]

    if line.startswith(_PASSTHROUGH_PREFIXES):
        return ssl_name, [line]

    if "rf-detr-large.pth:" in line:
        # Download progress bar: report completion only.
        if "100%|" in line:
            return ssl_name, [f"[TRACK][{ssl_name}] backbone download complete"]
        return ssl_name, []

    if line.startswith("Namespace("):
        return ssl_name, [f"[TRACK][{ssl_name}] trainer config loaded"]

    m_epoch = _EPOCH_RE.search(line)
    if m_epoch:
        ep = int(m_epoch.group(1)) + 1  # logged epochs are zero-based
        it_cur, it_tot, lr, loss = m_epoch.group(2, 3, 4, 5)
        return ssl_name, [
            f"[EPOCH][{ssl_name}] {ep}/{EPOCHS_EXPECTED}  iter {it_cur}/{it_tot}  lr={lr}  loss={loss}"
        ]

    if line.startswith("Epoch: [") and "Total time:" in line:
        return ssl_name, [f"[EPOCH][{ssl_name}] {line}"]

    m_test = _TEST_RE.search(line)
    if m_test:
        t_cur = int(m_test.group(1))
        t_tot = int(m_test.group(2))
        # Show the first batch, every tenth batch and the last batch.
        if t_cur == 0 or t_cur % 10 == 0 or t_cur == (t_tot - 1):
            return ssl_name, [f"[TEST][{ssl_name}] batch {t_cur + 1}/{t_tot}"]
        return ssl_name, []

    if line.startswith("Test: Total time:"):
        return ssl_name, [f"[TEST][{ssl_name}] {line}"]

    if line.startswith((" Average Precision", " Average Recall")):
        return ssl_name, [f"[METRIC][{ssl_name}] {line.strip()}"]

    if line.startswith(("Accumulating evaluation results", "IoU metric:")):
        return ssl_name, [f"[TEST][{ssl_name}] {line}"]

    if line.startswith(("Loading pretrain weights", "Start training")):
        return ssl_name, [f"[TRACK][{ssl_name}] {line}"]

    if line.startswith("Averaged stats:"):
        return ssl_name, [f"[EPOCH][{ssl_name}] {line[:220]}..."]

    if line.startswith("UserWarning:"):
        return ssl_name, [f"[WARN][{ssl_name}] {line}"]

    # Everything else is hidden from the live view; it is still kept in the
    # tail buffer by the caller.
    return ssl_name, []


p = subprocess.Popen(
    cmd,
    cwd=str(wd),
    text=True,
    # Decode explicitly so the result does not depend on the kernel's locale,
    # and replace undecodable bytes instead of aborting the stream.
    encoding="utf-8",
    errors="replace",
    bufsize=1,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)
current_ssl = "unknown"
# Most recent raw lines (stderr is merged into stdout), so that a failure can
# be diagnosed even though the live view filters most output.
tail = collections.deque(maxlen=_TAIL_LINES)
try:
    for raw in p.stdout:
        line = raw.rstrip("\n")
        tail.append(line)
        current_ssl, messages = _condense(line, current_ssl)
        for message in messages:
            print(message)
except BaseException:
    # Interrupted (e.g. kernel interrupt) or failed while streaming: stop the
    # child instead of leaving it running. Only the direct child is signalled.
    p.terminate()
    try:
        p.wait(timeout=10)
    except subprocess.TimeoutExpired:
        p.kill()
    raise
finally:
    p.stdout.close()
    p.wait()

print("\n[RETURNCODE]", p.returncode)
if p.returncode != 0:
    # Show the unfiltered end of the output; this is where a traceback lands.
    print(f"\n========== LAST {len(tail)} RAW LINES ==========")
    for line in tail:
        print(line)
    raise subprocess.CalledProcessError(p.returncode, cmd, output="\n".join(tail))


