In [1]:
import os
import numpy as np
import pandas as pd
import cv2
import tifffile
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import random as r

  from .autonotebook import tqdm as notebook_tqdm


Chargement d'une image d'exemple et de son masque, pour vérifier (sanity check) le bon fonctionnement des fonctions d'augmentation

In [2]:
image_dir = '../data/train_images'
mask_dir = '../data/train_masks'

image_id = '10044'

# Image and mask share the same file name, so both paths come from one id.
image_path, mask_path = (
    os.path.join(folder, '{}.tiff'.format(image_id))
    for folder in (image_dir, mask_dir)
)
# tifffile.imread gives numpy arrays; wrap them as tensors for the transforms below.
image, mask = (torch.tensor(tifffile.imread(path)) for path in (image_path, mask_path))


Fonction d'affichage : l'image seule, puis le masque superposé à l'image

In [3]:
def show_image(item,cmaps='coolwarm_r'):
    """Display an (image, mask) pair as two vertically stacked panels.

    The top panel shows the image as-is. The bottom panel shows half of the
    image's first channel plus the mask, rendered with the ``cmaps`` colormap.

    Args:
        item: ``(image, mask)`` tuple. ``image`` is indexed as ``image[:, :, 0]``
            (channels last, presumably HxWxC) and ``mask`` as 2-D (presumably HxW).
        cmaps: matplotlib colormap name used for the overlay panel.
    """
    img, msk = item
    _, axes = plt.subplots(nrows=2, ncols=1, figsize=(16, 32))
    top, bottom = axes
    # Overlay: first channel at half intensity, plus the mask values.
    overlay = img[:, :, 0] / 2 + msk[:, :]

    top.imshow(img)
    bottom.imshow(overlay, cmap=cmaps)
    for panel, title in ((top, 'IMAGE'), (bottom, 'MASK ON IMAGE')):
        panel.axis('off')
        panel.set_title(title)
    plt.show()

In [4]:
#show_image((image,mask))

Flip vertical aléatoire (image et masque retournés ensemble, une fois sur deux)

In [5]:
class RandomFlip(object):
    """Randomly flip an (image, mask) pair along the same axes.

    Image and mask are flipped together so the mask stays aligned with the
    image. With the default ``dims=(0,)`` the first axis is reversed. For the
    channels-last images used in this notebook (presumably HxWxC, mask HxW)
    that is a vertical flip.

    Args:
        p: probability of flipping. The default 0.5 matches the previous behaviour.
        dims: axis or axes to flip (int or sequence of ints). They must be valid
            for both tensors. The mask has no channel axis, so only the spatial
            axes 0 and 1 are safe.
    """

    def __init__(self, p=0.5, dims=(0,)):
        self.p = p
        # Tensor.flip only accepts a sequence of dims, so normalise a bare int.
        self.dims = (dims,) if isinstance(dims, int) else tuple(dims)

    def __call__(self, item):
        image, mask = item
        # Flip when r.random() <= p; otherwise pass the pair through untouched.
        if r.random() > self.p:
            return item
        # A bare int here (flip(dims=0)) raises TypeError; a tuple is required.
        image = image.flip(dims=self.dims)
        mask = mask.flip(dims=self.dims)
        return (image, mask)


#show_image(RandomFlip()((image,mask)))

