In [1]:

from pathlib import Path
import shutil, zipfile, time

# Seungho's Google Drive layout (kept for reference, inactive):
#ZIP_DRIVE1 = Path("/content/drive/MyDrive/DataSet/MultiCameraSemanticSegmentation_Left,Label.zip")
#ZIP_DRIVE2 = Path("/content/drive/MyDrive/DataSet/MonoCameraSemanticSegmentation.zip")

# Own Google Drive layout (active).
ZIP_DRIVE1 = Path("/content/drive/MyDrive/datasets/SDLane.zip")
#ZIP_DRIVE2 = Path("/content/drive/MyDrive/datasets/MonoCameraSemanticSegmentation.zip")
#ZIP_DRIVE3 = Path("/content/drive/MyDrive/datasets/MultiCameraSemanticSegmentation_g_labels.zip")

# Local VM cache: the zip is copied here and extracted next to it.
CACHE_DIR = Path("/content/ds_cache")
ZIP_LOCAL1 = CACHE_DIR / "SDLane.zip"    # copying SDLane took ~107 s
EXTRACT_DIR1 = CACHE_DIR / "SDLane"      # extracting it took ~130 s
#ZIP_LOCAL2 = CACHE_DIR / "MCSSeg_mono.zip"
#EXTRACT_DIR2 = CACHE_DIR / "MCSSeg_mono"
#ZIP_LOCAL3 = CACHE_DIR / "g_labels.zip"
#EXTRACT_DIR3 = CACHE_DIR

# Create the cache and extraction directories (no-op if they already exist).
for _cache_subdir in (CACHE_DIR, EXTRACT_DIR1):
    _cache_subdir.mkdir(parents=True, exist_ok=True)
#EXTRACT_DIR2.mkdir(parents=True, exist_ok=True)


assert ZIP_DRIVE1.exists(), f"ZIP not found: {ZIP_DRIVE1}"

# 1) Copy the zip from Drive to local disk (re-copy when missing or the size differs).
t0 = time.time()
if not ZIP_LOCAL1.exists() or ZIP_LOCAL1.stat().st_size != ZIP_DRIVE1.stat().st_size:
    print("Copying zip to local...")
    shutil.copy2(ZIP_DRIVE1, ZIP_LOCAL1)
#    shutil.copy2(ZIP_DRIVE2, ZIP_LOCAL2)
    #shutil.copy2(ZIP_DRIVE3, ZIP_LOCAL3)

print(f"zip local: {ZIP_LOCAL1} ({ZIP_LOCAL1.stat().st_size/1024/1024:.1f} MB), copy_time={time.time()-t0:.1f}s")

# 2) Extract unless this exact local zip was already extracted.
#    The marker stores the zip's size and mtime (shutil.copy2 preserves mtime).
#    If the zip was re-copied because it changed, the stale extraction is redone
#    instead of being silently reused. A marker written by an older version of
#    this cell (content "ok") forces one re-extraction.
marker = EXTRACT_DIR1 / ".unzipped_done"
zip_stat = ZIP_LOCAL1.stat()
zip_signature = f"{zip_stat.st_size}:{int(zip_stat.st_mtime)}"
if not marker.exists() or marker.read_text().strip() != zip_signature:
    print("Unzipping...")
    with zipfile.ZipFile(ZIP_LOCAL1, "r") as zf:
        zf.extractall(EXTRACT_DIR1)
    #with zipfile.ZipFile(ZIP_LOCAL2, "r") as zf:
    #    zf.extractall(EXTRACT_DIR2)
    #with zipfile.ZipFile(ZIP_LOCAL3, "r") as zf:
    #    zf.extractall(EXTRACT_DIR3)
    marker.write_text(zip_signature)
print("Extracted to:", EXTRACT_DIR1)

# 3) Quick look at the extracted layout (first 20 top-level entries, sorted).
top = sorted(EXTRACT_DIR1.glob("*"))[:20]
print("Top-level entries (sample):")
for p in top:
    print(" -", p)

Copying zip to local...
zip local: /content/ds_cache/SDLane.zip (8334.9 MB), copy_time=124.0s
Unzipping...
Extracted to: /content/ds_cache/SDLane
Top-level entries (sample):
 - /content/ds_cache/SDLane/SDLane
 - /content/ds_cache/SDLane/.unzipped_done


In [None]:
# ============================================================
# DeepLabV3+ (semantic) for SDLane (polygon/polyline JSON) - Single Colab Cell (TRAIN ONLY)
# - Train: /content/ds_cache/SDLane/SDLane/train/{images,labels}/<hash>/*.jpg|.png|.json
# - Val  : /content/ds_cache/SDLane/SDLane/test/{images,labels}/<hash>/*.jpg|.png|.json
# - Exclude bad train folder: a03dce1fc941e29b7a692717e40ca88d7c3aa18e
# - GT: union(binary) lane mask from geometries
# - Metrics: Dice/IoU + Boundary-F1 @ tolerance (2/4/8 px)
# - Saves: /content/drive/MyDrive/seman_seg_runs/<run>/ (epoch ckpt, best.pt, history)
# - Includes: profiling(fetch/H2D/compute/opt/imgs/s), torch.compile, channels_last, optional fixed crop
# ============================================================

!pip -q install -U segmentation-models-pytorch timm

import os, glob, json, math, time, random, shutil, subprocess
import numpy as np
from datetime import datetime
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm

# ---------------------------
# 0) Run dir: one timestamped folder per run under RUNS_ROOT
# ---------------------------

RUNS_ROOT = "/content/drive/MyDrive/seman_seg_runs"
run_name = f"deeplabv3p_{datetime.now():%Y%m%d_%H%M%S}"
RUN_DIR = os.path.join(RUNS_ROOT, run_name)
CKPT_DIR = os.path.join(RUN_DIR, "checkpoints")
os.makedirs(CKPT_DIR, exist_ok=True)  # creates RUN_DIR as well
print("[RunDir]", RUN_DIR)

