In [None]:
import os
import numpy as np

import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import tensorflow.keras.backend as K
import tensorflow.keras.preprocessing.image as prep

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

import json
from matplotlib import pyplot as plt
import cv2
import segmentation_models as sm

import OpenEXR as exr
import Imath

# Reset Keras' global state so that re-running the notebook in the same kernel
# does not keep piling up previously built models/layers.
K.clear_session()

In [None]:
def load_image_paths(base_dir, scene_id):
    """Build the nine file paths that belong to one rendered scene.

    Args:
        base_dir: Directory holding the renders (joined with a literal '/').
        scene_id: Scene identifier used as the common file-name prefix.

    Returns:
        A 9-tuple in this order: normal before/after, no-materials
        before/after, random-materials before/after, depth before/after (EXR)
        and the change mask.
    """
    prefix = base_dir + '/' + scene_id
    # Order matters: callers unpack the result positionally.
    suffixes = (
        '_change-0.png',
        '_change-1.png',
        '_change-0-nomats0001.png',
        '_change-1-nomats0001.png',
        '_change-0-randommats.png',
        '_change-1-randommats.png',
        '_change-0-depth0001.exr',
        '_change-1-depth0001.exr',
        '_mask.png',
    )
    return tuple(prefix + suffix for suffix in suffixes)

