
# Dirt Segmentation → Clean/Dirty Classification (PyTorch, Gradio)

This notebook implements the **cleanliness** part of our methodological framework:

- Loads your **COCO (segmentation)** dataset (uses only the `dirty` category to build masks).
- Trains a **lightweight U-Net** (pure PyTorch) for **dirt segmentation**.
- Converts the predicted mask into an **image-level clean/dirty** decision by thresholding the dirt-area fraction.
- Includes **single‑image inference** helpers and an optional **Gradio demo**.

> Tip: Keep **offline augmentations out** of validation/test to avoid leakage. Add on‑the‑fly augmentation only in the training dataset if needed.


In [14]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
# %% [code] 1 - Setup: GPU & packages

import os, sys, json, math, random, zipfile, shutil
from pathlib import Path

import numpy as np
from PIL import Image, ImageDraw, ImageOps

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF

import cv2
import gradio as gr

random.seed(13)
np.random.seed(13)
torch.manual_seed(13)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(13)

print("Torch:", torch.__version__, "| CUDA:", torch.cuda.is_available())


Torch: 2.8.0+cu126 | CUDA: True



## Configure dataset path

The notebook expects a **Roboflow-style COCO segmentation** export:

```
dataset_root/
  train/  _annotations.coco.json  <images...>
  valid/  _annotations.coco.json  <images...>
  test/   _annotations.coco.json  <images...>
```


In [16]:
# %% [code] 2 - Config & Paths
DATA_ROOT = Path("/content/drive/MyDrive/decentrathon_dataset")  # will be auto-detected after unzip if needed

SIZE = 384
BATCH = 2
EPOCHS = 20

# thresholds (can be recalibrated later)
IMG_W = 0.7
CLEAN_W = 1.0
MASK_THR_DEFAULT  = 0.7
CLEAN_THR_DEFAULT = 0.12
MIN_BLOB_DEFAULT = 0.02
DEFAULT_MASK_THR  = 0.80

Path("checkpoints").mkdir(exist_ok=True, parents=True)


In [17]:
# Basic structure & path checks per split
for split in ['train', 'valid', 'test']:
    split_dir = DATA_ROOT / split
    assert split_dir.exists(), f"Missing split folder: {split_dir}"
    ann_path = split_dir / '_annotations.coco.json'
    assert ann_path.exists(), f"Missing COCO annotations: {ann_path}"

    data = json.loads(ann_path.read_text())
    uses_prefix = any(
        '/' in im.get('file_name', '') and im['file_name'].split('/')[0] in {'train','valid','test'}
        for im in data.get('images', [])[:10]
    )

    # Resolve a sample image path to catch path-style mismatches
    if data.get('images'):
        sample = data['images'][0]['file_name']
        # style A: "train/xxx.jpg" lives under DATA_ROOT
        p_a = DATA_ROOT / sample if uses_prefix else None
        # style B: "xxx.jpg" lives directly under split dir
        p_b = split_dir / Path(sample).name

        if uses_prefix:
            assert p_a.exists() or p_b.exists(), f"[{split}] Example image not found: tried {p_a} and {p_b}"
        else:
            assert p_b.exists() or (DATA_ROOT / sample).exists(), f"[{split}] Example image not found: tried {p_b} and {DATA_ROOT/sample}"

    print(f"{split}: images={len(data.get('images', []))} | annotations={len(data.get('annotations', []))}")

print("DATA_ROOT:", DATA_ROOT)

train: images=200 | annotations=236
valid: images=42 | annotations=86
test: images=27 | annotations=60
DATA_ROOT: /content/drive/MyDrive/decentrathon_dataset


In [18]:
# %% [code] 5 - Geometry & mask utilities
def _denorm_if_needed(flat_xy, w, h):
    if not flat_xy:
        return flat_xy
    mx, mn = max(flat_xy), min(flat_xy)
    # If coords in [0..1] range, scale to pixels
    if 0.0 <= mn and mx <= 1.5:
        out = []
        for i, v in enumerate(flat_xy):
            out.append(v * (w if (i % 2) == 0 else h))
        return out
    return flat_xy

