<a href="https://colab.research.google.com/github/ALYAMBR/Notebooks/blob/master/augmentations_image_segmentation_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Sketches of image augmentations for an image segmentation problem.

In [None]:
import os
import random
from glob import glob

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms.functional as TF
from skimage import io, transform
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision.transforms import RandomRotation

In [None]:
class ToTensor(object):
    """Convert a numpy (image, mask) pair into channel-first int32 tensors.

    torchvision's tensor functions (rotate, affine, center_crop, pad, resize)
    treat the last two dimensions as (H, W), so an (H, W, C) array is turned
    into a (C, H, W) tensor -- the channel axis moves to the front while H
    stays before W. A 2-D (H, W) array, e.g. a grayscale mask, gets a leading
    channel axis of size 1.
    """

    @staticmethod
    def _to_chw(array):
        """Return an int32 (C, H, W) tensor copied from an (H, W) or (H, W, C) array."""
        # np.array always copies, so the tensor never aliases the caller's array
        # (same as the legacy torch.IntTensor(...) constructor did).
        tensor = torch.from_numpy(np.array(array, dtype=np.int32, order='C'))
        if tensor.ndim == 2:
            return tensor.unsqueeze(0)
        return tensor.permute(2, 0, 1)

    def __call__(self, image, mask):
        im_tens = self._to_chw(image)
        ms_tens = self._to_chw(mask)
        return im_tens, ms_tens

In [None]:
class RandRot(object):
    def __init__(self, maxdeg=45):
        self.max_deg = maxdeg
    
    def __call__(self, image, mask):
        rot_angle = random.randint(-self.max_deg, self.max_deg)
        rot_image = TF.rotate(image, angle=rot_angle)
        rot_mask = TF.rotate(mask, angle=rot_angle)
        return rot_image, rot_mask

In [None]:
class RandShear(object):
    def __init__(self, x_maxdeg=20, y_maxdeg=20):
        self.x_max_deg = x_maxdeg
        self.y_max_deg = y_maxdeg

    def __call__(self, image, mask):
        x_deg = random.randint(-self.x_max_deg, self.x_max_deg)
        y_deg = random.randint(-self.y_max_deg, self.y_max_deg)
        shear_image = TF.affine(image, 0, [0, 0], 1,  shear=[x_deg, y_deg])
        shear_mask = TF.affine(mask, 0, [0, 0], 1, shear=[x_deg, y_deg])
        return shear_image, shear_mask

In [None]:
class RandCenterCrop(object):
    def __init__(self, max_crop_percent = 0.2):
        self.max_crop_coef = max_crop_percent

    def __call__(self, image, mask):
        im_h = image.shape[1]
        im_w = image.shape[2]
        ms_h = mask.shape[1]
        ms_w = mask.shape[2]
        crop_coef = 1 - random.uniform(0, self.max_crop_coef)
        croped_im = TF.center_crop(image, [int(im_h * crop_coef), int(im_w * crop_coef)])
        croped_ms = TF.center_crop(mask, [int(ms_h * crop_coef), int(ms_w * crop_coef)])
        return croped_im, croped_ms

In [None]:
class RandPadding(object):
    def __init__(self, max_pad_percent = 0.3):
        self.max_pad_coef = max_pad_percent

    def __call__(self, image, mask):
        coef = random.uniform(0, self.max_pad_coef)
        h_coef = random.uniform(0, 1)
        w_coef = random.uniform(0, 1)
        rh_coef = 1 - h_coef
        rw_coef = 1 - w_coef
        im_h = image.shape[1] * coef
        im_w = image.shape[2] * coef
        ms_h = mask.shape[1] * coef
        ms_w = mask.shape[2] * coef
        pad_im = TF.pad(image, [int(im_w * w_coef),
                                int(im_h * h_coef),
                                int(im_w * rw_coef),
                                int(im_h * rh_coef)])
        pad_ms = TF.pad(mask, [int(ms_w * w_coef),
                               int(ms_h * h_coef),
                               int(ms_w * rw_coef),
                               int(ms_h * rh_coef)])
        return pad_im, pad_ms

In [None]:
class RandColorTransform(object):
    def __init__(self, max_bright=0.3, max_contrast=0.3, max_satur=0.3, max_sharp=0.7):
        self.bright_max = max_bright
        self.contrast_max = max_contrast
        self.satur_max = max_satur
        self.sharp_max = max_sharp

    def __call__(self, image, mask):
        # mask doesn't need these transformations!
        bright = 1 + random.uniform(-self.bright_max, self.bright_max)
        contrast = 1 + random.uniform(-self.contrast_max, self.contrast_max)
        satur = 1 + random.uniform(-self.satur_max, self.satur_max)
        sharp = 1 + random.uniform(-self.sharp_max, self.sharp_max)

        bright_image = TF.adjust_brightness(image, bright)
        contrast_image = TF.adjust_contrast(bright_image, contrast)
        satur_image = TF.adjust_saturation(contrast_image, satur)
        sharp_image = TF.adjust_sharpness(satur_image, sharp)

        return sharp_image, mask

