In [40]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Dropout, Conv2DTranspose, Concatenate, add, BatchNormalization, Activation
import os
import glob
import tifffile as tiff
import rasterio
import re
import tensorflow as tf
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score


# import json

# with open('config.json') as json_file:
#     config = json.load(json_file)
#     corruption = config['corruption_level']


In [41]:
# Sanity check: which GPUs can TensorFlow see? The list is this cell's displayed output.
visible_gpus = tf.config.list_physical_devices('GPU')
visible_gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [42]:
def conv_block(x, filters, batchnorm=True):
    """Apply two identical 3x3 conv stages: Conv2D -> optional BatchNorm -> ReLU.

    Args:
        x: input Keras tensor.
        filters: number of filters used by both 3x3 convolutions.
        batchnorm: a BatchNormalization(axis=3) layer follows each convolution
            only when this is exactly ``True`` (identity check, not truthiness).

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    out = x
    # Two repetitions of the same stage; layers are created in the same order
    # as the unrolled version (conv, bn, relu, conv, bn, relu).
    for _stage in range(2):
        out = Conv2D(filters, (3, 3), kernel_initializer='he_normal', padding='same')(out)
        if batchnorm is True:
            out = BatchNormalization(axis=3)(out)
        out = Activation('relu')(out)
    return out

In [43]:
def residual_conv_block(x, filters, batchnorm=True):
    """Two 3x3 conv stages summed with a 1x1-projected shortcut of the input.

    Main path: (Conv2D 3x3 -> optional BatchNorm -> ReLU) twice.
    Shortcut:  Conv2D 1x1 -> optional BatchNorm -> ReLU, so its channel count
    matches `filters` before the element-wise add. Note the shortcut is itself
    passed through ReLU, unlike a plain identity shortcut.

    Args:
        x: input Keras tensor.
        filters: number of output channels of both paths.
        batchnorm: BatchNormalization(axis=3) is inserted only when exactly ``True``.

    Returns:
        Keras tensor ``add([shortcut, main])``.
    """
    use_bn = batchnorm is True

    main = x
    for _stage in range(2):
        main = Conv2D(filters, (3, 3), kernel_initializer='he_normal', padding='same')(main)
        if use_bn:
            main = BatchNormalization(axis=3)(main)
        main = Activation('relu')(main)

    # Built after the main path to keep the original layer creation order.
    skip = Conv2D(filters, kernel_size=(1, 1), kernel_initializer='he_normal', padding='same')(x)
    if use_bn:
        skip = BatchNormalization(axis=3)(skip)
    skip = Activation('relu')(skip)

    return add([skip, main])

In [44]:
def dense_block(inputs, num_filters):
    """Dense-style block: the input concatenated (channel axis) with conv_block features.

    Args:
        inputs: input Keras tensor.
        num_filters: filters passed to ``conv_block``.

    Returns:
        ``Concatenate()([inputs, conv_block(inputs, num_filters)])``.
    """
    features = conv_block(inputs, num_filters)
    return Concatenate()([inputs, features])

In [45]:
def residual_unet(input_shape, bottleneck_activation='relu'):
    """Build a U-Net whose encoder and decoder stages are ``residual_conv_block`` units.

    Encoder: residual stages with 64/128/256/512 filters, each followed by 2x2 max
    pooling. Bottleneck: two 1024-filter 3x3 convs and Dropout(0.5). Decoder:
    stride-2 Conv2DTranspose, concatenation with the matching encoder output,
    then a residual stage. Output: 1x1 Conv2D with sigmoid (one channel).

    Args:
        input_shape: shape passed to ``Input`` (e.g. ``(512, 512, 1)``). Height and
            width, when known, must be divisible by 16 because of the four 2x2
            poolings and the skip concatenations.
        bottleneck_activation: activation of the second bottleneck conv. Defaults
            to ``'relu'`` to match the first one; pass ``None`` to rebuild the
            previous architecture (linear second conv), e.g. to reload weights
            trained with it.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16.
    """
    spatial_dims = tuple(input_shape)[:2]
    if any(dim is not None and dim % 16 != 0 for dim in spatial_dims):
        raise ValueError(
            f"residual_unet needs height and width divisible by 16 "
            f"(four 2x2 poolings), got input_shape={input_shape}"
        )

    inputs = Input(input_shape)

    # Encoder: keep each stage's output for the matching decoder skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = residual_conv_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck. The second conv previously had no activation, leaving
    # conv -> Dropout -> Conv2DTranspose with no nonlinearity in between.
    x = Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Conv2D(1024, (3, 3), activation=bottleneck_activation, kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(0.5)(x)

    # Decoder: upsample, concatenate with the encoder output of the same scale, refine.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, skip])
        x = residual_conv_block(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

In [46]:
def dense_unet(input_shape, bottleneck_activation='relu'):
    """Build a U-Net with ``dense_block`` encoder stages and ``residual_conv_block`` decoder stages.

    Encoder: dense stages with 64/128/256/512 filters, each followed by 2x2 max
    pooling. Bottleneck: two 1024-filter 3x3 convs and Dropout(0.5). Decoder:
    stride-2 Conv2DTranspose, concatenation with the matching encoder output,
    then a residual stage. Output: 1x1 Conv2D with sigmoid (one channel).

    Args:
        input_shape: shape passed to ``Input`` (e.g. ``(512, 512, 1)``). Height and
            width, when known, must be divisible by 16 because of the four 2x2
            poolings and the skip concatenations.
        bottleneck_activation: activation of the second bottleneck conv. Defaults
            to ``'relu'`` to match the first one; pass ``None`` to rebuild the
            previous architecture (linear second conv), e.g. to reload weights
            trained with it.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: if a known spatial dimension is not divisible by 16.
    """
    spatial_dims = tuple(input_shape)[:2]
    if any(dim is not None and dim % 16 != 0 for dim in spatial_dims):
        raise ValueError(
            f"dense_unet needs height and width divisible by 16 "
            f"(four 2x2 poolings), got input_shape={input_shape}"
        )

    inputs = Input(input_shape)

    # Encoder: keep each stage's output for the matching decoder skip connection.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = dense_block(x, filters)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck. The second conv previously had no activation, leaving
    # conv -> Dropout -> Conv2DTranspose with no nonlinearity in between.
    x = Conv2D(1024, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Conv2D(1024, (3, 3), activation=bottleneck_activation, kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(0.5)(x)

    # Decoder: upsample, concatenate with the encoder output of the same scale, refine.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, skip])
        x = residual_conv_block(x, filters)

    outputs = Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs=inputs, outputs=outputs)

In [47]:
def dice_coefficient(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient ``(2|A∩B| + s) / (|A| + |B| + s)`` for binary masks.

    Inputs are flattened, so they only need the same number of elements
    (e.g. an (H, W) mask against an (H, W, 1) prediction). Any array-like is
    accepted.

    Args:
        y_true: ground-truth mask, expected to contain 0/1 values.
        y_pred: predicted mask, expected to contain 0/1 values.
        smooth: additive smoothing; with the default of 1, two empty masks score 1.0.

    Returns:
        Dice score as a numpy float.

    Raises:
        ValueError: if the inputs have different element counts.
    """
    # float64 so the element-wise product cannot wrap around for small integer
    # dtypes such as uint8 (e.g. 255 * 255).
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true_f.size != y_pred_f.size:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true_f.size} and {y_pred_f.size}"
        )
    intersection = np.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (np.sum(y_true_f) + np.sum(y_pred_f) + smooth)

