In [None]:
import os
import warnings
warnings.filterwarnings('ignore')

# Step 1: Import Modules
from DataLoader import ADE20KDownloader
from DataHandler import DataHandler 
from ADE20KLoader import ADE20KLoader
from DeepLabV3Plus import DeepLabV3Plus  # The DeepLabV3+ model implementation
from Trainer import Trainer  # The Trainer class to handle the training process
from ModelEvaluator import ModelEvaluator

# --- Model / training settings ---------------------------------------------
IMG_SIZE = 256   # handed to DeepLabV3Plus as image_size; adjust to your data
N_CLASSES = 21   # handed to DeepLabV3Plus as num_classes; adjust to your data
BATCH_SIZE = 4
EPOCH = 2

# --- Dataset locations (all rooted at the current working directory) --------
BASE_DIR = os.getcwd()
DOWNLOAD_PATH = os.path.join(BASE_DIR, 'path_to_download')
DATASET_PATH = os.path.join(DOWNLOAD_PATH, 'ADEChallengeData2016/images')

# --- Fetch the data ----------------------------------------------------------
# overwrite=False: files that already exist are not overwritten.
downloader = ADE20KDownloader(DOWNLOAD_PATH)
downloader.download_ade(overwrite=False)

In [None]:
# Step 2: Load the Dataset
# File names handed to ADE20KLoader together with the image directory.
# NOTE(review): presumably the class-label list and the per-class colour map --
# confirm against the ADE20KLoader implementation.
label_file = 'objectInfo151.txt'
colormap_file = 'ade20k_colormap.csv'
 
loader = ADE20KLoader(DATASET_PATH, label_file, colormap_file)
# `dataset` is later passed to the Trainer and indexed with 'val' for evaluation.
dataset = loader.prepare_dataset()

In [None]:
# Step 3: Initialize the Model
# Build DeepLabV3+ from the IMG_SIZE and N_CLASSES constants defined in the
# configuration cell. The wrapped model object is exposed as `deeplab.model`;
# that attribute is what the trainer and the prediction loop consume.
# NOTE(review): N_CLASSES must match the number of label ids produced by the
# dataset loader -- verify before training.

deeplab = DeepLabV3Plus(image_size=IMG_SIZE, num_classes=N_CLASSES)


In [None]:
# Step 4: Configure the Trainer
# The trainer receives the bare model plus the loader and the dataset that
# were prepared in the earlier cells.
trainer = Trainer(
    model=deeplab.model,
    data_loader=loader,
    dataset=dataset,
    batch_size=BATCH_SIZE,
    epochs=EPOCH,
)

# Step 5: Start Training
model_history = trainer.train()

# Qualitative inspection of predictions (input image, true mask, predicted
# mask) is done in the evaluation cell further down in this notebook.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

def visualize_segmentation_sky(mask, color_dict, name_to_index_dict):
    """Render a two-colour image that separates sky pixels from everything else.

    Args:
        mask: 2-D array of class ids (its first two dimensions give the
            height and width of the rendered image).
        color_dict: mapping that provides the colour under the key 'sky'.
        name_to_index_dict: mapping that provides the sky class id under
            the key 'sky'.

    Returns:
        uint8 array of shape (height, width, 3): sky pixels carry
        ``color_dict['sky']``, all remaining pixels carry (245, 245, 220).
    """
    background_rgb = [245, 245, 220]
    is_sky = mask == name_to_index_dict['sky']

    canvas = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    # Paint the two complementary regions; together they cover every pixel.
    canvas[is_sky] = color_dict['sky']
    canvas[~is_sky] = background_rgb
    return canvas
    
