# Train U-Net on the CAS Landslide Detection Dataset

*Author: Abdelouahed Drissi*

## Dataset

In [1]:
import os
import tensorflow as tf
from tensorflow.keras import backend
import matplotlib.pyplot as plt

from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda

from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from losses import dice_loss

### Function to read an image file and its mask file

In [2]:
def load_image_file(image_path, mask_path):
    """Read and decode one image/mask pair from PNG files.

    Args:
        image_path: Path of the PNG image, decoded with 3 channels.
        mask_path: Path of the PNG mask, decoded with 1 channel.

    Returns:
        A dict with the decoded image under "image" and the decoded mask
        under "segmentation_mask".
    """
    # Read both files first, then decode, so a missing file is reported
    # before any decoding work is attempted.
    raw_image, raw_mask = map(tf.io.read_file, (image_path, mask_path))
    return {
        "image": tf.image.decode_png(raw_image, channels=3),
        "segmentation_mask": tf.image.decode_png(raw_mask, channels=1),
    }

### Loading the dataset

In [3]:
# Image and mask directories, one pair per dataset split.
train_image_dir, train_mask_dir = "./dataset/train/images", "./dataset/train/masks"
valid_image_dir, valid_mask_dir = "./dataset/validation/images", "./dataset/validation/masks"
test_image_dir, test_mask_dir = "./dataset/test/images", "./dataset/test/masks"

# Load datasets and match images with masks
def load_data(image_dir, mask_dir):
    """Load every image in `image_dir` that has a matching mask in `mask_dir`.

    The expected mask file name is derived from the image file name by
    replacing "image" with "mask". Images for which no such file exists in
    `mask_dir` are skipped silently.

    Args:
        image_dir: Directory containing the image files.
        mask_dir: Directory containing the mask files.

    Returns:
        A list of the dicts returned by `load_image_file`, ordered by image
        file name.
    """
    image_names = sorted(os.listdir(image_dir))
    # A set gives constant-time membership tests; a list would be rescanned
    # once per image.
    available_masks = set(os.listdir(mask_dir))

    pairs = []
    for img_name in image_names:
        mask_name = img_name.replace("image", "mask")
        if mask_name in available_masks:
            pairs.append(
                (os.path.join(image_dir, img_name), os.path.join(mask_dir, mask_name))
            )

    return [load_image_file(image_path, mask_path) for image_path, mask_path in pairs]

# Decode every split up front; each result is a list of image/mask dicts.
data_train = load_data(image_dir=train_image_dir, mask_dir=train_mask_dir)
data_valid = load_data(image_dir=valid_image_dir, mask_dir=valid_mask_dir)
data_test = load_data(image_dir=test_image_dir, mask_dir=test_mask_dir)

# Number of samples in the (train, validation, test) splits.
tuple(map(len, (data_train, data_valid, data_test)))

(1385, 396, 199)

### Normalization and Image Resizing

In [5]:
# Side length, in pixels, that images and masks are resized to.
image_size = 256
# Per-channel mean and standard deviation used to standardize images.
# NOTE(review): these match the commonly used ImageNet channel statistics,
# which assume pixel values scaled to [0, 1] - confirm against the inputs.
mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)


def normalize(input_image, input_mask):
    """Standardize an image channel-wise and rescale its mask.

    Args:
        input_image: Image tensor; it is passed through
            `tf.image.convert_image_dtype(..., tf.float32)` and then
            standardized with the module-level `mean` and `std`.
        input_mask: Mask tensor; it is divided by 255.

    Returns:
        A `(image, mask)` tuple of the transformed tensors.
    """
    # NOTE(review): convert_image_dtype only rescales integer inputs; a float
    # input is presumably expected to already be in [0, 1] - confirm at the
    # call site.
    as_float = tf.image.convert_image_dtype(input_image, tf.float32)
    # Floor the divisor at epsilon so a zero std can never divide by zero.
    safe_std = tf.maximum(std, backend.epsilon())
    standardized = (as_float - mean) / safe_std
    return standardized, input_mask / 255