# ---------------------------
# 1) Config (where script variants disagreed, the Resume+Profile settings are used)
# ---------------------------
SEED = 42

# Data: the SDLane "test" split is used as the validation set.
TRAIN_ROOT = "/content/ds_cache/SDLane/SDLane/train"
VAL_ROOT   = "/content/ds_cache/SDLane/SDLane/test"
BAD_FOLDER = "a03dce1fc941e29b7a692717e40ca88d7c3aa18e"  # train sequence excluded from training

# Image sizing: keep aspect ratio (moderate downscale), then pad to a multiple of SIZE_DIVISOR.
RESIZE_SHORTEST, RESIZE_LONGEST, SIZE_DIVISOR = 1024, 1920, 32

# Rasterization of the JSON lane geometries.
LINE_WIDTH, USE_AA = 6, False

# Training (values from the later code variant take priority: BS=16, LR=6e-4).
EPOCHS, TRAIN_BS = 13, 16
LR, WEIGHT_DECAY = 6e-4, 1e-4
LAMBDA_DICE = 0.5   # weight of the Dice term added to the BCE term
POS_WEIGHT  = 8.0   # BCEWithLogitsLoss pos_weight for the positive (lane) class

# Dataset size limits (None = use every indexed sample).
MAX_TRAIN_SAMPLES, MAX_VAL_SAMPLES = None, 500

# Boundary-F1 tolerances, in pixels.
BND_TOLS = (2, 4, 8)

# Profiling cadence in train steps (later code variant takes priority) + gradient accumulation.
PROFILE_EVERY, SMI_EVERY = 100, 200
GRAD_ACCUM_STEPS = 1

# Speed knobs (later code variant takes priority).
USE_COMPILE, USE_CHANNELS_LAST = True, True

# Optional fixed-size random crop, applied to the train split only.
USE_TRAIN_FIXED_CROP, FIXED_CROP_SIZE = False, 1024

# Model.
ENCODER_NAME, ENCODER_WTS = "resnet50", "imagenet"

# DataLoader workers: up to 8, bounded by the CPU count.
NUM_WORKERS = min(8, os.cpu_count() or 2)

# Snapshot of the run configuration, written to RUN_DIR/config.json.
cfg = {
    "seed": SEED,
    "train_root": TRAIN_ROOT, "val_root": VAL_ROOT, "bad_folder": BAD_FOLDER,
    "resize_shortest": RESIZE_SHORTEST, "resize_longest": RESIZE_LONGEST, "size_divisor": SIZE_DIVISOR,
    "line_width": LINE_WIDTH, "use_aa": USE_AA,
    "epochs": EPOCHS, "train_bs": TRAIN_BS, "lr": LR, "weight_decay": WEIGHT_DECAY,
    "lambda_dice": LAMBDA_DICE, "pos_weight": POS_WEIGHT,
    "bnd_tols": list(BND_TOLS),
    "max_train_samples": MAX_TRAIN_SAMPLES, "max_val_samples": MAX_VAL_SAMPLES,
    "num_workers": NUM_WORKERS,
    "profile_every": PROFILE_EVERY, "smi_every": SMI_EVERY, "grad_accum_steps": GRAD_ACCUM_STEPS,
    "use_compile": USE_COMPILE, "use_channels_last": USE_CHANNELS_LAST,
    "use_train_fixed_crop": USE_TRAIN_FIXED_CROP, "fixed_crop_size": FIXED_CROP_SIZE,
    "encoder_name": ENCODER_NAME, "encoder_weights": ENCODER_WTS,
}
with open(os.path.join(RUN_DIR, "config.json"), "w") as cfg_file:
    json.dump(cfg, cfg_file, indent=2)

# ---------------------------
# 2) Reproducibility & speed knobs
# ---------------------------
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Disable OpenCV threading and use one intra-op torch thread in this (main)
# process. Presumably this avoids oversubscription with the DataLoader workers
# (confirm). NOTE(review): it also makes CPU-side torch work here single-threaded.
cv2.setNumThreads(0)
torch.set_num_threads(1)
# Allow TF32 for matmul/conv and let cuDNN autotune convolution algorithms.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("[Device]", device, "| workers:", NUM_WORKERS, "| BS:", TRAIN_BS, "| Accum:", GRAD_ACCUM_STEPS)

# AMP dtype: bf16 on compute capability >= 8 (Ampere+), else fp16 + GradScaler.
AMP_DTYPE = None
USE_SCALER = False
if device.type == "cuda":
    major, _ = torch.cuda.get_device_capability()
    AMP_DTYPE = torch.bfloat16 if major >= 8 else torch.float16
    USE_SCALER = (AMP_DTYPE == torch.float16)
# torch.cuda.amp.GradScaler(...) is deprecated in recent PyTorch in favour of
# torch.amp.GradScaler("cuda", ...). Use the new API when it exists and fall
# back to the old one on older builds.
if hasattr(torch, "amp") and hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda", enabled=USE_SCALER)
else:
    scaler = torch.cuda.amp.GradScaler(enabled=USE_SCALER)
print(f"[AMP] dtype={AMP_DTYPE} | GradScaler={USE_SCALER}")

# ---------------------------
# 3) Utils: resize / pad
# ---------------------------
def resize_keep_aspect(H, W, shortest_edge, longest_edge):
    """Compute an aspect-preserving target size.

    The short side is scaled to ``shortest_edge``. If that makes the long side
    exceed ``longest_edge``, the result is scaled down again so the long side
    equals ``longest_edge``.

    Returns:
        (newH, newW) as ints, each at least 1.
    """
    factor = shortest_edge / min(H, W)
    out_h = int(round(H * factor))
    out_w = int(round(W * factor))

    longest_now = max(out_h, out_w)
    if longest_now > longest_edge:
        cap = longest_edge / longest_now
        out_h = int(round(out_h * cap))
        out_w = int(round(out_w * cap))

    return max(1, out_h), max(1, out_w)

