In [1]:
import os
import sys
import torch.utils.data as data
from PIL import Image

class CPN(data.Dataset):
    """In-memory segmentation dataset built from the CPN and Median image sets.

    Every image/mask pair is opened, converted and (optionally) transformed
    once inside ``__init__`` and cached; ``__getitem__`` only indexes that
    cache. As a consequence, any random transform (for example a random crop)
    is drawn once per sample at construction time and stays fixed afterwards.

    Files are looked up under ``root`` as follows::

        CPN_all/Images/<name>.bmp
        CPN_all/Masks/<name>_mask.bmp
        CPN_all/<dver>/<image_set>.txt
        Median/Images/<name>.jpg
        Median/Masks/<name>.jpg
        Median/splits/<image_set>.txt

    Args:
        root (string): Root directory that contains ``CPN_all`` and ``Median``.
        datatype (string): Dataset type. Accepted for interface compatibility;
            it is not used by this class.
        dver (string): Split directory relative to ``CPN_all``
            (ex) ``splits/v5/3``.
        image_set (string): Name of the split file (without ``.txt``) to load,
            e.g. ``train``, ``val`` or ``test``. Only ``train`` and ``val``
            are extended with the Median samples.
        transform (callable, optional): Called as ``transform(img, target)``
            with two PIL images; must return the transformed
            ``(img, target)`` pair.
        is_rgb (bool): If True images are converted to ``RGB``, otherwise to
            single-channel ``L``. Masks are always converted to ``L``.

    Raises:
        Exception: If the CPN image/mask directories or the CPN split file
            do not exist.
        FileNotFoundError: If ``image_set`` is ``train`` or ``val`` and the
            Median split file does not exist.
    """

    # Splits that are extended with the Median samples.
    _MEDIAN_SETS = ('train', 'val')

    def __init__(self, root, datatype='CPN', dver='splits',
                    image_set='train', transform=None, is_rgb=True):

        self.transform = transform
        self.is_rgb = is_rgb

        # Normalise once so that path building and the train/val check below
        # always agree (a trailing newline used to disable the Median merge).
        image_set = image_set.rstrip('\n')

        image_dir = os.path.join(root, 'CPN_all', 'Images')
        mask_dir = os.path.join(root, 'CPN_all', 'Masks')
        m_image_dir = os.path.join(root, 'Median', 'Images')
        m_mask_dir = os.path.join(root, 'Median', 'Masks')

        if not os.path.exists(image_dir) or not os.path.exists(mask_dir):
            raise Exception('Dataset not found or corrupted.')

        split_f = os.path.join(root, 'CPN_all', dver, image_set + '.txt')
        m_split_f = os.path.join(root, 'Median', 'splits', image_set + '.txt')

        if not os.path.exists(split_f):
            raise Exception('Wrong image_set entered! '
                            'Please use image_set="train" or image_set="val"', split_f)

        # The Median split is only mandatory for the splits that merge it in.
        use_median = image_set in self._MEDIAN_SETS
        if use_median and not os.path.exists(m_split_f):
            raise FileNotFoundError('Median split file not found: ' + m_split_f)

        file_names = self._read_split(split_f)
        m_file_names = self._read_split(m_split_f) if os.path.exists(m_split_f) else []

        self.image = [os.path.join(image_dir, x + ".bmp") for x in file_names]
        self.mask = [os.path.join(mask_dir, x + "_mask.bmp") for x in file_names]
        self.m_images = [os.path.join(m_image_dir, x + ".jpg") for x in m_file_names]
        self.m_masks = [os.path.join(m_mask_dir, x + ".jpg") for x in m_file_names]

        if use_median:
            self.image.extend(self.m_images)
            self.mask.extend(self.m_masks)

        image_mode = 'RGB' if self.is_rgb else 'L'

        self.images = []
        self.masks = []

        for image_path, mask_path in zip(self.image, self.mask):
            img, target = self._load_pair(image_path, mask_path, image_mode)

            if self.transform is not None:
                img, target = self.transform(img, target)

            self.images.append(img)
            self.masks.append(target)

        print(f'Images len: {len(self.images)} shape: {self._describe_shape()}')

    @staticmethod
    def _read_split(path):
        """Return the non-empty, whitespace-stripped lines of a split file."""
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def _load_pair(image_path, mask_path, image_mode):
        """Load one image/mask pair and close the underlying files right away."""
        # ``convert`` returns a fully loaded copy, so the lazily opened source
        # file can be closed as soon as the conversion is done.
        with Image.open(image_path) as src:
            img = src.convert(image_mode)
        with Image.open(mask_path) as src:
            target = src.convert('L')
        return img, target

    def _describe_shape(self):
        """Return the shape (or PIL size) of the first cached image, or None if empty."""
        if not self.images:
            return None
        first = self.images[0]
        shape = getattr(first, 'shape', None)
        # Untransformed samples are PIL images, which expose ``size`` instead.
        return shape if shape is not None else getattr(first, 'size', None)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the image segmentation.
        """

        return self.images[index], self.masks[index]

    def __len__(self):
        """Return the number of cached image/mask pairs."""
        return len(self.images)

In [2]:
from utils import ext_transforms as et
from torch.utils.data import DataLoader
from tqdm import tqdm

# Smoke test for the CPN dataset: build the training loader and pull the first
# batch of every epoch to check that loading and batching work end to end.

# Tunable settings for this cell.
DATA_ROOT = '/mnt/server5/sdi/datasets'
SPLIT_VERSION = 'splits/v5/3'
CROP_SIZE = (512, 512)
SCALE = 0.5
BATCH_SIZE = 16
NUM_WORKERS = 2
NUM_EPOCHS = 100

# Joint image/mask transform: crop, downscale, convert to tensor, normalise.
transform = et.ExtCompose([
        et.ExtRandomCrop(size=CROP_SIZE, pad_if_needed=True),
        et.ExtScale(scale=SCALE),
        et.ExtToTensor(),
        et.ExtNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

dst = CPN(root=DATA_ROOT, datatype='CPN', image_set='train',
                transform=transform, is_rgb=True, dver=SPLIT_VERSION)
train_loader = DataLoader(dst, batch_size=BATCH_SIZE,
                            shuffle=True, num_workers=NUM_WORKERS, drop_last=True)

for epoch in range(NUM_EPOCHS):
    # Wrap the loader itself (not ``enumerate``) so tqdm can read its length
    # and show a real total instead of "0it [00:00, ?it/s]".
    for i, (ims, lbls) in enumerate(tqdm(train_loader)):
        print(f'Epoch {epoch}')
        # Only the first batch of each epoch is needed for this check.
        break
    

Images len: 1366 shape: torch.Size([3, 256, 256])


0it [00:00, ?it/s]

Epoch 0



0it [00:00, ?it/s]

Epoch 1



0it [00:00, ?it/s]

Epoch 2



0it [00:00, ?it/s]

Epoch 3



0it [00:00, ?it/s]

Epoch 4



0it [00:00, ?it/s]

Epoch 5



0it [00:00, ?it/s]

Epoch 6



0it [00:00, ?it/s]


Epoch 7


0it [00:00, ?it/s]

Epoch 8



0it [00:00, ?it/s]

Epoch 9



0it [00:00, ?it/s]

Epoch 10



0it [00:00, ?it/s]

Epoch 11



0it [00:00, ?it/s]

Epoch 12



0it [00:00, ?it/s]

Epoch 13



0it [00:00, ?it/s]

Epoch 14



0it [00:00, ?it/s]

Epoch 15



0it [00:00, ?it/s]


Epoch 16


0it [00:00, ?it/s]


Epoch 17


0it [00:00, ?it/s]


Epoch 18


0it [00:00, ?it/s]

Epoch 19



0it [00:00, ?it/s]

Epoch 20



0it [00:00, ?it/s]

Epoch 21



0it [00:00, ?it/s]

Epoch 22



0it [00:00, ?it/s]

Epoch 23



0it [00:00, ?it/s]

Epoch 24



0it [00:00, ?it/s]


Epoch 25


0it [00:00, ?it/s]

Epoch 26



0it [00:00, ?it/s]

Epoch 27



0it [00:00, ?it/s]

Epoch 28



0it [00:00, ?it/s]


Epoch 29


0it [00:00, ?it/s]


Epoch 30


0it [00:00, ?it/s]

Epoch 31



0it [00:00, ?it/s]


Epoch 32


0it [00:00, ?it/s]


Epoch 33


0it [00:00, ?it/s]


Epoch 34


0it [00:00, ?it/s]


Epoch 35


0it [00:00, ?it/s]

Epoch 36



0it [00:00, ?it/s]

Epoch 37



0it [00:00, ?it/s]


Epoch 38


0it [00:00, ?it/s]


Epoch 39


0it [00:00, ?it/s]


Epoch 40


0it [00:00, ?it/s]


Epoch 41


0it [00:00, ?it/s]


Epoch 42


0it [00:00, ?it/s]


Epoch 43


0it [00:00, ?it/s]


Epoch 44


0it [00:00, ?it/s]


Epoch 45


0it [00:00, ?it/s]


Epoch 46


0it [00:00, ?it/s]


Epoch 47


0it [00:00, ?it/s]


Epoch 48


0it [00:00, ?it/s]


Epoch 49


0it [00:00, ?it/s]


Epoch 50


0it [00:00, ?it/s]


Epoch 51


0it [00:00, ?it/s]


Epoch 52


0it [00:00, ?it/s]


Epoch 53


0it [00:00, ?it/s]


Epoch 54


0it [00:00, ?it/s]


Epoch 55


0it [00:00, ?it/s]


Epoch 56


0it [00:00, ?it/s]


Epoch 57


0it [00:00, ?it/s]


Epoch 58


0it [00:00, ?it/s]


Epoch 59


0it [00:00, ?it/s]


Epoch 60


0it [00:00, ?it/s]


Epoch 61


0it [00:00, ?it/s]


Epoch 62


0it [00:00, ?it/s]


Epoch 63


0it [00:00, ?it/s]


Epoch 64


0it [00:00, ?it/s]


Epoch 65


0it [00:00, ?it/s]


Epoch 66


0it [00:00, ?it/s]


Epoch 67


0it [00:00, ?it/s]


Epoch 68


0it [00:00, ?it/s]


Epoch 69


0it [00:00, ?it/s]


Epoch 70


0it [00:00, ?it/s]


Epoch 71


0it [00:00, ?it/s]


Epoch 72


0it [00:00, ?it/s]


Epoch 73


0it [00:00, ?it/s]


Epoch 74


0it [00:00, ?it/s]


Epoch 75


0it [00:00, ?it/s]


Epoch 76


0it [00:00, ?it/s]


Epoch 77


0it [00:00, ?it/s]


Epoch 78


0it [00:00, ?it/s]


Epoch 79


0it [00:00, ?it/s]


Epoch 80


0it [00:00, ?it/s]


Epoch 81


0it [00:00, ?it/s]


Epoch 82


0it [00:00, ?it/s]


Epoch 83


0it [00:00, ?it/s]


Epoch 84


0it [00:00, ?it/s]


Epoch 85


0it [00:00, ?it/s]


Epoch 86


0it [00:00, ?it/s]


Epoch 87


0it [00:00, ?it/s]

Epoch 88



0it [00:00, ?it/s]


Epoch 89


0it [00:00, ?it/s]


Epoch 90


0it [00:00, ?it/s]


Epoch 91


0it [00:00, ?it/s]


Epoch 92


0it [00:00, ?it/s]


Epoch 93


0it [00:00, ?it/s]


Epoch 94


0it [00:00, ?it/s]


Epoch 95


0it [00:00, ?it/s]


Epoch 96


0it [00:00, ?it/s]


Epoch 97


0it [00:00, ?it/s]


Epoch 98


0it [00:00, ?it/s]


Epoch 99