def polygons_to_mask(polygons, height, width):
    mask = Image.new("L", (width, height), 0)
    draw = ImageDraw.Draw(mask)
    if isinstance(polygons, dict) and "counts" in polygons:
        return np.zeros((height, width), dtype=np.uint8)  # ignore RLE
    if isinstance(polygons, list) and len(polygons) > 0:
        polys = polygons if isinstance(polygons[0], (list, tuple)) else [polygons]
        for poly in polys:
            if len(poly) >= 6:
                flat = _denorm_if_needed(list(poly), width, height)
                pts = [(flat[i], flat[i+1]) for i in range(0, len(flat), 2)]
                if len(pts) >= 3:
                    draw.polygon(pts, outline=1, fill=1)
    return np.array(mask, dtype=np.uint8)

def letterbox(img: Image.Image, size=512, fill=128):
    w, h = img.size
    s = size
    scale = min(s / w, s / h)
    nw, nh = int(round(w * scale)), int(round(h * scale))
    img = img.resize((nw, nh), Image.Resampling.BILINEAR)
    canvas = Image.new("RGB", (s, s), (fill, fill, fill))
    pad = ((s - nw) // 2, (s - nh) // 2)
    canvas.paste(img, pad)
    return canvas, pad, (nw, nh)


In [19]:
# %% [code] 6 - Dataset
class DirtCocoSemDataset(Dataset):
    def __init__(self, root: Path, split: str, size=512, augment=False):
        self.root = Path(root)/split
        self.size = size
        self.augment = augment
        data = json.loads((self.root/"_annotations.coco.json").read_text())
        self.images = {im['id']: im for im in data['images']}
        self.anns_by_img = {}
        for a in data.get('annotations', []):
            self.anns_by_img.setdefault(a['image_id'], []).append(a)
        self.cat_ok = {c['id']: any(k in c['name'].lower() for k in ['dirt','mud','grime'])
                       for c in data.get('categories', [])}
        self.list = list(self.images.keys())

    def __len__(self): return len(self.list)

    def _load_img(self, info):
        fn = info['file_name']
        p = (self.root/fn) if "/" not in fn else (self.root.parent/fn)
        if not p.exists():
            p = self.root/Path(fn).name
        return Image.open(p).convert("RGB")

    def _mask_for_img(self, img_id, H, W):
        mask = np.zeros((H, W), dtype=np.uint8)
        for a in self.anns_by_img.get(img_id, []):
            if not self.cat_ok.get(a.get('category_id'), False):
                continue
            seg = a.get('segmentation', [])
            mask |= polygons_to_mask(seg, H, W)
        return mask  # all-zero => clean

    def __getitem__(self, idx):
        img_id = self.list[idx]
        info = self.images[img_id]
        im = self._load_img(info)
        W, H = info.get('width', im.size[0]), info.get('height', im.size[1])
        m = self._mask_for_img(img_id, H, W)

        lb_img, (px, py), (nw, nh) = letterbox(im, self.size)
        mask_img = Image.fromarray(m).resize((nw, nh), Image.Resampling.NEAREST)
        lb_mask = Image.new("L", (self.size, self.size), 0)
        lb_mask.paste(mask_img, (px, py))

        if self.augment:
            if random.random() < 0.5:
                lb_img = ImageOps.mirror(lb_img); lb_mask = ImageOps.mirror(lb_mask)
            if random.random() < 0.2:
                lb_img = ImageOps.autocontrast(lb_img)

        x = TF.to_tensor(lb_img)
        y = torch.from_numpy(np.array(lb_mask)).float()
        y = (y > 0).float()
        y_img = (y.max() > 0).long()
        return x, y, y_img


In [20]:
# %% [code] 8 - Model & loss
class DoubleConv(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_c, out_c, 3, 1, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, 1, 1), nn.BatchNorm2d(out_c), nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class UNetSmall(nn.Module):
    """
    width: base channel count. width=32 matches your old model.
           width=16 halves channels (much lower memory).
    """
    def __init__(self, in_ch=3, out_ch=1, width=32):
        super().__init__()
        c1, c2, c3, cb = width, width*2, width*4, width*8

        self.d1 = DoubleConv(in_ch, c1); self.p1 = nn.MaxPool2d(2)
        self.d2 = DoubleConv(c1, c2);    self.p2 = nn.MaxPool2d(2)
        self.d3 = DoubleConv(c2, c3);    self.p3 = nn.MaxPool2d(2)
        self.b  = DoubleConv(c3, cb)

        self.u3  = nn.ConvTranspose2d(cb, c3, 2, 2); self.dc3 = DoubleConv(c3 + c3, c3)
        self.u2  = nn.ConvTranspose2d(c3, c2, 2, 2); self.dc2 = DoubleConv(c2 + c2, c2)
        self.u1  = nn.ConvTranspose2d(c2, c1, 2, 2); self.dc1 = DoubleConv(c1 + c1, c1)
        self.out = nn.Conv2d(c1, out_ch, 1)

    def forward(self, x):
        c1 = self.d1(x)
        c2 = self.d2(self.p1(c1))
        c3 = self.d3(self.p2(c2))
        b  = self.b(self.p3(c3))

        x  = self.u3(b); x = self.dc3(torch.cat([x, c3], 1))
        x  = self.u2(x); x = self.dc2(torch.cat([x, c2], 1))
        x  = self.u1(x); x = self.dc1(torch.cat([x, c1], 1))
        return self.out(x)

def dice_loss(logits, target, eps=1e-6):
    probs = torch.sigmoid(logits)
    num = 2 * (probs * target).sum() + eps
    den = probs.sum() + target.sum() + eps
    return 1 - num / den


In [21]:
# %% [code] 9 - DataLoaders
train_ds = DirtCocoSemDataset(DATA_ROOT, "train", SIZE, augment=True)
valid_ds = DirtCocoSemDataset(DATA_ROOT, "valid", SIZE, augment=False)
test_ds  = DirtCocoSemDataset(DATA_ROOT, "test",  SIZE, augment=False)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,  num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_ds, batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH, shuffle=False, num_workers=2, pin_memory=True)

