In [32]:
from pathlib import Path
import re, glob

DATA_ROOT = Path("AOI_3_Paris_Train")  # change to your local SpaceNet AOI_3_Paris_Train path
RGB_DIR = DATA_ROOT / "RGB-PanSharpen"
GEO_DIR = DATA_ROOT / "geojson/buildings"

rgb_tifs = sorted(RGB_DIR.glob("*.tif"))
geojsons = sorted(GEO_DIR.glob("buildings_AOI_3_Paris_img*.geojson"))

print("tifs:", len(rgb_tifs))
print("geojsons:", len(geojsons))

# Fail with an actionable message instead of an IndexError below when DATA_ROOT is wrong.
if not rgb_tifs:
    raise FileNotFoundError(f"No .tif files found in {RGB_DIR.resolve()} - check DATA_ROOT")
if not geojsons:
    raise FileNotFoundError(f"No building geojsons found in {GEO_DIR.resolve()} - check DATA_ROOT")

print("example tif:", rgb_tifs[0].name)
print("example geo:", geojsons[0].name)


tifs: 1148
geojsons: 1148
example tif: RGB-PanSharpen_AOI_3_Paris_img10.tif
example geo: buildings_AOI_3_Paris_img10.geojson


In [33]:
def extract_img_id(name: str) -> str:
    """Return the SpaceNet scene token (e.g. ``"img380"``) found in a file name.

    Raises:
        ValueError: if ``name`` contains no ``img<digits>`` token.
    """
    if (match := re.search(r"(img\d+)", name)) is None:
        raise ValueError(f"Could not find img### in: {name}")
    return match[1]

# Map scene id -> geojson path. Built with an explicit loop (not a dict
# comprehension) so a duplicate scene id fails loudly instead of silently
# keeping only the last file.
geo_by_img = {}
for geo_path in geojsons:
    gid = extract_img_id(geo_path.name)
    if gid in geo_by_img:
        raise ValueError(f"Duplicate scene id {gid}: {geo_by_img[gid].name} and {geo_path.name}")
    geo_by_img[gid] = geo_path

# spot-check the tif -> geojson pairing on the first few scenes
for p in rgb_tifs[:5]:
    img_id = extract_img_id(p.name)
    print(p.name, "->", geo_by_img.get(img_id))

missing = [p.name for p in rgb_tifs if extract_img_id(p.name) not in geo_by_img]
print("tifs without geojson:", len(missing))


RGB-PanSharpen_AOI_3_Paris_img10.tif -> AOI_3_Paris_Train/geojson/buildings/buildings_AOI_3_Paris_img10.geojson
RGB-PanSharpen_AOI_3_Paris_img100.tif -> AOI_3_Paris_Train/geojson/buildings/buildings_AOI_3_Paris_img100.geojson
RGB-PanSharpen_AOI_3_Paris_img1000.tif -> AOI_3_Paris_Train/geojson/buildings/buildings_AOI_3_Paris_img1000.geojson
RGB-PanSharpen_AOI_3_Paris_img1001.tif -> AOI_3_Paris_Train/geojson/buildings/buildings_AOI_3_Paris_img1001.geojson
RGB-PanSharpen_AOI_3_Paris_img1003.tif -> AOI_3_Paris_Train/geojson/buildings/buildings_AOI_3_Paris_img1003.geojson


In [34]:
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
from rasterio.features import rasterize
import geopandas as gpd
import cv2
from shapely.geometry import box
import hashlib

# All generated tiles live under OUT, split into train/val image and mask folders.
OUT = Path("./tiles_paris_rgb")
train_img = OUT / "train" / "images"
train_msk = OUT / "train" / "masks"
val_img = OUT / "val" / "images"
val_msk = OUT / "val" / "masks"
for p in (train_img, train_msk, val_img, val_msk):
    p.mkdir(parents=True, exist_ok=True)

# Sliding-window geometry in pixels: 512 px tiles, consecutive tiles overlap by TILE - STRIDE = 128 px.
TILE = 512
STRIDE = 384

def tile_starts(full, tile, stride):
    """Top-left offsets of sliding windows covering ``full`` pixels along one axis.

    Windows start every ``stride`` pixels. If that grid does not end flush with
    the edge, one extra window anchored at ``full - tile`` is appended, so the
    trailing pixels are always covered. When the axis is no longer than one
    tile, ``[0]`` is returned.
    """
    last = full - tile
    if last <= 0:
        return [0]
    starts = [s for s in range(0, last + 1, stride)]
    return starts if starts[-1] == last else starts + [last]

