In [120]:
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils import data
import bdd_utils
from PIL import Image
from utils.pad_collate import PadCollate
from utils._utils import get_config

import os, sys
import argparse
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import json

# Make the parent directory importable. Guarded so that re-running this cell
# does not keep appending duplicate entries.
# NOTE: this runs after the project imports at the top of the cell, so it
# cannot influence how those imports are resolved.
if ".." not in sys.path:
    sys.path.append("..")

def __deterministic_worker_init_fn(worker_id, seed=0):
    """Reseed the Python, NumPy and PyTorch global RNGs with ``seed``.

    Args:
        worker_id: Accepted for signature compatibility but not used, so
            every call with the same ``seed`` produces the same RNG state
            regardless of which worker it runs in.
        seed: Value handed to all three seeding functions (default 0).
    """
    import random
    import numpy

    # Same order as always: Python's RNG, then NumPy's, then PyTorch's.
    for reseed in (random.seed, numpy.random.seed, torch.manual_seed):
        reseed(seed)

def get_train_labels(labels):
    """Build a lookup from each record's field 0 to its field 2.

    Args:
        labels: Iterable of indexable records.

    Returns:
        dict keyed by ``record[0]`` with value ``record[2]``; when a key
        occurs more than once, the last record wins.
    """
    mapping = {}
    for record in labels:
        mapping[record[0]] = record[2]
    return mapping