In [None]:
# Data generator
class ChangeDetectionDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields (input, mask) batches for change detection.

    Every sample is built from a before/after render pair, the matching
    before/after depth maps and one colour-coded change mask.  The network
    input has 8 channels (before RGB, after RGB, before depth, after depth);
    the target is a 4-channel one-hot mask (background, taken, added, shifted).
    """

    def __init__(self, image_pairs = None, depth_pairs = None, masks = None, batch_size=1, image_size=(512, 512), shuffle=True, augment=True):
        """Create the dataset.

        Args:
            image_pairs: List of ``(before_path, after_path)`` tuples.  When
                both ``image_pairs`` and ``masks`` are ``None`` the lists are
                discovered on disk under ``data/renders_multicam_diff_1``.
            depth_pairs: List of ``(before_exr, after_exr)`` tuples aligned
                with ``image_pairs``.
            masks: List of mask paths aligned with ``image_pairs``.
            batch_size: Number of samples per batch.
            image_size: ``(height, width)`` every image is resized to.
            shuffle: Reshuffle the sample order at the end of every epoch.
            augment: Apply the same random geometric transform to all images
                of a sample.
        """
        super().__init__()
        # Read the annotations and release the handle right away; the original
        # kept an open file object on the instance, which leaks the descriptor
        # and makes the Sequence impossible to pickle for worker processes.
        with open('utils/synthetic_anno.json') as json_file:
            self.coco = json.load(json_file)
        self.process_images()
        if image_pairs is None and masks is None:
            self.image_pairs, self.depth_pairs, self.masks = self.get_image_pairs_and_masks('data/renders_multicam_diff_1')
        else:
            self.image_pairs = image_pairs
            self.depth_pairs = depth_pairs
            self.masks = masks
        self.batch_size = batch_size
        self.image_size = image_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.image_pairs))
        self.augment = augment
        if self.augment:
            augmentation = dict(
                rotation_range=180,
                width_shift_range=0.2,
                height_shift_range=0.2,
                #shear_range=0.2,
                zoom_range=0.2,
                horizontal_flip=True,
                fill_mode='nearest',
            )
            self.image_datagen = prep.ImageDataGenerator(**augmentation)
            # Masks must be resampled with nearest-neighbour interpolation:
            # the default linear interpolation blends the class colours, and
            # blended pixels match no colour in rgb_to_onehot (all-zero target).
            self.mask_datagen = prep.ImageDataGenerator(interpolation_order=0, **augmentation)
        self.on_epoch_end()

    def get_image_pairs_and_masks(self, base_dir):
        """Collect the samples of every annotated scene that is complete on disk.

        A scene is skipped unless all nine of its files exist.  Each kept
        scene contributes three samples (normal, no-materials and
        random-materials renders) that share the same depth pair and mask.

        Returns:
            ``(pairs, depth_pairs, masks)`` - three index-aligned lists.
        """
        masks = []
        pairs = []
        depth_pairs = []

        scene_ids = [item['scene'] for item in self.coco['images']]

        for scene_id in scene_ids:
            paths = load_image_paths(base_dir, scene_id)
            if not all(os.path.exists(path) for path in paths):
                continue
            (normal_before, normal_after, nomats_before, nomats_after,
             randommats_before, randommats_after, depth_before, depth_after, mask) = paths

            pairs.extend([
                (normal_before, normal_after),
                (nomats_before, nomats_after),
                (randommats_before, randommats_after),
            ])
            depth_pairs.extend([(depth_before, depth_after)] * 3)
            masks.extend([mask] * 3)

        return pairs, depth_pairs, masks

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return int(np.ceil(len(self.image_pairs) / self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as a pair of tensors ``(X, y)``."""
        indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_image_pairs = [self.image_pairs[i] for i in indices]
        batch_depth_pairs = [self.depth_pairs[i] for i in indices]
        batch_masks = [self.masks[i] for i in indices]

        X, y = self.__data_generation(batch_image_pairs, batch_depth_pairs, batch_masks)

        return tf.convert_to_tensor(X), tf.convert_to_tensor(y)

    def on_epoch_end(self):
        """Reset the sample order and reshuffle it when shuffling is enabled."""
        self.indices = np.arange(len(self.image_pairs))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __data_generation(self, batch_image_pairs, batch_depth_pairs, batch_masks):
        """Load, optionally augment and stack the samples of one batch."""
        # Size the arrays by the samples actually present so a short final
        # batch is not padded with all-zero images and masks.
        n_samples = len(batch_image_pairs)
        X = np.zeros((n_samples, *self.image_size, 8), dtype=np.float32)  # 2 x RGB + 2 x depth
        y = np.zeros((n_samples, *self.image_size, 4), dtype=np.float32)  # one-hot, 4 classes

        for i, (img_paths, depth_paths, mask_path) in enumerate(zip(batch_image_pairs, batch_depth_pairs, batch_masks)):
            before_img = img_to_array(load_img(img_paths[0], target_size=self.image_size)) / 255.0
            after_img = img_to_array(load_img(img_paths[1], target_size=self.image_size)) / 255.0
            # NOTE(review): depth values are used exactly as stored in the EXR
            # file (no normalisation here) - confirm the range suits the model.
            before_depth = self.load_depth_image(depth_paths[0], self.image_size)
            after_depth = self.load_depth_image(depth_paths[1], self.image_size)
            mask = img_to_array(load_img(mask_path, target_size=self.image_size))

            if self.augment:
                # One seed per sample so every image receives the same transform.
                seed = np.random.randint(int(1e6))
                before_img = self.image_datagen.random_transform(before_img, seed=seed)
                after_img = self.image_datagen.random_transform(after_img, seed=seed)
                before_depth = self.image_datagen.random_transform(np.expand_dims(before_depth, axis=-1), seed=seed)
                after_depth = self.image_datagen.random_transform(np.expand_dims(after_depth, axis=-1), seed=seed)
                mask = self.mask_datagen.random_transform(mask, seed=seed)

                before_depth = np.squeeze(before_depth)
                after_depth = np.squeeze(after_depth)

            mask = self.rgb_to_onehot(mask)

            X[i, :, :, :3] = before_img
            X[i, :, :, 3:6] = after_img
            X[i, :, :, 6] = before_depth
            X[i, :, :, 7] = after_depth
            y[i, :, :, :] = mask

        return X, y

    def process_images(self):
        """Index the annotation entries by image id in ``self.images``."""
        self.images = {}
        for image in self.coco['images']:
            image_id = image['id']
            if image_id in self.images:
                print("ERROR: Skipping duplicate image id: {}".format(image))
            else:
                self.images[image_id] = image

    def rgb_to_onehot(self,rgb_image):
        """Convert a colour-coded mask to a 4-channel one-hot array.

        Only exact colour matches are encoded; any other pixel stays all-zero.
        """
        onehot_image = np.zeros((rgb_image.shape[0], rgb_image.shape[1], 4), dtype=np.float32)
        onehot_image[(rgb_image == [0, 0, 0]).all(axis=-1)] = [1, 0, 0, 0]     # Background
        onehot_image[(rgb_image == [255, 0, 0]).all(axis=-1)] = [0, 1, 0, 0]   # Red (Taken)
        onehot_image[(rgb_image == [0, 255, 0]).all(axis=-1)] = [0, 0, 1, 0]   # Green (Added)
        onehot_image[(rgb_image == [0, 0, 255]).all(axis=-1)] = [0, 0, 0, 1]   # Blue (Shifted)
        return onehot_image

    def onehot_to_rgb(self, onehot_mask):
        """Convert a 4-channel one-hot mask back to its colour coding (float32, 0..255)."""
        rgb_image = np.zeros((onehot_mask.shape[0], onehot_mask.shape[1], 3), dtype=np.float32)
        rgb_image[onehot_mask[:, :, 0] == 1] = [0, 0, 0]   # Background
        rgb_image[onehot_mask[:, :, 1] == 1] = [255, 0, 0] # Red (Taken)
        rgb_image[onehot_mask[:, :, 2] == 1] = [0, 255, 0] # Green (Added)
        rgb_image[onehot_mask[:, :, 3] == 1] = [0, 0, 255] # Blue (Shifted)
        return rgb_image

    def load_depth_image(self, path, target_size):
        """Read the R channel of an EXR file as a float32 depth map and resize it.

        Args:
            path: Path of the EXR file.
            target_size: ``(height, width)`` of the returned array, the same
                convention ``load_img`` uses for the colour images.

        Returns:
            A 2-D float32 array of shape ``target_size``.
        """
        file = exr.InputFile(path)
        try:
            dw = file.header()['dataWindow']
            width = dw.max.x - dw.min.x + 1
            height = dw.max.y - dw.min.y + 1

            FLOAT = Imath.PixelType(Imath.PixelType.FLOAT)
            (R, G, B) = file.channels("RGB", FLOAT)
        finally:
            file.close()

        # np.fromstring is deprecated for binary data; frombuffer is the
        # replacement.  copy() yields a writable array like fromstring did.
        depth_map = np.frombuffer(R, dtype=np.float32).reshape(height, width).copy()

        # cv2.resize expects dsize as (width, height), the reverse of target_size.
        dsize = (target_size[1], target_size[0])
        depth_map_resized = cv2.resize(depth_map, dsize, interpolation=cv2.INTER_LINEAR)
        return depth_map_resized
    


