In [None]:
import os
import numpy as np
import pandas as pd
import cv2
import tifffile
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
import torchvision.transforms as transforms
import random as r
from dataset import CustomDataset

Image pour le sanity check des fonctions

In [None]:
# Relative path handed to CustomDataset as its root directory.
root_dir = '../data/'

# Index of the dataset item used as the visual sanity-check sample in the cells below.
image_id = 62

# NOTE(review): the second argument (1024) is presumably the image size -- confirm in dataset.CustomDataset.
dataset = CustomDataset(root_dir, 1024)

# The item unpacks into three values; only the first (image) and last (mask) are kept.
image, _, mask = dataset[image_id]


Fonction d'affichage

In [None]:
def show_image(item,cmaps='coolwarm_r'):
    """Plot an image and, below it, the mask overlaid on the image's first channel.

    Args:
        item: ``(image, mask)`` pair. ``image`` is indexed channel-first and
            permuted to channel-last for display. The overlay is permuted the
            same way, so ``image[0] / 2 + mask`` must be 3-D (presumably
            ``mask`` is a ``(1, H, W)`` tensor -- confirm against the dataset).
        cmaps: Matplotlib colormap name applied to the overlay panel only.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    image, mask = item
    fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(16, 32))
    image_ax, overlay_ax = axes

    # First image channel at half intensity, plus the mask.
    overlay = image[0, :, :] / 2 + mask[:, :]

    # Top panel: the image itself, channel-last as imshow expects.
    image_ax.imshow(image.permute(1, 2, 0))
    image_ax.axis('off')
    image_ax.set_title('IMAGE')

    # Bottom panel: the overlay, drawn with the requested colormap.
    overlay_ax.imshow(overlay.permute(1, 2, 0), cmap=cmaps)
    overlay_ax.axis('off')
    overlay_ax.set_title('MASK ON IMAGE')

    plt.show()

In [None]:
# Display the sanity-check sample exactly as it was read from the dataset.
show_image((image,mask))

Flip

In [None]:
class RandomHorizontalFlip(object):
    """Randomly mirror an (image, mask) pair left-to-right.

    Both tensors are flipped along their last (width) dimension, which is what
    a horizontal flip means for channel-first ``(C, H, W)`` data. Flipping
    dim 1 instead would reverse the rows, i.e. perform a vertical flip.

    Args:
        p: Probability of applying the flip. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, item):
        """Return the pair flipped along the width axis, or ``item`` itself when no flip is drawn.

        Args:
            item: ``(image, mask)`` tuple of tensors sharing their last two
                (height, width) dimensions.

        Returns:
            The original ``item`` object when the flip is skipped, otherwise a
            new ``(image, mask)`` tuple with both tensors mirrored.
        """
        image, mask = item
        if r.random() >= self.p:
            return item
        # Negative index so the same code handles (C, H, W) images and
        # (H, W) or (1, H, W) masks.
        image = image.flip(dims=(-1,))
        mask = mask.flip(dims=(-1,))
        return (image, mask)


# Visual check of the transform on the sample pair; it is random, so the pair may be displayed unchanged.
show_image(RandomHorizontalFlip()((image,mask)))