def iou(y_true, y_pred, smooth=1):
    """Smoothed intersection-over-union ``(|A∩B| + s) / (|A∪B| + s)`` for binary masks.

    Inputs are flattened, so they only need the same number of elements
    (e.g. an (H, W) mask against an (H, W, 1) prediction). Any array-like is
    accepted.

    Args:
        y_true: ground-truth mask, expected to contain 0/1 values.
        y_pred: predicted mask, expected to contain 0/1 values.
        smooth: additive smoothing; with the default of 1, two empty masks score 1.0.

    Returns:
        IoU score as a numpy float.

    Raises:
        ValueError: if the inputs have different element counts.
    """
    # float64 so the element-wise product cannot wrap around for small integer dtypes.
    y_true_f = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred_f = np.asarray(y_pred, dtype=np.float64).ravel()
    if y_true_f.size != y_pred_f.size:
        raise ValueError(
            f"y_true and y_pred must have the same number of elements, "
            f"got {y_true_f.size} and {y_pred_f.size}"
        )
    intersection = np.sum(y_true_f * y_pred_f)
    union = np.sum(y_true_f) + np.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)

def save_image(image, filepath):
    """Write `image` to `filepath` as a uint8 TIFF after multiplying it by 255.

    Intended for arrays in [0, 1] (sigmoid outputs or 0/1 masks). The scaled
    values are cast with ``astype(np.uint8)`` directly (no clipping or rounding).
    """
    pixels = (255 * image).astype(np.uint8)
    tiff.imwrite(filepath, pixels)

