# Attention error generator
This notebook uses the image-mask folders generated by [PIPELINE 1] and the attention maps generated by ./[task_type]/[[task_type] ADS] to calculate the attention-mask errors.

## Usage:
Run all cells

## Requirements:
- Image folders:
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_[ADS]_att
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_additional_mask

## Outputs:
    ./[task_type]/content/output_plots/[domain_type]/[domain_type]_att_loss

In [6]:
import os
import cv2
import numpy as np
import glob
import tensorflow as tf
import tensorflow as tf
import re
def natural_sort_key(s):
    """Return a key that sorts strings with embedded integers numerically.

    The string is split into alternating non-digit / digit runs. Digit runs
    become ``int`` and the other runs are lower-cased, so ``"img10"`` sorts
    after ``"img2"`` and comparison is case-insensitive.
    """
    key = []
    for chunk in re.split(r'(\d+)', s):
        if chunk.isdigit():
            key.append(int(chunk))
        else:
            key.append(chunk.lower())
    return key

def calculate_scores(type, model_type, source_name):
    """Compute attention errors for one model/source and save one score per image.

    Attention maps are loaded, in natural-sort order, from
    ``./<type>/content/output_plots/<model_type>/<source_name>_driver_att/``
    (``'donkey'``) or ``.../<source_name>_yolo_att/`` (``'kitti'``). They are
    paired by position with the real-domain masks in
    ``.../output_plots/real/real_mask/`` (donkey) or
    ``.../output_plots/real/real_additional_mask/`` (kitti). Each pair's score
    is written as text to
    ``.../<model_type>/<source_name>_att_loss/<stem>.txt``, where ``<stem>`` is
    the attention-map file name without its extension.

    Args:
        type: Task type, ``'donkey'`` or ``'kitti'``. The parameter name
            shadows the builtin but is kept so existing callers still work.
        model_type: Sub-folder of ``output_plots`` holding the attention maps.
        source_name: Prefix of the attention-map and output folder names.

    Raises:
        ValueError: If ``type`` is neither ``'donkey'`` nor ``'kitti'``.
    """
    task_type = type
    # Select folders and loss function from one decision so they always agree.
    if task_type == 'donkey':
        att_suffix, mask_dir, loss_fn = '_driver_att', 'real_mask', calculate_att_loss_donkey
    elif task_type == 'kitti':
        att_suffix, mask_dir, loss_fn = '_yolo_att', 'real_additional_mask', calculate_att_loss_kitti
    else:
        raise ValueError(f"Unknown task type {task_type!r}; expected 'donkey' or 'kitti'")

    plots_dir = os.path.join('.', task_type, 'content', 'output_plots')
    # load_images_from_folder may append '*.png' directly, so keep a trailing separator.
    folder1_path = os.path.join(plots_dir, model_type, f'{source_name}{att_suffix}', '')
    folder2_path = os.path.join(plots_dir, 'real', mask_dir, '')
    output_folder = os.path.join(plots_dir, model_type, f'{source_name}_att_loss')
    os.makedirs(output_folder, exist_ok=True)

    images1, path1 = load_images_from_folder(folder1_path)
    images2, path2 = load_images_from_folder(folder2_path)
    print(path1)
    print(path2)
    if len(images1) != len(images2):
        # Pairs are matched purely by sort position; zip() below stops at the
        # shorter sequence, so report the mismatch instead of crashing.
        print(f"Warning: {len(images1)} attention maps in {folder1_path} but "
              f"{len(images2)} masks in {folder2_path}; scoring only the first "
              f"{min(len(images1), len(images2))} pairs.")

    print("Calculating att scores...")
    for path, img1, img2 in zip(path1, images1, images2):
        score = loss_fn(img1, img2)
        # splitext keeps inner dots (e.g. 'frame.001.png' -> 'frame.001'), so
        # such files do not collide on the same output name.
        stem = os.path.splitext(os.path.basename(path))[0]
        with open(os.path.join(output_folder, stem + '.txt'), 'w') as f:
            f.write(str(score))