In [None]:
# Index the complete render set once; its path lists (image_pairs, depth_pairs,
# masks) are what the following cells split into train/validation/test.
dataset = ChangeDetectionDataset(augment=True)

In [None]:
# Example usage with provided scene_ids
# NOTE(review): base_dir is not used below and differs from the
# 'data/renders_multicam_diff_1' directory the dataset class scans - confirm
# which spelling is the right one.
base_dir = 'data/renders_multicam_diff1'
images_arr = dataset.images

# Split dataset (60 % train / 20 % validation / 20 % test) at SCENE level.
# Every scene contributes three samples (normal / nomats / randommats) that
# share the same depth maps and mask, so splitting individual samples would put
# near-duplicates of a training scene into validation and test.  The mask path
# is unique per scene and therefore serves as the grouping key.
scene_masks = sorted(set(dataset.masks))
scenes_train, scenes_test = train_test_split(scene_masks, test_size=0.2, random_state=42)
scenes_train, scenes_validation = train_test_split(scenes_train, test_size=0.25, random_state=42)


def _select_samples(scene_subset):
    """Return the (image_pairs, depth_pairs, masks) lists of the given scenes."""
    wanted = set(scene_subset)
    keep = [i for i, mask_path in enumerate(dataset.masks) if mask_path in wanted]
    return (
        [dataset.image_pairs[i] for i in keep],
        [dataset.depth_pairs[i] for i in keep],
        [dataset.masks[i] for i in keep],
    )


