# In [None]:
import torch
import torch.utils.data
import torchvision.transforms as T

import numpy as np
import cv2
import os


class DatasetTrain_SYNTHIA(torch.utils.data.Dataset):
    """SYNTHIA training split with flip / rescale / random-crop augmentation.

    Each item is ``(img, bias_label, label_img)``:

    * ``img``        -- float32 tensor (3, CROP_SIZE, CROP_SIZE), normalised
                        with the mean/std used for the pretrained ResNet.
    * ``bias_label`` -- int64 tensor (3, CROP_SIZE // BIAS_DOWNSAMPLE,
                        CROP_SIZE // BIAS_DOWNSAMPLE): the cropped image
                        downsampled to the bias branch's feature-map size and
                        quantised per channel into bins of BIAS_BIN_WIDTH.
    * ``label_img``  -- tensor (CROP_SIZE, CROP_SIZE) read from the label PNG.

    ``option`` must provide ``train_greyscale``; when true the image is read
    as greyscale and replicated over three channels so networks expecting a
    3-channel input can consume it.

    NOTE(review): augmentation draws from the global ``np.random`` state. With
    several DataLoader workers make sure each worker seeds NumPy differently
    (e.g. via ``worker_init_fn``) or workers may repeat the same
    augmentations -- verify for the PyTorch version in use.
    """

    # Side length of the square random crop fed to the network.
    CROP_SIZE = 256
    # Downsampling factor of the bias branch's feature map (ResNet output stride 8).
    BIAS_DOWNSAMPLE = 8
    # Pixel values are floor-divided by this to form bias bins (256 / 32 = 8 bins;
    # 16 would give 16 bins).
    BIAS_BIN_WIDTH = 32

    def __init__(self, option):  # TODO: could read option.data_path / option.metadata_path
        self.option = option

        self.img_dir = '/content/drive/MyDrive/BiasMitigation/Datasets/SYNTHIA/RGB'
        # NOTE(review): DatasetVal_SYNTHIA reads its labels from GT/LABELS while this
        # split uses GT/SEMANTIC_LABELS -- confirm both directories are intended.
        self.label_dir = '/content/drive/MyDrive/BiasMitigation/Datasets/SYNTHIA/GT/SEMANTIC_LABELS'

        # Native SYNTHIA frame size (informational; not used below).
        self.img_h = 760
        self.img_w = 1280

        # Size every frame is first resized to before augmentation.
        self.new_img_h = 512
        self.new_img_w = 1024

        train_img_dir_path = self.img_dir + "/train"
        # os.listdir order is arbitrary; sorting keeps index -> file reproducible.
        self.examples = [
            {
                "img_path": train_img_dir_path + '/' + file_name,
                "label_img_path": self.label_dir + '/train/' + file_name,
                "img_id": file_name,
            }
            for file_name in sorted(os.listdir(train_img_dir_path))
        ]
        self.num_examples = len(self.examples)

    @staticmethod
    def _grey_to_three_channel(grey_img):
        """Replicate an (H, W) greyscale array into an (H, W, 3) array."""
        return np.repeat(grey_img[:, :, np.newaxis], 3, axis=2)

    @classmethod
    def _read_image(cls, img_path, greyscale):
        """Read an image with OpenCV; in greyscale mode replicate it to 3 channels.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        flag = cv2.IMREAD_GRAYSCALE if greyscale else cv2.IMREAD_UNCHANGED
        img = cv2.imread(img_path, flag)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return cls._grey_to_three_channel(img) if greyscale else img

    @staticmethod
    def _read_label(label_img_path):
        """Read the label image unchanged.

        NOTE(review): downstream code assumes an (H, W) class-id map; confirm the
        SYNTHIA GT PNGs decode that way with IMREAD_UNCHANGED.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``label_img_path``.
        """
        label_img = cv2.imread(label_img_path, cv2.IMREAD_UNCHANGED)
        if label_img is None:
            raise FileNotFoundError(f"Could not read label image: {label_img_path}")
        return label_img

    @staticmethod
    def _crop_start(length, crop_size):
        """Draw a uniform crop offset in [0, length - crop_size], both ends inclusive."""
        # randint's ``high`` is exclusive, hence the + 1.
        return np.random.randint(low=0, high=length - crop_size + 1)

    @staticmethod
    def _quantize_bias_label(small_img, bin_width):
        """Turn a non-negative (h, w, C) image into a (C, h, w) int64 tensor of bins."""
        bins = small_img // bin_width
        return torch.from_numpy(np.ascontiguousarray(np.transpose(bins, (2, 0, 1)))).long()

    @staticmethod
    def _normalize(img):
        """Scale an (H, W, 3) image to [0, 1], apply mean/std, return float32 (3, H, W)."""
        img = img / 255.0
        img = (img - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
        return torch.from_numpy(
            np.ascontiguousarray(np.transpose(img, (2, 0, 1)), dtype=np.float32))

    def __getitem__(self, index):
        example = self.examples[index]

        img = self._read_image(example["img_path"], self.option.train_greyscale)
        label_img = self._read_label(example["label_img_path"])

        # Nearest-neighbour resizing throughout: keeps the image pixel-aligned with
        # the label and keeps label values valid (no blended class ids).
        base_size = (self.new_img_w, self.new_img_h)  # cv2 takes (width, height)
        img = cv2.resize(img, base_size, interpolation=cv2.INTER_NEAREST)
        label_img = cv2.resize(label_img, base_size, interpolation=cv2.INTER_NEAREST)

        # Horizontal flip with probability 0.5 (image and label together).
        if np.random.randint(low=0, high=2) == 1:
            img = cv2.flip(img, 1)
            label_img = cv2.flip(label_img, 1)

        # Random rescale by a factor drawn from [0.7, 2.0).
        scale = np.random.uniform(low=0.7, high=2.0)
        new_img_h = int(scale * self.new_img_h)
        new_img_w = int(scale * self.new_img_w)
        img = cv2.resize(img, (new_img_w, new_img_h), interpolation=cv2.INTER_NEAREST)
        label_img = cv2.resize(label_img, (new_img_w, new_img_h),
                               interpolation=cv2.INTER_NEAREST)

        # Random square crop; with the defaults the smallest rescaled side is
        # int(0.7 * 512) = 358, which still fits a 256 crop.
        crop = self.CROP_SIZE
        start_x = self._crop_start(new_img_w, crop)
        start_y = self._crop_start(new_img_h, crop)
        img = img[start_y:start_y + crop, start_x:start_x + crop]
        label_img = label_img[start_y:start_y + crop, start_x:start_x + crop]

        # Bias label: the cropped image at the bias branch's feature-map
        # resolution (crop / output stride), quantised into colour bins.
        bias_side = crop // self.BIAS_DOWNSAMPLE
        bias_label = cv2.resize(img, (bias_side, bias_side), interpolation=cv2.INTER_NEAREST)
        bias_label = self._quantize_bias_label(bias_label, self.BIAS_BIN_WIDTH)

        # NOTE(review): cv2 loads channels in BGR order, while these mean/std values
        # are the ImageNet statistics usually listed in RGB order. Kept unchanged
        # for compatibility with models already trained on this pipeline.
        img = self._normalize(img)
        label_img = torch.from_numpy(label_img)

        return (img, bias_label, label_img)

    def __len__(self):
        return self.num_examples


class DatasetVal_SYNTHIA(torch.utils.data.Dataset):
    """SYNTHIA validation split, optionally with bias-probing image transforms.

    Each item is ``(img, bias_label, label_img, img_id)``:

    * ``img``        -- float32 tensor (3, 512, 1024), normalised with the
                        mean/std used for the pretrained ResNet.
    * ``bias_label`` -- int64 tensor (3, 512 // BIAS_DOWNSAMPLE,
                        1024 // BIAS_DOWNSAMPLE): the resized image at the bias
                        branch's feature-map size, quantised per channel into
                        bins of BIAS_BIN_WIDTH.
    * ``label_img``  -- tensor (512, 1024) read from the label PNG.
    * ``img_id``     -- the image file name.

    Option flags read from ``option``:

    * ``train_greyscale`` (required) or ``val_only_greyscale`` -- read the image
      as greyscale and replicate it over three channels.
    * ``val_only_jitter`` -- apply a random torchvision ColorJitter.
    * ``val_only_invert`` -- invert pixel intensities.

    The ``val_only_*`` flags default to False when ``option`` does not define
    them. When several are set they compose in the order
    greyscale -> jitter -> invert.
    """

    # Downsampling factor of the bias branch's feature map (ResNet output stride 8).
    BIAS_DOWNSAMPLE = 8
    # Pixel values are floor-divided by this to form bias bins (256 / 32 = 8 bins;
    # 16 would give 16 bins).
    BIAS_BIN_WIDTH = 32

    def __init__(self, option):
        self.option = option
        self.network = option.network_type

        self.img_dir = '/content/drive/MyDrive/BiasMitigation/Datasets/SYNTHIA/RGB'
        # NOTE(review): DatasetTrain_SYNTHIA reads its labels from GT/SEMANTIC_LABELS
        # while this split uses GT/LABELS -- confirm both directories are intended.
        self.label_dir = '/content/drive/MyDrive/BiasMitigation/Datasets/SYNTHIA/GT/LABELS'

        # Native SYNTHIA frame size (informational; not used below).
        self.img_h = 760
        self.img_w = 1280

        # Size every frame and label is resized to.
        self.new_img_h = 512
        self.new_img_w = 1024

        val_img_dir_path = self.img_dir + "/val"
        # os.listdir order is arbitrary; sorting keeps evaluation order reproducible.
        self.examples = [
            {
                "img_path": val_img_dir_path + '/' + file_name,
                "label_img_path": self.label_dir + '/val/' + file_name,
                "img_id": file_name,
            }
            for file_name in sorted(os.listdir(val_img_dir_path))
        ]
        self.num_examples = len(self.examples)

    @staticmethod
    def _grey_to_three_channel(grey_img):
        """Replicate an (H, W) greyscale array into an (H, W, 3) array."""
        return np.repeat(grey_img[:, :, np.newaxis], 3, axis=2)

    @classmethod
    def _read_image(cls, img_path, greyscale):
        """Read an image with OpenCV; in greyscale mode replicate it to 3 channels.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        flag = cv2.IMREAD_GRAYSCALE if greyscale else cv2.IMREAD_UNCHANGED
        img = cv2.imread(img_path, flag)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return cls._grey_to_three_channel(img) if greyscale else img

    @staticmethod
    def _read_label(label_img_path):
        """Read the label image unchanged.

        NOTE(review): downstream code assumes an (H, W) class-id map; confirm the
        SYNTHIA GT PNGs decode that way with IMREAD_UNCHANGED.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``label_img_path``.
        """
        label_img = cv2.imread(label_img_path, cv2.IMREAD_UNCHANGED)
        if label_img is None:
            raise FileNotFoundError(f"Could not read label image: {label_img_path}")
        return label_img

    @staticmethod
    def _jitter(img):
        """Apply a random torchvision ColorJitter to an (H, W, 3) image array.

        Brightness is not jittered. NOTE(review): cv2 arrays are BGR while
        torchvision's saturation/hue ops assume RGB; with hue=.5 covering the
        full hue range this is presumably harmless -- confirm if it matters.
        """
        chw = torch.from_numpy(np.ascontiguousarray(np.transpose(img, (2, 0, 1))))
        jitter = T.ColorJitter(contrast=.9, saturation=.9, hue=.5)
        jittered = jitter(chw).numpy()
        # Back to (H, W, C) and contiguous memory for OpenCV.
        return np.ascontiguousarray(np.transpose(jittered, (1, 2, 0)))

    @staticmethod
    def _invert(img):
        """Invert an unsigned-integer image: each value v becomes dtype_max - v."""
        return np.iinfo(img.dtype).max - img

    @staticmethod
    def _quantize_bias_label(small_img, bin_width):
        """Turn a non-negative (h, w, C) image into a (C, h, w) int64 tensor of bins."""
        bins = small_img // bin_width
        return torch.from_numpy(np.ascontiguousarray(np.transpose(bins, (2, 0, 1)))).long()

    @staticmethod
    def _normalize(img):
        """Scale an (H, W, 3) image to [0, 1], apply mean/std, return float32 (3, H, W)."""
        img = img / 255.0
        img = (img - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
        return torch.from_numpy(
            np.ascontiguousarray(np.transpose(img, (2, 0, 1)), dtype=np.float32))

    def __getitem__(self, index):
        example = self.examples[index]
        option = self.option

        # Load once; the bias-probing transforms below then act on this image
        # instead of each re-reading the file (which discarded earlier ones).
        greyscale = option.train_greyscale or getattr(option, 'val_only_greyscale', False)
        img = self._read_image(example["img_path"], greyscale)

        # Bias-probing transforms used to evaluate the LNTL and baseline methods.
        if getattr(option, 'val_only_jitter', False):
            img = self._jitter(img)
        if getattr(option, 'val_only_invert', False):
            img = self._invert(img)

        # Nearest-neighbour resizing keeps the image aligned with the label and
        # keeps label values valid (no blended class ids).
        size = (self.new_img_w, self.new_img_h)  # cv2 takes (width, height)
        img = cv2.resize(img, size, interpolation=cv2.INTER_NEAREST)
        label_img = self._read_label(example["label_img_path"])
        label_img = cv2.resize(label_img, size, interpolation=cv2.INTER_NEAREST)

        # Bias label at the bias branch's feature-map resolution (w/8, h/8 for an
        # output-stride-8 backbone), quantised into colour bins.
        bias_size = (self.new_img_w // self.BIAS_DOWNSAMPLE,
                     self.new_img_h // self.BIAS_DOWNSAMPLE)
        bias_label = cv2.resize(img, bias_size, interpolation=cv2.INTER_NEAREST)
        bias_label = self._quantize_bias_label(bias_label, self.BIAS_BIN_WIDTH)

        # NOTE(review): cv2 loads channels in BGR order, while these mean/std values
        # are the ImageNet statistics usually listed in RGB order. Kept unchanged
        # for compatibility with existing models.
        img = self._normalize(img)
        label_img = torch.from_numpy(label_img)

        return (img, bias_label, label_img, example["img_id"])

    def __len__(self):
        return self.num_examples