def colormap_jet_to_gray(colormap_jet_image):
    """Convert a 3-channel image to grayscale using OpenCV's BGR->gray formula.

    NOTE: despite the name, this does not invert the jet colormap. It applies
    ``cv2.COLOR_BGR2GRAY`` to the pixels as stored.
    """
    conversion_code = cv2.COLOR_BGR2GRAY
    return cv2.cvtColor(colormap_jet_image, conversion_code)

def colormap_jet2_to_gray(colormap_jet_image):
    """Convert a 3-channel image to grayscale using OpenCV's RGB->gray formula.

    Same as ``colormap_jet_to_gray`` but treats channel 0 as red
    (``cv2.COLOR_RGB2GRAY``). It does not invert the jet colormap.
    """
    conversion_code = cv2.COLOR_RGB2GRAY
    return cv2.cvtColor(colormap_jet_image, conversion_code)

def custom_mapping(image):
    """Collapse a colour-coded mask into a single-channel uint8 mask.

    Pixels exactly equal to ``(255, 0, 0)`` become 0 and pixels exactly equal
    to ``(128, 128, 0)`` become 255. The channel order is whatever ``image``
    stores (BGR if it came from ``cv2.imread``). Every other pixel is 0.

    Args:
        image: Array of shape ``(H, W, C)``; each pixel is compared against
            3-element colour tuples.

    Returns:
        ``(H, W)`` uint8 array.
    """
    rows, cols, _ = image.shape
    result = np.zeros((rows, cols), dtype=np.uint8)

    colour_values = (
        ((255, 0, 0), 0),
        ((128, 128, 0), 255),
    )
    for colour, value in colour_values:
        hits = (image == np.array(colour)).all(axis=-1)
        result[hits] = value

    return result

def calculate_att_loss_kitti(image1, image2):
    """Return the MSE between a thresholded KITTI attention map and its mask.

    Both images are cropped to their common top-left region. The attention
    map is converted with ``colormap_jet_to_gray`` and then thresholded and
    stretched. The mask is converted with ``custom_mapping``.

    Args:
        image1: 3-channel attention-map image, as loaded by ``cv2.imread``.
        image2: 3-channel colour-coded mask image.

    Returns:
        Mean squared error (float64) over the common region.
    """
    # Crop to the overlapping region so differently sized images can be compared.
    rows = min(image1.shape[0], image2.shape[0])
    cols = min(image1.shape[1], image2.shape[1])
    image1 = image1[:rows, :cols, ...]
    image2 = image2[:rows, :cols, ...]

    image1 = colormap_jet_to_gray(image1)
    image2 = custom_mapping(image2)

    # The three steps below run in sequence. After doubling, any value >= 50
    # (now >= 100) is also promoted to 255 by the last step.
    image1[image1 < 20] = 0
    image1[image1 < 100] *= 2
    image1[image1 > 99] = 255

    # Compute in float: in uint8, subtraction and squaring wrap modulo 256
    # (e.g. a difference of 255 squares to 1), giving a meaningless MSE.
    diff = image2.astype(np.float64) - image1.astype(np.float64)
    return np.mean(diff ** 2)

def load_images_from_folder(folder_path):
    """Load every ``*.png`` file in ``folder_path`` with ``cv2.imread``.

    Files are visited in natural-sort order (``natural_sort_key``). Entries
    that are not regular files, or that ``cv2.imread`` cannot decode
    (returns ``None``), are skipped silently.

    Args:
        folder_path: Directory to scan, with or without a trailing separator.

    Returns:
        Tuple ``(images, paths)`` of NumPy arrays holding the decoded images
        and their file paths, index-aligned.
        NOTE(review): ``np.array(images)`` expects all images in the folder to
        share one shape. With ragged shapes, recent NumPy versions raise
        ``ValueError``.
    """
    images = []
    paths = []
    # os.path.join works whether or not folder_path ends with a separator.
    pattern = os.path.join(folder_path, '*.png')
    for img_path in sorted(glob.glob(pattern), key=natural_sort_key):
        if os.path.isfile(img_path):
            img = cv2.imread(img_path)
            if img is not None:
                images.append(img)
                paths.append(img_path)
    return np.array(images), np.array(paths)