def pad_to_divisor_tensor(x, divisor=32, pad_value=0.0):
    """Pad a (C, H, W) tensor on the bottom/right so H and W are multiples of ``divisor``.

    Returns:
        (padded, pad) where ``pad`` is F.pad's (left, right, top, bottom) tuple,
        i.e. (0, padW, 0, padH). When no padding is needed, the input tensor
        itself is returned together with (0, 0, 0, 0).
    """
    _, height, width = x.shape
    extra_h = -height % divisor
    extra_w = -width % divisor
    if not (extra_h or extra_w):
        return x, (0, 0, 0, 0)
    pad_spec = (0, extra_w, 0, extra_h)
    return F.pad(x, pad_spec, value=pad_value), pad_spec

def pad_to_divisor_mask(y, divisor=32, pad_value=0):
    """Pad an (H, W) label mask on the bottom/right to multiples of ``divisor``.

    The result is always returned as an int64 (long) tensor. The padding tuple
    matches pad_to_divisor_tensor: (0, padW, 0, padH), or (0, 0, 0, 0) when the
    mask is already aligned.
    """
    height, width = y.shape
    extra_h = -height % divisor
    extra_w = -width % divisor
    if extra_h == 0 and extra_w == 0:
        return y.long(), (0, 0, 0, 0)
    pad_spec = (0, extra_w, 0, extra_h)
    # Pad as a float (1, 1, H, W) image, then drop the two leading axes.
    as_image = y[None, None].float()
    padded = F.pad(as_image, pad_spec, value=float(pad_value))
    return padded[0, 0].long(), pad_spec

# ---------------------------
# 4) Rasterize: JSON geometries -> union binary mask
# ---------------------------
def _is_closed_polygon(pts_xy, close_tol=5.0, min_area=50.0):
    """Decide whether a point sequence should be treated as a filled polygon.

    True only if there are at least 3 points, the first and last points are
    within ``close_tol`` (Euclidean), and the enclosed |area| from
    cv2.contourArea is at least ``min_area``.
    Assumes pts_xy is an (N, 2) array of (x, y) coordinates. TODO confirm.
    """
    # Short-circuit: the point-count check guards the indexing in the gap check.
    if pts_xy.shape[0] < 3 or np.linalg.norm(pts_xy[0] - pts_xy[-1]) > close_tol:
        return False
    enclosed = cv2.contourArea(pts_xy.astype(np.float32))
    return abs(enclosed) >= min_area

def make_union_mask(json_path, H, W, line_width=6, use_aa=False):
    """Rasterize all geometries of an SDLane label JSON into one binary mask.

    Reads ``ann["geometry"]``, presumably a list of [[x, y], ...] point lists
    (verify against the dataset spec). Entries that are empty or have fewer
    than two points are skipped. Geometries accepted by _is_closed_polygon are
    filled; all others are drawn as open polylines ``line_width`` px thick.

    Returns:
        (H, W) uint8 mask with 1 on drawn pixels and 0 elsewhere.
    """
    with open(json_path, "r") as fh:
        geometries = json.load(fh).get("geometry", [])
    canvas = np.zeros((H, W), dtype=np.uint8)
    line_type = cv2.LINE_AA if use_aa else cv2.LINE_8

    for geom in geometries:
        if not geom or len(geom) < 2:
            continue
        pts_f = np.array(geom, dtype=np.float32)
        pts_i = np.round(pts_f).astype(np.int32)
        if _is_closed_polygon(pts_f, close_tol=5.0, min_area=50.0):
            cv2.fillPoly(canvas, [pts_i], 1)
            continue
        cv2.polylines(canvas, [pts_i], isClosed=False, color=1,
                      thickness=line_width, lineType=line_type)
    return canvas

# ---------------------------
# 5) Index dataset
# ---------------------------
def build_items(split_root, exclude_folder=None):
    """Pair every image under ``split_root/images/<folder>/`` with its label JSON.

    A pair (img_path, json_path) is emitted only when
    ``split_root/labels/<folder>/<stem>.json`` exists. Folders are visited in
    sorted order. Within a folder, all *.jpg files come first (sorted), then all
    *.png files (sorted). Folders without a labels counterpart, and the folder
    named ``exclude_folder``, are skipped.
    """
    img_root = os.path.join(split_root, "images")
    lbl_root = os.path.join(split_root, "labels")
    pairs = []
    for folder in sorted(os.listdir(img_root)):
        img_dir = os.path.join(img_root, folder)
        if not os.path.isdir(img_dir):
            continue
        if exclude_folder is not None and folder == exclude_folder:
            continue
        lbl_dir = os.path.join(lbl_root, folder)
        if not os.path.isdir(lbl_dir):
            continue
        candidates = [
            path
            for pattern in ("*.jpg", "*.png")
            for path in sorted(glob.glob(os.path.join(img_dir, pattern)))
        ]
        for img_path in candidates:
            stem, _ext = os.path.splitext(os.path.basename(img_path))
            json_path = os.path.join(lbl_dir, stem + ".json")
            if os.path.isfile(json_path):
                pairs.append((img_path, json_path))
    return pairs

train_items = build_items(TRAIN_ROOT, exclude_folder=BAD_FOLDER)
val_items   = build_items(VAL_ROOT, exclude_folder=None)

# NOTE(review): build_items returns pairs sorted by folder, so this slice keeps
# only the leading folders (sequences). The 500-sample val subset is therefore
# not spread across the test split; consider a seeded random sample instead.
if MAX_TRAIN_SAMPLES is not None:
    train_items = train_items[:MAX_TRAIN_SAMPLES]
if MAX_VAL_SAMPLES is not None:
    val_items = val_items[:MAX_VAL_SAMPLES]

print(f"[Index] train={len(train_items)} | val={len(val_items)} | excluded={BAD_FOLDER}")

# Fail fast with a clear message. An empty train split would otherwise fail
# later inside the DataLoader with a less obvious error, and an empty val split
# would make the evaluation averages meaningless.
assert train_items, f"No (image, json) pairs found under {TRAIN_ROOT}"
assert val_items, f"No (image, json) pairs found under {VAL_ROOT}"