image_pairs_train, depth_pairs_train, masks_train = _select_samples(scenes_train)
image_pairs_validation, depth_pairs_validation, masks_validation = _select_samples(scenes_validation)
image_pairs_test, depth_pairs_test, masks_test = _select_samples(scenes_test)

# Create datasets.  Only the training data is augmented and shuffled;
# validation and test must be deterministic so val_loss (which drives the
# callbacks) and the final scores are comparable between runs.
train_dataset = ChangeDetectionDataset(image_pairs_train, depth_pairs_train, masks_train)
validation_dataset = ChangeDetectionDataset(image_pairs_validation, depth_pairs_validation, masks_validation, shuffle=False, augment=False)
test_dataset = ChangeDetectionDataset(image_pairs_test, depth_pairs_test, masks_test, shuffle=False, augment=False)

In [None]:
# Class weights are estimated from the TRAINING masks only, so no test-set
# statistics leak into the loss.  Iterating the path list directly also avoids
# dataset.__len__(), which counts batches rather than masks.
num_classes = 4
masks = []

for mask_path in masks_train:
    mask = dataset.rgb_to_onehot(img_to_array(load_img(mask_path, target_size=(256, 256))))
    masks.append(mask)

# Pixels whose colour matched no class are all-zero and therefore count as
# class 0 (background) after argmax.
y_train_flat = np.argmax(masks, axis=-1).flatten()

# Always build one weight per class so that entry k lines up with one-hot
# channel k.  Classes missing from the sampled masks keep a neutral weight of
# 1.0 instead of silently shortening (and misaligning) the vector.
present_classes = np.unique(y_train_flat)
class_weights = np.ones(num_classes, dtype=np.float64)
class_weights[present_classes] = compute_class_weight('balanced', classes=present_classes, y=y_train_flat)

class_weights_dict = dict(enumerate(class_weights))

print("Class weights:", class_weights_dict)

class_weights_tensor = tf.constant(class_weights, dtype=tf.float32)

def weighted_categorical_crossentropy(y_true, y_pred):
    """Pixel-wise categorical cross-entropy weighted by the true class.

    Args:
        y_true: One-hot ground truth, last axis = classes.
        y_pred: Predicted class probabilities with the same shape.

    Returns:
        Scalar tensor: mean over all pixels of ``weight[class] * loss``.
    """
    # Use the functional form, which keeps one loss value per pixel.  The
    # CategoricalCrossentropy *class* reduces to a single scalar by default, so
    # multiplying it by the weight map merely rescaled the batch loss and did
    # not weight individual pixels at all.
    pixel_loss = tf.keras.losses.categorical_crossentropy(y_true, y_pred)

    # One-hot y_true selects the weight of each pixel's true class.
    pixel_weights = tf.reduce_sum(class_weights_tensor * y_true, axis=-1)

    return tf.reduce_mean(pixel_loss * pixel_weights)

In [None]:
# U-Net with a ResNet34 encoder on the 8-channel (2 x RGB + 2 x depth) input;
# no pretrained encoder weights are loaded (encoder_weights=None).
model = sm.Unet(
    'resnet34',
    input_shape=(512, 512, 8),
    classes=4,
    activation='softmax',
    encoder_weights=None,
)

optimizer = Adam(learning_rate=0.001)

model.compile(
    loss=weighted_categorical_crossentropy,
    optimizer=optimizer,
    metrics=[
        sm.metrics.iou_score,
        sm.metrics.f1_score,
        sm.metrics.precision,
        sm.metrics.recall,
        'accuracy',
    ],
)

# Stop once val_loss has not improved for 5 epochs and roll back to the
# weights of the best epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Multiply the learning rate by 0.2 after 3 epochs without val_loss
# improvement, but never go below 1e-5.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.00001)

In [None]:
# train_dataset is a keras Sequence that already produces complete batches, so
# `batch_size` must not be passed to fit(); the batch size is set on the
# dataset itself (default 1).
model.fit(train_dataset, epochs=25, validation_data=validation_dataset, callbacks=[early_stopping, reduce_lr])

In [None]:
# Evaluate on the held-out test split; `score` holds what model.evaluate
# returns for the compiled loss and metrics.
score = model.evaluate(test_dataset)