def _as_single_channel(img):
    """Return `img` as a single-channel (H, W, 1) array ready for batching.

    Arrays with more than two dims whose first axis has length 2 are treated as
    channel-first and reduced to their first channel (assumption kept from the
    original loader -- TODO confirm against the data). Dropping to one channel
    yields a 2-D array, so the trailing channel axis is added afterwards too.
    """
    if img.ndim > 2 and img.shape[0] == 2:
        img = img[0]
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)
    return img


def _save_prediction_figure(path, image, mask, corrupt_mask, predicted, predicted_thresh):
    """Save a 2x3 grid (five panels used) comparing input, masks and predictions."""
    panels = (
        (image, 'Arbitrary Image'),
        (mask, 'Actual Mask'),
        (corrupt_mask, 'Corrupted Mask'),
        (predicted, 'Predicted Mask without thresholding'),
        (predicted_thresh, 'Predicted Mask'),
    )
    fig = plt.figure(figsize=(15, 5))
    try:
        for position, (data, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(2, 3, position)
            # squeeze drops a trailing singleton channel, which imshow cannot display.
            ax.imshow(np.squeeze(data), cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.savefig(path)
    finally:
        # Close even if plotting/saving fails so figures don't accumulate in the loop.
        plt.close(fig)


def predict1(model, mask_dir, corrupted_mask_dir, image_dir, output_dir, output_dir_vis, model_type, number=-1):
    """Predict a mask for each ``*_resized.tif`` in `mask_dir`, save outputs and report mean metrics.

    For a mask file such as ``NDWI_Mask_<id>_resized.tif`` (third ``_``-separated
    field is the id), this reads the input image ``<id>.tif`` from `image_dir` and
    the corrupted mask ``<mask stem>_corrupt.tif`` from `corrupted_mask_dir`. It then:

    * runs `model` and thresholds the prediction at 0.5;
    * saves the raw and thresholded predictions as TIFFs in `output_dir`;
    * saves a comparison figure as JPEG in `output_dir_vis`;
    * computes Dice, IoU, precision, recall and F1 of the thresholded prediction
      against the clean mask.

    Args:
        model: Keras model; presumably a single-output sigmoid segmentation model
            taking (1, H, W, 1) input -- TODO confirm for other models.
        mask_dir: directory of clean ``*_resized.tif`` masks.
        corrupted_mask_dir: directory of the matching ``*_resized_corrupt.tif`` masks.
        image_dir: directory of input images named ``<id>.tif``.
        output_dir: where predicted mask TIFFs are written (created if missing).
        output_dir_vis: where visualisation JPEGs are written (created if missing).
        model_type: prefix used in every output file name.
        number: stop after this many masks; -1 (default) processes all of them.

    Returns:
        Tuple ``(mean_dice, mean_iou, mean_precision, mean_recall, mean_f1)``.
    """
    dice_scores = []
    iou_scores = []
    precisions = []
    recalls = []
    f1_scores = []

    # Both output folders are written to below, so create both.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(output_dir_vis, exist_ok=True)

    mask_files = [f for f in os.listdir(mask_dir) if f.endswith('_resized.tif')]
    print(len(mask_files))

    for processed, mask_file in enumerate(mask_files, start=1):
        print(mask_file)
        # e.g. "NDWI_Mask_0_resized.tif" -> "0", the stem of the matching input image.
        i_str = mask_file.split('_')[2]
        print(i_str)
        image_file = f"{i_str}.tif"

        corrupt_mask_file = mask_file.replace(".tif", "_corrupt.tif")
        print(corrupt_mask_file)

        mask_path = os.path.join(mask_dir, mask_file)
        image_path = os.path.join(image_dir, image_file)
        corrupt_mask_path = os.path.join(corrupted_mask_dir, corrupt_mask_file)
        print(image_path)

        arbitrary_img = _as_single_channel(tiff.imread(image_path))
        arbitrary_mask = tiff.imread(mask_path)
        arbitrary_corrupt_mask = tiff.imread(corrupt_mask_path)

        # Add a batch axis for predict, then take the single result back out.
        predicted_mask = model.predict(np.expand_dims(arbitrary_img, axis=0))[0]
        predicted_mask_thresh = (predicted_mask > 0.5).astype(np.uint8)

        # os.path.join (not f"./{dir}/...") so absolute output directories also work.
        save_image(predicted_mask, os.path.join(output_dir, f"{model_type}_Predicted_Image_{i_str}.tif"))
        save_image(predicted_mask_thresh, os.path.join(output_dir, f"{model_type}_{i_str}.tif"))

        _save_prediction_figure(
            os.path.join(output_dir_vis, f"{model_type}_Visualization_{i_str}.jpg"),
            arbitrary_img,
            arbitrary_mask,
            arbitrary_corrupt_mask,
            predicted_mask,
            predicted_mask_thresh,
        )

        # Metrics compare the clean mask against the thresholded prediction.
        y_true = arbitrary_mask.flatten()
        y_pred = predicted_mask_thresh.flatten()
        dice_scores.append(dice_coefficient(arbitrary_mask, predicted_mask_thresh))
        iou_scores.append(iou(arbitrary_mask, predicted_mask_thresh))
        precisions.append(precision_score(y_true, y_pred))
        recalls.append(recall_score(y_true, y_pred))
        f1_scores.append(f1_score(y_true, y_pred))

        if processed == number:
            break

    mean_dice = np.mean(dice_scores)
    mean_iou = np.mean(iou_scores)
    mean_precision = np.mean(precisions)
    mean_recall = np.mean(recalls)
    mean_f1 = np.mean(f1_scores)

    print(f"Mean Dice Coefficient: {mean_dice}")
    print(f"Mean IoU: {mean_iou}")
    print(f"Mean Precision: {mean_precision}")
    print(f"Mean Recall: {mean_recall}")
    print(f"Mean F1 Score: {mean_f1}")

    return mean_dice, mean_iou, mean_precision, mean_recall, mean_f1

In [48]:
def plotAllResults(corruptions, image_dir, mask_dir, corrupted_mask_parent_dir, predicted_mask_parent_dir, output_dir, model_type, number=-1,
                   corrupted_subdir_fmt='train_{corruption}_gee_with_diff_kernels',
                   predicted_subdir_fmt='New_20_Epoch_{corruption}_with_diff_kernels'):
    """Save one figure per mask comparing previously saved predictions across corruption levels.

    For every ``*_resized.tif`` in `mask_dir` (id = third ``_``-separated field),
    the figure has one row per entry in `corruptions` and four columns: the input
    image ``<id>.tif`` from `image_dir`, the corrupted mask, the raw predicted mask
    and the thresholded predicted mask. Corrupted masks are read from
    ``corrupted_mask_parent_dir/<corrupted_subdir_fmt>``. Predictions are read from
    ``predicted_mask_parent_dir/<predicted_subdir_fmt>`` as
    ``<model_type>_Predicted_Image_<id>.tif`` and ``<model_type>_<id>.tif``.

    Args:
        corruptions: corruption levels (shown as percentages in the titles).
        image_dir: directory of input images named ``<id>.tif``.
        mask_dir: directory of clean ``*_resized.tif`` masks (used to enumerate ids).
        corrupted_mask_parent_dir: parent of the per-corruption corrupted-mask folders.
        predicted_mask_parent_dir: parent of the per-corruption prediction folders.
        output_dir: existing directory where ``<model_type>_Visualization_<id>.jpg`` is written.
        model_type: prefix of the prediction file names and of the output file.
        number: stop after this many masks; -1 (default) processes all of them.
        corrupted_subdir_fmt: ``str.format`` template (field ``corruption``) for the
            corrupted-mask sub-folder name.
        predicted_subdir_fmt: ``str.format`` template (field ``corruption``) for the
            prediction sub-folder name.
    """
    mask_files = [f for f in os.listdir(mask_dir) if f.endswith('_resized.tif')]
    print(len(mask_files))

    subplot_rows = len(corruptions)
    subplot_columns = 4

    for processed, mask_file in enumerate(mask_files, start=1):
        print(mask_file)
        i_str = mask_file.split('_')[2]
        print(i_str)

        # The input image does not depend on the corruption level: load and
        # normalise it once per mask instead of once per row.
        input_img = tiff.imread(os.path.join(image_dir, f"{i_str}.tif"))
        if input_img.ndim > 2 and input_img.shape[0] == 2:
            input_img = input_img[0]  # assumed channel-first; keep the first channel
        input_img = np.squeeze(input_img)

        corrupt_mask_file = mask_file.replace(".tif", "_corrupt.tif")
        predicted_mask_file_thresh = f'{model_type}_{i_str}.tif'
        predicted_mask_file = f'{model_type}_Predicted_Image_{i_str}.tif'
        print(corrupt_mask_file)

        # rc_context scopes the small font to this figure instead of changing
        # the global rcParams for every later plot in the kernel.
        with plt.rc_context({'font.size': 8}):
            fig = plt.figure(figsize=(12, 20))
            try:
                for row, corruption in enumerate(corruptions):
                    corrupted_mask_dir = os.path.join(
                        corrupted_mask_parent_dir, corrupted_subdir_fmt.format(corruption=corruption))
                    predicted_mask_dir = os.path.join(
                        predicted_mask_parent_dir, predicted_subdir_fmt.format(corruption=corruption))

                    corrupt_mask = tiff.imread(os.path.join(corrupted_mask_dir, corrupt_mask_file))
                    predicted_mask = tiff.imread(os.path.join(predicted_mask_dir, predicted_mask_file))
                    predicted_mask_thresh = tiff.imread(os.path.join(predicted_mask_dir, predicted_mask_file_thresh))

                    panels = (
                        (input_img, 'Input Image'),
                        (corrupt_mask, f'Mask with {corruption}% corruption'),
                        (predicted_mask, 'Predicted Mask without thresholding'),
                        (predicted_mask_thresh, 'Predicted Mask'),
                    )
                    for col, (data, title) in enumerate(panels):
                        ax = fig.add_subplot(subplot_rows, subplot_columns, row * subplot_columns + col + 1)
                        # squeeze drops a singleton channel axis, which imshow cannot display.
                        ax.imshow(np.squeeze(data), cmap='gray')
                        ax.set_title(title)
                        ax.axis('off')

                # os.path.join (not f"./{dir}/...") so absolute output directories also work.
                fig.savefig(os.path.join(output_dir, f"{model_type}_Visualization_{i_str}.jpg"), bbox_inches="tight")
            finally:
                # Close even on failure so figures don't accumulate across masks.
                plt.close(fig)
        print("Done")

        if processed == number:
            break

In [49]:
# Corruption levels (in percent) whose corrupted masks / predictions are compared below.
corruptions = [0, 2, 12, 15, 17, 20, 25, 30]

In [50]:
# for corruption in corruptions:

#     input_shape = (512, 512, 1)
#     model = dense_unet(input_shape)
#     model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

#     input_dir = './data_new/'
#     mask_dir = './GEE_Masks/GEE_resized/train_gee/'
#     corrupt_mask_dir = f'./GEE_Masks/GEE_resized/train_gee/train_{corruption}_gee_with_diff_kernels'
#     output_dir = f'./Training_data_outputs/Masks/New_20_Epoch_{corruption}_with_diff_kernels'
#     output_dir_vis = f'./Training_data_outputs/New_20_Epoch_{corruption}_with_diff_kernels'
#     os.makedirs(output_dir,exist_ok=True)
#     os.makedirs(output_dir_vis,exist_ok=True)

#     model_name = f'UNet_dense_GEE_20_epoch_{corruption}_with_diff_kernels'
#     model = tf.keras.models.load_model(model_name, compile=False)
#     model.compile()
#     predict1(model, mask_dir, corrupt_mask_dir, input_dir, output_dir, output_dir_vis, 'dense')


In [51]:
# input_shape = (512, 512, 1)
# corruption = 0

# model = dense_unet(input_shape)
# model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# input_dir = './data_new/'
# mask_dir = './GEE_Masks/GEE_resized/train_gee/'
# corrupt_mask_dir = f'./GEE_Masks/GEE_resized/train_gee/train_{corruption}_gee_with_diff_kernels'
# output_dir = f'./Training_data_outputs/Masks/New_20_Epoch_{corruption}_with_diff_kernels'
# output_dir_vis = f'./Training_data_outputs/New_20_Epoch_{corruption}_with_diff_kernels'
# os.makedirs(output_dir,exist_ok=True)
# os.makedirs(output_dir_vis,exist_ok=True)

# model_name = f'UNet_dense_GEE_20_epoch_{corruption}_with_diff_kernels'
# model = tf.keras.models.load_model(model_name, compile=False)
# model.compile()
# predict1(model, mask_dir, corrupt_mask_dir, input_dir, output_dir, output_dir_vis, 'dense')

In [52]:
# Source data: input images and clean masks.
input_dir = './data_new/'
mask_dir = './GEE_Masks/GEE_resized/train_gee/'

# Parent folders that hold one sub-folder per corruption level.
corrupted_mask_parent_dir = './GEE_Masks/GEE_resized/train_gee'
predicted_mask_parent_dir = './Training_data_outputs/Masks'
model_type = 'dense'

# Destination for the combined per-mask figures (plotAllResults does not create it).
output_dir = './Training_data_outputs/All_Visualisations'
os.makedirs(output_dir, exist_ok=True)

plotAllResults(
    corruptions=corruptions,
    image_dir=input_dir,
    mask_dir=mask_dir,
    corrupted_mask_parent_dir=corrupted_mask_parent_dir,
    predicted_mask_parent_dir=predicted_mask_parent_dir,
    output_dir=output_dir,
    model_type=model_type,
)

1010
NDWI_Mask_0_resized.tif
0
NDWI_Mask_0_resized_corrupt.tif
Done
NDWI_Mask_1000_resized.tif
1000
NDWI_Mask_1000_resized_corrupt.tif
Done
NDWI_Mask_1001_resized.tif
1001
NDWI_Mask_1001_resized_corrupt.tif
Done
NDWI_Mask_1002_resized.tif
1002
NDWI_Mask_1002_resized_corrupt.tif
Done
NDWI_Mask_1003_resized.tif
1003
NDWI_Mask_1003_resized_corrupt.tif
Done
NDWI_Mask_1004_resized.tif
1004
NDWI_Mask_1004_resized_corrupt.tif
Done
NDWI_Mask_1005_resized.tif
1005
NDWI_Mask_1005_resized_corrupt.tif
Done
NDWI_Mask_1006_resized.tif
1006
NDWI_Mask_1006_resized_corrupt.tif
Done
NDWI_Mask_1007_resized.tif
1007
NDWI_Mask_1007_resized_corrupt.tif
Done
NDWI_Mask_1009_resized.tif
1009
NDWI_Mask_1009_resized_corrupt.tif
Done
NDWI_Mask_100_resized.tif
100
NDWI_Mask_100_resized_corrupt.tif
Done
NDWI_Mask_1011_resized.tif
1011
NDWI_Mask_1011_resized_corrupt.tif
Done
NDWI_Mask_1013_resized.tif
1013
NDWI_Mask_1013_resized_corrupt.tif
Done
NDWI_Mask_1014_resized.tif
1014
NDWI_Mask_1014_resized_corrupt.tif
Done