def is_val_scene(img_id, frac=0.2):
    """Deterministically decide whether a scene belongs to the validation split.

    The md5 digest of ``img_id`` is mapped to one of 1000 buckets. Scenes whose
    bucket falls below ``frac * 1000`` go to validation. The assignment depends
    only on the id, so it is stable across runs and file orderings.

    Args:
        img_id: scene identifier such as ``"img380"``.
        frac: target validation fraction in [0, 1].

    Returns:
        True if the scene is assigned to validation.

    Raises:
        ValueError: if ``frac`` is outside [0, 1].
    """
    if not 0.0 <= frac <= 1.0:
        raise ValueError(f"frac must be in [0, 1], got {frac}")
    bucket = int(hashlib.md5(img_id.encode()).hexdigest(), 16) % 1000
    # round() rather than int(): int() truncates, so a product landing a hair
    # below an integer in floating point would silently drop one bucket.
    return bucket < round(frac * 1000)

def write_png(path, arr):
    """Write an RGB image array to ``path`` (format chosen by the extension, PNG here).

    OpenCV stores images in BGR order, so the RGB array is converted first.

    Raises:
        OSError: if OpenCV fails to write the file. ``cv2.imwrite`` signals
            failure by returning False rather than raising, so without this
            check a bad path or full disk would go unnoticed.
    """
    if not cv2.imwrite(str(path), cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed for {path}")


import numpy as np

def to_uint8_percentile_nonzero(rgb_chw, p_low=2, p_high=98, min_range=50, fallback_hi=1200, min_valid=1000):
    """Contrast-stretch a channel-first image to uint8 using per-channel percentiles.

    Args:
        rgb_chw: ``(C, H, W)`` array (uint16 or float). C is 3 for RGB, but any
            channel count is accepted.
        p_low, p_high: percentiles of each channel's non-zero pixels that are
            mapped to 0 and 255.
        min_range: if ``hi - lo`` is below this, the stretch is considered
            degenerate and the fixed range ``[0, fallback_hi]`` is used instead.
        fallback_hi: upper bound of the fixed fallback range.
        min_valid: minimum number of non-zero pixels a channel needs before its
            percentiles are trusted. Otherwise the fallback range is used.

    Returns:
        ``(C, H, W)`` uint8 array. Pixels that are 0 in every channel are
        treated as nodata and stay 0.

    Only values ``> 0`` contribute to the percentiles, so black nodata borders
    do not drag ``lo`` down.
    """
    x = rgb_chw.astype(np.float32)
    out = np.empty_like(x, dtype=np.float32)

    # nodata = pixel that is zero in every channel
    nodata = np.all(x == 0, axis=0)

    for c in range(x.shape[0]):
        vals = x[c][x[c] > 0]  # ignore 0 (nodata / black)

        if vals.size < min_valid:
            lo, hi = 0.0, float(fallback_hi)
        else:
            lo, hi = np.percentile(vals, [p_low, p_high])
            if (hi - lo) < min_range:
                lo, hi = 0.0, float(fallback_hi)

        out[c] = (x[c] - lo) / (hi - lo + 1e-6)

    # scale to [0, 255]; astype truncates toward zero
    out = (np.clip(out, 0, 1) * 255).astype(np.uint8)

    # keep nodata black
    out[:, nodata] = 0
    return out



def rasterize_mask(polys_gdf, win_transform, h, w, all_touched=False):
    """Burn polygons into an ``(h, w)`` uint8 mask: 1 inside a polygon, 0 elsewhere.

    Args:
        polys_gdf: GeoDataFrame (or any sized object exposing a ``.geometry``
            iterable) whose geometries are already in the raster's CRS.
        win_transform: affine transform of the target window.
        h, w: output mask height and width in pixels.
        all_touched: forwarded to ``rasterio.features.rasterize``. False (the
            default) burns only pixels selected by rasterio's default rule
            (pixel centre inside the polygon).

    Returns:
        ``(h, w)`` uint8 array.
    """
    if len(polys_gdf) == 0:
        return np.zeros((h, w), dtype=np.uint8)

    # Skip missing/empty geometries. rasterio's rasterize raises ValueError when
    # it is given no valid geometry, so an all-empty input must take the zero path.
    shapes = [
        (geom, 1)
        for geom in polys_gdf.geometry
        if geom is not None and not geom.is_empty
    ]
    if not shapes:
        return np.zeros((h, w), dtype=np.uint8)

    return rasterize(
        shapes,
        out_shape=(h, w),
        transform=win_transform,
        fill=0,
        dtype=np.uint8,
        all_touched=all_touched,
    )

rows = []

for tif in tqdm(rgb_tifs, desc="tifs"):
    img_id = extract_img_id(tif.name)
    geo_path = geo_by_img.get(img_id)
    if geo_path is None:
        print("WARN: no geojson for", tif.name, "img_id", img_id)
        continue

    # The split is decided once per scene: every tile of one tif lands in the
    # same folder, so overlapping tiles cannot leak between train and val.
    split = "val" if is_val_scene(img_id, frac=0.2) else "train"
    img_dir, msk_dir = (val_img, val_msk) if split == "val" else (train_img, train_msk)

    with rasterio.open(tif) as src:
        # Reproject buildings into the raster CRS so window bounds and polygon
        # coordinates are directly comparable.
        gdf = gpd.read_file(geo_path)
        if gdf.crs != src.crs:
            gdf = gdf.to_crs(src.crs)
        sindex = gdf.sindex  # build the spatial index once per scene

        for y0 in tile_starts(src.height, TILE, STRIDE):
            for x0 in tile_starts(src.width, TILE, STRIDE):
                win = Window(x0, y0, TILE, TILE)
                win_transform = rasterio.windows.transform(win, src.transform)

                # coarse bbox lookup via the spatial index, then an exact intersects test
                left, bottom, right, top = rasterio.windows.bounds(win, src.transform)
                cand = gdf.iloc[list(sindex.intersection((left, bottom, right, top)))]
                cand = cand[cand.intersects(box(left, bottom, right, top))]

                rgb = src.read([1, 2, 3], window=win)    # (3, H, W) in the source dtype
                rgb = to_uint8_percentile_nonzero(rgb)    # (3, H, W) uint8
                rgb = np.transpose(rgb, (1, 2, 0))        # (H, W, 3) for image writing

                msk = rasterize_mask(cand, win_transform, TILE, TILE)

                base = f"{tif.stem}_x{x0}_y{y0}"
                write_png(img_dir / f"{base}.png", rgb)
                # cv2.imwrite reports failure via its return value, not an exception
                if not cv2.imwrite(str(msk_dir / f"{base}.png"), (msk * 255).astype(np.uint8)):
                    raise OSError(f"cv2.imwrite failed for mask {msk_dir / f'{base}.png'}")

                rows.append({"id": base, "tif": tif.name, "img_id": img_id,
                             "x0": x0, "y0": y0, "split": split})

meta = pd.DataFrame(rows)
if meta.empty:
    raise RuntimeError("No tiles were written - check rgb_tifs and geo_by_img")
meta.to_csv(OUT / "tiles_meta.csv", index=False)

print(
    "tiles:", len(meta),
    "train:", (meta["split"] == "train").sum(),
    "val:",   (meta["split"] == "val").sum()
)


tifs: 100%|██████████| 1148/1148 [01:46<00:00, 10.73it/s]
tifs: 100%|██████████| 1148/1148 [01:46<00:00, 10.73it/s]


DEBUG: Total rows appended: 4592
tiles: 4592 train: 3660 val: 932


In [None]:
import pandas as pd
# Must match OUT in the tiling cell: reading ./tiles_vegas_rgb here would
# inspect a different AOI than the one tiled above.
meta = pd.read_csv("./tiles_paris_rgb/tiles_meta.csv")
tiles_per_scene = meta.groupby("img_id").size()
print(tiles_per_scene.describe())
print(tiles_per_scene.head())


dtype: uint16
min/max per band: [(np.uint16(0), np.uint16(1525)), (np.uint16(0), np.uint16(1772)), (np.uint16(0), np.uint16(1202))]
p2/p98 per band: [(np.float64(0.0), np.float64(807.0)), (np.float64(0.0), np.float64(890.0)), (np.float64(0.0), np.float64(610.0))]


In [None]:
import torch
from torch.utils.data import DataLoader
import albumentations as A
import numpy as np
from PIL import Image
from pathlib import Path

class BuildingsTiles(torch.utils.data.Dataset):
    """Binary building-segmentation dataset over paired PNG tiles.

    Every ``*.png`` in ``img_dir`` is paired with the same-named file in
    ``msk_dir``. Images are loaded as RGB. Any non-zero mask pixel becomes
    class 1 (building) and zero becomes class 0 (background).

    Args:
        img_dir: folder of image tiles.
        msk_dir: folder of mask tiles with matching file names.
        augment: optional albumentations-style callable taking ``image=`` /
            ``mask=`` keyword arguments and returning a dict with ``"image"``
            and ``"mask"`` entries.

    Each item is ``(img, msk)``: a float tensor ``(3, H, W)`` scaled to [0, 1]
    and a long tensor of class ids.
    """

    def __init__(self, img_dir, msk_dir, augment=None):
        self.img_paths = sorted(Path(img_dir).glob("*.png"))
        self.msk_dir = Path(msk_dir)
        self.augment = augment

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, i):
        p = self.img_paths[i]
        # context managers release the file handles as soon as pixels are read
        with Image.open(p) as im:
            img = np.array(im.convert("RGB"))
        with Image.open(self.msk_dir / p.name) as im:
            msk = (np.array(im) > 0).astype(np.uint8)

        if img.shape[:2] != msk.shape[:2]:
            raise ValueError(
                f"image/mask size mismatch for {p.name}: {img.shape[:2]} vs {msk.shape[:2]}"
            )

        if self.augment:
            out = self.augment(image=img, mask=msk)
            img, msk = out["image"], out["mask"]

        # Flip/rotate augmentations may return negatively-strided views, which
        # torch.from_numpy rejects; force contiguous memory first.
        img = torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() / 255.0
        msk = torch.from_numpy(np.ascontiguousarray(msk)).long()  # (H,W) class id 0/1
        return img, msk