In [None]:

def visualize_predictions(model, dataset, batch_index=0):
    """Plot inputs, ground truth and prediction for the first sample of a batch.

    Args:
        model: Object with a ``predict(X)`` method returning per-class scores
            of shape (batch, height, width, classes).
        dataset: Indexable batch source returning ``(X, y_true)`` and offering
            ``onehot_to_rgb``; X is laid out as
            [before RGB, after RGB, before depth, after depth].
        batch_index: Index of the batch to visualise.
    """
    # Get a batch of data; batches may arrive as tensors, so convert once and
    # work on plain arrays for slicing and plotting.
    X, y_true = dataset[batch_index]
    X = np.asarray(X)
    y_true = np.asarray(y_true)

    if X.shape[0] == 0:
        print('Batch {} is empty, nothing to visualize.'.format(batch_index))
        return

    # Summed mask for visualization check
    summed_mask = np.sum(y_true, axis=3)
    print('Summed Mask Shape: ', summed_mask.shape)

    img_channels = 3    # RGB channels per image
    depth_channels = 1  # channels per depth map

    # Split the data back into before images, after images, and depth images.
    X1 = X[..., :img_channels]
    X2 = X[..., img_channels:2 * img_channels]
    # Index single channels so the depth maps are 2-D images for imshow.
    depth_before = X[..., 2 * img_channels]
    depth_after = X[..., 2 * img_channels + depth_channels]

    # Predict the masks
    y_pred = np.asarray(model.predict(X))

    # Reduce both masks to hard class labels, then colour-code them.
    num_classes = y_true.shape[-1]
    identity = np.eye(num_classes)

    def _to_rgb(scores):
        labels = np.argmax(scores, axis=-1)
        colored = np.array([dataset.onehot_to_rgb(identity[sample]) for sample in labels])
        # onehot_to_rgb yields floats in 0..255; imshow interprets float RGB
        # as 0..1, so hand it integers to avoid clipping (and the warning).
        return colored.astype(np.uint8)

    y_true_rgb = _to_rgb(y_true)
    y_pred_rgb = _to_rgb(y_pred)

    # Plot the results for the first image in the batch
    index = 0
    panels = [
        ('Before Image', X1[index], None),
        ('After Image', X2[index], None),
        ('Before Depth', depth_before[index], 'gray'),
        ('After Depth', depth_after[index], 'gray'),
        ('Ground Truth Mask', y_true_rgb[index], None),
        ('Predicted Mask', y_pred_rgb[index], None),
    ]

    fig, axs = plt.subplots(1, 6, figsize=(25, 5))
    for ax, (title, image, cmap) in zip(axs, panels):
        ax.imshow(image, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()


In [None]:
# Show at most ten batches, never indexing past the end of the test set.
for i in range(min(10, len(test_dataset))):
    visualize_predictions(model, test_dataset, batch_index=i)

In [None]:
# Siamese U-Net Implementation based on 
# @misc{růžička2020deep,
#       title={Deep Active Learning in Remote Sensing for data efficient Change Detection}, 
#       author={Vít Růžička and Stefano D'Aronco and Jan Dirk Wegner and Konrad Schindler},
#       year={2020},
#       eprint={2008.11201},
#       archivePrefix={arXiv},
#       primaryClass={cs.CV}
# }

def unet_encoder(input_tensor, name_prefix):
    """Build a five-stage U-Net encoder on ``input_tensor``.

    Each stage is two 3x3 ReLU convolutions, each followed by batch
    normalisation; stages are separated by 2x2 max pooling.  The stage widths
    are 32, 64, 128, 256 and 512 filters.

    Args:
        input_tensor: Tensor the encoder is applied to.
        name_prefix: Prefix for the names of the convolution and pooling layers.

    Returns:
        Tuple with the output of every stage, shallowest first; the last entry
        is the bottleneck.
    """
    stage_filters = (32, 64, 128, 256, 512)
    stage_outputs = []
    x = input_tensor
    for stage, filters in enumerate(stage_filters, start=1):
        if stage > 1:
            # Down-sample the previous stage's output before this stage.
            x = MaxPooling2D((2, 2), name=f'{name_prefix}_pool{stage - 1}')(x)
        for repeat in (1, 2):
            x = Conv2D(filters, (3, 3), activation='relu', padding='same', name=f'{name_prefix}_conv{stage}_{repeat}')(x)
            x = BatchNormalization()(x)
        stage_outputs.append(x)

    return tuple(stage_outputs)

def unet_decoder(conv1_b, conv1_a, conv2_b, conv2_a, conv3_b, conv3_a, conv4_b, conv4_a, center_b, center_a, num_classes):
    """Build the shared decoder that fuses the before (_b) and after (_a) features.

    The two bottlenecks are concatenated, then four up-sampling stages each
    apply a stride-2 transposed convolution, concatenate the result with both
    skip tensors of the matching encoder stage and run two 3x3 convolutions,
    each followed by batch normalisation.

    Args:
        conv1_b .. conv4_a: Skip tensors of encoder stages 1-4 for the before
            and after branch.
        center_b, center_a: Bottleneck tensors of the two branches.
        num_classes: Number of output channels.

    Returns:
        Tensor with ``num_classes`` channels holding per-pixel class
        probabilities.
    """
    x = concatenate([center_b, center_a], axis=-1)

    # Deepest skip connection first: the decoder walks back up the encoder.
    stages = (
        (256, conv4_b, conv4_a),
        (128, conv3_b, conv3_a),
        (64, conv2_b, conv2_a),
        (32, conv1_b, conv1_a),
    )
    for filters, skip_b, skip_a in stages:
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same', activation='relu')(x)
        x = concatenate([x, skip_b, skip_a], axis=-1)
        for _ in range(2):
            x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
            x = BatchNormalization()(x)

    # The classes are mutually exclusive (one-hot masks trained with
    # categorical cross-entropy), so the head must be a softmax over the
    # classes, matching the sm.Unet model in this notebook.  The previous
    # 'sigmoid' scored every class independently.
    output = Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return output

def siamese_unet(input_size=(512, 512, 8), num_classes=4):
    """Assemble the two-branch U-Net for before/after change detection.

    The before image (channels 0-2) and the after image (channels 3-5) are
    sliced from the stacked input, encoded separately and decoded jointly.

    NOTE(review): channels 6 and 7 of the input (the depth maps produced by the
    data generator) are not consumed by this model - confirm that is intended.
    NOTE(review): unet_encoder is invoked once per branch and builds new layers
    on every call, so the two encoders do not share weights - confirm whether
    a weight-sharing (strictly Siamese) encoder was intended.

    Args:
        input_size: Shape of one input sample, channels last.
        num_classes: Number of output classes.

    Returns:
        The assembled (uncompiled) Keras ``Model``.
    """
    network_input = Input(shape=input_size)

    # Both slices keep batch, height and width whole and take 3 channels.
    rgb_extent = [-1, -1, -1, 3]
    before_rgb = tf.slice(network_input, [0, 0, 0, 0], rgb_extent)
    after_rgb = tf.slice(network_input, [0, 0, 0, 3], rgb_extent)

    before_features = unet_encoder(before_rgb, 'before')
    after_features = unet_encoder(after_rgb, 'after')

    # The decoder takes the stage outputs interleaved as
    # (stage1_before, stage1_after, stage2_before, ...), bottleneck pair last.
    interleaved = [
        tensor
        for stage_pair in zip(before_features, after_features)
        for tensor in stage_pair
    ]
    segmentation = unet_decoder(*interleaved, num_classes)

    return Model(inputs=network_input, outputs=segmentation)



# Fresh callback instances for training the Siamese model.

# Stop once val_loss has not improved for 5 epochs and roll back to the
# weights of the best epoch.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Multiply the learning rate by 0.2 after 3 epochs without val_loss
# improvement, but never go below 1e-5.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.00001)


# EXAMPLE USAGE
#model = siamese_unet(input_size=(512, 512, 8), num_classes=4)
#optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
#model.compile(optimizer=optimizer, loss=weighted_categorical_crossentropy, metrics=['accuracy'])