In [1]:
import os, subprocess, sys
from pathlib import Path

# Source repository and the location of the local checkout inside the Colab VM.
REPO_URL = "https://github.com/DeogenesMaranan/ngiml"  # update to your fork if needed
REPO_DIR = Path("/content/ngiml")

# Pull only when REPO_DIR is an actual git checkout. A bare `exists()` test would
# send an empty directory (or one left behind by an interrupted clone that never
# created .git) down the `git pull` path, which can only fail.
if (REPO_DIR / ".git").exists():
    subprocess.run(["git", "-C", str(REPO_DIR), "pull"], check=True)
else:
    subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)

# Make the checkout importable. Guarded so that re-running this cell does not
# keep stacking duplicate entries at the front of sys.path.
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))
print("Repo ready at", REPO_DIR)

In [2]:
import os
from pathlib import Path
from huggingface_hub import login, snapshot_download

# Dataset location on the Hugging Face Hub and the local download target.
HF_TOKEN = os.getenv("HF_TOKEN", "")
DATASET_REPO = "juhenes/ngiml"
DATASET_REVISION = "main"
DATA_DIR = "/content/data"

# Authenticate only when a token was supplied through the environment.
if HF_TOKEN:
    login(token=HF_TOKEN)

os.makedirs(DATA_DIR, exist_ok=True)
snapshot_download(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    local_dir=DATA_DIR,
    revision=DATASET_REVISION,
    # An empty string is not the same as "no token": an explicit string is used
    # as the credential as-is. Pass None when HF_TOKEN is unset so the library
    # falls back to its own default token lookup.
    token=HF_TOKEN or None,
    # NOTE(review): `resume_download` is reported as deprecated in recent
    # huggingface_hub releases -- confirm against the installed version.
    resume_download=True,
)

# Quick sanity report of what was downloaded.
root = Path(DATA_DIR)
manifest_files = sorted(
    p for p in root.rglob("manifest.*")
    if p.name in {"manifest.parquet", "manifest.json"}
)
# The three patterns are disjoint ("*.tar" does not match "x.tar.gz"), so one
# pass per pattern counts every shard exactly once.
tar_count = sum(
    1
    for pattern in ("*.tar", "*.tar.gz", "*.tgz")
    for _ in root.rglob(pattern)
)

print("Dataset ready at", DATA_DIR)
print("Found manifests:", [str(p) for p in manifest_files[:5]])
print("Tar shards count:", tar_count)

In [3]:
from google.colab import drive
from pathlib import Path

# Mount Google Drive to store checkpoints/logs
# DRIVE_MOUNT is the mount point inside the Colab VM; OUTPUT_DIR is a folder
# under MyDrive and is later passed to the trainer as "output_dir".
DRIVE_MOUNT = "/content/drive"
OUTPUT_DIR = f"{DRIVE_MOUNT}/MyDrive/ngiml_runs"

drive.mount(DRIVE_MOUNT)
# Create the run folder (and any missing parents); a no-op when it already exists.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
print("Checkpoints will be written to", OUTPUT_DIR)


In [4]:
from pathlib import Path
import json
import dataclasses

from src.data.dataloaders import AugmentationConfig, load_manifest
from src.model.hybrid_ngiml import HybridNGIMLConfig, HybridNGIMLOptimizerConfig, OptimizerGroupConfig
from src.model.feature_fusion import FeatureFusionConfig
from src.model.unet_decoder import UNetDecoderConfig
from src.model.backbones.efficientnet_backbone import EfficientNetBackboneConfig
from src.model.backbones.swin_backbone import SwinBackboneConfig
from src.model.backbones.residual_noise_branch import ResidualNoiseConfig
from src.model.losses import MultiStageLossConfig

# Locate the dataset manifest under the download directory chosen in Cell 2.
data_root = Path(DATA_DIR)
manifest_names = ("manifest.parquet", "manifest.json")
# Cache file written by this cell once all sample paths have been resolved.
resolved_manifest_path = data_root / "manifest_resolved.json"

