In [32]:
import os
from PIL import Image 
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as tfs
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader
%matplotlib inline

In [20]:
root = './VOCdevkit/VOC2007'


def read_list(train=True, voc_root=root):
    """Return parallel lists of image and label paths for a VOC segmentation split.

    Args:
        train: If True read the ``train`` split, otherwise the ``val`` split.
        voc_root: Root directory of the VOC2007 devkit. Defaults to the
            module-level ``root`` value.

    Returns:
        (image_paths, label_paths): for every id listed (whitespace separated)
        in ``ImageSets/Segmentation/<split>.txt``, in file order, the paths
        ``JPEGImages/<id>.jpg`` and ``SegmentationClass/<id>.png``.
    """
    split = 'train' if train else 'val'
    txt_name = os.path.join(voc_root, 'ImageSets', 'Segmentation', split + '.txt')
    with open(txt_name, 'r') as reader:
        image_name_list = reader.read().split()
    data = [os.path.join(voc_root, 'JPEGImages', i + '.jpg') for i in image_name_list]
    label = [os.path.join(voc_root, 'SegmentationClass', i + '.png') for i in image_name_list]
    return data, label

In [21]:
# Derive the VOC class names from ImageSets/Main, whose per-class files are
# named "<class>_<split>.txt". Only files that match that shape are used:
# split lists without an underscore ("train.txt", ...) and non-.txt entries
# such as a stray ".DS_Store" would otherwise leak in as bogus class names.
tempfilename_list = [
    name.partition('_')[0]
    for name in os.listdir(os.path.join(root, 'ImageSets', 'Main'))
    if name.endswith('.txt') and '_' in name
]
# Index 0 is reserved for 'background'; np.unique de-duplicates and sorts the
# rest alphabetically. NOTE(review): this order is assumed to match the index
# order of `colormap` — confirm.
classList = np.append(['background'], np.unique(tempfilename_list))

In [22]:
# RGB colour used for each class in the SegmentationClass PNGs. The position
# in this list is the integer label assigned by the cm2lbl table / image2label.
# NOTE(review): presumably aligned with classList (background first) — confirm.
colormap = [
    [0, 0, 0],        # 0
    [128, 0, 0],      # 1
    [0, 128, 0],      # 2
    [128, 128, 0],    # 3
    [0, 0, 128],      # 4
    [128, 0, 128],    # 5
    [0, 128, 128],    # 6
    [128, 128, 128],  # 7
    [64, 0, 0],       # 8
    [192, 0, 0],      # 9
    [64, 128, 0],     # 10
    [192, 128, 0],    # 11
    [64, 0, 128],     # 12
    [192, 0, 128],    # 13
    [64, 128, 128],   # 14
    [192, 128, 128],  # 15
    [0, 64, 0],       # 16
    [128, 64, 0],     # 17
    [0, 192, 0],      # 18
    [128, 192, 0],    # 19
    [0, 64, 128],     # 20
]

In [23]:
# Lookup table from a packed 24-bit colour (R*256 + G)*256 + B to class index.
# uint8 holds every label of the 21-entry colormap and keeps the table at
# 16 MiB, instead of the 128 MiB a default float64 np.zeros(256**3) needs.
# Colours not listed in `colormap` stay 0, i.e. they are labelled as class 0.
# NOTE(review): VOC draws object borders in (224, 224, 192); those pixels end
# up as class 0 here rather than being ignored — confirm that is intended.
cm2lbl = np.zeros(256 ** 3, dtype=np.uint8)
assert len(colormap) <= np.iinfo(cm2lbl.dtype).max + 1, 'too many classes for uint8 table'
for i, cm in enumerate(colormap):
    cm2lbl[(cm[0] * 256 + cm[1]) * 256 + cm[2]] = i


def image2label(im):
    """Convert a colour-coded label image into a 2-D array of class indices.

    Args:
        im: Anything ``np.array`` turns into an (H, W, C>=3) array of 0-255
            colour values, e.g. a PIL image in RGB mode.

    Returns:
        np.ndarray of dtype int64 and shape (H, W) whose entries are the
        ``colormap`` index of each pixel's colour (0 for unlisted colours).
    """
    # int32 is wide enough for the packed value (max 256**3 - 1).
    data = np.array(im, dtype='int32')
    idx = (data[:, :, 0] * 256 + data[:, :, 1]) * 256 + data[:, :, 2]
    return np.array(cm2lbl[idx], dtype='int64')

In [24]:
def img_rndCrop(data, label, height, width):
    """Crop the same random ``height`` x ``width`` window from an image and its label.

    Args:
        data: Input image (PIL image or tensor, as accepted by torchvision).
        label: Label image; assumed to share ``data``'s spatial size.
        height: Crop height in pixels.
        width: Crop width in pixels.

    Returns:
        (data, label) both cropped at identical offsets.
    """
    # RandomCrop(...)(img) returns only the cropped image (not the crop box),
    # and torchvision has no FixedCrop transform. Sample the window once with
    # get_params and apply it to both inputs so image and label stay aligned.
    top, left, h, w = tfs.RandomCrop.get_params(data, (height, width))
    data = tfs.functional.crop(data, top, left, h, w)
    label = tfs.functional.crop(label, top, left, h, w)
    return data, label

