In [3]:
# SegFormer fine-tuning on /kaggle/input/taping-cracks (cracks + taping)
# ➜ 3 classes: 0=bg, 1=crack, 2=taping
# Logs per-epoch: loss, mIoU, Dice (classes 1&2), and saves best-by-mIoU.
# No augmentations (only resize).

import os, glob, json, random
from pathlib import Path

import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset

from transformers import (
    SegformerImageProcessor,
    SegformerForSemanticSegmentation,
    get_cosine_schedule_with_warmup,
)

# ======================
# Config
# ======================
DATA_CANDIDATES = [
    "/kaggle/input/taping-cracks",
    "/kaggle/input/taping-cracks/data copy",
    "/kaggle/input/taping-cracks/data_copy",
]
OUT_DIR   = Path("/kaggle/working/ckpts_segformer")
IMG_SIZE  = 512
BATCH     = 6
EPOCHS    = 20
LR        = 5e-5
WEIGHT_DECAY = 1e-4
WARMUP_FRAC  = 0.05
SEED = 42
WORKERS = 2
NUM_CLASSES = 3  # 0 bg, 1 crack, 2 taping

# ======================
# Utils
# ======================
def set_seed(s=SEED):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)
    torch.backends.cudnn.benchmark=False; torch.backends.cudnn.deterministic=True

def find_data_root(candidates):
    for base in candidates:
        root = Path(base)
        if (root/"cracks").exists() and (root/"taping").exists():
            return root
        if root.exists():
            for sub in root.iterdir():
                if sub.is_dir() and (sub/"cracks").exists() and (sub/"taping").exists():
                    return sub
    raise FileNotFoundError("Dataset not found in: " + ", ".join(candidates))

def imread_rgb(path: str) -> np.ndarray:
    arr = cv2.imread(path, cv2.IMREAD_COLOR)
    if arr is None:
        raise FileNotFoundError(path)
    return cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)

def read_mask_or_zero(mask_path: str, hw) -> np.ndarray:
    """Return uint8 mask (0/255). If missing, zeros of requested (H,W)."""
    H, W = int(hw[0]), int(hw[1])
    if mask_path and os.path.exists(mask_path):
        m = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if m is None:
            m = np.zeros((H, W), np.uint8)
    else:
        m = np.zeros((H, W), np.uint8)
    # binarize
    m = (m > 127).astype(np.uint8) * 255
    return m

def resize_image_and_masks(img, crack_mask, taping_mask, size=IMG_SIZE):
    img_r = cv2.resize(img, (size, size), interpolation=cv2.INTER_LINEAR)
    ck_r  = cv2.resize(crack_mask, (size, size), interpolation=cv2.INTER_NEAREST)
    tp_r  = cv2.resize(taping_mask, (size, size), interpolation=cv2.INTER_NEAREST)
    return img_r, ck_r, tp_r

def build_semantic_label(ck_mask_255, tp_mask_255):
    """Return HxW int64 class map with {0,1,2}; taping overrides crack on overlaps."""
    ck = (ck_mask_255 > 127).astype(np.uint8)
    tp = (tp_mask_255 > 127).astype(np.uint8)
    lab = np.zeros_like(ck, dtype=np.uint8)
    lab[ck == 1] = 1
    lab[tp == 1] = 2
    return lab.astype(np.int64)

# -------- metrics (exclude background) --------
def batch_mean_iou_and_dice(pred_logits, labels, num_classes=NUM_CLASSES, eps=1e-6):
    """
    pred_logits: [B,C,H,W]; labels: [B,H,W] int64.
    Compute mIoU and Dice averaged over classes 1..num_classes-1 (exclude bg).
    """
    with torch.no_grad():
        pred = pred_logits.argmax(dim=1)  # [B,H,W]
        B = pred.shape[0]
        classes = list(range(1, num_classes))
        ious, dices = [], []
        for c in classes:
            pred_c = (pred == c)
            gt_c   = (labels == c)
            inter  = (pred_c & gt_c).sum(dim=(1,2)).float()  # [B]
            pred_sum = pred_c.sum(dim=(1,2)).float()
            gt_sum   = gt_c.sum(dim=(1,2)).float()
            union = pred_sum + gt_sum - inter
            # For images where class absent in both, skip from average
            valid = (union > 0)
            if valid.any():
                ious.append((inter[valid] / (union[valid] + eps)))
            # Dice
            denom = pred_sum + gt_sum
            valid_d = (denom > 0)
            if valid_d.any():
                dices.append((2*inter[valid_d] / (denom[valid_d] + eps)))
        # Mean over classes then over batch
        miou = torch.stack([x.mean() for x in ious]).mean().item() if ious else 0.0
        dice = torch.stack([x.mean() for x in dices]).mean().item() if dices else 0.0
        return miou, dice