len(train_ds), len(valid_ds), len(test_ds)


(200, 42, 27)

In [22]:
# %% [code] 10 - Losses, optimizer, metrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetSmall(width=16).to(device)

def estimate_pos_weight(loader, max_batches=100):
    pos = tot = 0
    for b, (_, y, _) in enumerate(loader):
        y = y.unsqueeze(1).float()
        pos += y.sum().item()
        tot += y.numel()
        if b+1 >= max_batches: break
    p = max(1e-6, pos / tot)
    return max(1.0, (1-p)/p)

LR = 3e-4
POS_WEIGHT = estimate_pos_weight(train_loader)
print(f"Estimated pos_weight ≈ {POS_WEIGHT:.2f}")
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
sch  = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10, eta_min=LR*0.2)

bce_seg = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([POS_WEIGHT], device=device))
bce_img = nn.BCEWithLogitsLoss()

def remove_small(predm: torch.Tensor, min_frac=0.003):
    # predm: [B,1,H,W] bool. Morphological opening-ish with pooling to kill tiny blobs.
    b,_,h,w = predm.shape
    k = int(((h*w*min_frac)**0.5)//2*2 + 1)   # odd kernel size
    x = predm.float()
    x = F.max_pool2d(x, k, 1, k//2)
    x = 1 - F.max_pool2d(1-x, k, 1, k//2)
    return (x>0.5)

def tversky_loss(logits, target, alpha=0.35, beta=0.65, gamma=1.2, eps=1e-6):
    p = torch.sigmoid(logits)
    tp = (p*target).sum()
    fp = (p*(1-target)).sum()
    fn = ((1-p)*target).sum()
    tversky = (tp + eps) / (tp + alpha*fn + beta*fp + eps)
    return (1.0 - tversky).pow(gamma)

@torch.no_grad()
def eval_metrics(model, loader, mask_thr=MASK_THR_DEFAULT, clean_thr=CLEAN_THR_DEFAULT, min_blob=MIN_BLOB_DEFAULT):
    model.eval()
    inter_all = union_all = 0.0
    tp=tn=fp=fn = 0.0
    ious_pos = []
    clean_total = clean_fp = 0

    for x, y, y_img in loader:
        x = x.to(device)
        y = y.to(device).unsqueeze(1).float()
        y_bool = (y > 0.5).bool()

        logits = model(x)
        prob   = torch.sigmoid(logits)            # [B,1,H,W]

        # 1) threshold -> boolean mask
        predm  = (prob > mask_thr)                # bool [B,1,H,W]

        # 2) optional speckle clean-up
        predm  = remove_small(predm, min_frac=0.002)

        # global IoU
        inter_all += (predm & y_bool).sum().item()
        union_all += (predm | y_bool).sum().item()

        # per-positive-image IoU and clean-FPR
        bs = y.shape[0]
        for i in range(bs):
            if y_bool[i].any():
                inter = (predm[i] & y_bool[i]).sum().item()
                union = (predm[i] | y_bool[i]).sum().item()
                ious_pos.append(inter / (union + 1e-9))
            else:
                clean_total += 1
                clean_fp   += int(predm[i].any())

        # image-level clean/dirty
        frac   = predm.float().mean(dim=(2,3)).squeeze(1)  # fraction of positives per image
        yhat   = (frac >= clean_thr)
        y_true = y_img.to(device).view(-1).bool()

        tp += ( yhat &  y_true).sum().item()
        tn += (~yhat & ~y_true).sum().item()
        fp += ( yhat & ~y_true).sum().item()
        fn += (~yhat &  y_true).sum().item()

    iou_all  = inter_all / max(1.0, union_all)
    iou_pos  = float(sum(ious_pos) / max(1, len(ious_pos)))
    clean_fpr = clean_fp / max(1, clean_total)
    acc = (tp+tn) / max(1.0, tp+tn+fp+fn)
    return iou_all, iou_pos, clean_fpr, acc

Estimated pos_weight ≈ 2.85


In [23]:
# %% [code] 11 - Training
from contextlib import nullcontext
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
autocast_ctx = (lambda: torch.autocast('cuda', dtype=torch.float16)) if torch.cuda.is_available() else nullcontext

WARMUP_EPOCHS = 5  # try 5–8 if you want to pretrain pure segmentation
ACCUM_STEPS = 1

best = (0.0, 0.0)
CKPT = "checkpoints/unet_dirt_best.pth"

for ep in range(1, EPOCHS+1):
    model.train()
    total = 0.0

    for step, (x, y, y_img) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device).unsqueeze(1).float()
        y_img = y_img.to(device).float().unsqueeze(1)

        with autocast_ctx():
            logits   = model(x)

            seg_loss = 0.5*bce_seg(logits, y) + 0.5*tversky_loss(logits, y, alpha=0.9, beta=0.1)

            img_logit = logits.mean(dim=(2,3))          # [B,1] logits for image-level
            img_loss  = bce_img(img_logit, y_img)

            # extra: penalize area ABOVE clean threshold for clean images only
            prob      = torch.sigmoid(logits)
            area      = prob.mean(dim=(2,3), keepdim=True)             # [B,1]
            over      = (area - CLEAN_THR_DEFAULT).clamp(min=0)        # > 0 only if too big
            clean_pen = ((1 - y_img) * over).mean()

            loss = seg_loss + IMG_W*img_loss + CLEAN_W*clean_pen

        scaler.scale(loss/ACCUM_STEPS).backward()
        if (step+1) % ACCUM_STEPS == 0:
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)

        total += loss.item()

    sch.step()

    torch.cuda.empty_cache()
    iou_all, iou_pos, clean_fpr, acc = eval_metrics(
        model, valid_loader,
        mask_thr=MASK_THR_DEFAULT,
        clean_thr=CLEAN_THR_DEFAULT,
        min_blob=MIN_BLOB_DEFAULT
    )

    print(f"Epoch {ep}/{EPOCHS} | loss={total/len(train_loader):.3f} | "
          f"IoU_all={iou_all:.3f} | IoU_pos={iou_pos:.3f} | cleanFPR={clean_fpr:.3f} | acc={acc:.3f}")

    if iou_all>best[0] or (iou_all==best[0] and acc>best[1]):
        best = (iou_all, acc)
        torch.save({"model": model.state_dict(), "size": SIZE}, CKPT)
        print("  ↳ saved:", CKPT)

print("Best:", best)

  scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


Epoch 1/20 | loss=1.300 | IoU_all=0.364 | IoU_pos=0.391 | cleanFPR=1.000 | acc=0.762
  ↳ saved: checkpoints/unet_dirt_best.pth
Epoch 2/20 | loss=1.228 | IoU_all=0.346 | IoU_pos=0.371 | cleanFPR=1.000 | acc=0.643
Epoch 3/20 | loss=1.207 | IoU_all=0.246 | IoU_pos=0.253 | cleanFPR=1.000 | acc=0.524
Epoch 4/20 | loss=1.189 | IoU_all=0.478 | IoU_pos=0.486 | cleanFPR=1.000 | acc=0.857
  ↳ saved: checkpoints/unet_dirt_best.pth
Epoch 5/20 | loss=1.194 | IoU_all=0.468 | IoU_pos=0.493 | cleanFPR=1.000 | acc=0.833
Epoch 6/20 | loss=1.123 | IoU_all=0.501 | IoU_pos=0.537 | cleanFPR=1.000 | acc=0.762
  ↳ saved: checkpoints/unet_dirt_best.pth
Epoch 7/20 | loss=1.135 | IoU_all=0.518 | IoU_pos=0.543 | cleanFPR=1.000 | acc=0.786
  ↳ saved: checkpoints/unet_dirt_best.pth
Epoch 8/20 | loss=1.131 | IoU_all=0.354 | IoU_pos=0.338 | cleanFPR=0.800 | acc=0.667
Epoch 9/20 | loss=1.121 | IoU_all=0.537 | IoU_pos=0.568 | cleanFPR=1.000 | acc=0.786
  ↳ saved: checkpoints/unet_dirt_best.pth
Epoch 10/20 | loss=1.132 

In [24]:
# %% [code] 12 - Threshold sweep
for mt in [0.65, 0.70, 0.75, 0.80]:
    for ct in [0.65, 0.70, 0.75, 0.80]:
      for mb in [0.12, 0.14, 0.16, 0.18, 0.20]:
        i_all, i_pos, fpr, acc = eval_metrics(
            model, valid_loader, mask_thr=mt, clean_thr=ct, min_blob=MIN_BLOB_DEFAULT
        )
        print(f"mask={mt:.2f} | clean_thr={ct:.3f} | IoU_all={i_all:.3f} | "
              f"IoU_pos={i_pos:.3f} | cleanFPR={fpr:.3f} | acc={acc:.3f}")

mask=0.65 | clean_thr=0.650 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.650 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.650 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.650 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.650 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.700 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.700 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.700 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.700 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.700 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr=0.750 | IoU_all=0.348 | IoU_pos=0.329 | cleanFPR=0.900 | acc=0.238
mask=0.65 | clean_thr

In [25]:
# %% [code] 13 - Inference utilities
def area_filter(mask_np: np.ndarray, min_pixels: int):
    n, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np.astype('uint8'), connectivity=8)
    keep = np.zeros_like(mask_np, dtype=np.uint8)
    for i in range(1, n):
        if stats[i, cv2.CC_STAT_AREA] >= min_pixels:
            keep[labels == i] = 1
    return keep