# ---------------------------
# 6) Dataset
# ---------------------------
# Per-channel ImageNet RGB mean/std, shaped (3, 1, 1) to broadcast over (C, H, W).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).reshape(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).reshape(3, 1, 1)

class SDLaneSemanticDataset(Dataset):
    """SDLane (image, label-JSON) pairs -> normalized image tensor + binary lane mask.

    Per item:
      1. Rasterize the JSON geometries into a 0/1 mask at the original resolution
         (make_union_mask).
      2. Resize image (bilinear) and mask (nearest) to resize_keep_aspect's size.
      3. Optionally take a square random crop (train only).
      4. Scale to [0, 1], normalize with ImageNet mean/std, and pad image and
         mask on the bottom/right to a multiple of ``size_divisor``.

    Args:
        items: list of (img_path, json_path) tuples, e.g. from build_items().
        shortest_edge, longest_edge: target sizes passed to resize_keep_aspect.
        size_divisor: padded height/width are multiples of this.
        line_width, use_aa: forwarded to make_union_mask.
        return_orig_for_eval: if True, also return the full-resolution GT mask
            plus the original/resized sizes and padding (read by evaluate()).
            The random crop is never applied in this mode.
        use_fixed_crop, fixed_crop_size: square random crop after resizing;
            only applied when return_orig_for_eval is False.
    """
    def __init__(self, items, shortest_edge, longest_edge, size_divisor=32,
                 line_width=6, use_aa=False, return_orig_for_eval=False,
                 use_fixed_crop=False, fixed_crop_size=1024):
        self.items = items
        self.shortest_edge = shortest_edge
        self.longest_edge = longest_edge
        self.size_divisor = size_divisor
        self.line_width = line_width
        self.use_aa = use_aa
        self.return_orig_for_eval = return_orig_for_eval
        self.use_fixed_crop = use_fixed_crop
        self.fixed_crop_size = fixed_crop_size

    def __len__(self):
        return len(self.items)

    def _random_crop(self, img_rs, m_rs, crop_size):
        """Take an aligned crop_size x crop_size crop of (image, mask) at a random offset.

        If either side is smaller than crop_size, both arrays are first padded on
        the bottom/right with zeros. Offsets come from Python's ``random`` module
        (y drawn before x).
        """
        h, w = m_rs.shape
        if h < crop_size or w < crop_size:
            # NOTE(review): the image is padded with black in raw pixel space,
            # which becomes negative values after ImageNet normalization.
            # pad_to_divisor_tensor instead pads the *normalized* tensor with 0.0
            # (the mean color), so the two pad values are inconsistent.
            pad_h = max(0, crop_size - h)
            pad_w = max(0, crop_size - w)
            img_rs = cv2.copyMakeBorder(img_rs, 0, pad_h, 0, pad_w, borderType=cv2.BORDER_CONSTANT, value=(0,0,0))
            m_rs   = np.pad(m_rs, ((0,pad_h),(0,pad_w)), mode="constant", constant_values=0)
            h, w = m_rs.shape
        y0 = random.randint(0, h - crop_size)
        x0 = random.randint(0, w - crop_size)
        img_c = img_rs[y0:y0+crop_size, x0:x0+crop_size]
        m_c   = m_rs[y0:y0+crop_size, x0:x0+crop_size]
        return img_c, m_c

    def __getitem__(self, idx):
        """Load, rasterize, resize, (optionally crop), normalize and pad one sample.

        Returns:
            dict with "pixel_values" (3, Hp, Wp) float tensor and "labels"
            (Hp, Wp) long tensor. When return_orig_for_eval is set, it also has
            "orig_hw", "gt_orig" (H, W uint8), "resized_hw" and "pad".
        """
        img_path, json_path = self.items[idx]
        img = Image.open(img_path).convert("RGB")
        W, H = img.size  # PIL reports (width, height)

        gt_mask = make_union_mask(json_path, H, W, line_width=self.line_width, use_aa=self.use_aa)  # (H,W) uint8 0/1

        newH, newW = resize_keep_aspect(H, W, self.shortest_edge, self.longest_edge)
        img_np = np.array(img)
        # cv2.resize takes dsize as (width, height).
        img_rs = cv2.resize(img_np, (newW, newH), interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps the mask binary.
        # NOTE(review): downscaling a thin rasterized line this way may thin or
        # break it; consider rasterizing at the resized resolution.
        m_rs   = cv2.resize(gt_mask, (newW, newH), interpolation=cv2.INTER_NEAREST)

        if (not self.return_orig_for_eval) and self.use_fixed_crop:
            img_rs, m_rs = self._random_crop(img_rs, m_rs, self.fixed_crop_size)

        # HWC uint8 -> CHW float in [0, 1] -> ImageNet-normalized.
        x = torch.from_numpy(img_rs).permute(2,0,1).float() / 255.0
        x = (x - IMAGENET_MEAN) / IMAGENET_STD
        y = torch.from_numpy(m_rs).long()

        # Pad so the encoder/decoder strides divide the spatial size evenly.
        x, pad = pad_to_divisor_tensor(x, divisor=self.size_divisor, pad_value=0.0)
        y, _   = pad_to_divisor_mask(y, divisor=self.size_divisor, pad_value=0)

        out = {"pixel_values": x, "labels": y}

        if self.return_orig_for_eval:
            # Everything evaluate() needs to undo padding/resizing and score at full resolution.
            out["orig_hw"]   = (H, W)
            out["gt_orig"]   = torch.from_numpy(gt_mask).to(torch.uint8)
            out["resized_hw"]= (newH, newW)
            out["pad"]       = pad  # (0,padW,0,padH)
        return out

def train_collate_fn(batch):
    """Stack per-sample tensors into batch tensors.

    All samples must share the same padded spatial size. Only "pixel_values"
    and "labels" are kept.
    """
    images = torch.stack([sample["pixel_values"] for sample in batch])
    masks = torch.stack([sample["labels"] for sample in batch])
    return {"pixel_values": images, "labels": masks}

def val_collate_fn(batch):
    """Collate a validation batch of exactly one sample.

    The tensors get a leading batch axis. The per-image metadata (orig_hw,
    gt_orig, resized_hw, pad) is passed through unbatched, since evaluate()
    maps each prediction back to its own original resolution.

    Raises:
        ValueError: if the loader yields more than one sample (or none).
            Previously every sample but the first was silently dropped.
    """
    if len(batch) != 1:
        raise ValueError(f"val_collate_fn expects batch_size=1, got {len(batch)} samples")
    (b,) = batch
    return {
        "pixel_values": b["pixel_values"].unsqueeze(0),
        "labels": b["labels"].unsqueeze(0),
        "orig_hw": b["orig_hw"],
        "gt_orig": b["gt_orig"],
        "resized_hw": b["resized_hw"],
        "pad": b["pad"],
    }

# DataLoaders. prefetch_factor may only be passed when num_workers > 0.
# Pinned host memory only benefits CUDA transfers; on a CPU-only runtime,
# requesting it just triggers a PyTorch warning, so enable it only on CUDA.
PIN_MEMORY = device.type == "cuda"
# max(1, ...) guarantees at least one val worker, so prefetch_factor is always valid there.
VAL_NUM_WORKERS = max(1, NUM_WORKERS // 2)

dl_common = dict(pin_memory=PIN_MEMORY, persistent_workers=(NUM_WORKERS > 0))
train_loader = DataLoader(
    SDLaneSemanticDataset(
        train_items, RESIZE_SHORTEST, RESIZE_LONGEST, SIZE_DIVISOR,
        line_width=LINE_WIDTH, use_aa=USE_AA, return_orig_for_eval=False,
        use_fixed_crop=USE_TRAIN_FIXED_CROP, fixed_crop_size=FIXED_CROP_SIZE
    ),
    batch_size=TRAIN_BS, shuffle=True, num_workers=NUM_WORKERS,
    collate_fn=train_collate_fn, drop_last=True,
    prefetch_factor=4 if NUM_WORKERS > 0 else None,
    **dl_common
)

val_loader = DataLoader(
    SDLaneSemanticDataset(
        val_items, RESIZE_SHORTEST, RESIZE_LONGEST, SIZE_DIVISOR,
        line_width=LINE_WIDTH, use_aa=USE_AA, return_orig_for_eval=True,
        use_fixed_crop=False
    ),
    # batch_size must stay 1: val_collate_fn handles exactly one sample.
    batch_size=1, shuffle=False, num_workers=VAL_NUM_WORKERS,
    collate_fn=val_collate_fn,
    prefetch_factor=2,
    pin_memory=PIN_MEMORY, persistent_workers=True
)

# ---------------------------
# 7) Model
# ---------------------------
# Binary segmentation head: one output channel of raw logits. The sigmoid is
# applied inside the losses and in evaluate(), hence activation=None.
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WTS,
    in_channels=3,
    classes=1,
    activation=None,
)
model = model.to(device)

if USE_CHANNELS_LAST:
    model = model.to(memory_format=torch.channels_last)

if USE_COMPILE:
    # NOTE(review): torch.compile is lazy. Graph capture and compilation happen
    # on the first forward call, so most compile failures surface in the train
    # loop rather than in this try/except. mode="reduce-overhead" also uses
    # CUDA graphs.
    try:
        model = torch.compile(model, mode="reduce-overhead")
        print("[Compile] enabled")
    except Exception as e:
        print("[Compile] failed:", repr(e))

# ---------------------------
# 8) Loss: BCEWithLogits + Dice
# ---------------------------
# pos_weight multiplies the loss on positive (lane) pixels in the BCE term.
_bce_pos_weight = torch.tensor([POS_WEIGHT], device=device)
bce = nn.BCEWithLogitsLoss(pos_weight=_bce_pos_weight)

def soft_dice_loss(logits, targets, eps=1e-6):
    """Soft Dice loss, averaged over the batch (and the single channel).

    Args:
        logits: raw scores, expected (B, 1, H, W).
        targets: 0/1 labels, expected (B, H, W); an axis is added to match logits.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean(dice)``, computed in float32.
    """
    # Compute in float32. Under bf16/fp16 autocast the logits arrive in reduced
    # precision; upcasting first keeps sigmoid accurate over the very large
    # per-image pixel counts summed below.
    probs = torch.sigmoid(logits.float())
    tgt = targets.unsqueeze(1).float()
    num = 2 * (probs * tgt).sum(dim=(2, 3)) + eps
    den = (probs + tgt).sum(dim=(2, 3)) + eps
    return 1.0 - (num / den).mean()

def loss_fn(logits, targets):
    """Combined training loss: BCE + LAMBDA_DICE * soft Dice.

    Returns:
        (total_loss_tensor, bce_value, dice_value). The last two are Python
        floats for logging; extracting them synchronizes with the device.
    """
    bce_term = bce(logits, targets.unsqueeze(1).float())
    dice_term = soft_dice_loss(logits, targets)
    total = bce_term + LAMBDA_DICE * dice_term
    return total, bce_term.item(), dice_term.item()

# ---------------------------
# 9) Metrics: Dice/IoU + Boundary F1
# ---------------------------
def dice_iou(pred_bin, gt_bin, eps=1e-6):
    """Dice coefficient and IoU between two binary masks (any nonzero counts as 1).

    ``eps`` smooths both ratios, so two empty masks score (1.0, 1.0).

    Returns:
        (dice, iou) as Python floats.
    """
    p_mask = pred_bin.astype(bool)
    g_mask = gt_bin.astype(bool)
    overlap = np.count_nonzero(p_mask & g_mask)
    n_pred = np.count_nonzero(p_mask)
    n_gt = np.count_nonzero(g_mask)
    union = n_pred + n_gt - overlap
    dice = (2.0 * overlap + eps) / (n_pred + n_gt + eps)
    iou = (overlap + eps) / (union + eps)
    return float(dice), float(iou)

def boundary_map(bin_mask):
    """Return a uint8 0/1 map of the mask's edge pixels.

    Uses a 3x3 morphological gradient (dilation minus erosion), so the edge band
    is roughly two pixels wide: one just inside and one just outside the region.
    """
    grad = cv2.morphologyEx(bin_mask.astype(np.uint8), cv2.MORPH_GRADIENT,
                            np.ones((3, 3), np.uint8))
    return (grad != 0).astype(np.uint8)

def boundary_f1(pred_bin, gt_bin, tol_px=3):
    """Boundary precision / recall / F1 between two binary masks.

    An edge pixel (see boundary_map) of one mask counts as matched when an edge
    pixel of the other mask lies within ``tol_px``, measured with
    cv2.distanceTransform.

    Returns:
        (precision, recall, f1) as Python floats. Returns (1, 1, 1) when neither
        mask has edge pixels and (0, 0, 0) when exactly one of them has none.
    """
    pred_edges = boundary_map(pred_bin)
    gt_edges = boundary_map(gt_bin)
    n_pred = pred_edges.sum()
    n_gt = gt_edges.sum()

    if n_pred == 0 or n_gt == 0:
        score = 1.0 if n_pred == n_gt else 0.0
        return score, score, score

    # distanceTransform measures the distance to the nearest ZERO pixel, so each
    # edge map is inverted first (edges -> 0, everything else -> 1).
    # NOTE(review): DIST_L2 with a 3x3 mask is OpenCV's approximate Euclidean
    # distance; cv2.DIST_MASK_PRECISE would give exact distances.
    dist_to_gt = cv2.distanceTransform((1 - gt_edges).astype(np.uint8), cv2.DIST_L2, 3)
    dist_to_pred = cv2.distanceTransform((1 - pred_edges).astype(np.uint8), cv2.DIST_L2, 3)

    matched_pred = np.logical_and(pred_edges == 1, dist_to_gt <= tol_px).sum()
    matched_gt = np.logical_and(gt_edges == 1, dist_to_pred <= tol_px).sum()

    precision = matched_pred / (n_pred + 1e-6)
    recall = matched_gt / (n_gt + 1e-6)
    f1 = (2 * precision * recall) / (precision + recall + 1e-6)
    return float(precision), float(recall), float(f1)

@torch.no_grad()
def evaluate(model, val_loader, tols=(2,4,8), threshold=0.5):
    """Validate ``model``; return mean Dice, IoU and Boundary-F1 per tolerance.

    Predictions are scored against the GT mask at the ORIGINAL image
    resolution. Logits are cropped back to the resized (unpadded) size,
    bilinearly upsampled to (H, W) in float32, passed through a sigmoid, and
    thresholded.

    Args:
        model: network returning (1, 1, h_pad, w_pad) logits for one image.
        val_loader: yields dicts built by val_collate_fn (batch size 1).
        tols: Boundary-F1 tolerances in pixels.
        threshold: probability cut-off for the binary prediction (default 0.5,
            the previously hard-coded value).

    Returns:
        dict with "dice", "iou" and "bf1@{t}px" for every t in tols. Every value
        is 0.0 if the loader yields nothing (previously dice/iou became NaN).

    Note:
        Leaves the model in eval mode; the train loop calls model.train() at
        the start of each epoch.
    """
    model.eval()
    dices, ious = [], []
    bf1 = {t: [] for t in tols}

    for batch in tqdm(val_loader, desc="Val", leave=False):
        x = batch["pixel_values"].to(device, non_blocking=True)
        if USE_CHANNELS_LAST:
            x = x.to(memory_format=torch.channels_last)

        if device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=True):
                logits = model(x)  # (1,1,h_pad,w_pad)
        else:
            logits = model(x)

        # Remove the divisor padding, then upsample to the original resolution
        # so the metrics are not affected by padding or resizing.
        (newH, newW) = batch["resized_hw"]
        (H, W)       = batch["orig_hw"]
        gt_orig = batch["gt_orig"].numpy().astype(np.uint8)

        # Upcast before interpolating: under autocast the logits may be bf16/fp16,
        # and the threshold should be applied to full-precision values.
        logits_rs = logits[:, :, :newH, :newW].float()
        pred_up = F.interpolate(logits_rs, size=(H, W), mode="bilinear", align_corners=False)
        pred_prob = torch.sigmoid(pred_up)[0, 0].cpu().numpy()
        pred_bin = (pred_prob >= threshold).astype(np.uint8)

        d, j = dice_iou(pred_bin, gt_orig)
        dices.append(d); ious.append(j)
        for t in tols:
            _, _, f1 = boundary_f1(pred_bin, gt_orig, tol_px=t)
            bf1[t].append(f1)

    def _mean(values):
        # Empty loader -> 0.0 instead of np.mean([]) == nan (+ RuntimeWarning).
        return float(np.mean(values)) if values else 0.0

    out = {"dice": _mean(dices), "iou": _mean(ious)}
    for t in tols:
        out[f"bf1@{t}px"] = _mean(bf1[t])
    return out