# ======================
# Dataset
# ======================
class CrackTapingSemSeg(Dataset):
    """
    Expects:
      <root>/
        cracks/train/images/*.jpg (or png)
        cracks/train/masks/*.png
        taping/train/images/*.jpg
        taping/train/masks/*.png
      and same for 'val'.
    We build a unified image list from both tasks; mask paths are resolved per image name.
    """
    def __init__(self, root, split="train", img_size=IMG_SIZE):
        self.root  = Path(root)
        self.split = split
        self.size  = img_size

        self.cr_img_dir = self.root/"cracks"/split/"images"
        self.cr_msk_dir = self.root/"cracks"/split/"masks"
        self.tp_img_dir = self.root/"taping"/split/"images"
        self.tp_msk_dir = self.root/"taping"/split/"masks"

        # Gather all image paths from both sub-datasets
        self.items = []
        for d in [self.cr_img_dir, self.tp_img_dir]:
            for ip in sorted(glob.glob(str(d/"*"))):
                self.items.append(ip)
        if not self.items:
            raise RuntimeError(f"No images found under {self.cr_img_dir} and {self.tp_img_dir}")

        # For mask lookup, we’ll try the same stem in both mask dirs (png/jpg/jpeg)
        self.mask_exts = (".png", ".jpg", ".jpeg")

        # Minimal processor just to normalize/format pixel_values
        self.proc = SegformerImageProcessor.from_pretrained(
            "nvidia/segformer-b2-finetuned-ade-512-512"
        )

    def _find_mask(self, mdir: Path, stem: str):
        for ext in self.mask_exts:
            cand = mdir / f"{stem}{ext}"
            if cand.exists():
                return str(cand)
        return None

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        ip = self.items[idx]  # <-- keep name consistent to avoid NameError
        img = imread_rgb(ip)
        H, W = img.shape[:2]

        stem = Path(ip).stem
        # Resolve masks from both categories (could be None)
        ck_mask_path = self._find_mask(self.cr_msk_dir, stem)
        tp_mask_path = self._find_mask(self.tp_msk_dir, stem)

        # Read or zero, then resize
        crack = read_mask_or_zero(ck_mask_path, (H, W))
        taping= read_mask_or_zero(tp_mask_path, (H, W))
        img_r, crack_r, taping_r = resize_image_and_masks(img, crack, taping, self.size)

        # Merge into class map (0/1/2)
        label = build_semantic_label(crack_r, taping_r)  # int64 [H,W]

        # Processor to float + normalize to pixel_values [3,H,W]
        # (We pass a PIL Image to keep consistency with HF preproc expectations)
        enc = self.proc(images=Image.fromarray(img_r), return_tensors="pt")
        pixel_values = enc["pixel_values"][0]  # [3,H,W], float

        return {
            "pixel_values": pixel_values,
            "labels": torch.from_numpy(label).long(),  # [H,W]
            "path": ip,
        }

# ======================
# Train
# ======================
set_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUT_DIR.mkdir(parents=True, exist_ok=True)

DATA_ROOT = find_data_root(DATA_CANDIDATES)
print("Using DATA_ROOT:", DATA_ROOT)

# Datasets / loaders
ds_tr = CrackTapingSemSeg(DATA_ROOT, "train", IMG_SIZE)
ds_va = CrackTapingSemSeg(DATA_ROOT, "val",   IMG_SIZE)

dl_tr = DataLoader(ds_tr, batch_size=BATCH, shuffle=True,  num_workers=WORKERS,
                   pin_memory=True, drop_last=True)
dl_va = DataLoader(ds_va, batch_size=BATCH, shuffle=False, num_workers=WORKERS,
                   pin_memory=True)

# Model
id2label = {0: "background", 1: "crack", 2: "taping"}
label2id = {v:k for k,v in id2label.items()}
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-ade-512-512",
    num_labels=NUM_CLASSES,
    id2label=id2label, label2id=label2id,
    ignore_mismatched_sizes=True,  # ADE head is 150 classes
).to(device)

# Optim + sched
opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
steps_total = EPOCHS * len(dl_tr)
warm = max(1, int(WARMUP_FRAC * steps_total))
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=warm, num_training_steps=steps_total)

use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

best_miou, best_ckpt = -1.0, None