# Well-known locations, probed in priority order; the first existing one wins.
manifest_candidates = [
    resolved_manifest_path,
    data_root / "manifest.parquet",
    data_root / "manifest.json",
    data_root / "prepared" / "manifest.parquet",
    data_root / "prepared" / "manifest.json",
    data_root / "ngiml" / "manifest.parquet",
    data_root / "ngiml" / "manifest.json",
]
MANIFEST_PATH = next((p for p in manifest_candidates if p.exists()), None)

if MANIFEST_PATH is None:
    # Fallback: recursive search below data_root. The glob has to be "manifest*":
    # "manifest.*" requires a dot right after "manifest" and therefore can never
    # match "manifest_resolved.json", which made that name unreachable here.
    accepted_names = set(manifest_names) | {resolved_manifest_path.name}
    discovered = sorted(
        p for p in data_root.rglob("manifest*")
        if p.name in accepted_names
    )
    if discovered:
        MANIFEST_PATH = discovered[0]
    else:
        raise FileNotFoundError(
            f"No manifest.parquet or manifest.json found under {data_root}. "
            "Check Cell 2 download path/repo, or set DATA_DIR to the folder containing the manifest file."
        )

# Fast path: if a resolved manifest already exists, reuse it and skip expensive rewrite/index work.
# Any non-empty file at resolved_manifest_path is trusted as-is (its contents are not
# validated here); delete that file to force a fresh path resolution.
if resolved_manifest_path.exists() and resolved_manifest_path.stat().st_size > 0:
    MANIFEST_PATH = resolved_manifest_path
    print(f"Using cached resolved manifest: {MANIFEST_PATH}")