class CroppedDataset(data.Dataset):
    """Dataset of random ordered pairs of labeled images.

    The labels file ``2871_labels.json`` (inside ``root``) is read as a JSON
    object whose keys are image paths and whose values are label names.
    ``num_pairs`` distinct ordered pairs of images are drawn with
    ``np.random`` at construction time. Each item is
    ``(img_1, img_2, label_1, label_2)``: two float image tensors in
    channel-first layout and two one-hot float label vectors of length
    ``num_classes``.

    Relies on the module-level names ``michael_labels`` (records with
    ``name`` and ``id`` attributes), ``get_train_labels``, ``transforms``
    and ``Image``.
    """

    # Height in pixels that every image is resized to when ``crop`` is set.
    TARGET_HEIGHT = 160
    # When ``crop`` is set, the width is cut down to a multiple of this value.
    WIDTH_MULTIPLE = 32

    def __init__(self, root="/home/ege/datasets", split="train", dataset_size='100k', noisy=False,
                 alterations=None, img_size=(720 , 1280), val_percent=0.2, test_percent=0.1,
                 resize=False, crop=False, crop_size=None, num_pairs=2000, num_classes=50):
        """Read the labels file, build the transforms and sample the pairs.

        Args:
            root: Directory containing ``2871_labels.json``.
            split: Split name; a name containing ``'train'`` additionally
                gets a ``ColorJitter`` pipeline in ``self.color_transforms``.
            dataset_size: Currently unused.
            noisy: When true, ``alterations['gaussian_noise']`` is read as
                ``(mean, std)``.
            alterations: Mapping with the keys ``'jitter'`` and
                ``'normalize'`` (required) and ``'resize'`` (optional,
                falls back to ``resize``).
            img_size: Stored on the instance; not otherwise used here.
            val_percent: Currently unused.
            test_percent: Currently unused.
            resize: Fallback for ``self.resize`` when ``alterations`` has no
                ``'resize'`` entry.
            crop: When true, images are resized to ``TARGET_HEIGHT`` and
                their width is trimmed to a multiple of ``WIDTH_MULTIPLE``.
            crop_size: Stored on the instance; not otherwise used here.
            num_pairs: Number of distinct ordered pairs to sample; capped at
                the number of pairs that actually exist.
            num_classes: Length of the one-hot label vectors.

        Raises:
            KeyError: If ``alterations`` lacks ``'jitter'`` or
                ``'normalize'`` (or ``'gaussian_noise'`` when ``noisy``).
            ValueError: If pairs are requested but fewer than two labeled
                images are available.
        """
        super().__init__()
        # Avoid a shared mutable default argument.
        alterations = {} if alterations is None else alterations

        self.data = []
        self.img_size = img_size
        self.crop = crop
        self.crop_size = crop_size
        # The config entry wins; the ``resize`` argument is only a fallback.
        try:
            self.resize = alterations['resize']
        except KeyError:
            self.resize = resize
        self.jitter = alterations['jitter']
        self.normalize = alterations['normalize']
        self.split = split
        self.noisy = noisy
        self.num_classes = num_classes

        if self.noisy:
            noise = alterations['gaussian_noise']
            self.gauss_mean = noise[0]
            self.gauss_std = noise[1]
        self.label_map = get_train_labels(michael_labels)

        # Every split reads the same labels file. os.path.join keeps this
        # working whether or not ``root`` ends with a path separator.
        labels_file = os.path.join(root, "2871_labels.json")
        with open(labels_file, 'r') as f:
            labels = json.load(f)

        # Normalization pipeline; the training split also gets color jitter.
        # NOTE: neither pipeline is applied in __getitem__.
        self.transforms = transforms.Compose([
            transforms.Normalize(mean=self.normalize['mean'], std=self.normalize['std'])
        ])
        if 'train' in self.split:
            self.color_transforms = transforms.Compose([
                transforms.ColorJitter(brightness=self.jitter['brightness'], contrast=self.jitter['contrast'], saturation=self.jitter['saturation'], hue=self.jitter['hue']),
            ])
        else:
            self.color_transforms = None

        # Labeled image paths, rewritten from the "michael/" to the "ege/" prefix.
        images_list = [key.replace("michael/", "ege/") for key in labels]
        label_list = list(labels.values())

        self.data = self._sample_pairs(images_list, label_list, num_pairs)
        np.random.shuffle(self.data)

    @staticmethod
    def _sample_pairs(images_list, label_list, num_pairs):
        """Draw up to ``num_pairs`` distinct ordered index pairs as pair dicts."""
        if num_pairs <= 0:
            return []
        n_images = len(images_list)
        if n_images < 2:
            raise ValueError(
                "need at least two labeled images to build pairs, got {}".format(n_images))

        # Only n * (n - 1) ordered pairs of distinct images exist; without the
        # cap, the rejection loop below could never finish on a small dataset.
        target = min(num_pairs, n_images * (n_images - 1))
        indices = np.arange(n_images)
        seen = set()
        pairs = []
        while len(pairs) < target:
            first, second = (int(i) for i in np.random.choice(indices, size=2, replace=False))
            if (first, second) in seen:
                continue
            seen.add((first, second))
            pairs.append({
                'image_1': images_list[first],
                'image_2': images_list[second],
                'label_1': label_list[first],
                'label_2': label_list[second]
            })
        return pairs

    @staticmethod
    def _label_id(name):
        """Return the ``id`` of the first ``michael_labels`` entry named ``name``, or None."""
        return next((el.id for el in michael_labels if el.name == name), None)

    @staticmethod
    def _one_hot(label_id, num_classes):
        """Return a float one-hot vector; all zeros for a missing or out-of-range id."""
        encoded = torch.zeros(num_classes, dtype=torch.float32)
        if label_id is not None and 0 <= label_id < num_classes:
            encoded[label_id] = 1.0
        return encoded

    def _load_image(self, path):
        """Load ``path`` as a float32 tensor in (channels, height, width) layout."""
        # The context manager closes the file; convert() forces three channels
        # so that the permute below and batching see a consistent layout.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        if self.crop:
            # Scale to the fixed height, keeping the aspect ratio.
            width, height = img.size
            scaled_width = int(width * (self.TARGET_HEIGHT / height))
            img = img.resize((scaled_width, self.TARGET_HEIGHT), resample=Image.NEAREST, box=None)

            # Trim the right edge so the width is a multiple of WIDTH_MULTIPLE.
            # NOTE: an image narrower than WIDTH_MULTIPLE after resizing is
            # trimmed to zero width.
            width, height = img.size
            img = img.crop((0, 0, width - width % self.WIDTH_MULTIPLE, height))

        pixels = np.array(img, dtype=np.float32)
        # (height, width, channels) -> (channels, height, width)
        return torch.from_numpy(pixels).permute(2, 0, 1)

    def __len__(self):
        """Return the number of sampled pairs."""
        return len(self.data)

    def __getitem__(self,index):
        """Return ``(img_1, img_2, label_1, label_2)`` for the pair at ``index``."""
        entry = self.data[index]

        img_1 = self._load_image(entry['image_1'])
        img_2 = self._load_image(entry['image_2'])

        label_id_1 = self._label_id(entry['label_1'])
        label_id_2 = self._label_id(entry['label_2'])
        if label_id_1 is None or label_id_2 is None:
            # Best effort: report the offending pair and fall back to an
            # all-zero vector for the unknown label instead of failing.
            print(entry['label_1'], entry['label_2'])

        label_1 = self._one_hot(label_id_1, self.num_classes)
        label_2 = self._one_hot(label_id_2, self.num_classes)

        return img_1, img_2, label_1, label_2