train_aug = A.Compose([
    A.RandomRotate90(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
])

# Must match OUT in the tiling cell: ./tiles_vegas_rgb would train on a
# different AOI than the one tiled in this notebook.
TILES_DIR = Path("./tiles_paris_rgb")
train_ds = BuildingsTiles(TILES_DIR / "train/images", TILES_DIR / "train/masks", train_aug)
val_ds   = BuildingsTiles(TILES_DIR / "val/images",   TILES_DIR / "val/masks",   None)

# An empty folder otherwise surfaces as an obscure sampler error inside DataLoader.
if len(train_ds) == 0 or len(val_ds) == 0:
    raise FileNotFoundError(f"No tiles found under {TILES_DIR.resolve()} - run the tiling cell first")

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=16, shuffle=False, num_workers=4, pin_memory=True)

print(len(train_ds), len(val_ds))


In [None]:
import torch
import torch.nn.functional as F
from transformers import SegformerForSemanticSegmentation

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# ADE20K-pretrained SegFormer-B2 reconfigured for 2 labels (mask class ids 0/1).
# The pretrained classifier does not match 2 labels, so ignore_mismatched_sizes
# lets Transformers load the rest of the weights and initialise that head fresh.
model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/segformer-b2-finetuned-ade-512-512",
    ignore_mismatched_sizes=True,
    num_labels=2,
)
model = model.to(device)