else:
    # Slow path: the rest of this branch defines the path-resolution helpers,
    # rewrites the sample paths and writes resolved_manifest_path.
    print("Using manifest:", MANIFEST_PATH)

    def _norm(value: str) -> str:
        return str(value).replace("\\", "/")

    def _suffix_score(a_parts, b_parts):
        score = 0
        for ax, bx in zip(reversed(a_parts), reversed(b_parts)):
            if ax != bx:
                break
            score += 1
        return score

    # Build a one-time tar index (fast lookup by basename)
    # tar_files: every archive found below data_root, one recursive pass per suffix.
    tar_files = [
        tar_path
        for pattern in ("*.tar", "*.tar.gz", "*.tgz")
        for tar_path in data_root.rglob(pattern)
    ]
    # tar_by_name: file name -> all archives with that name, in discovery order.
    tar_by_name = {}
    for tar_path in tar_files:
        bucket = tar_by_name.get(tar_path.name)
        if bucket is None:
            bucket = tar_by_name[tar_path.name] = []
        bucket.append(tar_path)
    print(f"Indexed tar files under {data_root}: {len(tar_files)}")

    def _candidate_paths(value: str, manifest_path: Path, data_root: Path):
        """List the locations where the file referenced by ``value`` may live.

        Args:
            value: Path string from the manifest (backslashes are accepted).
            manifest_path: Manifest file; its parent directory is the first base tried.
            data_root: Root directory of the downloaded dataset.

        Returns:
            Candidate ``Path`` objects in priority order, without duplicates
            (duplicates are detected by their POSIX string form).
        """
        normalized = _norm(value)
        target = Path(normalized)
        content_root = Path("/content")
        repo_root = Path("/content/ngiml")

        if target.is_absolute():
            pool = [target]
        else:
            # Relative path: try it beneath each plausible base directory.
            bases = (
                manifest_path.parent,
                data_root,
                data_root / "ngiml",
                content_root,
                Path("/content/data"),
                repo_root,
            )
            pool = [base / target for base in bases]

        # Re-anchor whatever follows the first "prepared/" or "datasets/" marker
        # under the known roots.
        for marker, folder in (("prepared/", "prepared"), ("datasets/", "datasets")):
            if marker not in normalized:
                continue
            tail = normalized.split(marker, 1)[1]
            pool += [
                data_root / folder / tail,
                data_root / "ngiml" / folder / tail,
                content_root / folder / tail,
                repo_root / folder / tail,
            ]

        # Order-preserving de-duplication; the first occurrence of each path wins.
        first_seen = {}
        for cand in pool:
            first_seen.setdefault(cand.as_posix(), cand)
        return list(first_seen.values())

    def _match_tar_by_basename(value: str):
        """Pick the indexed archive that best matches ``value`` by file name.

        Looks up the file name of ``value`` in ``tar_by_name``. Among archives
        sharing that name, the one with the most trailing path components in
        common with ``value`` is returned (the first such archive on ties).
        Returns ``None`` when no archive has that file name.
        """
        hint = Path(_norm(value))
        same_name = tar_by_name.get(hint.name, [])
        if same_name:
            hint_parts = hint.parts
            return max(same_name, key=lambda tar: _suffix_score(tar.parts, hint_parts))
        return None

    def _resolve_file(value: str, manifest_path: Path, data_root: Path) -> Path:
        """Map a manifest path onto a file on this machine.

        Resolution order:
          1. the first candidate from ``_candidate_paths`` that exists;
          2. for values ending in .tar / .tar.gz / .tgz, the best archive from
             the basename index;
          3. otherwise the first candidate, even though it does not exist.
        """
        candidates = _candidate_paths(value, manifest_path, data_root)
        existing = next((cand for cand in candidates if cand.exists()), None)
        if existing is not None:
            return existing

        looks_like_tar = str(value).endswith((".tar", ".tar.gz", ".tgz"))
        tar_match = _match_tar_by_basename(value) if looks_like_tar else None
        if tar_match is not None:
            return tar_match

        if not candidates:
            return Path(_norm(value))
        return candidates[0]

    def _resolve_path(path_str, manifest_path: Path, data_root: Path) -> str | None:
        if path_str is None:
            return None
        normalized = _norm(path_str)
        if "::" in normalized:
            archive, member = normalized.split("::", 1)
            archive_path = _resolve_file(archive, manifest_path, data_root).as_posix()
            member_path = _norm(member)
            return f"{archive_path}::{member_path}"
        return _resolve_file(normalized, manifest_path, data_root).as_posix()

    def _sample_files_exist(sample) -> bool:
        image_path = str(sample.image_path)
        if "::" in image_path:
            archive_path, _ = image_path.split("::", 1)
            if not Path(archive_path).exists():
                return False
        else:
            if not Path(image_path).exists():
                return False

        if sample.mask_path is not None and not Path(sample.mask_path).exists():
            return False
        if sample.high_pass_path is not None and not Path(sample.high_pass_path).exists():
            return False
        return True

    # Load the selected manifest and point every sample at files on this machine.
    manifest_obj = load_manifest(MANIFEST_PATH)
    rewritten = 0  # number of individual path fields whose value changed
    for sample in manifest_obj.samples:
        image_new = _resolve_path(sample.image_path, MANIFEST_PATH, data_root)
        # Optional fields: a falsy value (None or empty) is normalised to None.
        mask_new = _resolve_path(sample.mask_path, MANIFEST_PATH, data_root) if sample.mask_path else None
        hp_new = _resolve_path(sample.high_pass_path, MANIFEST_PATH, data_root) if sample.high_pass_path else None

        if image_new != sample.image_path:
            sample.image_path = image_new
            rewritten += 1
        if mask_new != sample.mask_path:
            sample.mask_path = mask_new
            rewritten += 1
        if hp_new != sample.high_pass_path:
            sample.high_pass_path = hp_new
            rewritten += 1

    # Drop samples whose files are still missing after resolution.
    original_count = len(manifest_obj.samples)
    manifest_obj.samples = [s for s in manifest_obj.samples if _sample_files_exist(s)]
    filtered_out = original_count - len(manifest_obj.samples)

    if not manifest_obj.samples:
        raise FileNotFoundError(
            "No valid samples remain after path resolution. "
            f"Indexed tar files: {len(tar_files)} under {data_root}. "
            "Likely the downloaded dataset does not contain prepared shards referenced by the manifest."
        )

    # Write to a sibling temp file and rename it into place. The fast path at the
    # top of this cell reuses any non-empty resolved manifest, so writing it in
    # place would let an interrupted run leave behind truncated JSON that the
    # next run picks up as a valid cache.
    tmp_manifest_path = resolved_manifest_path.with_name(resolved_manifest_path.name + ".tmp")
    try:
        with open(tmp_manifest_path, "w", encoding="utf-8") as f:
            json.dump(manifest_obj.to_dict(), f)
        tmp_manifest_path.replace(resolved_manifest_path)
    finally:
        # No-op after a successful rename; removes the partial file after a failure.
        tmp_manifest_path.unlink(missing_ok=True)
    MANIFEST_PATH = resolved_manifest_path
    print(
        f"Wrote resolved manifest to {MANIFEST_PATH} "
        f"(updated fields: {rewritten}, removed missing samples: {filtered_out})"
    )