for epoch in range(1, EPOCHS+1):
    model.train()
    running = 0.0
    pbar = tqdm(dl_tr, desc=f"Epoch {epoch}/{EPOCHS}", total=len(dl_tr))
    for step, batch in enumerate(pbar, 1):
        pixel_values = batch["pixel_values"].to(device, non_blocking=True)  # [B,3,H,W]
        labels       = batch["labels"].to(device, non_blocking=True)        # [B,H,W]
        opt.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            out = model(pixel_values=pixel_values, labels=labels)
            loss = out.loss
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        sched.step()
        running += loss.item()
        pbar.set_postfix(loss=f"{running/step:.4f}")

    # ---- validation (mIoU & Dice over classes 1,2) ----
    model.eval()
    miou_list, dice_list = [], []
    with torch.no_grad():
        for batch in tqdm(dl_va, leave=False, desc="Valid"):
            pixel_values = batch["pixel_values"].to(device)
            labels       = batch["labels"].to(device)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(pixel_values=pixel_values).logits  # [B,C,h,w]
            # Upsample logits to labels size if needed
            if logits.shape[-2:] != labels.shape[-2:]:
                logits = F.interpolate(logits, size=labels.shape[-2:], mode="bilinear", align_corners=False)
            mi, di = batch_mean_iou_and_dice(logits, labels, NUM_CLASSES)
            miou_list.append(mi); dice_list.append(di)
    m_miou = float(np.mean(miou_list)) if miou_list else 0.0
    m_dice = float(np.mean(dice_list)) if dice_list else 0.0
    print(f"[Epoch {epoch}] Val mIoU={m_miou:.4f}  Dice={m_dice:.4f}")

    # Save best-by-mIoU
    if m_miou > best_miou:
        best_miou = m_miou
        best_ckpt = OUT_DIR / f"segformer_b2_best_e{epoch}_miou{m_miou:.4f}.pt"
        torch.save({
            "model": model.state_dict(),
            "epoch": epoch,
            "val_miou": m_miou,
            "val_dice": m_dice,
            "id2label": id2label,
        }, best_ckpt)
        print("Saved:", best_ckpt)

# Save final
final_ckpt = OUT_DIR / "segformer_b2_final.pt"
torch.save({"model": model.state_dict(), "id2label": id2label}, final_ckpt)
print("Final:", final_ckpt, "| Best:", best_ckpt)

# Tiny inference helper
helper = OUT_DIR / "inference_helper.py"
helper.write_text(f"""import torch, cv2
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
proc = SegformerImageProcessor.from_pretrained("nvidia/segformer-b2-finetuned-ade-512-512")
ckpt = r"{str(best_ckpt if best_ckpt else final_ckpt)}"
state = torch.load(ckpt, map_location="cpu")
id2label = state.get("id2label", {{0: "background", 1: "crack", 2: "taping"}})
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-ade-512-512", num_labels=len(id2label),
    ignore_mismatched_sizes=True, id2label=id2label, label2id={{v:k for k,v in id2label.items()}}
)
model.load_state_dict(state["model"], strict=False)
model.to(device).eval()

def predict_semantic(img_path, out_png, img_size={IMG_SIZE}):
    img = Image.open(img_path).convert("RGB").resize((img_size, img_size), resample=Image.BILINEAR)
    enc = proc(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(pixel_values=enc["pixel_values"].to(device)).logits
    # upsample to processor size (already img_size)
    logits = torch.nn.functional.interpolate(logits, size=(img_size, img_size), mode="bilinear", align_corners=False)
    pred = logits.argmax(dim=1)[0].cpu().numpy().astype(np.uint8)
    # write colorful visualization for quick checking
    palette = np.array([[0,0,0],[255,0,0],[0,255,0]], dtype=np.uint8)
    vis = palette[pred]
    cv2.imwrite(out_png, cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))

# Example:
# predict_semantic("/kaggle/input/taping-cracks/cracks/val/images/ANY.jpg", "/kaggle/working/ANY_segformer_vis.png")
""")
print("Helper written:", helper)


