In [1]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
import albumentations as A

In [2]:
def visualize_mask(image, mask, classes=None, figsize=(10, 10)):
    """Plot an image followed by one panel per channel of its mask.

    Parameters
    ----------
    image : array-like
        Image handed to ``imshow`` for the first panel.
    mask : array
        Mask indexed as ``mask[:, :, k]``; every channel ``k`` along the last
        axis gets its own panel.
    classes : sequence of str, optional
        Panel titles, one per mask channel in channel order. Defaults to the
        five names ``background, square, circle, triangle, star``. Channels
        beyond the supplied names are titled ``class <k>``.
    figsize : tuple, optional
        Figure size forwarded to ``plt.subplots``.

    Returns
    -------
    None
    """
    CLASSES = ['background', 'square', 'circle', 'triangle', 'star']
    names = list(CLASSES if classes is None else classes)

    # Size the grid from the mask itself so masks built for a subset of the
    # classes can be displayed too, instead of always assuming 5 channels.
    n_channels = mask.shape[-1]

    # squeeze=False keeps the axes container 2-D whatever the column count is.
    _, grid = plt.subplots(nrows=1, ncols=n_channels + 1, figsize=figsize, squeeze=False)
    panels = grid[0]

    panels[0].set_title('Ori Image')
    panels[0].imshow(image)
    panels[0].axis('off')

    for channel in range(n_channels):
        panel = panels[channel + 1]
        title = names[channel] if channel < len(names) else 'class {}'.format(channel)
        panel.set_title(title)
        panel.imshow(mask[:, :, channel])
        panel.axis('off')

In [3]:
# One id per entry in the images folder: the part of the name before the first '.'.
imglist = list(map(lambda entry: entry.partition('.')[0], os.listdir('images')))
print(len(imglist))


241