def calculate_att_loss_donkey(image1, image2):
    """Return the MSE between a binarised Donkey attention map and its mask.

    Both images are cropped to their common top-left region. The attention
    map is converted to grayscale, inverted (``255 - x``) and binarised at 80
    (values < 80 -> 0, otherwise 255). The mask becomes 255 where the pixel
    equals ``[128, 128, 0]`` or ``[0, 0, 0]``, and 0 elsewhere.

    Args:
        image1: 3-channel attention-map image, as loaded by ``cv2.imread``.
        image2: 3-channel colour-coded mask image.

    Returns:
        Mean squared error computed in float64, so each mismatched pixel
        contributes 255**2. A uint8 computation wraps modulo 256 and, for
        these binary images, reduces to the fraction of mismatched pixels,
        which is 1/65025 of this value.
    """
    # Crop to the overlapping region so differently sized images can be compared.
    rows = min(image1.shape[0], image2.shape[0])
    cols = min(image1.shape[1], image2.shape[1])
    image1 = image1[:rows, :cols, ...]
    image2 = image2[:rows, :cols, ...]

    # NOTE(review): image1 is converted to RGB and then passed to a helper that
    # applies COLOR_BGR2GRAY, so the red/blue luminance weights are swapped
    # compared with a plain BGR->gray conversion. Confirm this is intended.
    image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    image1 = colormap_jet_to_gray(image1)
    image1 = 255 - image1

    mask = np.all(image2 == [128, 128, 0], axis=2)
    mask2 = np.all(image2 == [0, 0, 0], axis=2)

    image2 = np.zeros_like(image2[:, :, 0])
    image2[mask] = 255
    image2[mask2] = 255

    # Binarise the attention map at 80.
    image1[image1 < 80] = 0
    image1[image1 > 79] = 255

    # Compute in float: in uint8, subtraction and squaring wrap modulo 256.
    diff = image2.astype(np.float64) - image1.astype(np.float64)
    return np.mean(diff ** 2)


### Calculate attention error metric for the KITTI domain

In [None]:
# Score every (model_type, source_name) combination for the KITTI task.
# A run table is used instead of assigning to ``type``, which would shadow
# the builtin for the rest of the notebook session.
kitti_runs = [
    ('sim', 'sim'),
    ('real', 'real'),
    ('cyclegan', 'cyclegan_1'),
    ('cyclegan', 'cyclegan_2'),
    ('cyclegan', 'cyclegan_3'),
    ('pix2pix_mask_manual', 'pix2pix_mask_1_sim'),
    ('pix2pix_mask_manual', 'pix2pix_mask_2_sim'),
    ('pix2pix_mask_manual', 'pix2pix_mask_3_sim'),
]
for kitti_model_type, kitti_source_name in kitti_runs:
    calculate_scores('kitti', kitti_model_type, kitti_source_name)

### Calculate attention error metric for the Donkey domain

In [None]:
# Score every (model_type, source_name) combination for the Donkey task.
# A run table is used instead of assigning to ``type``, which would shadow
# the builtin for the rest of the notebook session.
donkey_runs = [
    ('sim', 'sim'),
    ('real', 'real'),
    ('cyclegan', 'cyclegan_1'),
    ('cyclegan', 'cyclegan_2'),
    ('cyclegan', 'cyclegan_3'),
    ('pix2pix_mask_manual', 'pix2pix_mask_1_sim'),
    ('pix2pix_mask_manual', 'pix2pix_mask_2_sim'),
    ('pix2pix_mask_manual', 'pix2pix_mask_3_sim'),
]
for donkey_model_type, donkey_source_name in donkey_runs:
    calculate_scores('donkey', donkey_model_type, donkey_source_name)