In [None]:
import os, subprocess, sys
from pathlib import Path

# Where the project is fetched from and where it is checked out on the Colab VM.
REPO_URL = "https://github.com/your-user/ngiml.git"  # update to your fork if needed
REPO_DIR = Path("/content/ngiml")

# Only treat REPO_DIR as an existing checkout when it actually carries git
# metadata. A bare `exists()` check would try to `git pull` inside an empty or
# leftover directory and fail instead of cloning.
if (REPO_DIR / ".git").exists():
    subprocess.run(["git", "-C", str(REPO_DIR), "pull"], check=True)
else:
    subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)

# Put the checkout first on sys.path exactly once: drop any entry left by a
# previous run of this cell before re-inserting, so re-running the cell does not
# accumulate duplicates while the repo still takes import priority.
repo_path = str(REPO_DIR)
sys.path[:] = [entry for entry in sys.path if entry != repo_path]
sys.path.insert(0, repo_path)
print("Repo ready at", REPO_DIR)

In [None]:
import os
from huggingface_hub import login, snapshot_download

# Which dataset snapshot to fetch from the Hugging Face Hub, and where to put it.
DATASET_REPO = "your-hf-user/your-prepared-dataset"  # update to your dataset repo id
DATASET_REVISION = "main"
DATA_DIR = "/content/data"

# Access token read from the environment (None when HF_TOKEN is unset).
# Expected to be needed only when the dataset repo is private.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Skip the login step entirely when no (or an empty) token was provided.
if HF_TOKEN:
    login(token=HF_TOKEN)

# Make sure the target directory exists, then download the pinned revision into it.
os.makedirs(DATA_DIR, exist_ok=True)
snapshot_download(
    repo_id=DATASET_REPO,
    revision=DATASET_REVISION,
    repo_type="dataset",
    local_dir=DATA_DIR,
    token=HF_TOKEN,
)
print("Dataset ready at", DATA_DIR)

In [None]:
from google.colab import drive
from pathlib import Path

# Checkpoints and logs are stored on Google Drive, under the mount point below.
DRIVE_MOUNT = "/content/drive"
OUTPUT_DIR = f"{DRIVE_MOUNT}/MyDrive/ngiml_runs"

# Mount first: the output folder lives inside the mounted Drive tree, so it can
# only be created once the mount call has returned.
drive.mount(DRIVE_MOUNT)
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)
print("Checkpoints will be written to", OUTPUT_DIR)


In [None]:
from pathlib import Path
import json
import dataclasses

from src.data.dataloaders import AugmentationConfig
from src.model.hybrid_ngiml import HybridNGIMLConfig, HybridNGIMLOptimizerConfig, OptimizerGroupConfig
from src.model.feature_fusion import FeatureFusionConfig
from src.model.unet_decoder import UNetDecoderConfig
from src.model.backbones.efficientnet_backbone import EfficientNetBackboneConfig
from src.model.backbones.swin_backbone import SwinBackboneConfig
from src.model.backbones.residual_noise_branch import ResidualNoiseConfig
from src.model.losses import MultiStageLossConfig

# Dataset manifest, looked up directly under the downloaded data directory.
MANIFEST_PATH = Path(DATA_DIR).joinpath("manifest.json")

# Optimizer settings: one lr / weight-decay group per model component, plus the
# betas/eps values shared by all groups.
optimizer_cfg = HybridNGIMLOptimizerConfig(
    efficientnet=OptimizerGroupConfig(lr=5e-5, weight_decay=1e-5),
    swin=OptimizerGroupConfig(lr=2e-5, weight_decay=1e-5),
    residual=OptimizerGroupConfig(lr=1e-4, weight_decay=0.0),
    fusion=OptimizerGroupConfig(lr=1e-4, weight_decay=1e-5),
    decoder=OptimizerGroupConfig(lr=1e-4, weight_decay=1e-5),
    betas=(0.9, 0.999),
    eps=1e-8,
)

# Full model configuration: both pretrained backbones, the residual-noise
# branch, the fusion stage and the decoder, with all three branch flags on.
model_cfg = HybridNGIMLConfig(
    efficientnet=EfficientNetBackboneConfig(pretrained=True),
    swin=SwinBackboneConfig(
        model_name="swin_tiny_patch4_window7_224",
        pretrained=True,
    ),
    residual=ResidualNoiseConfig(
        gaussian_kernel=5,
        gaussian_sigma=1.2,
        highpass_strength=1.0,
    ),
    fusion=FeatureFusionConfig(fusion_channels=(128, 192, 256, 320)),
    decoder=UNetDecoderConfig(
        decoder_channels=None,
        out_channels=1,
        per_stage_heads=True,
    ),
    optimizer=optimizer_cfg,
    use_low_level=True,
    use_context=True,
    use_residual=True,
)

# Loss configuration: dice and BCE terms weighted equally.
loss_cfg = MultiStageLossConfig(
    dice_weight=1.0,
    bce_weight=1.0,
    pos_weight=2.0,
    stage_weights=None,
    smooth=1e-6,
)

# Augmentation settings stored under "default_aug" in the training config.
default_aug = AugmentationConfig(
    enable=True,
    copies_per_sample=1,
    enable_flips=True,
    enable_rotations=True,
    max_rotation_degrees=5.0,
    enable_random_crop=True,
    crop_scale_range=(0.9, 1.0),
    enable_color_jitter=True,
    color_jitter_factors=(0.9, 1.1),
    enable_noise=True,
    noise_std_range=(0.0, 0.02),
)

# Optional per-dataset overrides; empty by default.
per_dataset_aug = {
    # "dataset_name_from_manifest": AugmentationConfig(enable=False),
}

# Flat keyword set that is later expanded into the trainer's config object.
training_config = {
    # Paths
    "manifest": str(MANIFEST_PATH),
    "output_dir": OUTPUT_DIR,
    # Loop / loader settings
    "batch_size": 4,
    "epochs": 10,
    "num_workers": 2,
    "amp": True,
    "grad_clip": 1.0,
    "val_every": 1,
    "checkpoint_every": 1,
    "resume": None,
    "round_robin_seed": 0,
    "prefetch_factor": 2,
    "persistent_workers": False,
    "drop_last": True,
    # NOTE(review): the next three repeat values already set on default_aug
    # (copies_per_sample, max_rotation_degrees, upper bound of noise_std_range);
    # presumably they must be kept in sync - confirm how the trainer uses them.
    "copies_per_sample": 1,
    "max_rotation_degrees": 5.0,
    "noise_std_max": 0.02,
    "disable_aug": False,
    "device": "cuda",
    "aug_seed": None,
    # Structured sub-configs defined above
    "default_aug": default_aug,
    "per_dataset_aug": per_dataset_aug,
    "model_config": model_cfg,
    "loss_config": loss_cfg,
}


def _json_fallback(obj):
    """Convert values json cannot encode: dataclasses become dicts, anything else its str()."""
    if dataclasses.is_dataclass(obj):
        return dataclasses.asdict(obj)
    return str(obj)


# Echo the resolved configuration so the run settings are visible in the cell output.
print(json.dumps(training_config, indent=2, default=_json_fallback))


In [None]:
# Training entry point module; presumably resolved from the cloned repository
# that an earlier cell placed on sys.path.
from tools import train_ngiml
# Expand the plain dict into a TrainConfig: every key of training_config is
# passed as a keyword argument, so the keys must match TrainConfig's parameters.
cfg = train_ngiml.TrainConfig(**training_config)
# Start training with the assembled configuration.
train_ngiml.run_training(cfg)