In [3]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset
from imageio import imread
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
from torchvision import transforms as T


class CCPDataset(Dataset):

    def __init__(self, df, patch_size=224, transforms=None, mode='train', class_dict_path='clothes/class_dict.csv'):
        super(CCPDataset, self).__init__()
        self.df = df
        self.ps = patch_size
        self.transforms = transforms
        self.mode = mode

        class_df = pd.read_csv(class_dict_path)

        r = class_df['r'].to_numpy(dtype=np.uint32)
        g = class_df['g'].to_numpy(dtype=np.uint32)
        b = class_df['b'].to_numpy(dtype=np.uint32)

        keys = (r << 16) | (g << 8) | b

        self.color_to_class = dict(zip(keys.tolist(), range(len(keys))))
        self.num_classes = len(self.color_to_class)

    def mask_rgb_to_ids(self, mask):
        packed = (mask[...,0].astype(np.uint32) << 16) | \
                 (mask[...,1].astype(np.uint32) << 8)  | \
                  mask[...,2].astype(np.uint32)

        mapped = np.vectorize(self.color_to_class.get)(packed, 0)

        return mapped.astype(np.int64)


    def __getitem__(self, idx):
        r = self.df.iloc[idx]

        x = imread(r.image_path)
        y = imread(r.mask_path)

        if self.mode == 'eval':
            x = x[r.coords[0]:r.coords[0]+self.ps, r.coords[1]:r.coords[1]+self.ps]
            y = y[r.coords[0]:r.coords[0]+self.ps, r.coords[1]:r.coords[1]+self.ps]


        if y.ndim == 3 and y.shape[2] == 3:
            y = self.mask_rgb_to_ids(y)
        else:
            y = y.astype(np.int64)

        x = to_pil_image(x.astype('float32'))
        y = to_pil_image(y.astype('uint8'), mode='L')

        if self.transforms:
            x, y = self.transforms[0](x, y)
            x = self.transforms[1](x)

        y = torch.from_numpy(np.array(y, dtype=np.int64))

        return x, y


    def __len__(self):
        return self.df.shape[0]


    def __show_item__(self, x, y):
        f, ax = plt.subplots(1, 3, figsize=(15, 5))

        ax[0].imshow(x.permute(1, 2, 0))
        ax[0].set_title('Image')
        ax[0].set_xticks([])
        ax[0].set_yticks([])

        ax[1].imshow(y, cmap='tab20')
        ax[1].set_title('Mask')
        ax[1].set_xticks([])
        ax[1].set_yticks([])

        ax[2].imshow(x.permute(1, 2, 0))
        ax[2].imshow(y, alpha=.5, cmap='tab20', vmin=0, vmax=self.num_classes-1)
        ax[2].set_title('Overlay')
        ax[2].set_xticks([])
        ax[2].set_yticks([])

        plt.show()

In [2]:
image_and_mask_transforms = T.Compose([
    T.RandomCrop((224, 224)),
    T.Resize(size=(256, 256)),
])

image_only_transforms = T.Compose([
    T.ToImage(),                               
    T.ToDtype(torch.float32, scale=True),      
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]), 
])
ds = CCPDataset(df, patch_size=224, transforms=(image_and_mask_transforms, image_only_transforms), mode='train')

rndm_idx = torch.randint(ds.__len__(), [1]).item()
x, y = ds.__getitem__(rndm_idx)
print(y.dtype)
ds.__show_item__(x, y)

NameError: name 'T' is not defined