def load_data(dataset, batch_size, num_workers, split='train', deterministic=False, shuffle=False):
    """
    Wrap ``dataset`` in a ``DataLoader`` that collates with ``PadCollate(dim=0)``.

    Args:
        dataset: Dataset handed to ``data.DataLoader``.
        batch_size: Forwarded to the loader.
        num_workers: Forwarded to the loader.
        split: Accepted but not used by this function.
        deterministic: When true, ``__deterministic_worker_init_fn`` is
            installed as the loader's ``worker_init_fn``; otherwise no
            ``worker_init_fn`` is set.
        shuffle: Forwarded to the loader.

    Returns:
        The configured ``DataLoader`` (``pin_memory`` is always enabled).
    """
    if deterministic:
        init_fn = __deterministic_worker_init_fn
    else:
        init_fn = None

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=True,
        worker_init_fn=init_fn,
        collate_fn=PadCollate(dim=0),
    )
    return data.DataLoader(dataset, **loader_options)



In [121]:
# Notebook smoke test: build the training dataset from the YAML config and
# print the tensor shape and label of both images for the first ten batches.
parser = argparse.ArgumentParser(description='Config file')
# The config path is hard-coded below; the parser is created but no
# arguments are registered or parsed in this cell.

config = get_config('../config/train/label-match.yaml')
train_config = config['train_dataloaders'][0]

dataset = CroppedDataset(
    split=train_config['split'],
    root=train_config['root'],
    img_size=train_config['img_size'],
    alterations=train_config['alterations'],
    crop=train_config['crop'],
    crop_size=train_config['crop_size'],
)
# TODO: debug dataset initialization
trainloader = data.DataLoader(
    dataset,
    batch_size=train_config['batch_size'],
    num_workers=train_config['num_workers'],
    shuffle=train_config['shuffle'],
)

for i, (image_1, image_2, label_1, label_2) in enumerate(trainloader):
    if i >= 10:
        # Keep iterating so the remaining batches are still loaded,
        # just not printed.
        continue
    print(image_1.shape, ";", label_1)
    print(image_2.shape, ";", label_2)
        


torch.Size([1, 3, 160, 32]) ; tensor([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([1, 3, 160, 128]) ; tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([1, 3, 160, 96]) ; tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
torch.Size([1, 3, 160, 64]) ; tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0

In [78]:
# Scratch cell: display the integers 0..31 as an array.
np.arange(32)

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [8]:
# Scratch cell: one-hot encode index 2 over 32 slots by broadcasting an
# equality test against arange (the same trick CroppedDataset.__getitem__
# applies with 50 slots).
(np.array([2]) == np.arange(32)).astype(np.uint32)

array([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=uint32)

In [2]:
# Scratch cell: display the michael_labels table (it is not defined in any of
# the cells shown here).
michael_labels

[Label(name='red_light', id=0, trainId=0, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='black_person', id=1, trainId=1, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='red_person', id=2, trainId=2, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='grey_car', id=3, trainId=3, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='grey_suv', id=4, trainId=4, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='green_light', id=5, trainId=5, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='white_truck', id=6, trainId=6, category='void', categoryId=0, hasInstances=False, ignoreInEval=False, color=(0, 0, 0)),
 Label(name='black_car', id=7, trainId=7, category='void', categoryId=0, hasI

In [6]:
# Scratch cell: look up a label's id by name; [0] takes the first match and
# raises IndexError when no entry has that name.
[el.id for el in michael_labels if el.name == 'white_truck'][0]

6