## General Imports

In [1]:
%env SM_FRAMEWORK=tf.keras
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
import warnings

def function_that_warns():
    warnings.warn("deprecated", DeprecationWarning)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    function_that_warns()  # this will not show a warning

import os
import numpy as np
import sys
import tensorflow as tf
from tensorflow import keras
import glob
import cv2
from patchify import patchify, unpatchify
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score, precision_recall_fscore_support
from PIL import Image

env: SM_FRAMEWORK=tf.keras


2024-01-11 14:00:28.624405: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.2


In [None]:

# Replace with your path
save_dir = '/home/alexis/workspace/DATA/ROY/'

## FINE TUNING FOR ROY

In [None]:
## Choose a model with the right path
#model = tf.keras.models.load_model('/home/alexis/workspace/notebook/models/model_1.h5')
#model = tf.keras.models.load_model('/home/alexis/workspace/notebook/models/model_1_cus.h5')
model = tf.keras.models.load_model('/home/alexis/workspace/DATA/EM/EM_model_3.h5')

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Load the pre-trained UNet model
pretrained_model = model

# Replace the final layer with a new layer that predicts building locations
x = pretrained_model.layers[-2].output
x = Conv2D(1, (1, 1), activation='sigmoid')(x)
fine_tuned_model = Model(pretrained_model.input, x)

# Freeze some of the pre-trained layers
for layer in fine_tuned_model.layers[:-4]:
    layer.trainable = False

# Compile the model with a suitable loss function and optimizer
fine_tuned_model.compile(loss='binary_crossentropy', optimizer=Adam(lr=1e-4), metrics=['accuracy'])


In [None]:
# Split the data from path_in into training and validation sets
# Creating Images and masks path and deleting images whose masks didnt exist
count = 0
train_ids_paths=[]
train_masks_path=[]
base = '/home/alexis/workspace/DATA/ROY/TRAINING_DATA'

total_maps= os.listdir(base)
t = sorted(glob.glob(os.path.join(base, 'images', '*.png')))
m = sorted(glob.glob(os.path.join(base, 'labels','0', '*.png')))

print(len(t), len(m))

total_imgs=[]
total_masks=[]

for item in t:
    t1=item.split('/')
    total_imgs.append(t1[-1])

for item in m:
    m1=item.split('/')
    total_masks.append(m1[-1]) 

list_difference = []
for item in total_imgs:
    if item not in total_masks:
        list_difference.append(item)

for item in list_difference:
    toRemove= os.path.join(base, 'images', item)
    t.remove(toRemove)

if len(t) == len(m):
    train_ids_paths +=t
    train_masks_path += m
else:
    print( 'number of images',len(t), i) 
    print( 'number of masks',len(m), i)
print(len(train_ids_paths), len(train_masks_path))


In [None]:
#Reading Images and masks and creating tensors respectively
Masks=[]
Images=[]  
count = 0 
with tf.device('/gpu:0'):     
  for img_path,mask_path in tqdm(zip(train_ids_paths, train_masks_path),total=len(train_ids_paths),position=0, leave=False):
    img = Image.open(img_path)
    Images.append(np.array(img))
    mask= cv2.imread(mask_path,2)
    mask= np.array(mask)
    mask = np.expand_dims(mask,-1)
    mask[mask>0]=1
    Masks.append(mask)
    count = count +1

In [None]:
#Converting Images and masks into numpy array
Images=np.array(Images)
Masks=np.array(Masks)
print(Images.shape, Masks.shape)

#Shuffling images and masks to get rid of overfitting
from sklearn.utils import shuffle
Images, Masks=shuffle(Images, Masks, random_state=42)

#Display 10 images and respective maps
count = 0
for i in range(len(Images)):
  if count < 5:
      fig = plt.figure(figsize=(10, 7))
      rows,columns=1,2
      fig.add_subplot(rows, columns, 1)
      toShow=(1)*4  
      # showing image
      plt.imshow(Images[i])
      plt.axis('off')
      plt.title("Image")

      #Adds a subplot at the 2nd position
      fig.add_subplot(rows, columns, 2)
      plt.imshow(Masks[i][:,:,0], cmap='gray')
      plt.axis('off')
      plt.title("Mask")
  count +=1

print (count)

In [None]:
from sklearn.model_selection import KFold, StratifiedKFold

kf = KFold(n_splits = 3)
VALIDATION_ACCURACY = []
VALIDATION_LOSS = []

fold_var = 1
count=0

