# Map Segmentation (U-Net) — End-to-end notebook

**What this notebook does:**

- Loads paired satellite images and masks (robust pairing by base filename)
- Preprocesses and augments data (masks use nearest resampling)
- Builds a U-Net model, trains with BCE + Dice loss, and monitors Dice coefficient
- Provides inference, post-processing, and visualization utilities

**How to use:** edit `TRAIN_DIR` and `VALID_DIR` (and, if needed, `MASK_DIR` and `TEST_IMAGES`) to point to your dataset. Image files should follow a naming convention like `0001_sat.jpg`, with the matching mask named `0001_mask.png`.

Run all cells sequentially.


In [None]:
# Imports
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, backend as K
import json
from tensorflow.keras.optimizers.experimental import AdamW
print('tf version', tf.__version__)

In [None]:
# ---------- Utilities: reading images and masks ----------
def read_image_rgb(path, shape):
    """Load an image file as a square RGB float32 array.

    Args:
        path: Path of the image file to read.
        shape: Side length in pixels; the image is resized to (shape, shape).

    Returns:
        A float32 array in RGB channel order, resized to (shape, shape).
        Integer images are divided by the maximum value of their dtype
        (255 for 8-bit, 65535 for 16-bit); any other dtype is divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file at ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise FileNotFoundError(f'Image not found: {path}')

    # OpenCV decodes colour data as BGR(A); convert each layout to RGB in one step.
    if img.ndim == 2:  # grayscale -> 3 identical channels
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:  # drop alpha while reordering channels
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img = cv2.resize(img, (shape, shape), interpolation=cv2.INTER_AREA)

    # IMREAD_UNCHANGED preserves the file's bit depth, so a fixed /255 would
    # push 16-bit images far outside [0, 1]. Scale by the dtype's own maximum.
    if np.issubdtype(img.dtype, np.integer):
        scale = float(np.iinfo(img.dtype).max)
    else:
        scale = 255.0
    return img.astype(np.float32) / scale

def read_mask(path, shape):
    """Load a mask file as a binary float32 array with a trailing channel axis.

    Args:
        path: Path of the mask file to read (decoded as grayscale).
        shape: Side length in pixels; the mask is resized to (shape, shape).

    Returns:
        A float32 array of shape (shape, shape, 1) holding 1.0 where the
        resized grayscale value exceeds 127 and 0.0 elsewhere.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file at ``path``.
    """
    gray = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f'Mask not found: {path}')
    # Nearest-neighbour resampling avoids inventing intermediate gray levels.
    resized = cv2.resize(gray, (shape, shape), interpolation=cv2.INTER_NEAREST)
    binary = np.greater(resized, 127).astype(np.float32)
    return binary[..., np.newaxis]

def LoadData_pairwise(img_dir, mask_dir=None, shape=128, img_suffix='_sat.jpg', mask_suffix='_mask.png'):
    """Pair images and masks by base filename and load them into arrays.

    A file's base name is its filename with the given suffix removed (and any
    trailing underscores stripped). Only base names present for both an image
    and a mask are loaded.

    Args:
        img_dir: Directory containing the image files.
        mask_dir: Directory containing the mask files; defaults to ``img_dir``.
        shape: Side length passed to ``read_image_rgb`` / ``read_mask``.
        img_suffix: Filename ending that identifies image files.
        mask_suffix: Filename ending that identifies mask files.

    Returns:
        dict with keys 'img' and 'mask', each a float32 numpy array. The two
        arrays always contain the same number of examples, in sorted base-name
        order.
    """
    if mask_dir is None:
        mask_dir = img_dir

    def _index_by_base(directory, suffix):
        # Strip exactly the configured suffix rather than a hard-coded tag, so
        # custom suffixes pair correctly and names that merely contain the tag
        # elsewhere are left intact.
        index = {}
        for fname in os.listdir(directory):
            if fname.endswith(suffix):
                base = fname[:len(fname) - len(suffix)].rstrip('_')
                index[base] = os.path.join(directory, fname)
        return index

    img_map = _index_by_base(img_dir, img_suffix)
    mask_map = _index_by_base(mask_dir, mask_suffix)

    common = sorted(img_map.keys() & mask_map.keys())
    print(f'Found {len(common)} paired examples.')

    images = []
    masks = []
    for base in common:
        # Load both halves before appending either, so a failure on the mask
        # cannot leave an orphan image that shifts every later pair.
        try:
            image = read_image_rgb(img_map[base], shape)
            mask = read_mask(mask_map[base], shape)
        except Exception as e:
            print('Skipping', base, 'because', e)
            continue
        images.append(image)
        masks.append(mask)

    images = np.array(images, dtype=np.float32)
    masks = np.array(masks, dtype=np.float32)
    return {'img': images, 'mask': masks}

In [None]:
# ---------- Augmentation (tf.data friendly) ----------
# Geometric transforms must move image and mask together, so this pipeline is
# applied to the channel-stacked (image + mask) tensor.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal_and_vertical"),
    layers.RandomZoom((-0.15, 0.0)),
], name="data_augmentation")

# Photometric transforms are applied to the image only: shifting the intensity
# of a mask channel would corrupt the labels. Images are scaled to [0, 1], so
# RandomBrightness needs value_range=(0, 1); its default range is (0, 255),
# which would turn a factor of 0.1 into shifts of up to +/-25.5.
image_augmentation = tf.keras.Sequential([
    layers.RandomContrast(0.2),
    layers.RandomBrightness(0.1, value_range=(0.0, 1.0)),
], name="image_augmentation")


def augment_fn(image, mask):
    """Apply random augmentation to one image/mask pair.

    Args:
        image: Image tensor whose last axis holds 3 channels, scaled to [0, 1].
        mask: Mask tensor with the same leading dimensions as ``image``.

    Returns:
        Tuple ``(image_aug, mask_aug)``. The mask receives exactly the same
        flip/zoom as the image and is re-binarised to {0.0, 1.0}.
    """
    # Stack along channels so both tensors get one shared random flip/zoom.
    stacked = tf.concat([image, mask], axis=-1)
    stacked = data_augmentation(stacked, training=True)

    image_aug = image_augmentation(stacked[..., :3], training=True)

    # Zoom interpolation can produce fractional mask values; snap back to binary.
    mask_aug = tf.cast(stacked[..., 3:] > 0.5, tf.float32)

    return image_aug, mask_aug
  
def make_dataset(frames, batch_size=16, augment=True, shuffle=True):
    """Build a batched, prefetched tf.data pipeline from loaded arrays.

    Args:
        frames: dict with 'img' and 'mask' arrays.
        batch_size: Number of examples per batch.
        augment: When true, map ``augment_fn`` over every example.
        shuffle: When true, shuffle with a buffer of at least 1000 elements
            (or the dataset length, if larger).

    Returns:
        A ``tf.data.Dataset`` yielding (image_batch, mask_batch) tuples.
    """
    features, targets = frames['img'], frames['mask']
    pipeline = tf.data.Dataset.from_tensor_slices((features, targets))
    if shuffle:
        pipeline = pipeline.shuffle(buffer_size=max(1000, len(features)))
    if augment:
        # Augment per example, before batching.
        pipeline = pipeline.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.batch(batch_size)
    return pipeline.prefetch(tf.data.AUTOTUNE)

In [None]:
# ---------- U-Net model definition ----------
def Conv2dBlock(inputTensor, numFilters, kernelSize=3, doBatchNorm=True):
    """Apply two successive Conv2D -> (BatchNorm) -> ReLU stages.

    Args:
        inputTensor: Keras tensor fed into the first convolution.
        numFilters: Number of filters used by both convolutions.
        kernelSize: Side length of the square convolution kernel.
        doBatchNorm: When true, insert BatchNormalization after each convolution.

    Returns:
        The output tensor of the second ReLU activation.
    """
    x = inputTensor
    for _ in range(2):
        x = layers.Conv2D(
            filters=numFilters,
            kernel_size=(kernelSize, kernelSize),
            kernel_initializer='he_normal',
            padding='same',
        )(x)
        if doBatchNorm:
            x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    return x

def unetBlock(inputImage, numFilters=16, droupouts=0.1, doBatchNorm=True):
    """Assemble a U-Net with four down-sampling and four up-sampling stages.

    Args:
        inputImage: Keras input tensor for the network.
        numFilters: Filter count of the first stage; later stages use
            2x, 4x, 8x and (at the bottleneck) 16x this value.
        droupouts: Dropout rate applied after every pooling and concatenation.
        doBatchNorm: Forwarded to ``Conv2dBlock``.

    Returns:
        A ``tf.keras.Model`` mapping ``inputImage`` to a single-channel
        sigmoid output produced by a 1x1 convolution.
    """
    skip_connections = []
    x = inputImage

    # Encoder: conv block, remember its output for the skip, then pool + dropout.
    for multiplier in (1, 2, 4, 8):
        encoded = Conv2dBlock(x, numFilters * multiplier, kernelSize=3, doBatchNorm=doBatchNorm)
        skip_connections.append(encoded)
        x = layers.MaxPooling2D((2,2))(encoded)
        x = layers.Dropout(droupouts)(x)

    # Bottleneck.
    x = Conv2dBlock(x, numFilters * 16, kernelSize=3, doBatchNorm=doBatchNorm)

    # Decoder: upsample, concatenate with the matching encoder output
    # (deepest first), dropout, then a conv block.
    for multiplier in (8, 4, 2, 1):
        x = layers.Conv2DTranspose(numFilters * multiplier, (3,3), strides=(2,2), padding='same')(x)
        x = layers.concatenate([x, skip_connections.pop()])
        x = layers.Dropout(droupouts)(x)
        x = Conv2dBlock(x, numFilters * multiplier, kernelSize=3, doBatchNorm=doBatchNorm)

    output = layers.Conv2D(1, (1,1), activation='sigmoid')(x)
    return tf.keras.Model(inputs=[inputImage], outputs=[output])

In [None]:
# ---------- Losses and metrics (BCE + Dice) ----------
def dice_coef(y_true, y_pred, smooth=1e-6):
    """Return the Dice coefficient between targets and predictions.

    Both tensors are flattened, so the score is computed jointly over every
    element passed in rather than per example.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor: (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth).
    """
    truth = K.flatten(y_true)
    guess = K.flatten(y_pred)
    overlap = K.sum(truth * guess)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(guess) + smooth
    return numerator / denominator

def dice_loss(y_true, y_pred):
    """Return one minus the Dice coefficient of ``y_true`` and ``y_pred``."""
    overlap_score = dice_coef(y_true, y_pred)
    return 1.0 - overlap_score

def bce_dice_loss(y_true, y_pred):
    """Return the sum of binary cross-entropy and Dice loss."""
    cross_entropy_term = tf.keras.losses.BinaryCrossentropy()(y_true, y_pred)
    dice_term = dice_loss(y_true, y_pred)
    return cross_entropy_term + dice_term

In [None]:
def exact_count_dataset_images(dataset):
    """Count examples by iterating once over every batch of ``dataset``.

    Args:
        dataset: Iterable of (images, masks) batches, where ``images.shape[0]``
            is the batch size.

    Returns:
        Total number of images across all batches (0 for an empty iterable).
    """
    return sum(image_batch.shape[0] for image_batch, _ in dataset)

In [None]:
# ---------- Training setup & run ----------
# EDIT THESE PATHS before running
TRAIN_DIR = 'data/train'     # Your train folder
VALID_DIR = 'data/valid'     # Your validation folder
MASK_DIR = None              # set to None if masks are in same folders

IMG_SIZE = 128
BATCH_SIZE = 16
EPOCHS = 83

# Load every image/mask pair into memory at IMG_SIZE x IMG_SIZE.
train_frames = LoadData_pairwise(TRAIN_DIR, mask_dir=MASK_DIR, shape=IMG_SIZE)
print('Train - images shape:', train_frames['img'].shape, 'masks shape:', train_frames['mask'].shape)

valid_frames = LoadData_pairwise(VALID_DIR, mask_dir=MASK_DIR, shape=IMG_SIZE)
print('Valid - images shape:', valid_frames['img'].shape, 'masks shape:', valid_frames['mask'].shape)

# Only the training stream is shuffled and augmented.
train_ds = make_dataset(train_frames, batch_size=BATCH_SIZE, augment=True, shuffle=True)
valid_ds = make_dataset(valid_frames, batch_size=BATCH_SIZE, augment=False, shuffle=False)

# Counting walks each dataset end to end, so report the counts once only.
print("=== EXACT COUNTS ===")
print(f"Training - Original: {len(train_frames['img'])}, Augmented per epoch: {exact_count_dataset_images(train_ds)}")
print(f"Validation - Original: {len(valid_frames['img'])}, Augmented per epoch: {exact_count_dataset_images(valid_ds)}")

inputs = layers.Input((IMG_SIZE, IMG_SIZE, 3))
model = unetBlock(inputs, droupouts=0.05)
model.compile(
    optimizer=AdamW(1e-3, weight_decay=1e-4),
    loss=bce_dice_loss,
    metrics=[dice_coef, tf.keras.metrics.BinaryAccuracy(), 
             tf.keras.metrics.Precision(), tf.keras.metrics.Recall()]
)

# NOTE(review): LR scheduling and early stopping watch val_loss while the
# checkpoint keeps the epoch with the best val_dice_coef — confirm the split
# is intended.
callbacks = [
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=6, min_lr=1e-7, verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True, verbose=1),
    tf.keras.callbacks.ModelCheckpoint('best_model.keras', monitor='val_dice_coef', mode='max', save_best_only=True, verbose=1),
]

history = model.fit(
    train_ds, 
    validation_data=valid_ds,
    epochs=EPOCHS, 
    callbacks=callbacks, 
    verbose=1
)

In [None]:
# ---------- Plot training history ----------
# Draws one panel per tracked quantity (train vs. validation curve). The grid
# is sized to the number of panels so no empty axes are rendered.
try:
    import matplotlib.pyplot as plt

    # (history key, panel title, train label, validation label)
    panels = [
        ('loss', 'Loss', 'Train Loss', 'Val Loss'),
        ('dice_coef', 'Dice Coefficient', 'Train Dice', 'Val Dice'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))

    for ax, (key, title, train_label, val_label) in zip(axes, panels):
        ax.plot(history.history[key], label=train_label)
        ax.plot(history.history['val_' + key], label=val_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()
except Exception as e:
    # Best-effort cell: typically hit when training has not been run yet.
    print('No history object found:', e)

In [None]:
# ---------- Inference + postprocessing ----------
def postprocess_mask(pred, min_area=50):
    """Binarise a prediction and drop small connected components.

    Args:
        pred: numpy array of predicted values; either 2-D or 3-D with a
            trailing axis of length 1.
        min_area: Minimum pixel count a component needs in order to be kept.

    Returns:
        A uint8 array (2-D for the supported inputs) with 1 on pixels
        belonging to kept components and 0 elsewhere.
    """
    m = (pred > 0.5).astype('uint8')
    if m.ndim == 3 and m.shape[-1] == 1:
        m = np.squeeze(m, axis=-1)
    _, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8)

    # Build a per-label keep/drop table and index it with the label image:
    # one pass over the pixels instead of one full-image comparison per label.
    keep = stats[:, cv2.CC_STAT_AREA] >= min_area
    keep[:1] = False  # label 0 is the background and is never kept
    return keep[labels].astype('uint8')

def predict_image(model, image_path, img_size=128):
    """Run ``model`` on a single image file.

    Args:
        model: Object exposing ``predict`` on a batch of images.
        image_path: Path of the image to load via ``read_image_rgb``.
        img_size: Side length the image is resized to before prediction.

    Returns:
        The prediction for the single image with the batch axis and the
        last (channel) axis removed.
    """
    image = read_image_rgb(image_path, img_size)
    batch = np.expand_dims(image, axis=0)  # add the batch axis
    prediction = model.predict(batch)
    return prediction[0, ..., 0]

# Example usage (edit paths if you have test images)
# For each test image: predict, post-process, and show the input, the raw
# prediction and the cleaned mask side by side. Failures are reported and skipped.
TEST_IMAGES = ['data/test/1_sat.jpg']  # modify as needed
for p in TEST_IMAGES:
    try:
        pred = predict_image(model, p, IMG_SIZE)
        proc = postprocess_mask(pred, min_area=30)
        orig = read_image_rgb(p, IMG_SIZE)

        # (title, data, extra imshow keyword arguments) for each panel, left to right
        panels = (
            ('orig', orig, {}),
            ('raw pred', pred, {'vmin': 0, 'vmax': 1}),
            ('postproc', proc, {'cmap': 'gray'}),
        )
        plt.figure(figsize=(8,4))
        for slot, (title, data, imshow_kwargs) in enumerate(panels, start=1):
            plt.subplot(1, 3, slot)
            plt.imshow(data, **imshow_kwargs)
            plt.axis('off')
            plt.title(title)
        plt.show()
    except Exception as e:
        print('Skipping', p, 'because', e)