In [12]:
class Dataset:
    """In-memory segmentation dataset: RGB images paired with one-hot masks.

    Every non-hidden entry of ``imgs_dir`` contributes one sample. Its image is
    read from ``<imgs_dir>/<id>.jpg`` and its colour-coded annotation from
    ``<masks_dir>/<id>.png``; a missing annotation yields an all-zero
    placeholder of the image's size. Each annotation is converted to a float32
    array with one channel per requested class.

    Parameters
    ----------
    imgs_dir : str
        Folder of images.
    masks_dir : str
        Folder of masks.
    classes : sequence of str
        Classes we use; each must be a name from ``CLASSES`` (case-insensitive).
        The output mask has one channel per entry, in this order.
    augmentation : bool
        When true, ``AUGMENTED_COPIES`` randomly transformed copies of every
        original sample are appended after the originals.
    preprocessing : bool
        When true, ``__getitem__`` returns the image as float32 divided by 255.
    class_colors : mapping, optional
        ``{class name: (r, g, b)}`` colour that marks each class in the
        annotation image (compared after BGR->RGB conversion). A class missing
        from the mapping falls back to an attribute of the same name
        (``self.square``, ``self.circle``, ...).

    Raises
    ------
    FileNotFoundError
        If an image listed in ``imgs_dir`` cannot be read as ``<id>.jpg``.
    AttributeError
        If a mask has to be encoded and no colour is configured for a class.
    """
    CLASSES = ['background', 'square', 'circle', 'triangle', 'star']
    # Number of augmented copies generated for each original sample.
    AUGMENTED_COPIES = 5

    def __init__(self, imgs_dir, masks_dir, classes, augmentation=True, preprocessing=True,
                 class_colors=None):
        self.images_id = self._list_image_ids(imgs_dir)
        self.images_path = [os.path.join(imgs_dir, img_id) + '.jpg' for img_id in self.images_id]
        self.masks_path = [os.path.join(masks_dir, img_id) + '.png' for img_id in self.images_id]
        self.images = []
        self.masks = []
        self.class_value = [self.CLASSES.index(cls.lower()) for cls in classes]
        self.class_colors = {str(name).lower(): color for name, color in (class_colors or {}).items()}
        self.augmentation = augmentation
        self.preprocessing = preprocessing

        for image_path, mask_path in zip(self.images_path, self.masks_path):
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure with None; fail here with the path
                # rather than inside cvtColor with an opaque OpenCV error.
                raise FileNotFoundError('Could not read image: {}'.format(image_path))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            ori_mask = cv2.imread(mask_path)
            if ori_mask is not None:
                ori_mask = cv2.cvtColor(ori_mask, cv2.COLOR_BGR2RGB)
            else:
                # No annotation file: use a blank mask the same size as the
                # image so image and mask stay aligned for augmentation/batching.
                ori_mask = np.zeros(image.shape[:2] + (3,), dtype=np.uint8)

            self.images.append(image)
            self.masks.append(self._encode_mask(ori_mask))

        if self.augmentation:
            transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.5),
                A.RandomRotate90(),
                A.Perspective(p=0.8),
                A.Transpose(p=0.5),
                A.RandomBrightnessContrast((-0.2, 0.2), p=0.5)
            ])
            # Iterate over the originals only; augmented samples are appended
            # behind them and are never themselves re-augmented.
            for i in range(len(self.images_id)):
                for _ in range(self.AUGMENTED_COPIES):
                    augmented = transform(image=self.images[i], mask=self.masks[i])
                    self.images.append(augmented['image'])
                    self.masks.append(augmented['mask'])

    @staticmethod
    def _list_image_ids(imgs_dir):
        """Return the sorted extension-less names of the non-hidden entries of imgs_dir."""
        ids = []
        # Sorted so the sample order does not depend on the file system.
        for name in sorted(os.listdir(imgs_dir)):
            if name.startswith('.'):
                # Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) are not images.
                continue
            # splitext drops only the final extension, so 'a.b.jpg' -> 'a.b'.
            ids.append(os.path.splitext(name)[0])
        return ids

    def _class_color(self, name):
        """Return the annotation colour configured for the class called name."""
        if name in self.class_colors:
            return self.class_colors[name]
        try:
            return getattr(self, name)
        except AttributeError:
            raise AttributeError(
                "No annotation colour configured for class '{0}'; pass "
                "class_colors={{'{0}': (r, g, b), ...}} or define a '{0}' attribute".format(name)
            ) from None

    def _encode_mask(self, rgb_mask):
        """Convert a colour-coded RGB annotation into a float32 one-hot mask."""
        label = np.zeros(rgb_mask.shape[:2], dtype=np.uint8)
        # Foreground classes are written first and background (index 0) last,
        # so background wins if two classes were given the same colour.
        for value in list(range(1, len(self.CLASSES))) + [0]:
            color = np.asarray(self._class_color(self.CLASSES[value]))
            label[np.all(rgb_mask == color, axis=-1)] = value
        channels = [(label == v) for v in self.class_value]
        return np.stack(channels, axis=-1).astype(np.float32)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair stored at ``index``."""
        image = self.images[index]
        if self.preprocessing:
            image = image.astype(np.float32)
            image = image / 255.0
        return image, self.masks[index]


    def __len__(self):
        """Return the number of stored samples, augmented copies included."""
        return len(self.images)


In [14]:
# The five class names listed in Dataset.CLASSES; one mask channel is built per entry.
cls = [
    'background',
    'square',
    'circle',
    'triangle',
    'star',
]

# Load images/ and annotation/ into memory with augmentation and preprocessing enabled.
train_data = Dataset(
    imgs_dir='images',
    masks_dir='annotation',
    classes=cls,
    augmentation=True,
    preprocessing=True,
)

In [15]:
class DataLoader(tf.keras.utils.Sequence):
    """
    Keras Sequence that serves a dataset as stacked (X, Y) batches.

    dataset: object supporting len() and integer indexing that returns an
        (image, mask) pair
    batch_size: number of samples per batch; the last batch holds the
        remainder and may be smaller
    shuffle: shuffle data or not (the order is redrawn in on_epoch_end)
    **kwargs: forwarded to the base class constructor
    """
    def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs):
        # Newer Keras versions expect the base constructor to run and warn
        # when it is skipped.
        super().__init__(**kwargs)
        if batch_size < 1:
            raise ValueError('batch_size must be >= 1, got {}'.format(batch_size))
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __getitem__(self, index):
        """Return batch number ``index`` as a tuple of stacked arrays (X, Y)."""
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            # Otherwise the empty slice below would surface as an unrelated
            # "need at least one array to stack" ValueError.
            raise IndexError('batch index out of range')
        indexes = self.indexes[index * self.batch_size: (index + 1) * self.batch_size]
        # Fetch each sample once: indexing the dataset may redo per-sample work.
        samples = [self.dataset[j] for j in indexes]
        X = np.stack([sample[0] for sample in samples], axis=0)
        Y = np.stack([sample[1] for sample in samples], axis=0)
        return X, Y

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial batch."""
        return int(np.ceil(len(self.dataset) / self.batch_size))

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.dataset))
        if self.shuffle:
            np.random.shuffle(self.indexes)


In [8]:
# Serve train_data in batches of 16, in fixed (unshuffled) order.
dataloader = DataLoader(dataset=train_data, batch_size=16, shuffle=False)