for train_index, val_index in kf.split(Images,Masks):
    training_data = Images[train_index]
    validation_data = Images[val_index]

    training_mask = Masks[train_index]
    validation_mask = Masks [val_index]
    
    model = fine_tuned_model
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model_name= 'model_ROY'+str(fold_var)+'.h5'
    
    checkpoint = tf.keras.callbacks.ModelCheckpoint(save_dir+"ROY_model_"+str(fold_var)+".h5", monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
    callbacks = [
             tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=2),
              tf.keras.callbacks.TensorBoard(log_dir="logs"),checkpoint
        ]
    
    history = model.fit(x=training_data, y = training_mask, batch_size=64, epochs=10, verbose = 0, validation_data=(validation_data, validation_mask),callbacks=callbacks)	
    model.load_weights(save_dir+"ROY_model_"+str(fold_var)+".h5")
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(1, len(loss) + 1)
    plt.plot(epochs, loss, 'y', label='Training loss')
    plt.plot(epochs, val_loss, 'r', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    
    results = model.evaluate(validation_data,validation_mask)
    results = dict(zip(model.metrics_names,results))
    VALIDATION_ACCURACY.append(results['accuracy'])
    VALIDATION_LOSS.append(results['loss'])
    tf.keras.backend.clear_session()
    fold_var += 1

## APPLY TO PREDICT IMAGES

In [None]:
from tensorflow.keras.models import load_model

gt_folder = '/home/alexis/workspace/DATA/ROY/TEST_DATA/labels/0/'

def calculate_metrics(pred, gt):
    # Binarize the masks
    gt[gt > 0] = 1
    pred[pred > 0] = 1

    # Compute precision, recall, and F1 score
    precision, recall, f1_score, _ = precision_recall_fscore_support(gt.ravel(), pred.ravel(), average='binary')
    
    # Compute IoU
    IoU = jaccard_score(gt.ravel(), pred.ravel(), average='binary')
    
    return precision, recall, f1_score, IoU

# define 9 thresholds:
threshold_values = [0.01, 0.012, 0.014, 0.016, 0.018, 0.020, 0.022, 0.024, 0.026]

# Load the model
model = load_model(save_dir + "ROY_model_2.h5")


### PREDICT IMAGES

In [None]:
# Load the images
images = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_Low/images'
output = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_Low/preds/'

for img in tqdm(os.listdir(images)):
    if img.endswith(".TIF"):
        #print('predicting image: ', img)
        # check that the image has not already been processed by checking whether the output file exists
        if os.path.isfile(os.path.join(output, img)):
            continue
        # Read the image
        large_image = cv2.imread(os.path.join(images,img),cv2.IMREAD_COLOR)
        #print("Original image shape: ",large_image.shape)
        
        # #Uncomment to show the image in original color converting from BGR to RGB
        # plt.imshow(cv2.cvtColor(large_image, cv2.COLOR_BGR2RGB))
        # plt.show()
        
        Large_Image=np.array(large_image)
        large_mask = cv2.imread(os.path.join(images,img), 0)
        if large_mask is None:
            print(f"Cannot read mask for image {img}")
            continue
        Large_mask=np.array(large_mask)
        Large_mask=np.expand_dims(Large_mask, axis=-1)
        #print("Mask shape: ",Large_mask.shape)

        remainderW =  (Large_Image.shape[0] - 256) % 256
        remainderH =  (Large_Image.shape[1] - 256) % 256

        if remainderW != 0:
            width= Large_Image.shape[0] -remainderW +256
        else:
            width = Large_Image.shape[0]

        if remainderH != 0:
            height= Large_Image.shape[1] -remainderH +256
        else:
            height = Large_Image.shape[1]
            
        container = np.zeros((width, height,3), dtype=int)
        container.shape
        container[0:Large_Image.shape[0], 0:Large_Image.shape[1],:] = Large_Image[0:Large_Image.shape[0],0:Large_Image.shape[1],:]
        #print('New image shape', container.shape)

        patches= patchify(container, (256 ,256,3),step=256)
        #print('Patches shape: ',patches.shape)
        patches1=patches.reshape(patches.shape[0]*patches.shape[1]*patches.shape[2], 256,256, 3)
        #print('Patches1 shape: ',patches1.shape)

        pred = model.predict(patches1, verbose=0) # type: ignore

        # Threshold predictions
        pred_thresholds = [pred > t for t in threshold_values]

        total_pixels = large_image.shape[0] * large_image.shape[1]
        final_pred = np.zeros((1, total_pixels))

        num_patches_x = container.shape[1] // 256
        num_patches_y = container.shape[0] // 256
        num_patches = num_patches_x * num_patches_y

        # Stack the thresholded predictions along a new axis
        stacked_preds = np.stack([t.ravel() for t in pred_thresholds], axis=0)

        # Sum the True values along the new axis and compare the sum to the threshold
        threshold = 4 # 4 out of 9 thresholds must be True
        final_pred = (np.sum(stacked_preds, axis=0) > threshold).astype(int)
        final_pred = final_pred.reshape((num_patches, 256, 256))

        # Merge the patches
        test_pred_threshold=final_pred
        merge_patches= patches1.reshape(patches.shape[0],patches.shape[1],patches.shape[2], 256,256, 3)
        merge_mask = test_pred_threshold.reshape(patches.shape[0],patches.shape[1],patches.shape[2], 256,256, 1)
        merge_patches=unpatchify(merge_patches,container.shape)
        merge_masks=unpatchify(merge_mask,(container.shape[0],container.shape[1],1))
        merge_masks = merge_masks.astype(int)
        merge_masks[merge_masks==1]=255
        np.unique(merge_masks)
        
        #print(merge_masks.shape)
        merge_masks = merge_masks[0:Large_Image.shape[0], 0:Large_Image.shape[1],:]

        merge_masks[merge_masks>40]=255
        merge_masks[merge_masks<40]=0

        # write the image as a tiff
        cv2.imwrite(os.path.join(output, img), merge_masks)

        # also write the image as a plain jpeg with the orignal image in background and the mask in bright blue
        # Ensure merge_masks is a 2D boolean array
        merge_masks = merge_masks.squeeze()  # Remove any extra dimensions
        merge_masks = merge_masks > 128  # Convert to boolean array

        # Create an RGBA mask (bright blue for the mask, transparent elsewhere)
        rgba_mask = np.zeros((merge_masks.shape[0], merge_masks.shape[1], 4), dtype=np.uint8)
        rgba_mask[merge_masks, :3] = [255, 0, 0]  # Bright blue color
        rgba_mask[merge_masks, 3] = 255  # Full opacity for the mask

        rgba_mask_image = Image.fromarray(rgba_mask)

        # Load the original image and convert to RGBA
        image = Image.fromarray(large_image).convert("RGBA")

        # Overlay RGBA mask onto the original image
        image.paste(rgba_mask_image, (0, 0), rgba_mask_image)

        # Convert to RGB and save
        final_image = image.convert("RGB")
        output_filename = img.replace('.TIF', '.jpg').replace('.tif', '.jpg')

        output_path = os.path.join(output, output_filename)
        final_image.save(output_path)
        print(f"Saved composite image to {output_path}")

        # Display the final image
        plt.imshow(final_image)
        plt.axis('off')
        plt.show()


## CHECK THRESHOLDS

In [None]:
# Load the images
image_test = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_High/images/'
mask_test = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_High/labels/0/'
output_test = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_High/test/'
gt_folder = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_High/labels/0/'

# load the predicted masks
preds = '/home/alexis/workspace/DATA/ROY/ROY_SPLIT_High/preds/'
preds_list = []
for img in os.listdir(preds):
    if img.endswith(".tif") or img.endswith(".TIF"):
        preds_list.append(img)

# load the ground truth masks
gt_list = []
for img in os.listdir(gt_folder):
    if img.endswith(".tif") or img.endswith(".TIF"):
        gt_list.append(img)

# only keep the images that are in both folders
preds_list = [x for x in preds_list if x in gt_list]
gt_list = [x for x in gt_list if x in preds_list]

print('Number of predicted masks: ', len(preds_list))
print('Number of ground truth masks: ', len(gt_list))

metrics_results = {img: {} for img in preds_list if img.endswith(".tif") or img.endswith(".TIF")}

for img in preds_list:
    if img.endswith(".tif") or img.endswith(".TIF"):
        print('checking image: ', img)
        large_image = cv2.imread(os.path.join(image_test,img),cv2.IMREAD_COLOR)
        #print("Original image shape: ",large_image.shape)
        
        # Uncomment to show the image in original color converting from BGR to RGB
        #plt.imshow(cv2.cvtColor(large_image, cv2.COLOR_BGR2RGB))
        #plt.show()
        
        Large_Image=np.array(large_image)
        large_mask = cv2.imread(os.path.join(image_test,img), 0)
        if large_mask is None:
            print(f"Cannot read mask for image {img}")
            continue
        Large_mask=np.array(large_mask)
        Large_mask=np.expand_dims(Large_mask, axis=-1)
        #print("Mask shape: ",Large_mask.shape)

        remainderW =  (Large_Image.shape[0] - 256) % 256
        remainderH =  (Large_Image.shape[1] - 256) % 256

        if remainderW != 0:
            width= Large_Image.shape[0] -remainderW +256
        else:
            width = Large_Image.shape[0]

        if remainderH != 0:
            height= Large_Image.shape[1] -remainderH +256
        else:
            height = Large_Image.shape[1]
            
        container = np.zeros((width, height,3), dtype=int)
        container.shape
        container[0:Large_Image.shape[0], 0:Large_Image.shape[1],:] = Large_Image[0:Large_Image.shape[0],0:Large_Image.shape[1],:]
        #print('New image shape', container.shape)

        patches= patchify(container, (256 ,256,3),step=256)
        #print('Patches shape: ',patches.shape)
        patches1=patches.reshape(patches.shape[0]*patches.shape[1]*patches.shape[2], 256,256, 3)
        #print('Patches1 shape: ',patches1.shape)

        test_pred = model.predict(patches1, verbose=0)

        # Threshold predictions
        test_pred_thresholds = [test_pred > t for t in threshold_values]

        total_pixels = large_image.shape[0] * large_image.shape[1]
        final_pred = np.zeros((1, total_pixels))

        num_patches_x = large_image.shape[1] // 256
        num_patches_y = large_image.shape[0] // 256
        num_patches = num_patches_x * num_patches_y

        # Stack the thresholded predictions along a new axis
        stacked_preds = np.stack([t.ravel() for t in test_pred_thresholds], axis=0)

        # find the corresponding gt mask for the image
        gt = cv2.imread(os.path.join(gt_folder,img),2)

        for idx, threshold_value in enumerate(threshold_values):
            # Calculate metrics
            precision, recall, f1_score, IoU = calculate_metrics(test_pred_thresholds[idx], gt)

            # Store metrics in the dictionary
            metrics_results[img][threshold_value] = {
                'precision': precision,
                'recall': recall,
                'f1_score': f1_score,
                'IoU': IoU
            }

        # Sum the True values along the new axis and compare the sum to the threshold
        threshold = 4 # 4 out of 9 thresholds must be True
        final_pred = (np.sum(stacked_preds, axis=0) > threshold).astype(int)
        final_pred = final_pred.reshape((num_patches, 256, 256, 1))

        # Merge the patches
        test_pred_threshold=final_pred
        merge_patches= patches1.reshape(patches.shape[0],patches.shape[1],patches.shape[2], 256,256, 3)
        merge_mask = test_pred_threshold.reshape(patches.shape[0],patches.shape[1],patches.shape[2], 256,256, 1)
        merge_patches=unpatchify(merge_patches,container.shape)
        merge_masks=unpatchify(merge_mask,(container.shape[0],container.shape[1],1))
        merge_masks = merge_masks.astype(int)
        merge_masks[merge_masks==1]=255
        np.unique(merge_masks)
        #print(merge_masks.shape)
        merge_masks = merge_masks[0:Large_Image.shape[0], 0:Large_Image.shape[1],:]

        merge_masks[merge_masks>40]=255
        merge_masks[merge_masks<40]=0

        # write the image
        cv2.imwrite(output_test+img, merge_masks)


In [None]:
# Calculate median F1 score and IoU for each threshold
median_f1_scores = []
median_IoUs = []
median_Precision = []
median_Recall = []


for threshold_value in threshold_values:
    f1_scores = [v.get(threshold_value, {}).get('f1_score', None) for v in metrics_results.values()]
    IoUs = [v.get(threshold_value, {}).get('IoU', None) for v in metrics_results.values()]
    Precision = [v.get(threshold_value, {}).get('precision', None) for v in metrics_results.values()]
    Recall = [v.get(threshold_value, {}).get('recall', None) for v in metrics_results.values()]

    # Filter out None values
    f1_scores = [score for score in f1_scores if score is not None]
    IoUs = [iou for iou in IoUs if iou is not None]
    Precision = [precision for precision in Precision if precision is not None]
    Recall = [recall for recall in Recall if recall is not None]

    median_f1_scores.append(np.median(f1_scores))
    median_IoUs.append(np.median(IoUs))
    median_Precision.append(np.median(Precision))
    median_Recall.append(np.median(Recall))

# Plot the median F1 score IoU and the recall and precision for each threshold
plt.plot(threshold_values, median_f1_scores, label='F1 score')
plt.plot(threshold_values, median_IoUs, label='IoU')
plt.plot(threshold_values, median_Precision, label='Precision')
plt.plot(threshold_values, median_Recall, label='Recall')
plt.xlabel('Threshold')
plt.ylabel('Score')
plt.legend()
plt.title('Median F1 score, IoU, Precision and Recall for each threshold')
plt.tight_layout()
plt.show()


## COLOR EXTRACTION for the Roy data

In [2]:
def colorextraction(images_list, color_ext):
    def extractRColor(src):
        """Extract the red component in an image and display it"""

        src_hsv = cv2.cvtColor(src, cv2.COLOR_BGR2HSV)

        # # Display the HSV image
        # plt.imshow(cv2.cvtColor(src_hsv, cv2.COLOR_HSV2RGB))
        # plt.title('HSV Image')
        # plt.show()

        # lower_red1 = np.array([160,40,140])
        # upper_red1 = np.array([180,255,255])

        # lower_red2 = np.array([0,40,30]) # empirically determined that 30 was the best threshold for ROY
        # upper_red2 = np.array([12,255,255])

        lower_red1 = np.array([0, 100, 100])
        upper_red1 = np.array([10, 255, 255])

        lower_red2 = np.array([160, 100, 100])
        upper_red2 = np.array([180, 255, 255])


        # Define the HSV range for the blue color to exclude
        lower_blue = np.array([110, 100, 100])  # Adjust these values based on specific blue
        upper_blue = np.array([130, 255, 255])

        # Creating masks
        mask_red1 = cv2.inRange(src_hsv, lower_red1, upper_red1)
        mask_red2 = cv2.inRange(src_hsv, lower_red2, upper_red2)
        mask_blue = cv2.inRange(src_hsv, lower_blue, upper_blue)

        # #dilate blue mask
        # kernel = np.ones((5,5), np.uint8)
        # mask_blue = cv2.dilate(mask_blue, kernel, iterations=1)

        # # dilate red masks
        # kernel = np.ones((25,25), np.uint8)
        # mask_red1 = cv2.dilate(mask_red1, kernel, iterations=1)
        # mask_red2 = cv2.dilate(mask_red2, kernel, iterations=1)

        # # show blue mask
        # plt.imshow(mask_blue, cmap='gray')
        # plt.title('Blue mask')
        # plt.show()

        # Combine red masks
        mask_red = cv2.add(mask_red1, mask_red2)
        # #show red mask
        # plt.imshow(mask_red, cmap='gray')
        # plt.title('Red mask')
        # plt.show()

        # Subtract blue mask from red mask
        mask = cv2.subtract(mask_red, mask_blue)
        # plt.imshow(mask, cmap='gray')
        # plt.title('Mask after blue subtraction')
        # plt.show()
        
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        kernel_2 = np.ones((2,2), np.uint8)
        kernel_hm = np.array((
            [0, -1, 0],
            [-1, 1, -1],
            [0, -1, 0]), dtype="int")

        kernel_hm_diag = np.array((
            [-1, 0, -1],
            [0, 1, 0],
            [-1, 0, -1]), dtype="int")

        res_hm = cv2.morphologyEx(mask, cv2.MORPH_HITMISS, kernel_hm)
        res_hm_diag = cv2.morphologyEx(mask, cv2.MORPH_HITMISS, kernel_hm_diag)
        
        # # Display hit-miss results
        # plt.imshow(res_hm, cmap='gray')
        # plt.title('Hit-miss result (kernel_hm)')
        # plt.show()

        # plt.imshow(res_hm_diag, cmap='gray')
        # plt.title('Hit-miss result (kernel_hm_diag)')
        # plt.show()

        mask = cv2.subtract(mask, res_hm)
        mask = cv2.subtract(mask, res_hm_diag)

        # # Display mask after hit-miss subtraction
        # plt.imshow(mask, cmap='gray')
        # plt.title('Mask after hit-miss subtraction')

        res_hm = cv2.morphologyEx(mask, cv2.MORPH_HITMISS, kernel_hm)
        res_hm_diag = cv2.morphologyEx(mask, cv2.MORPH_HITMISS, kernel_hm_diag)
        
        # # Display second hit-miss results
        # plt.imshow(res_hm, cmap='gray')
        # plt.title('Second hit-miss result (kernel_hm)')
        # plt.show()

        # plt.imshow(res_hm_diag, cmap='gray')
        # plt.title('Second hit-miss result (kernel_hm_diag)')
        # plt.show()

        mask = cv2.subtract(mask, res_hm)
        mask = cv2.subtract(mask, res_hm_diag)

        # # Display mask after second hit-miss subtraction
        # plt.imshow(mask, cmap='gray')
        # plt.title('Mask after second hit-miss subtraction')
        # plt.show()

        lower_white = np.array([253,253,253])
        upper_white = np.array([255,255,255])
        src_bg = cv2.inRange(src, lower_white, upper_white)
        
        # # Display white background mask
        # plt.imshow(src_bg, cmap='gray')
        # plt.title('White background mask')
        # plt.show()

        src_final = 255 - cv2.add(255 - mask, src_bg)

        # # Display the final result
        # plt.imshow(src_final, cmap='gray')
        # # show title with the name of the image wihtout the path
        # plt.title('Final result for image: ' + filename.split('/')[-1].split('.')[0])
        # plt.show()

        return src_final

    for filename in tqdm(images_list, total=len(images_list), position=0, leave=True):
        filename_wo_path = filename.split('/')[-1].split('.')[0]
        large_image = cv2.imread(filename)
        red = extractRColor(large_image)
        out_path = os.path.join(color_ext, filename_wo_path +'.tif')
        cv2.imwrite(out_path, red)

    return 'Colour extraction completed'


In [5]:
# apply to images from ROY dataset

images = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT'
colour_ext = r'/home/alexis/workspace/DATA/ROY/ROY_COLOR_PREDS'

images_list = glob.glob(os.path.join(images, '*.TIF'))
images_list.sort()

colorextraction(images_list, colour_ext)

100%|██████████| 1202/1202 [10:28<00:00,  1.91it/s]


'Colour extraction completed'

## MERGE CNN AND COLOR

### MERGE WITH CNN SO THAT ONLY COMPONENTS DETECTED BY THE CNN ARE MODIFIED 

In [2]:
def get_base_classification_path(file_path, base_preds):
    file_name = os.path.basename(file_path)
    file_name_no_ext = os.path.splitext(file_name)[0]
    return os.path.join(base_preds, file_name_no_ext)

def get_color_classification_path(file_path, color_ext):
    file_name = os.path.basename(file_path)
    file_name_no_ext = os.path.splitext(file_name)[0]
    return os.path.join(color_ext, file_name_no_ext)

def classification_files_exist(cnnpred_file, colorpred_file):
    return os.path.exists(cnnpred_file) and os.path.exists(colorpred_file)

def get_fusion_path(file_path, fusion_ext):
    file_name = os.path.basename(file_path)
    file_name_no_ext = os.path.splitext(file_name)[0]
    return os.path.join(fusion_ext, file_name_no_ext)

def process_images(cnnpred, colorpred):
    """Replace connected components in base predictions with connected components from color predictions."""
    # Find connected components in colorpred and cnnpred
    nb_lbl_color, lbl_color = cv2.connectedComponents(colorpred, connectivity=8)
    nb_lbl_cnn, lbl_cnn = cv2.connectedComponents(cnnpred, connectivity=8)

    # Create an empty array for the output
    output = np.zeros_like(cnnpred)

    # Iterate through the connected components in cnnpred
    for i in range(1, nb_lbl_cnn):
        # Get the mask for the current connected component in cnnpred
        cnn_component_mask = (lbl_cnn == i)

        # Find all the connected components in colorpred that intersect with the current connected component in cnnpred
        color_component_labels = np.unique(lbl_color[cnn_component_mask])
        color_component_labels = color_component_labels[color_component_labels != 0]  # Ignore the background component

        # Replace the current connected component in cnnpred with the corresponding connected components from colorpred
        for color_label in color_component_labels:
            output[(lbl_color == color_label)] = 255

    return output

# version where the color extraction serves to filter CNN predictions 
def filter_images(cnnpred, colorpred):
    """Replace connected components in base predictions with connected components from color predictions."""
    # Find connected components in colorpred and cnnpred
    nb_lbl_color, lbl_color = cv2.connectedComponents(colorpred, connectivity=8)
    nb_lbl_cnn, lbl_cnn = cv2.connectedComponents(cnnpred, connectivity=8)

    # Create an empty array for the output
    output = np.zeros_like(cnnpred)

    # Iterate through the connected components in cnnpred
    for i in range(1, nb_lbl_cnn):
        # Get the mask for the current connected component in cnnpred
        cnn_component_mask = (lbl_cnn == i)

        # Find all the connected components in cnnpred that intersect with the color component
        color_component_labels = np.unique(lbl_color[cnn_component_mask])
        color_component_labels = color_component_labels[color_component_labels != 0]

        # Keep only the connected component in cnnpred that intersect with connected components from colorpred
        for color_label in color_component_labels:
            output[(lbl_cnn == i)] = 255
        

    return output


In [None]:
#Define paths
images = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT/'
base_preds = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT_PREDS'
color_ext = r'/home/alexis/workspace/DATA/ROY/ROY_COLOR_PREDS'
fusion_ext = r'/home/alexis/workspace/DATA/ROY/ROY_FUSION'

# Get the list of image files
images_list = glob.glob(os.path.join(base_preds, '*.TIF'))
images_list.sort()

new_image_list = []

# Filter image list by checking the existence of classification files
for file_path in images_list:
    colorpred_file = get_color_classification_path(file_path, color_ext) + ".tif"
    cnnpred_file = get_base_classification_path(file_path, base_preds) + ".TIF"

    if classification_files_exist(cnnpred_file, colorpred_file):
        new_image_list.append(get_base_classification_path(file_path, base_preds))

print(len(new_image_list), "images have predictions and colour extraction")

# Processing
for base_pred_file in tqdm(new_image_list, total=len(new_image_list), position=0, leave=True):
    try:
        # extract filename from base_pred_file wihtout path and extension
        base_fusion_file = get_fusion_path(base_pred_file, fusion_ext)
        base_color_file = get_color_classification_path(base_pred_file, color_ext)

        #if the file has already been processed, skip it
        #if os.path.exists(base_fusion_file + '.tif'):
            #continue

        cnnpred = Image.open(base_pred_file + '.TIF')            
        cnnpred = cnnpred.convert('L')
        cnnpred = np.array(cnnpred)
        colorpred = cv2.imread(base_color_file + '.tif',0)
        src_final = process_images(cnnpred, colorpred)
        
        # Saving
        base_pred_file = os.path.basename(base_pred_file)
        cv2.imwrite(base_fusion_file + '.tif', src_final)

        # also write the image as a plain jpeg with the orignal image in background and the mask in bright blue
        # Ensure merge_masks is a 2D boolean array
        src_final = src_final.squeeze()  # Remove any extra dimensions
        src_final = src_final > 128  # Convert to boolean array

        # Create an RGBA mask (bright pink for the mask, transparent elsewhere)
        rgba_mask = np.zeros((src_final.shape[0], src_final.shape[1], 4), dtype=np.uint8)
        rgba_mask[src_final, :3] = [200,30,180] # Bright pink color
        rgba_mask[src_final, 3] = 255  # Full opacity for the mask

        rgba_mask_image = Image.fromarray(rgba_mask)

        # Load the original image and convert to RGBA
        image = Image.fromarray(cv2.imread(os.path.join(images, base_pred_file+'.TIF')))

        # Overlay RGBA mask onto the original image
        image.paste(rgba_mask_image, (0, 0), rgba_mask_image)

        # Convert to RGB and save
        final_image = image.convert("RGB")
        output_filename = base_pred_file.replace('.TIF', '.jpg').replace('.tif', '.jpg')

        output_path = os.path.join(fusion_ext, output_filename+'.jpg')
        final_image.save(output_path)
        print(f"Saved composite image to {output_path}")

        # Display the final image
        plt.imshow(final_image)
        plt.axis('off')

    except Exception as e:
        print(e)
        print("Error processing file: " + base_pred_file)


### (optional) VERSION WHERE COLOR SERVES TO FILTER CNN OUTPUT

In [3]:
#Define paths
images = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT/'
base_preds = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT_PREDS'
color_ext = r'/home/alexis/workspace/DATA/ROY/ROY_COLOR_PREDS'
fusion_ext = r'/home/alexis/workspace/DATA/ROY/ROY_FUSION'

# Get the list of image files
images_list = glob.glob(os.path.join(base_preds, '*.TIF'))
images_list.sort()

new_image_list = []

# Filter image list by checking the existence of classification files
for file_path in images_list:
    colorpred_file = get_color_classification_path(file_path, color_ext) + ".tif"
    cnnpred_file = get_base_classification_path(file_path, base_preds) + ".TIF"

    if classification_files_exist(cnnpred_file, colorpred_file):
        new_image_list.append(get_base_classification_path(file_path, base_preds))

print(len(new_image_list), "images have predictions and colour extraction")

# Processing
for base_pred_file in tqdm(new_image_list, total=len(new_image_list), position=0, leave=True):
    try:
        # extract filename from base_pred_file wihtout path and extension
        base_fusion_file = get_fusion_path(base_pred_file, fusion_ext)
        base_color_file = get_color_classification_path(base_pred_file, color_ext)

        # if the file has already been processed, skip it
        if os.path.exists(base_fusion_file + '.tif'):
            continue

        cnnpred = Image.open(base_pred_file + '.TIF')            
        cnnpred = cnnpred.convert('L')
        cnnpred = np.array(cnnpred)
        colorpred = cv2.imread(base_color_file + '.tif',0)
        src_final = filter_images(cnnpred, colorpred)
        
        # Saving
        base_pred_file = os.path.basename(base_pred_file)
        cv2.imwrite(base_fusion_file + '.tif', src_final)

        # # also write the image as a plain jpeg with the orignal image in background and the mask in bright blue
        # # Ensure merge_masks is a 2D boolean array
        # src_final = src_final.squeeze()  # Remove any extra dimensions
        # src_final = src_final > 128  # Convert to boolean array

        # # Create an RGBA mask (bright pink for the mask, transparent elsewhere)
        # rgba_mask = np.zeros((src_final.shape[0], src_final.shape[1], 4), dtype=np.uint8)
        # rgba_mask[src_final, :3] = [200,30,180] # Bright pink color
        # rgba_mask[src_final, 3] = 255  # Full opacity for the mask

        # rgba_mask_image = Image.fromarray(rgba_mask)

        # # Load the original image and convert to RGBA
        # image = Image.fromarray(cv2.imread(os.path.join(images, base_pred_file+'.TIF')))

        # # Overlay RGBA mask onto the original image
        # image.paste(rgba_mask_image, (0, 0), rgba_mask_image)

        # # Convert to RGB and save
        # final_image = image.convert("RGB")
        # output_filename = base_pred_file.replace('.tif', '.jpeg').replace('.TIF', '.jpeg')

        # output_path = os.path.join(fusion_ext, output_filename)
        # final_image.save(output_path)
        # print(f"Saved composite image to {output_path}")

        # # Display the final image
        # plt.imshow(final_image)
        # plt.axis('off')

    except Exception as e:
        print(e)
        print("Error processing file: " + base_pred_file)

1641 images have predictions and colour extraction


  0%|          | 0/1641 [00:00<?, ?it/s]

100%|██████████| 1641/1641 [2:29:11<00:00,  5.45s/it]  


### Visualisation

Show 3 random images with their predicted masks.

In [None]:
fig, ax = plt.subplots(3, 4, figsize=(20, 15))
img_folder = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT/'
fusion_folder = r'/home/alexis/workspace/DATA/ROY/ROY_FUSION/'
preds_folder = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT_PREDS/'
color_folder = r'/home/alexis/workspace/DATA/ROY/ROY_COLOR_PREDS/'

img_list = glob.glob(os.path.join(img_folder, '*.TIF'))
img_list.sort()

preds_list = glob.glob(os.path.join(preds_folder, '*.TIF'))
preds_list.sort()

color_list = glob.glob(os.path.join(color_folder, '*.tif'))
color_list.sort()

fusion_list = glob.glob(os.path.join(fusion_folder, '*.tif'))
fusion_list.sort()

for i in range(3):
    random_index = np.random.randint(0, len(img_list))
    img = Image.open(os.path.join(img_folder,img_list[random_index]))
    pred = Image.open(os.path.join(preds_folder,preds_list[random_index]))
    color = Image.open(os.path.join(color_folder,color_list[random_index]))
    fusion = Image.open(os.path.join(fusion_folder,fusion_list[random_index]))
    
    # convert to 8-bit format
    pred = pred.convert('L')
    fusion = fusion.convert('L')
    color = color.convert('L')

    # convert to numpy arrays
    pred = np.array(pred)
    fusion = np.array(fusion)
    color = np.array(color)
    
    # show the orginal image
    # keep only the filenmae without the path and extension
    img_name = os.path.basename(img_list[random_index])
    ax[i, 0].imshow(img)
    ax[i, 0].set_title('{}'.format(img_name))
    ax[i, 0].axis('off')

    # show the predicted mask
    pred_name = os.path.basename(preds_list[random_index])
    ax[i, 1].imshow(pred, cmap='gray')
    ax[i, 1].set_title('pred for: {}'.format(pred_name))
    ax[i, 1].axis('off')

    # show the color mask
    color_name = os.path.basename(color_list[random_index])
    ax[i, 2].imshow(color, cmap='gray')
    ax[i, 2].set_title('color for: {}'.format(color_name))
    ax[i, 2].axis('off')
    
    # show the fusion mask
    fusion_name = os.path.basename(fusion_list[random_index])
    ax[i, 3].imshow(fusion, cmap='gray')
    ax[i, 3].set_title('fusion for {}'.format(fusion_name))
    ax[i, 3].axis('off')

plt.tight_layout()
plt.show()

## REMOVE BLANK IMAGES AND CLEAN

In [None]:
#check if any original image is entirely white or black

import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import glob

img_folder = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT/'
img_list = glob.glob(os.path.join(img_folder, '*.TIF'))
img_list.sort()

from concurrent.futures import ProcessPoolExecutor

white_images = []

def process_image(img_path):
    img = Image.open(img_path)
    img_np = np.array(img)

    # if the image contains only values over 249
    if np.all(img_np > 240):
        #print("Image {} is all white or black".format(img_path))
        img_name = os.path.basename(img_path)
        return img_np, img_name
    return None, None

def parallel_check_images(img_list, num_workers=4):
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        results = list(tqdm(executor.map(process_image, img_list), total=len(img_list)))

    for img, img_name in results:
        if img is not None and img_name is not None:
            #add image to list
            white_images.append(img_name)

parallel_check_images(img_list)


img_folder = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT/'
white_folder = r'/home/alexis/workspace/DATA/ROY/ROY_SPLIT_WHITE/'

for img in tqdm(white_images):
    shutil.move(img_folder + img, white_folder + img)

fusion_folder = r'/home/alexis/workspace/DATA/ROY/GEOPRED_10_04_2023/'
img_list = glob.glob(os.path.join(white_folder, '*.TIF'))
img_list.sort()

# delete from the fusion folder the images that are the white_folder

for img in tqdm(img_list):
    img_name = os.path.basename(img)
    #replace .TIF with .tif
    img_name = img_name.replace('.TIF', '_geo.tif')
    os.remove(fusion_folder + img_name)


## MEASURE ACCURACY

In [None]:
# apply color extraction and fusion to test images

base_img = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/images/'
base_preds = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/preds/'
color_ext = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/color/'
fusion_ext = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/fusion/'

images_list = glob.glob(os.path.join(base_img, '*.tif'))
images_list.sort()

images_list = images_list

colorextraction(images_list, color_ext)

new_image_list = []

# Filter image list by checking the existence of classification files
for file_path in images_list:
    colorpred_file = get_color_classification_path(file_path, color_ext)+'.tif'
    cnnpred_file = get_base_classification_path(file_path, base_preds)+'.tif'

    if classification_files_exist(cnnpred_file, colorpred_file):
        new_image_list.append(get_base_classification_path(file_path, base_preds))

print(len(new_image_list), "images have predictions and colour extraction")

# Processing
for base_pred_file in tqdm(new_image_list, total=len(new_image_list), position=0, leave=True):
    try:
        # extract filename from base_pred_file wihtout path and extension
        base_color_file = get_color_classification_path(base_pred_file, color_ext)
        cnnpred = Image.open(base_pred_file + '.tif')
        # convert to 8-bit format
        cnnpred = cnnpred.convert('L')
        cnnpred = np.array(cnnpred)
        colorpred = cv2.imread(base_color_file + '.tif',0)
        src_final = process_images(cnnpred, colorpred)

        # Saving
        base_pred_file = os.path.basename(base_pred_file)
        cv2.imwrite(os.path.join(fusion_ext, base_pred_file + '.tif'), src_final)

    except Exception as e:
        print(e)
        print("Error processing file: " + base_pred_file)

In [None]:
from sklearn.metrics import jaccard_score, precision_score, recall_score
from sklearn.metrics import precision_recall_fscore_support
from PIL import Image

# load the predicted masks
preds = '/home/alexis/workspace/DATA/ROY/TEST_DATA/fusion/'
preds_list = []
for img in os.listdir(preds):
    if img.endswith(".tif"):
        preds_list.append(img)

# load the ground truth masks
gt_folder = '/home/alexis/workspace/DATA/ROY/TEST_DATA/labels/0/'
gt_list = []
for img in os.listdir(gt_folder):
    if img.endswith(".tif"):
        gt_list.append(img)

# only keep the images that are in both folders
preds_list = [x for x in preds_list if x in gt_list]
gt_list = [x for x in gt_list if x in preds_list]

print('Number of predicted masks: ', len(preds_list))
print('Number of ground truth masks: ', len(gt_list))

# create dictionaries to store the precision, recall, F1 score and IoU, with image names as keys
precision_dict = {}
recall_dict = {}
f1_score_dict = {}
iou_dict = {}

for i in range(len(preds_list)):
    pred = Image.open(os.path.join(preds,preds_list[i]))
    gt = Image.open(os.path.join(gt_folder,gt_list[i]))
    
    # convert to 8-bit format
    pred = pred.convert('L')
    gt = gt.convert('L')
    
    # convert to numpy arrays
    pred = np.array(pred)
    gt = np.array(gt)
    
    # in the gt mask all values above 0 are considered as 1
    gt[gt > 0] = 1
    pred[pred > 0] = 1

    # compute precision, recall, and F1 score, for 0 and 1 values
    precision, recall, f1_score, _ = precision_recall_fscore_support(gt.ravel(), pred.ravel())
    precision_dict[preds_list[i]] = precision
    recall_dict[preds_list[i]] = recall
    f1_score_dict[preds_list[i]] = f1_score

    # compute IoU
    IoU = jaccard_score(gt.ravel(), pred.ravel(), average=None)
    iou_dict[preds_list[i]] = IoU

# print the mean precision, recall, IoU, and F1 score
print('Mean precision: ', np.mean(list(precision_dict.values())))
print('Mean recall: ', np.mean(list(recall_dict.values())))
print('Mean IoU: ', np.mean(list(iou_dict.values())))
print('Mean F1 score: ', np.mean(list(f1_score_dict.values())))

In [None]:
# #order the dictionaries in decreasing order of values
# precision_dict = dict(sorted(precision_dict.items(), key=lambda item: item[1], reverse=True))
# recall_dict = dict(sorted(recall_dict.items(), key=lambda item: item[1], reverse=True))
# f1_score_dict = dict(sorted(f1_score_dict.items(), key=lambda item: item[1], reverse=True))
# iou_dict = dict(sorted(iou_dict.items(), key=lambda item: item[1], reverse=True))

# set dir for original images
img_folder = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/images/'
col_folder = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/color/'
preds_folder = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/preds/'
fusion_folder = r'/home/alexis/workspace/DATA/ROY/TEST_DATA/fusion/'

# for each key, plot on the same figure, the precision, recall, F1 score, and IoU, only iterating over the first 5 keys
for key in list(precision_dict.keys())[:5]:
    #if precision_dict[key] < 0.5:
    fig, ax = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle('{}'.format(key))

    #load the original image
    print(os.path.join(img_folder,key))
    img = Image.open(os.path.join(img_folder,key))
    img = np.array(img)
    
    # load the predicted mask
    pred = Image.open(os.path.join(preds_folder,key))
    pred = pred.convert('L')
    pred = np.array(pred)
    pred[pred > 0] = 1

    # load the color pred
    color = Image.open(os.path.join(col_folder,key))
    color = color.convert('L')
    color = np.array(color)
    color[color > 0] = 1

    # load the ground truth mask
    gt = Image.open(os.path.join(gt_folder,key))
    gt = gt.convert('L')
    gt = np.array(gt)
    gt[gt > 0] = 1

    # load the fusion mask
    fusion = Image.open(os.path.join(fusion_folder,key))
    fusion = fusion.convert('L')
    fusion = np.array(fusion)
    fusion[fusion > 0] = 1
    
    
    # show the predicted mask
    ax[0].imshow(pred, cmap='gray')
    ax[0].set_title('PRED, precision: {}'.format(precision_dict[key]))
    ax[0].axis('off')

    # show the color mask
    ax[1].imshow(fusion, cmap='gray')
    ax[1].set_title('FUSION, recall: {}'.format(recall_dict[key]))
    ax[1].axis('off')

    
    # show the ground truth mask
    ax[2].imshow(gt, cmap='gray')
    ax[2].set_title('GT')
    ax[2].axis('off')

    # show the original image
    ax[3].imshow(img, cmap='gray')
    ax[3].set_title('original image')
    ax[3].axis('off')
    
    
    plt.tight_layout()
    plt.show()
# # show the names of the images with the lower recall
# for i in range(len(recall_list)): 
#     if recall_list[i] < 0.1:
#         print(preds_list[i])


### Visualisation

Show 3 random images with their predicted masks, and the corresponding ground truth masks.

In [None]:
fig, ax = plt.subplots(3, 3, figsize=(15, 15))
img_folder = '/home/alexis/workspace/DATA/ROY/TEST_DATA/images/'
for i in range(3):
    #random_index = np.random.randint(0, len(preds_list))
    random_index = 30+i
    pred = Image.open(os.path.join(preds,gt_list[random_index]))
    gt = Image.open(os.path.join(gt_folder,gt_list[random_index]))
    img = Image.open(os.path.join(img_folder,gt_list[random_index]))
    
    # convert to 8-bit format
    pred = pred.convert('L')
    gt = gt.convert('L')
    
    # convert to numpy arrays
    pred = np.array(pred)
    gt = np.array(gt)
    
    # show the orginal image
    ax[i, 0].imshow(img)
    ax[i, 0].set_title('Original image for file {}'.format(preds_list[random_index]))
    ax[i, 0].axis('off')

    # show the original mask
    ax[i, 1].imshow(pred, cmap='gray')
    ax[i, 1].set_title('Predicted mask for file {}'.format(preds_list[random_index]))
    ax[i, 1].axis('off')
    
    # show the ground truth mask
    ax[i, 2].imshow(gt, cmap='gray')
    ax[i, 2].set_title('Ground truth mask for file {}'.format(gt_list[random_index]))
    ax[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Detect lines in image prediction

In [8]:
import subprocess
import os
import glob
from tqdm.auto import tqdm

# Set environment variable to access opencv static libraries
cv_lib_dir = '../opencv4/build/install_folder/lib:'
if os.getenv('LD_LIBRARY_PATH') == None:
    os.environ['LD_LIBRARY_PATH'] = cv_lib_dir
else:
    os.environ['LD_LIBRARY_PATH'] = cv_lib_dir + os.getenv('LD_LIBRARY_PATH')

print(os.getenv('LD_LIBRARY_PATH'))

base_classif = '/home/alexis/workspace/DATA/ROY/ROY_FUSION'
images_list = glob.glob(os.path.join(base_classif, '*.tif'))
images_list.sort()
# images_list = images_list[388:]
print(len(images_list))

base_rep_flash = "/home/alexis/workspace/notebook/yfaula_app/build/"
base_results_flash = "/home/alexis/workspace/notebook/yfaula_app/images/"

if not os.path.exists(base_results_flash):
    os.makedirs(base_results_flash)

compteur = 0

for filename in tqdm(images_list, total=len(images_list), position=0, leave=True):
    rproc = subprocess.run([base_rep_flash + "flash", "-p="+base_rep_flash + "flash_default_parameters.txt", 
                            "--output_dir=" + base_results_flash, filename], capture_output=True)
    if rproc.returncode == 0:
        compteur += 1
    else:
        print( 'Problem with : ', filename)
        print(rproc.stderr)
        
print('NB processed images: ', compteur, '/', len(images_list))

../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:../opencv4/build/install_folder/lib:/usr/local/cuda/lib64
1606


100%|██████████| 1606/1606 [12:33<00:00,  2.13it/s]

NB processed images:  1606 / 1606





In [10]:
import cv2
base_output = '/home/alexis/workspace/DATA/ROY/FLASH_ELIMINATION/'


if not os.path.exists(base_output):
    os.makedirs(base_output)

compteur=0

for filename in tqdm(images_list, total=len(images_list), position=0, leave=True):
    im_flash_path = base_results_flash + filename.split('/')[-1].split('.')[0] + "_flashlines.png"
    
    pred = cv2.imread(filename, 0)
    
    if os.path.exists(im_flash_path):
        lines = cv2.imread(im_flash_path, 0)
        
        ### POST PROCESSING OF LINE DETECTION : 2 possibilities :
        ## change the if condition in order to choose a solution (False or True)
        ## 1) (SOLUTION PROPOSAL 1) enlarge the detection to the adjacent pixels
        if False:
            nb_lbl, lbl = cv2.connectedComponents(pred,connectivity=8)

            #pred_merge_and  = cv2.bitwise_and(pred, cnnpred)

            counter_lbl = [ 0 ] * nb_lbl
            counter_lbl_lines = [ 0 ] * nb_lbl

            for i in range(lbl.shape[0]):
                for j in range(lbl.shape[1]):
                    if pred[i, j]:
                        counter_lbl[lbl[i, j]] += 1
                    if lines[i, j] and pred[i, j]:
                        counter_lbl_lines[lbl[i, j]] += 1

            for i in range(pred.shape[0]):
                for j in range(pred.shape[1]):
                    # between .20 and .33 because of flash mask size
                    if pred[i, j] > 0 and counter_lbl[lbl[i, j]] > 0:
                        pred[i, j] = 0 if counter_lbl_lines[lbl[i, j]]/counter_lbl[lbl[i, j]] > 0.25  else 255

            pred_final = pred
            
        else:
        ## 2) (BASIC SOLUTION) enlarge the detection with morpholgy operations
            ker = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (6, 6)) # depends on the size of flash kernel (7, 7)
            lines = cv2.dilate(lines, ker, iterations=1)
            
            mask_flash2 = cv2.bitwise_and(lines, pred)
            pred_final = cv2.bitwise_xor(pred, mask_flash2)
        
        compteur+=1
    else:
        pred_final = pred

    cv2.imwrite(base_output + filename.split('/')[-1].split('.')[0] + '.tif', pred_final)

print('NB lines elimination: ', compteur)

100%|██████████| 1606/1606 [04:44<00:00,  5.64it/s]

NB lines elimination:  1606