opt = torch.optim.AdamW(model.parameters(), weight_decay=1e-2, lr=3e-5)
# loss scaling only matters for fp16 autocast, i.e. when running on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")

def iou_from_logits(logits, target, positive_class=1):
    """IoU of one class between argmax predictions and the target, over the whole batch.

    Args:
        logits: ``(B, C, H, W)`` class scores.
        target: ``(B, H, W)`` integer class ids.
        positive_class: class id whose IoU is measured (1 = building here).

    Returns:
        ``|pred ∩ target| / |pred ∪ target|`` as a Python float. If neither the
        prediction nor the target contains ``positive_class``, the prediction
        is perfect and 1.0 is returned (rather than 0.0).
    """
    pred_pos = logits.argmax(1) == positive_class
    true_pos = target == positive_class
    inter = (pred_pos & true_pos).sum().item()
    union = (pred_pos | true_pos).sum().item()
    if union == 0:
        return 1.0
    return inter / union

best = -1
for epoch in range(1, 6):
    model.train()
    running_loss, n_seen = 0.0, 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)
        opt.zero_grad(set_to_none=True)
        # torch.autocast supersedes the deprecated torch.cuda.amp.autocast
        with torch.autocast(device_type=device, enabled=(device == "cuda")):
            out = model(pixel_values=imgs)
            # SegFormer logits are lower resolution than the input (H/4 x W/4 per
            # the Transformers docs); upsample to mask size before the loss.
            logits = F.interpolate(out.logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
            loss = F.cross_entropy(logits, masks)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        running_loss += loss.item() * imgs.size(0)
        n_seen += imgs.size(0)

    # Building IoU over the whole validation split (total intersection / total
    # union). A mean of per-batch IoUs would weight each batch equally no matter
    # how many building pixels it holds.
    model.eval()
    inter, union = 0, 0
    with torch.no_grad():
        for imgs, masks in val_loader:
            imgs, masks = imgs.to(device), masks.to(device)
            out = model(pixel_values=imgs)
            logits = F.interpolate(out.logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
            pred_pos = logits.argmax(1) == 1
            true_pos = masks == 1
            inter += (pred_pos & true_pos).sum().item()
            union += (pred_pos | true_pos).sum().item()

    val_iou = inter / union if union else 1.0
    print(f"epoch {epoch} | train loss={running_loss / max(n_seen, 1):.4f} | val IoU={val_iou:.4f}")
    if val_iou > best:
        best = val_iou
        # Paris-specific file names so this run cannot overwrite a Vegas checkpoint
        torch.save(model.state_dict(), "best_segformer_paris.pth")
        print("saved best")
    # save latest model with optimizer and scaler state so training can be resumed
    torch.save({
        "epoch": epoch,
        "best_val_iou": best,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        }, "latest_segformer_paris.pth")
    