In [25]:
# Tensor conversion + per-channel normalisation with the commonly used
# ImageNet mean/std. Built once here instead of on every img_transform call.
_IMG_TF_PIPELINE = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def img_transform(data, label, cropsize):
    """Randomly crop an (image, label) pair and convert both to tensors.

    Args:
        data: PIL input image.
        label: PIL colour-coded label image (RGB).
        cropsize: Crop size forwarded unchanged to ``rand_Crop``.

    Returns:
        (data, label): the normalised image tensor from ``_IMG_TF_PIPELINE``
        and a ``torch`` tensor of class indices produced by ``image2label``.
    """
    # rand_Crop is defined in a later cell; it is looked up at call time.
    data, label = rand_Crop(data, label, cropsize)
    data = _IMG_TF_PIPELINE(data)
    label = torch.from_numpy(image2label(label))
    return data, label

In [26]:
def rand_Crop(img, label, crop_size):
    """Crop the same random window out of an image and its label.

    Args:
        img: PIL image to crop.
        label: PIL label image; assumed to have the same size as ``img``.
        crop_size: ``(width, height)`` of the window, i.e. the same order as
            PIL's ``Image.size`` (and as ``VOCSegDataset._filter`` checks).

    Returns:
        (img, label) cropped to ``crop_size`` at identical offsets.

    Raises:
        ValueError: if ``img`` is smaller than ``crop_size`` in either dimension.
    """
    crop_w, crop_h = crop_size[0], crop_size[1]
    img_w, img_h = img.size
    if img_w < crop_w or img_h < crop_h:
        raise ValueError('image of size %s is smaller than crop size %s'
                         % ((img_w, img_h), (crop_w, crop_h)))
    start_x = np.random.randint(low=0, high=img_w - crop_w + 1)
    start_y = np.random.randint(low=0, high=img_h - crop_h + 1)
    # PIL's crop box is (left, upper, right, lower) = (x1, y1, x2, y2). The old
    # (y1, x1, y2, x2) order produced boxes running past the image edge, which
    # PIL silently zero-pads in both the image and the label.
    box = (start_x, start_y, start_x + crop_w, start_y + crop_h)
    return img.crop(box), label.crop(box)

In [27]:
# Crop size as (width, height), the order PIL's Image.size uses.
# NOTE(review): the dataset cells below pass [230, 280] directly and never
# read this variable.
crop_size = [200, 300]

In [28]:
class VOCSegDataset(Dataset):
    '''
    VOC segmentation dataset of (image, label) pairs.

    Args:
        trainflg: Passed to ``read_list(train=...)`` to choose the split.
        crop_size: ``(width, height)``; pairs smaller than this in either
            dimension are dropped, and it is forwarded to ``transforms``.
        transforms: Callable ``(img, label, crop_size) -> (img, label)``
            applied in ``__getitem__``.
    '''
    def __init__(self, trainflg, crop_size, transforms):
        self.crop_size = crop_size
        self.transforms = transforms
        data_list, label_list = read_list(train=trainflg)
        # Filter image/label pairs together so the two lists can never drift
        # out of alignment, which separate per-list filtering allowed.
        self.data_list, self.label_list = self._filter_pairs(data_list, label_list)
        print('Read ' + str(len(self.data_list)) + ' images')

    def _is_large_enough(self, path):
        # Image.open is lazy (reads the header only); the with-block closes the
        # file handle that bare Image.open(...) calls leave open.
        with Image.open(path) as im:
            width, height = im.size
        return width >= self.crop_size[0] and height >= self.crop_size[1]

    def _filter(self, images):
        # Keep only paths whose image is at least crop_size in both dimensions.
        return [im for im in images if self._is_large_enough(im)]

    def _filter_pairs(self, images, labels):
        # Keep (image, label) pairs where both files are large enough to crop.
        kept = [(im, lb) for im, lb in zip(images, labels)
                if self._is_large_enough(im) and self._is_large_enough(lb)]
        return [im for im, _ in kept], [lb for _, lb in kept]

    def __getitem__(self, idx):
        # convert() loads the pixels into a new image, so the source files can
        # be closed right away; RGB guards against non-RGB inputs.
        with Image.open(self.data_list[idx]) as img_file:
            img = img_file.convert('RGB')
        with Image.open(self.label_list[idx]) as label_file:
            label = label_file.convert('RGB')
        img, label = self.transforms(img, label, self.crop_size)
        return img, label

    def __len__(self):
        return len(self.data_list)

In [29]:
train_data = VOCSegDataset(trainflg=True, crop_size=[230, 280], transforms=img_transform)

Read 207 images


In [30]:
test_data = VOCSegDataset(trainflg=False, crop_size=[230, 280], transforms=img_transform)

Read 208 images


In [33]:
# Only the training loader is shuffled; validation order does not affect the
# metrics, so keeping it fixed makes evaluation runs reproducible.
# NOTE(review): with the 'spawn' multiprocessing start method, workers may be
# unable to unpickle classes defined in the notebook's __main__; fall back to
# num_workers=0 if iterating these loaders fails.
train_d = DataLoader(train_data, 64, shuffle=True, num_workers=4)
val_d = DataLoader(test_data, 128, shuffle=False, num_workers=4)

<__main__.VOCSegDataset at 0x113946790>