In [None]:
class RandomVerticalFlip(object):
    """Randomly mirror an (image, mask) pair top-to-bottom.

    Both tensors are flipped along their second-to-last (height) dimension,
    which is what a vertical flip means for channel-first ``(C, H, W)`` data.
    Flipping dim 2 instead would reverse the columns, i.e. perform a
    horizontal flip.

    Args:
        p: Probability of applying the flip. Defaults to 0.5.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, item):
        """Return the pair flipped along the height axis, or ``item`` itself when no flip is drawn.

        Args:
            item: ``(image, mask)`` tuple of tensors sharing their last two
                (height, width) dimensions.

        Returns:
            The original ``item`` object when the flip is skipped, otherwise a
            new ``(image, mask)`` tuple with both tensors mirrored.
        """
        image, mask = item
        if r.random() >= self.p:
            return item
        # Negative index so the same code handles (C, H, W) images and
        # (H, W) or (1, H, W) masks.
        image = image.flip(dims=(-2,))
        mask = mask.flip(dims=(-2,))
        return (image, mask)


# Visual check of the transform on the sample pair; it is random, so the pair may be displayed unchanged.
show_image(RandomVerticalFlip()((image,mask)))

Rotation aléatoire (angle tiré uniformément entre -180° et 180°)

In [None]:
class RandomRotation(object):
    """Rotate an (image, mask) pair by one shared, randomly drawn angle.

    Args:
        degrees: Half-width of the sampling interval. Angles are drawn
            uniformly from ``[-degrees, degrees]``. Defaults to 180, i.e. any
            orientation.
    """

    def __init__(self, degrees=180):
        self.degrees = degrees

    def angle(self):
        """Draw a rotation angle in degrees, uniformly from ``[-degrees, degrees]``."""
        return r.uniform(-self.degrees, self.degrees)

    def __call__(self, item):
        """Rotate both tensors of ``item`` by the same angle.

        Args:
            item: ``(image, mask)`` tuple accepted by ``F.rotate``.

        Returns:
            A new ``(image, mask)`` tuple, both rotated with ``F.rotate``'s
            default settings.
        """
        image, mask = item
        # Draw once so the image and its mask stay aligned.
        angle = self.angle()
        image = F.rotate(image, angle=angle)
        mask = F.rotate(mask, angle=angle)
        return (image, mask)

# Visual check: rotate the sample pair by a randomly drawn angle and display it.
show_image(RandomRotation()((image,mask)))

Saturation, Hue, Brightness

In [None]:
class CustomColorJitter(object):
    """Apply torchvision's ``ColorJitter`` to the image of an (image, mask) pair.

    Only brightness, hue and saturation are jittered, and the mask is passed
    through untouched.

    Args:
        brightness: Forwarded to ``ColorJitter(brightness=...)``.
        hue: Forwarded to ``ColorJitter(hue=...)``.
        saturation: Forwarded to ``ColorJitter(saturation=...)``.
    """

    def __init__(self, brightness=0.3, hue=0.3, saturation=0.3):
        self.brightness, self.hue, self.saturation = brightness, hue, saturation

    def __call__(self, item):
        """Return ``(jittered_image, mask)`` for the given ``(image, mask)`` tuple."""
        image, mask = item
        # A fresh ColorJitter is built on every call from the stored settings.
        jitter = transforms.ColorJitter(
            brightness=self.brightness,
            hue=self.hue,
            saturation=self.saturation,
        )
        return (jitter(image), mask)

# Visual check: jitter the sample image's colors (default settings) and display it with its unchanged mask.
show_image(CustomColorJitter()((image,mask)))

Blur

In [None]:
class RandomBlur(object):
    """Gaussian-blur the image of an (image, mask) pair on a random subset of calls.

    Args:
        kernel_size: Kernel size forwarded to ``F.gaussian_blur``.
        blurred_ratio: Threshold compared against ``random()``. The blur is
            skipped whenever the draw exceeds it, so it acts as the
            probability of blurring.
    """

    def __init__(self, kernel_size=25, blurred_ratio = 0.2):
        self.blurred_ratio = blurred_ratio
        self.kernel_size = kernel_size

    def __call__(self, item):
        """Return ``item`` itself when no blur is drawn, else ``(blurred_image, mask)``."""
        image, mask = item
        keep_sharp = r.random() > self.blurred_ratio
        if not keep_sharp:
            # The mask is never blurred, only the image.
            blurred = F.gaussian_blur(image, kernel_size=self.kernel_size)
            return (blurred, mask)
        return item

# Visual check with the default settings; the blur is only applied on some runs, so the image is often shown unchanged.
show_image(RandomBlur()((image,mask)))

Toutes les transformations d'un coup

In [None]:
# Full augmentation pipeline: Compose calls the transforms in list order, each
# one taking and returning an (image, mask) tuple.
data_transform = transforms.Compose([
    RandomVerticalFlip(),
    RandomHorizontalFlip(),
    RandomRotation(),
    CustomColorJitter(),
    RandomBlur()
])

# Visual check of the whole pipeline on the sample pair.
show_image(data_transform((image,mask)))