@torch.no_grad()
def predict_image(pil_im: Image.Image, mask_thr=DEFAULT_MASK_THR, clean_thr=CLEAN_THR_DEFAULT):
    model.eval()
    img, pad, _ = letterbox(pil_im.convert("RGB"), SIZE)
    x = TF.to_tensor(img).unsqueeze(0).to(device)
    prob = torch.sigmoid(model(x))[0,0].cpu().numpy()
    pred = (prob > mask_thr).astype(np.uint8)
    pred = area_filter(pred, int(0.001 * pred.size))   # drop tiny blobs (~0.1% of image)
    dirt_frac = float(pred.mean())
    label = "DIRTY" if dirt_frac >= clean_thr else "CLEAN"

    sem = Image.fromarray((pred*255).astype(np.uint8))
    ov  = img.convert("RGBA")
    red = ImageOps.colorize(sem, black=(0,0,0), white=(255,0,0)).convert("RGBA")
    out = Image.alpha_composite(ov, red).convert("RGB")
    return out, f"dirt_fraction={dirt_frac:.3f} → {label}"


In [26]:
# %% [code] 14 - Load checkpoint
CKPT = "checkpoints/unet_dirt_best.pth"
state = torch.load(CKPT, map_location=device)
model.load_state_dict(state["model"])
model.eval()
print("Loaded:", CKPT)


Loaded: checkpoints/unet_dirt_best.pth


In [27]:
# %% [code] 15 - Gradio demo
with gr.Blocks() as demo:
    gr.Markdown("### Dirt segmentation + image-level CLEAN/DIRTY")
    with gr.Row():
        inp = gr.Image(type="pil", label="Upload or paste a photo")
        out_img = gr.Image(label="Mask overlay")
    with gr.Row():
        mask_thr  = gr.Slider(0.1, 0.95, value=DEFAULT_MASK_THR, step=0.01, label="Mask threshold")
        clean_thr = gr.Slider(0.0, 0.4,  value=CLEAN_THR_DEFAULT, step=0.01, label="Clean/Dirty threshold (fraction)")
    out_txt = gr.Textbox(label="Result")
    btn = gr.Button("Submit")
    btn.click(predict_image, [inp, mask_thr, clean_thr], [out_img, out_txt])

demo.launch(share=True)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://a538010c8202051b2d.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