# ---------------------------
# 10) Saving: ckpt / best.pt / history
# ---------------------------
# Per-epoch rows are kept in memory and mirrored to RUN_DIR in two forms:
# an append-only JSON-lines log and a full JSON array rewritten after each row.
history = []
history_jsonl, history_json = (os.path.join(RUN_DIR, fname) for fname in ("history.jsonl", "history.json"))

def _atomic_write_json(path, obj):
    """Write ``obj`` as indented JSON to ``path`` via a sibling temp file.

    The temp file ``<path>.tmp`` is renamed over ``path`` with os.replace, so
    readers never see a half-written file at ``path``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w") as fh:
        json.dump(obj, fh, indent=2)
    os.replace(tmp_path, path)

def append_history(row):
    """Record one history row in memory, in history.jsonl, and in history.json.

    The JSON-lines file gets the row appended. The JSON file is rewritten
    atomically with the whole in-memory history list.
    """
    history.append(row)
    with open(history_jsonl, "a") as fh:
        fh.write(f"{json.dumps(row)}\n")
    _atomic_write_json(history_json, history)

def save_ckpt(tag, epoch, metrics=None, extra=None, optimizer=None, scheduler=None, best_score=None):
    """Save a training checkpoint to ``CKPT_DIR/<tag>.pt``.

    The payload holds the model weights, optional optimizer/scheduler state,
    metrics, the run cfg, free-form ``extra`` info, best_score and a timestamp.

    Args:
        tag: file stem, e.g. "epoch_003", "best", "final", "crash_last".
        epoch: epoch number stored in the payload.
        metrics, extra: optional dicts (stored as {} when None).
        optimizer, scheduler: state_dicts are stored when given, else None.
        best_score: best selection score so far.
    """
    pt = os.path.join(CKPT_DIR, f"{tag}.pt")
    # torch.compile wraps the network in a module whose state_dict keys carry an
    # "_orig_mod." prefix. Save the underlying module's weights so the checkpoint
    # loads directly into a plain smp.DeepLabV3Plus.
    net = getattr(model, "_orig_mod", model)
    payload = {
        "epoch": epoch,
        "model_state": net.state_dict(),
        "optim_state": optimizer.state_dict() if optimizer is not None else None,
        "sched_state": scheduler.state_dict() if scheduler is not None else None,
        "metrics": metrics or {},
        "cfg": cfg,
        "extra": extra or {},
        "best_score": best_score,
        "saved_at": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    # Write to a temp file and rename (same pattern as _atomic_write_json), so an
    # interruption mid-save does not leave a truncated checkpoint in place of the
    # previous one.
    tmp = pt + ".tmp"
    torch.save(payload, tmp)
    os.replace(tmp, pt)

# ---------------------------
# 11) Optim / sched
# ---------------------------
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# Cosine decay over EPOCHS scheduler steps (the train loop steps it once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)

best_score = -1.0   # best validation bf1@4px seen so far; -1.0 means "none yet"
global_step = 0     # incremented once per train batch (micro-step)

# ---------------------------
# 12) Profiling helpers
# ---------------------------
def smi_snapshot():
    """Return nvidia-smi's "util%, mem_used, mem_total" CSV line for the GPU(s).

    Any failure (missing binary, non-zero exit, decode error) yields "n/a".
    """
    query = [
        "nvidia-smi",
        "--query-gpu=utilization.gpu,memory.used,memory.total",
        "--format=csv,noheader,nounits",
    ]
    try:
        return subprocess.check_output(query).decode().strip()
    except Exception:
        return "n/a"

def cuda_time_ms(fn):
    """Run ``fn()`` and return its duration in milliseconds.

    On CUDA the time is measured with CUDA events and a device synchronize, so
    queued GPU work is included. Otherwise wall-clock time is used.
    """
    if device.type == "cuda":
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        fn()
        end_evt.record()
        torch.cuda.synchronize()
        return start_evt.elapsed_time(end_evt)
    began = time.time()
    fn()
    return (time.time() - began) * 1000.0

# ---------------------------
# 13) Train loop (with profiling)
# ---------------------------
def _optimizer_step():
    """Apply one optimizer update from the accumulated gradients.

    With fp16 AMP (USE_SCALER) the gradients are unscaled first, so clipping
    sees their true magnitudes. Gradients are clipped to a global norm of 1.0.
    """
    if USE_SCALER:
        scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    if USE_SCALER:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()


def _format_val_metrics(metrics, tols):
    """Render val metrics on one line; bf1 entries follow ``tols``, so any BND_TOLS works."""
    parts = [f"val dice={metrics['dice']:.4f}", f"iou={metrics['iou']:.4f}"]
    parts.extend(f"bf1@{t}={metrics[f'bf1@{t}px']:.4f}" for t in tols)
    return " ".join(parts)


for epoch in range(1, EPOCHS + 1):
    try:
        model.train()
        epoch_loss = 0.0
        epoch_bce  = 0.0
        epoch_dice = 0.0
        n_steps = 0

        it = iter(train_loader)
        steps_per_epoch = len(train_loader)
        pbar = tqdm(range(steps_per_epoch), desc=f"Train E{epoch}", leave=True)

        optimizer.zero_grad(set_to_none=True)

        for step_idx in pbar:
            # (A) fetch: time spent waiting on the DataLoader
            t_fetch0 = time.perf_counter()
            try:
                batch = next(it)
            except StopIteration:
                it = iter(train_loader)
                batch = next(it)
            t_fetch = time.perf_counter() - t_fetch0

            # (B) H2D: synchronize on both sides so the interval covers only the
            # (non_blocking) host-to-device copies.
            if device.type == "cuda":
                torch.cuda.synchronize()
            t_h2d0 = time.perf_counter()

            x = batch["pixel_values"].to(device, non_blocking=True)
            y = batch["labels"].to(device, non_blocking=True)
            if USE_CHANNELS_LAST:
                x = x.to(memory_format=torch.channels_last)

            if device.type == "cuda":
                torch.cuda.synchronize()
            t_h2d = time.perf_counter() - t_h2d0

            # (C) fwd+bwd. The stream was just drained by the synchronize above,
            # so the timer starts immediately; the trailing synchronize closes it.
            t_comp0 = time.perf_counter()

            if device.type == "cuda":
                with torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=True):
                    logits = model(x)
                    loss, bce_v, dice_v = loss_fn(logits, y)
                    loss_scaled = loss / GRAD_ACCUM_STEPS
            else:
                logits = model(x)
                loss, bce_v, dice_v = loss_fn(logits, y)
                loss_scaled = loss / GRAD_ACCUM_STEPS

            if USE_SCALER:
                scaler.scale(loss_scaled).backward()
            else:
                loss_scaled.backward()

            if device.type == "cuda":
                torch.cuda.synchronize()
            t_comp = time.perf_counter() - t_comp0

            # (D) optimizer step on accumulation boundaries
            did_step = False
            t_opt_ms = 0.0
            if (step_idx + 1) % GRAD_ACCUM_STEPS == 0:
                t_opt_ms = cuda_time_ms(_optimizer_step)
                optimizer.zero_grad(set_to_none=True)
                did_step = True

            # accumulate stats (global_step counts micro-batches, not optimizer steps)
            epoch_loss += float(loss.detach().cpu())
            epoch_bce  += float(bce_v)
            epoch_dice += float(dice_v)
            n_steps += 1
            global_step += 1

            iter_t = t_fetch + t_h2d + t_comp + (t_opt_ms/1000.0 if did_step else 0.0)
            imgs_s = TRAIN_BS / max(1e-9, iter_t)

            if (step_idx + 1) % PROFILE_EVERY == 0:
                mem_gb = torch.cuda.memory_allocated()/1024**3 if device.type=="cuda" else 0.0
                pbar.set_postfix({
                    "loss": f"{epoch_loss/n_steps:.4f}",
                    "fetch_s": f"{t_fetch:.3f}",
                    "h2d_s": f"{t_h2d:.3f}",
                    "comp_s": f"{t_comp:.3f}",
                    "opt_ms": f"{t_opt_ms:.1f}",
                    "iter_s": f"{iter_t:.3f}",
                    "imgs/s": f"{imgs_s:.1f}",
                    "memGB": f"{mem_gb:.1f}",
                })

            if (step_idx + 1) % SMI_EVERY == 0:
                print(f"[nvidia-smi] step={step_idx+1}/{steps_per_epoch} | {smi_snapshot()}")

        # Flush a trailing partial accumulation window. When steps_per_epoch is
        # not a multiple of GRAD_ACCUM_STEPS, the last micro-batches' gradients
        # would otherwise be discarded by zero_grad() at the next epoch's start.
        if n_steps % GRAD_ACCUM_STEPS != 0:
            _optimizer_step()
            optimizer.zero_grad(set_to_none=True)

        scheduler.step()

        avg_loss = epoch_loss / max(1, n_steps)
        avg_bce  = epoch_bce  / max(1, n_steps)
        avg_dice = epoch_dice / max(1, n_steps)

        # save epoch ckpt
        save_ckpt(
            tag=f"epoch_{epoch:03d}",
            epoch=epoch,
            optimizer=optimizer,
            scheduler=scheduler,
            best_score=float(best_score),
            extra={
                "train_loss": avg_loss,
                "train_bce": avg_bce,
                "train_dice_loss": avg_dice,
                "grad_accum_steps": GRAD_ACCUM_STEPS,
                "use_train_fixed_crop": USE_TRAIN_FIXED_CROP,
                "fixed_crop_size": FIXED_CROP_SIZE if USE_TRAIN_FIXED_CROP else None,
                "amp_dtype": str(AMP_DTYPE),
                "use_scaler": USE_SCALER,
            }
        )

        # eval
        metrics = evaluate(model, val_loader, tols=BND_TOLS)
        sel = float(metrics.get("bf1@4px", metrics["dice"]))

        print(
            f"[Epoch {epoch}] train_loss={avg_loss:.4f} (bce={avg_bce:.4f}, diceL={avg_dice:.4f}) | "
            f"{_format_val_metrics(metrics, BND_TOLS)} "
            f"| sel(bf1@4)={sel:.4f}"
        )

        append_history({
            "epoch": epoch,
            "global_step": global_step,
            "train_loss": avg_loss,
            "train_bce": avg_bce,
            "train_dice_loss": avg_dice,
            "lr": float(optimizer.param_groups[0]["lr"]),
            **metrics,
            "sel_score": sel,
            "best_score_before": float(best_score),
            "grad_accum_steps": GRAD_ACCUM_STEPS,
            "use_train_fixed_crop": USE_TRAIN_FIXED_CROP,
            "fixed_crop_size": FIXED_CROP_SIZE if USE_TRAIN_FIXED_CROP else None,
            "encoder_name": ENCODER_NAME,
            "encoder_weights": ENCODER_WTS,
            "amp_dtype": str(AMP_DTYPE),
            "use_scaler": USE_SCALER,
            "time": time.strftime("%Y-%m-%d %H:%M:%S"),
        })

        # best update -> save to best.pt
        if sel > float(best_score):
            best_score = sel
            save_ckpt(
                tag="best",
                epoch=epoch,
                metrics=metrics,
                optimizer=optimizer,
                scheduler=scheduler,
                best_score=float(best_score),
                extra={"train_loss": avg_loss, "note": "best by bf1@4px"}
            )
            print(f"  -> saved BEST (bf1@4px={best_score:.4f}) to {os.path.join(CKPT_DIR,'best.pt')}")

    except Exception as e:
        # Persist what we know about the failure, then re-raise.
        append_history({
            "epoch": epoch,
            "global_step": global_step,
            "error": repr(e),
            "time": time.strftime("%Y-%m-%d %H:%M:%S"),
        })
        save_ckpt(
            tag="crash_last",
            epoch=epoch,
            metrics={"error": repr(e)},
            optimizer=optimizer,
            scheduler=scheduler,
            best_score=float(best_score),
            extra={"note": "crashed during training/eval"}
        )
        print("[ERROR] crashed; saved crash_last.pt and history. Raising...")
        raise

# final: save the last state and report where the run artifacts live
save_ckpt(
    tag="final",
    epoch=EPOCHS,
    metrics={"best_score": float(best_score)},
    optimizer=optimizer,
    scheduler=scheduler,
    best_score=float(best_score),
)
print("[Done]")
for _label, _path in ((" - RUN_DIR:", RUN_DIR), (" - CKPT_DIR:", CKPT_DIR), (" - history:", history_jsonl)):
    print(_label, _path)