# Model configuration: pretrained EfficientNet and Swin backbones, a residual-noise
# branch, feature fusion, a U-Net decoder config, and one optimizer group per component.
# NOTE: when the later "speed+quality settings" cell is run, it overwrites the
# weight_decay of the efficientnet, swin, fusion and decoder groups with 1e-5;
# the residual group (0.0) is not touched by that cell.
model_cfg = HybridNGIMLConfig(
    efficientnet=EfficientNetBackboneConfig(pretrained=True),
    swin=SwinBackboneConfig(model_name="swin_tiny_patch4_window7_224", pretrained=True),
    residual=ResidualNoiseConfig(num_kernels=3, base_channels=32, num_stages=4),
    fusion=FeatureFusionConfig(fusion_channels=(128, 192, 256, 320)),
    decoder=UNetDecoderConfig(decoder_channels=None, out_channels=1, per_stage_heads=True),
    # Per-component learning rates / weight decay; betas and eps are shared.
    optimizer=HybridNGIMLOptimizerConfig(
        efficientnet=OptimizerGroupConfig(lr=5e-5, weight_decay=1e-3),
        swin=OptimizerGroupConfig(lr=2e-5, weight_decay=1e-2),
        residual=OptimizerGroupConfig(lr=1e-4, weight_decay=0.0),
        fusion=OptimizerGroupConfig(lr=1e-4, weight_decay=1e-5),
        decoder=OptimizerGroupConfig(lr=1e-4, weight_decay=1e-5),
        betas=(0.9, 0.999),
        eps=1e-8,
    ),
    # All three feature branches are switched on.
    use_low_level=True,
    use_context=True,
    use_residual=True,
)

# Loss configuration: Dice and BCE terms carry equal weight (1.0 each).
# NOTE(review): pos_weight presumably weights the positive class in the BCE
# term, and stage_weights=None presumably selects the library default -- confirm
# in src.model.losses.
loss_cfg = MultiStageLossConfig(
    dice_weight=1.0,
    bce_weight=1.0,
    pos_weight=2.0,
    stage_weights=None,
    smooth=1e-6,
)

# Baseline augmentation settings, passed to the trainer as "default_aug".
# NOTE(review): presumably applied to every dataset without an entry in
# per_dataset_aug -- confirm in src.data.dataloaders.
default_aug = AugmentationConfig(
    enable=True,
    views_per_sample=1,
    enable_flips=True,
    enable_rotations=True,
    max_rotation_degrees=5.0,
    enable_random_crop=True,
    crop_scale_range=(0.9, 1.0),
    enable_color_jitter=True,
    color_jitter_factors=(0.9, 1.1),
    enable_noise=True,
    noise_std_range=(0.0, 0.02),
)

# Dataset-specific overrides keyed by dataset name, passed as "per_dataset_aug".
# IMD2020 differs from default_aug in: views_per_sample 2 (vs 1), rotation up to
# 8.0 (vs 5.0), crop scale from 0.85 (vs 0.9), colour jitter 0.85-1.15
# (vs 0.9-1.1) and noise std up to 0.03 (vs 0.02).
# NOTE: the later "speed+quality settings" cell resets IMD2020 views_per_sample to 1.
per_dataset_aug = {
    "IMD2020": AugmentationConfig(
        enable=True,
        views_per_sample=2,
        enable_flips=True,
        enable_rotations=True,
        max_rotation_degrees=8.0,
        enable_random_crop=True,
        crop_scale_range=(0.85, 1.0),
        enable_color_jitter=True,
        color_jitter_factors=(0.85, 1.15),
        enable_noise=True,
        noise_std_range=(0.0, 0.03),
    ),
}