Using DATA_ROOT: /kaggle/input/taping-cracks/data copy


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b2-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 768, 1, 1]) in the checkpoint and torch.Size([3, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):
Epoch 1/20: 100%|██████████| 2932/2932 [18:33<00:00,  2.63it/s, loss=0.3210]
  with torch.cuda.amp.autocast(enabled=use_amp):
                                                        

[Epoch 1] Val mIoU=0.4356  Dice=0.5712
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e1_miou0.4356.pt


Epoch 2/20: 100%|██████████| 2932/2932 [18:26<00:00,  2.65it/s, loss=0.0820]
                                                        

[Epoch 2] Val mIoU=0.5303  Dice=0.6667
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e2_miou0.5303.pt


Epoch 3/20: 100%|██████████| 2932/2932 [18:20<00:00,  2.66it/s, loss=0.0638]
                                                        

[Epoch 3] Val mIoU=0.5833  Dice=0.7084
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e3_miou0.5833.pt


Epoch 4/20: 100%|██████████| 2932/2932 [18:22<00:00,  2.66it/s, loss=0.0542]
                                                        

[Epoch 4] Val mIoU=0.5853  Dice=0.7120
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e4_miou0.5853.pt


Epoch 5/20: 100%|██████████| 2932/2932 [18:25<00:00,  2.65it/s, loss=0.0478]
                                                        

[Epoch 5] Val mIoU=0.5931  Dice=0.7149
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e5_miou0.5931.pt


Epoch 6/20: 100%|██████████| 2932/2932 [18:22<00:00,  2.66it/s, loss=0.0427]
                                                        

[Epoch 6] Val mIoU=0.6097  Dice=0.7334
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e6_miou0.6097.pt


Epoch 7/20: 100%|██████████| 2932/2932 [18:24<00:00,  2.66it/s, loss=0.0387]
                                                        

[Epoch 7] Val mIoU=0.6281  Dice=0.7482
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e7_miou0.6281.pt


Epoch 8/20: 100%|██████████| 2932/2932 [18:22<00:00,  2.66it/s, loss=0.0356]
                                                        

[Epoch 8] Val mIoU=0.6317  Dice=0.7528
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e8_miou0.6317.pt


Epoch 9/20: 100%|██████████| 2932/2932 [18:24<00:00,  2.65it/s, loss=0.0329]
                                                        

[Epoch 9] Val mIoU=0.6410  Dice=0.7570
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e9_miou0.6410.pt


Epoch 11/20: 100%|██████████| 2932/2932 [18:22<00:00,  2.66it/s, loss=0.0283]
                                                        

[Epoch 11] Val mIoU=0.6318  Dice=0.7494


Epoch 12/20: 100%|██████████| 2932/2932 [18:21<00:00,  2.66it/s, loss=0.0267]
                                                        

[Epoch 12] Val mIoU=0.6513  Dice=0.7642
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e12_miou0.6513.pt


Epoch 13/20: 100%|██████████| 2932/2932 [18:25<00:00,  2.65it/s, loss=0.0252]
                                                        

[Epoch 13] Val mIoU=0.6540  Dice=0.7668
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e13_miou0.6540.pt


Epoch 14/20: 100%|██████████| 2932/2932 [18:28<00:00,  2.65it/s, loss=0.0240]
                                                        

[Epoch 14] Val mIoU=0.6536  Dice=0.7664


Epoch 15/20: 100%|██████████| 2932/2932 [18:25<00:00,  2.65it/s, loss=0.0230]
                                                        

[Epoch 15] Val mIoU=0.6576  Dice=0.7700
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e15_miou0.6576.pt


Epoch 16/20: 100%|██████████| 2932/2932 [18:25<00:00,  2.65it/s, loss=0.0223]
                                                        

[Epoch 16] Val mIoU=0.6560  Dice=0.7684


Epoch 17/20: 100%|██████████| 2932/2932 [18:28<00:00,  2.64it/s, loss=0.0217]
                                                        

[Epoch 17] Val mIoU=0.6570  Dice=0.7686


Epoch 18/20: 100%|██████████| 2932/2932 [18:29<00:00,  2.64it/s, loss=0.0213]
                                                        

[Epoch 18] Val mIoU=0.6570  Dice=0.7688


Epoch 19/20: 100%|██████████| 2932/2932 [18:27<00:00,  2.65it/s, loss=0.0210]
                                                        

[Epoch 19] Val mIoU=0.6591  Dice=0.7706
Saved: /kaggle/working/ckpts_segformer/segformer_b2_best_e19_miou0.6591.pt


Epoch 20/20: 100%|██████████| 2932/2932 [18:26<00:00,  2.65it/s, loss=0.0210]
                                                        

[Epoch 20] Val mIoU=0.6581  Dice=0.7698
Final: /kaggle/working/ckpts_segformer/segformer_b2_final.pt | Best: /kaggle/working/ckpts_segformer/segformer_b2_best_e19_miou0.6591.pt
Helper written: /kaggle/working/ckpts_segformer/inference_helper.py