In [None]:
class SampleResize(object):
    def __init__(self, width=512, height=512):
        self.width = width
        self.height = height

    def __call__(self, image, mask):
        image = TF.resize(image, [self.height, self.width])
        mask = TF.resize(mask, [self.height, self.width])
        return image, mask

In [None]:
class TensorToNumpy(object):
    """Turn an (image, mask) pair of CPU tensors into numpy arrays.

    No axes are permuted: each array keeps its tensor's layout.
    """

    def __call__(self, image, mask):
        return tuple(tensor.numpy() for tensor in (image, mask))

In [None]:
class Normalizer(object):
    """Scale image intensities by 1/255 into float; pass the mask through as-is.

    The mask is expected to be an integer array holding only 0 and 1.
    """

    def __call__(self, image, mask):
        scaled = np.divide(image.astype(float), 255.0)
        return scaled, mask

In [None]:
class DataLoaderSegmentation(data.Dataset):
    """Image/mask segmentation dataset with optional random augmentations.

    Images are listed from ``<folder_path>/images``. Each mask is expected at
    ``<folder_path>/masks`` under the same file name. Only images whose file
    name (basename, including extension) appears in ``names_list`` are used.

    Args:
        folder_path: Root directory containing ``images`` and ``masks``.
        names_list: File names to include.
        transform: If truthy, apply random rotation, shear, center crop,
            padding and color jitter before resizing.
    """

    def __init__(self, folder_path, names_list, transform=False):
        super(DataLoaderSegmentation, self).__init__()
        # A set gives O(1) membership; the old list scan plus list.remove was O(n^2).
        wanted = set(names_list)
        # glob order is filesystem-dependent; sorting makes index -> file reproducible.
        all_images = sorted(glob(os.path.join(folder_path, 'images', '*')))
        # os.path.basename respects the platform separator, unlike split('/').
        self.img_files = [path for path in all_images
                          if os.path.basename(path) in wanted]
        self.mask_files = [os.path.join(folder_path, 'masks', os.path.basename(path))
                           for path in self.img_files]
        self.to_tensor = ToTensor()
        self.rand_rot = RandRot()
        self.rand_shear = RandShear()
        self.rand_center_crop = RandCenterCrop()
        self.rand_pad = RandPadding()
        self.rand_color_trans = RandColorTransform()
        self.tensor_to_numpy = TensorToNumpy()
        self.resize = SampleResize()
        self.normalize = Normalizer()
        self.transform = transform

    def _load_sample(self, i):
        """Read sample ``i``, binarise the mask, optionally augment, resize and normalise it."""
        img = io.imread(self.img_files[i])
        mask = io.imread(self.mask_files[i])

        # Values above 100 are mask pixels (1); everything else is background (0).
        mask = (mask > 100).astype(int)

        # Channel-first tensors for the torchvision-based augmentations.
        img, mask = self.to_tensor(img, mask)

        if self.transform:
            img, mask = self.rand_rot(img, mask)
            img, mask = self.rand_shear(img, mask)
            img, mask = self.rand_center_crop(img, mask)
            img, mask = self.rand_pad(img, mask)
            img, mask = self.rand_color_trans(img, mask)

        img, mask = self.resize(img, mask)
        # Back to numpy arrays; the channel-first layout is kept.
        img, mask = self.tensor_to_numpy(img, mask)
        return self.normalize(img, mask)

    def __getitem__(self, index):
        """Return the sample(s) at ``index`` as float tensors with a leading batch axis.

        Args:
            index: A single integer (Python or numpy) or an iterable of integers.

        Returns:
            ``(images, masks)``. The batch axis has length 1 for a single index.
        """
        if isinstance(index, (int, np.integer)):
            index = [index]
        batch_X, batch_Y = [], []
        for i in index:
            img, mask = self._load_sample(i)
            batch_X.append(img)
            batch_Y.append(mask)
        batch_X = np.array(batch_X)
        batch_Y = np.array(batch_Y)
        return torch.from_numpy(batch_X).float(), torch.from_numpy(batch_Y).float()

    def __len__(self):
        return len(self.img_files)

    def get_len(self):
        """Return the number of samples (same as ``len(self)``)."""
        return len(self.img_files)