# Keyword arguments for train_ngiml.TrainConfig (see the training cell).
# "num_workers" and "persistent_workers" below are overwritten by the later
# "speed+quality settings" cell when it is run.
training_config = {
    # inputs / outputs
    "manifest": str(MANIFEST_PATH),
    "output_dir": OUTPUT_DIR,
    # schedule and checkpointing
    "batch_size": 8,
    "epochs": 50,
    "num_workers": 0,
    "amp": True,
    "grad_clip": 1.0,
    "val_every": 1,
    "checkpoint_every": 1,
    "resume": None,
    "auto_resume": True,
    # data loading
    "round_robin_seed": 42,
    "prefetch_factor": 2,
    "persistent_workers": False,
    "drop_last": True,
    # augmentation switches
    "views_per_sample": 1,
    "max_rotation_degrees": 5.0,
    "noise_std_max": 0.02,
    "disable_aug": False,
    # device and seeds
    "device": "cuda",
    "aug_seed": 42,
    "seed": 42,
    # early stopping
    "early_stopping_patience": 8,
    "early_stopping_min_delta": 1e-4,
    # structured configs built above
    "default_aug": default_aug,
    "per_dataset_aug": per_dataset_aug,
    "model_config": model_cfg,
    "loss_config": loss_cfg,
}


def _json_default(obj):
    """JSON fallback: dataclasses are expanded to dicts, anything else becomes ``str(obj)``."""
    if dataclasses.is_dataclass(obj):
        return dataclasses.asdict(obj)
    return str(obj)


print(json.dumps(training_config, indent=2, default=_json_default))

In [None]:
import os
import json
import dataclasses

# Throughput-oriented defaults (safer for Colab than compile+reduce-overhead).
# Half of the visible CPUs, clamped to the range 2..8 (4 CPUs assumed when the
# count is unavailable).
cpu_total = os.cpu_count() or 4
recommended_workers = max(2, min(8, cpu_total // 2))

speed_settings = {
    "num_workers": recommended_workers,
    "persistent_workers": True,
    "pin_memory": True,
    "auto_local_cache": True,
    "local_cache_dir": "/content/cache",
    "compile_model": False,
    "compile_mode": "default",
    "channels_last": True,
    "use_tf32": True,
    "balance_sampling": True,
}
training_config.update(speed_settings)

# Quality regression guardrails: restore legacy regularization defaults.
# The residual optimizer group is intentionally left as configured.
model_cfg = training_config.get("model_config")
optimizer_groups = None if model_cfg is None else getattr(model_cfg, "optimizer", None)
if optimizer_groups is not None:
    for group_name in ("efficientnet", "swin", "fusion", "decoder"):
        getattr(optimizer_groups, group_name).weight_decay = 1e-5

# IMD2020 views=2 can change optimization dynamics and slow training; keep 1 for parity/stability.
aug_overrides = training_config.get("per_dataset_aug", {})
if "IMD2020" in aug_overrides:
    aug_overrides["IMD2020"].views_per_sample = 1

# Views per sample that each dataset override yields (1 when its augmentation is disabled).
effective_view_multiplier = {}
for dataset_name, aug_cfg in training_config.get("per_dataset_aug", {}).items():
    effective_view_multiplier[dataset_name] = aug_cfg.views_per_sample if aug_cfg.enable else 1

print("Applied speed+quality settings:")
# Read the values back from training_config, in the order they were applied.
print({key: training_config[key] for key in speed_settings})
print("Per-dataset views_per_sample:", effective_view_multiplier)

# Effective config to be passed to TrainConfig
print("Effective training config (post-settings):")
print(json.dumps(training_config, indent=2, default=lambda o: dataclasses.asdict(o) if dataclasses.is_dataclass(o) else str(o)))

In [5]:
from importlib import reload
from tools import train_ngiml

# Ensure latest module state in this kernel
# (only tools.train_ngiml itself is reloaded here).
reload(train_ngiml)

# Every key of training_config becomes a TrainConfig keyword argument, so the
# dict must only contain names that TrainConfig accepts.
cfg = train_ngiml.TrainConfig(**training_config)
train_ngiml.run_training(cfg)

In [None]:
# Stand-alone Drive mount. It targets the same mount point ("/content/drive") as
# the earlier checkpoint-directory cell, so it repeats that cell's mount.
from google.colab import drive
drive.mount('/content/drive')