def calculate_mIoU(mask_true, mask_pred, class_id):
    """Mean of two IoU scores: the given class and its complement.

    This is a binary metric ("class_id" vs. "everything else"), not an
    average over all dataset classes.

    Args:
        mask_true: array of ground-truth class ids.
        mask_pred: array of predicted class ids, same shape as ``mask_true``.
        class_id: id of the class of interest (the sky id in this notebook).

    Returns:
        ``(IoU(class) + IoU(not class)) / 2``. A region whose union is empty
        contributes 0 to the average.
    """
    def _region_iou(region_true, region_pred):
        # Boolean masks in, IoU out; an empty union is scored as 0.
        overlap = np.sum(region_true & region_pred)
        combined = np.sum(region_true | region_pred)
        if combined == 0:
            return 0
        return overlap / combined

    class_iou = _region_iou(mask_true == class_id, mask_pred == class_id)
    rest_iou = _region_iou(mask_true != class_id, mask_pred != class_id)
    return (class_iou + rest_iou) / 2

def calculate_sky_iou(mask_true, mask_pred, class_id):
    """Intersection-over-union of one class between two label masks.

    Args:
        mask_true: array of ground-truth class ids.
        mask_pred: array of predicted class ids, same shape as ``mask_true``.
        class_id: id of the class to score (the sky id in this notebook).

    Returns:
        intersection / union of the pixels labelled ``class_id``; 0 when the
        class occurs in neither mask.
    """
    in_true = mask_true == class_id
    in_pred = mask_pred == class_id

    union = np.sum(in_true | in_pred)
    if union == 0:
        return 0
    return np.sum(in_true & in_pred) / union

# Module-level accumulators for the evaluation loop.
# display_sample_sky_iou_mIoU() adds to the two sums and bumps sky_count;
# count and sky_ratio are maintained by the loop itself.
sum_sky_iou_val = sum_mIoU_val = 0
sky_count = count = 0
sky_ratio = 0

def display_sample_sky_iou_mIoU(display_list, name_to_index_dict, color_dict):
    """Plot one sample (image, true mask, predicted mask) and score the sky class.

    Args:
        display_list: ``[image, true_mask, predicted_mask]``. Entries 1 and 2
            must expose ``.numpy()`` (e.g. eager tensors); entry 0 is handed
            to ``array_to_img``.
        name_to_index_dict: mapping that provides the sky class id under the
            key 'sky'.
        color_dict: mapping that provides the sky colour under the key 'sky'.

    Side effects:
        * Shows a matplotlib figure with one panel per entry.
        * Prints the sky IoU, the binary (sky vs. non-sky) mean IoU and the
          running ``sky_count`` / ``count`` values.
        * Adds the two scores to the module-level ``sum_sky_iou_val`` and
          ``sum_mIoU_val`` and increments ``sky_count``.
    """
    # Only names that are rebound here need a ``global`` declaration;
    # ``count`` is merely read when printing.
    global sum_sky_iou_val
    global sum_mIoU_val
    global sky_count

    sky_index = name_to_index_dict['sky']

    # Convert both masks once; the same arrays feed the metrics and the plot.
    true_mask = display_list[1].numpy().squeeze().astype(int)
    predicted_mask = display_list[2].numpy().squeeze().astype(int)

    sky_iou_val = calculate_sky_iou(true_mask, predicted_mask, sky_index)
    # Binary mean IoU: average of the sky IoU and the non-sky IoU.
    mIoU_val = calculate_mIoU(true_mask, predicted_mask, sky_index)

    plt.figure(figsize=(18, 18))
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    # Reusing the converted masks avoids calling ``.numpy()`` a second time.
    # Previously a 2-D mask was first turned into an ndarray by
    # ``np.expand_dims`` and the following ``.numpy()`` call then failed.
    mask_panels = {1: true_mask, 2: predicted_mask}

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])

        if i in mask_panels:
            display_img = visualize_segmentation_sky(mask_panels[i], color_dict, name_to_index_dict)
            plt.imshow(display_img, interpolation='lanczos')
        else:
            img_to_display = display_list[i]
            # array_to_img needs an explicit channel axis for 2-D input.
            if len(img_to_display.shape) == 2:
                img_to_display = np.expand_dims(img_to_display, axis=-1)
            plt.imshow(tf.keras.preprocessing.image.array_to_img(img_to_display))
        plt.axis('off')
    plt.show()

    # Report this sample's scores and fold them into the running totals.
    print(f"Sky IoU: {sky_iou_val:.4f}", end=" | ")
    sum_sky_iou_val += sky_iou_val

    print(f"mIoU: {mIoU_val:.4f}", end=" ")
    sum_mIoU_val += mIoU_val

    sky_count += 1
    print(sky_count, count)