def load_image(datapoint):
    """Resize one sample and normalize it into model-ready tensors.

    Args:
        datapoint: Mapping with an "image" tensor and a "segmentation_mask"
            tensor, as produced by `load_image_file`.

    Returns:
        A `({"pixel_values": image}, mask)` tuple in which both tensors have
        been resized to `image_size` x `image_size` and passed through
        `normalize`.
    """
    # Scale the image to float32 in [0, 1] *before* resizing. tf.image.resize
    # returns float32 without rescaling, and convert_image_dtype does not
    # rescale float input, so converting only afterwards would leave pixels
    # in [0, 255] while `mean`/`std` are expressed on a [0, 1] scale.
    input_image = tf.image.convert_image_dtype(datapoint["image"], tf.float32)
    input_image = tf.image.resize(input_image, (image_size, image_size))

    # Nearest-neighbour resizing keeps the mask's original label values;
    # bilinear interpolation would blend labels along object boundaries.
    input_mask = tf.image.resize(
        datapoint["segmentation_mask"],
        (image_size, image_size),
        method="nearest",
    )

    input_image, input_mask = normalize(input_image, input_mask)
    return ({"pixel_values": input_image}, input_mask)

In [6]:
# Resize and normalize every sample of each split, keeping the results in memory.
train_data = list(map(load_image, data_train))
valid_data = list(map(load_image, data_valid))
test_data = list(map(load_image, data_test))

### Build input pipeline

In [7]:
# Passed to `prefetch` so tf.data tunes the buffer size at runtime.
auto = tf.data.AUTOTUNE
# Number of samples per batch in every split.
batch_size = 4
# Generator used to feed in-memory samples to tf.data
def generator(data):
    """Yield the elements of `data` one at a time, in order."""
    yield from data


def _make_dataset(samples):
    """Wrap a list of preprocessed samples in an unbatched tf.data.Dataset.

    Args:
        samples: Sequence of `({"pixel_values": image}, mask)` tuples.

    Returns:
        A dataset yielding one sample per element, with float32 images of
        shape `(image_size, image_size, 3)` and int32 masks of shape
        `(image_size, image_size, 1)`.
    """
    # `output_signature` replaces the deprecated `output_types` /
    # `output_shapes` arguments and declares the same element structure.
    # NOTE(review): masks are declared as tf.int32 - confirm that the mask
    # values yielded by the generator are integral.
    return tf.data.Dataset.from_generator(
        lambda: generator(samples),
        output_signature=(
            {
                "pixel_values": tf.TensorSpec(
                    shape=(image_size, image_size, 3), dtype=tf.float32
                )
            },
            tf.TensorSpec(shape=(image_size, image_size, 1), dtype=tf.int32),
        ),
    )


# Training: cached, shuffled, batched and repeated indefinitely.
# NOTE(review): the shuffle buffer holds only batch_size * 10 samples, so
# shuffling is local relative to the full training set - confirm this is intended.
train_ds = (
    _make_dataset(train_data)
    .cache()
    .shuffle(batch_size * 10)
    .batch(batch_size)
    .repeat()
    .prefetch(auto)
)

# Validation: batched and repeated indefinitely, without shuffling.
valid_ds = _make_dataset(valid_data).batch(batch_size).repeat().prefetch(auto)

# Test: a single batched pass over the data.
test_ds = _make_dataset(test_data).batch(batch_size).prefetch(auto)


In [8]:
# Show the structure, shapes and dtypes of the batches produced by train_ds.
print(train_ds.element_spec)


({'pixel_values': TensorSpec(shape=(None, 256, 256, 3), dtype=tf.float32, name=None)}, TensorSpec(shape=(None, 256, 256, 1), dtype=tf.int32, name=None))