Rotation aléatoire (angle tiré uniformément entre -180° et 180°, identique pour l'image et le masque)

In [6]:
class RandomRotation(object):
    """Rotate an (image, mask) pair by one shared random angle.

    The angle is drawn uniformly from ``[-degrees, degrees]`` (default
    ``[-180, 180]``). The image is expected channels-last (presumably HxWxC)
    and the mask 2-D (presumably HxW). Both go through torchvision's
    ``rotate`` with its default interpolation and fill, since none are passed.

    Args:
        degrees: half-width of the symmetric angle range, in degrees.
    """

    def __init__(self, degrees=180):
        self.degrees = degrees

    def angle(self):
        """Sample a rotation angle in degrees. Its sign sets the rotation direction."""
        return r.uniform(-self.degrees, self.degrees)

    def __call__(self,item):
        image, mask = item
        angle = self.angle()
        # torchvision works channels-first: HWC -> CHW -> rotate -> HWC.
        image = F.rotate(image.permute(2, 0, 1), angle=angle).permute(1, 2, 0)
        # Give the mask a leading channel axis, rotate, then remove only that axis.
        # A bare squeeze() would also drop H or W whenever either equals 1.
        mask = F.rotate(mask[None, :], angle).squeeze(0)
        return (image, mask)
    
#show_image(RandomRotation()((image,mask)))

Luminosité, saturation et teinte (ColorJitter, appliqué à l'image seulement)

In [7]:
class CustomColorJitter(object):
    """Apply torchvision ``ColorJitter`` to the image of an (image, mask) pair.

    The mask is returned unchanged, so colour changes never touch the labels.
    The image is assumed channels-last (presumably HxWxC). It is permuted to
    channels-first for torchvision and then permuted back.

    Args:
        brightness, hue, saturation: jitter strengths forwarded to ``ColorJitter``.
        contrast: contrast jitter strength. The default 0 disables it, which
            keeps the previous behaviour; pass a positive value to enable it.
    """

    def __init__(self,brightness = 0.3, hue = 0.3, saturation = 0.3, contrast = 0):
        self.brightness = brightness
        self.hue = hue
        self.saturation = saturation
        self.contrast = contrast

    def __call__(self,item):
        image, mask = item
        # Built on each call, so later edits to the attributes take effect.
        jitter = transforms.ColorJitter(
            brightness=self.brightness,
            contrast=self.contrast,
            saturation=self.saturation,
            hue=self.hue,
        )
        image = jitter(image.permute(2, 0, 1)).permute(1, 2, 0)
        return (image, mask)

#show_image(CustomColorJitter()((image,mask)))
        

Flou gaussien aléatoire (appliqué à l'image seulement, avec une probabilité de 20 % par défaut)

In [8]:
class RandomBlur(object):
    """Gaussian-blur the image of an (image, mask) pair with probability ``blurred_ratio``.

    The mask is returned unchanged. The image is assumed channels-last
    (presumably HxWxC) and is permuted to channels-first for torchvision.

    Args:
        kernel_size: Gaussian kernel size, either an int or a sequence of ints.
            Each value must be positive and odd, as torchvision's
            ``gaussian_blur`` requires.
        blurred_ratio: probability of blurring a given sample.
        sigma: Gaussian standard deviation. ``None`` (the default) lets
            torchvision derive it from the kernel size.

    Raises:
        ValueError: if ``kernel_size`` is not made of positive odd ints. This is
            checked at construction, so a bad value fails immediately rather
            than only on the random subset of calls that actually blur.
    """

    def __init__(self, kernel_size=25, blurred_ratio = 0.2, sigma=None):
        sizes = (kernel_size,) if isinstance(kernel_size, int) else tuple(kernel_size)
        if not sizes or any(not isinstance(k, int) or k <= 0 or k % 2 == 0 for k in sizes):
            raise ValueError('kernel_size must be positive odd int(s), got {}'.format(kernel_size))
        self.kernel_size = kernel_size
        self.blurred_ratio = blurred_ratio
        self.sigma = sigma

    def __call__(self,item):
        image, mask = item
        # Blur when r.random() <= blurred_ratio; otherwise pass the pair through.
        if r.random() > self.blurred_ratio:
            return item
        # torchvision expects channels-first: HWC -> CHW -> blur -> HWC.
        image = F.gaussian_blur(
            image.permute(2, 0, 1), kernel_size=self.kernel_size, sigma=self.sigma
        ).permute(1, 2, 0)
        return (image, mask)

#show_image(RandomBlur()((image,mask)))

Pipeline complet : toutes les transformations enchaînées

In [9]:
# Compose runs these in list order. Flip and rotation move image and mask
# together; colour jitter and blur modify the image only.
data_transform = transforms.Compose(
    [RandomFlip(), RandomRotation(), CustomColorJitter(), RandomBlur()]
)

#show_image(data_transform((image,mask)))