def sky_pixel_ratio(mask, sky_index=3):
    """Return the fraction of elements in ``mask`` that carry the sky class id.

    Args:
        mask: NumPy array of class ids; every element counts as one pixel,
            whatever the array's shape.
        sky_index: class id treated as sky. Defaults to 3, the value this
            function used to hard-code. Pass ``name_to_index_dict['sky']``
            to stay in sync with the loader's class list.

    Returns:
        ``sky pixels / total pixels``; 0.0 for an empty mask (instead of
        dividing by zero).
    """
    total_pixels = mask.size
    if total_pixels == 0:
        return 0.0
    sky_pixels = np.sum(mask == sky_index)
    return sky_pixels / total_pixels

def create_mask(pred_mask: tf.Tensor) -> tf.Tensor:
    """Turn per-class scores into a label mask with a trailing channel axis.

    The last axis of ``pred_mask`` is reduced with argmax (index of the
    top-scoring class per position) and a new last axis of size 1 is appended
    so the result can be handled like a single-channel image.
    """
    return tf.expand_dims(tf.argmax(pred_mask, axis=-1), axis=-1)
    
# Number of validation elements processed so far (printed next to sky_count).
count = 0
# Map every class name to its position in loader.class_names.
name_to_index_dict = {name: index for index, name in enumerate(loader.class_names)}

# Selection thresholds for the samples that get visualised and scored.
MIN_SKY_RATIO = 0.02          # keep samples with at least 2% sky pixels
# NOTE(review): one specific sky ratio (51.30% after rounding) is excluded on
# purpose by the original analysis; the reason is not recorded -- confirm.
EXCLUDED_SKY_RATIO = 0.5130

# Loop through the validation dataset
for image, mask in dataset['val']:
    sample_image_test, sample_mask_test = image, mask

    # NOTE(review): the ratio is computed over the whole `mask` element with
    # sky_pixel_ratio's built-in class id, while only index 0 is predicted and
    # displayed below -- confirm the split yields one image per element and
    # that the id matches name_to_index_dict['sky'].
    sky_ratio = sky_pixel_ratio(sample_mask_test.numpy())
    is_excluded = round(sky_ratio, 4) == EXCLUDED_SKY_RATIO

    if sky_ratio >= MIN_SKY_RATIO and not is_excluded:
        # Predict the mask for the first image of this element
        one_img_batch = sample_image_test[0][tf.newaxis, ...]
        inference = deeplab.model.predict(one_img_batch, verbose=0)
        pred_mask_test = create_mask(inference)
        # Display the sample and accumulate its Sky IoU / mIoU
        display_sample_sky_iou_mIoU([sample_image_test[0], sample_mask_test[0], pred_mask_test[0]], name_to_index_dict, loader.color_dict)
    # Always advance the counter; previously the excluded sample skipped this
    # line via `continue`, leaving `count` one behind for the rest of the run.
    count += 1

# Print the overall statistics
print(f"sky_count: {sky_count}")
if sky_count:
    print(f"Average Sky IoU: {sum_sky_iou_val/sky_count:.4f}")
    print(f"Average mIoU: {sum_mIoU_val/sky_count:.4f}")
else:
    # Nothing was scored, so the averages would be a division by zero.
    print("No validation sample met the sky-ratio criterion; averages are undefined.")
