# Albumentation

TL;DR : Data_augment is powerful, but won't be used here for various reasons
1- Our models are trained using only 10% of dataset, and already need almost 24 hours each to train
2- Data augment have been tried, but local machine seems unable to run data-augment pipeline

In [5]:
# classic Librairies
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os
import shutil

import tensorflow as tf
import keras
os.environ["SM_FRAMEWORK"] = "tf.keras"

import segmentation_models as sm
from segmentation_models import get_preprocessing

# image imports
from matplotlib.image import imread
import PIL
from PIL import Image, ImageFilter, ImageEnhance

from tensorflow.keras import layers, models
import tensorflow.keras.backend as K
from keras.metrics import IoU
from tensorflow.keras.metrics import MeanIoU

import albumentations as A
import cv2

from metrics_and_loss import CombinedLoss, IoUMetric

Segmentation Models: using `tf.keras` framework.


  check_for_updates()


In [154]:
# Paramètres
IMG_HEIGHT = 512
IMG_WIDTH = 512
NUM_CLASSES = 8
BATCH_SIZE = 16

# Albumentations - train & validation
albumentations_transform_train = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    A.OneOf([
        A.RandomFog(p=0.25),         # If OneOf is chosen, RandomFog has a 25% chance within the OneOf block
        A.RandomSnow(p=0.25),           
        A.RandomRain(p=0.25), 
        A.RandomSunFlare(p=0.25)
    ], p=0.3), # The OneOf block has a 30% chance to be considered
    A.Resize(IMG_HEIGHT, IMG_WIDTH)
],seed=42, additional_targets={'mask': 'mask'}) # added seed for reproducibility

albumentations_transform_val = A.Compose([
    A.Resize(IMG_HEIGHT, IMG_WIDTH)
], additional_targets={'mask': 'mask'})

# Chargement avec OpenCV + Albumentations (appelée par tf.py_function)
def read_image_and_mask(img_path, mask_path, augment=True):
    img_path = img_path.numpy()
    mask_path = mask_path.numpy()
    
    #debug funcs
    if isinstance(img_path, bytes):
        print("Traitement de l'image :", img_path.decode('utf-8'), "is utf decoded")
        print("Traitement de l'image :", mask_path.decode('utf-8'), "is utf decoded")
    else:
        print("Traitement de l'image :", img_path)


    if isinstance(img_path, bytes):
        img_path = img_path.decode('utf-8')
    if isinstance(mask_path, bytes):
        mask_path = mask_path.decode('utf-8')
    
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    

    #checker if image / mask is available
    if image is None:
        raise ValueError(f"Impossible de lire l'image : {img_path}")
    if mask is None:
        raise ValueError(f"Impossible de lire le masque : {mask_path}")
        
    if augment:
        augmented = albumentations_transform_train(image=image, mask=mask)
    else:
        augmented = albumentations_transform_val(image=image, mask=mask)

    image = augmented["image"] / 255.0
    mask = augmented["mask"]

    mask_one_hot = np.zeros((IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.uint8)
    for i in range(NUM_CLASSES):
        mask_one_hot[:, :, i] = (mask == i).astype(np.uint8)
        
    #print("Somme des pixels image :", np.sum(image))
    #print("Somme des pixels masque :", np.sum(mask))
    print("Image OK : ", img_path)

    return image.astype(np.float32), mask_one_hot.astype(np.float32)

# Wrapper pour tf.data
def load_image_mask(img_path, mask_path, augment=True):
    image, mask = tf.py_function(
        func=lambda x, y: read_image_and_mask(x, y, augment=augment),
        inp=[img_path, mask_path],
        Tout=[tf.float32, tf.uint8]
    )
    image.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    mask.set_shape([IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES])
    return image, mask

def get_dataset(image_dir, mask_dir, augment=False):
    image_paths = sorted([os.path.join(image_dir, f) for f in os.listdir(image_dir) if f.endswith(".png")])
    mask_paths = sorted([os.path.join(mask_dir, f) for f in os.listdir(mask_dir) if f.endswith(".png")])

    assert len(image_paths) == len(mask_paths), "Mismatch entre le nombre d'images et de masques !"

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    dataset = dataset.map(lambda x, y: load_image_mask(x, y, augment), num_parallel_calls=1) #problem is not from threading count
    dataset = dataset.shuffle(buffer_size=100)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

In [141]:
def read_image_and_mask(img_path, mask_path, augment=True):
    print("Image path is :", img_path)
    img_path = img_path.numpy()
    mask_path = mask_path.numpy()
    print("Numpy Image path is :", img_path)
    
    #debug funcs
    if isinstance(img_path, bytes):
        print("Traitement de l'image :", img_path.decode('utf-8'), "is utf decoded")
        print("Traitement de l'image :", mask_path.decode('utf-8'), "is utf decoded")
    else:
        print("Traitement de l'image :", img_path)


    if isinstance(img_path, bytes):
        img_path = img_path.decode('utf-8')
    if isinstance(mask_path, bytes):
        mask_path = mask_path.decode('utf-8')
    
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    

    #checker if image / mask is available
    if image is None:
        raise ValueError(f"Impossible de lire l'image : {img_path}")
    if mask is None:
        raise ValueError(f"Impossible de lire le masque : {mask_path}")
        
    if augment:
        augmented = albumentations_transform_train(image=image, mask=mask)
    else:
        augmented = albumentations_transform_val(image=image, mask=mask)

    image = augmented["image"] / 255.0
    mask = augmented["mask"]

    mask_one_hot = np.zeros((IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES), dtype=np.uint8)
    for i in range(NUM_CLASSES):
        mask_one_hot[:, :, i] = (mask == i).astype(np.uint8)
        
    #print("Somme des pixels image :", np.sum(image))
    #print("Somme des pixels masque :", np.sum(mask))
    print("Image OK : ", img_path)

    return image.astype(np.float32), mask_one_hot.astype(np.float32)

#bit more harsh, but tf.reshape forces the shape to be as we want
def load_image_mask(img_path, mask_path, augment=True):
    image, mask = tf.py_function(
        func=lambda x, y: read_image_and_mask(x, y, augment=augment),
        inp=[img_path, mask_path],
        Tout=[tf.float32 , tf.float32]
    )
    image = tf.reshape(image,(IMG_HEIGHT, IMG_WIDTH, 3))
    mask = tf.reshape(mask, (IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES))
    #print("Image shape:", image.shape)
    #print("Mask shape:", mask.shape)
    return image, mask

In [181]:
train_image_dir = "data/train/images"
train_mask_dir = "data/train/gen_masks"
val_image_dir = 'data/val/images'
val_mask_dir = 'data/val/gen_masks'

# Charger datasets
train_dataset = get_dataset(train_image_dir, train_mask_dir, augment=False) #problem is not from training
val_dataset = get_dataset(val_image_dir, val_mask_dir, augment=False)

In [75]:
BACKBONE = 'inceptionv3'
combined_loss = CombinedLoss(smooth=100, alpha=0.5)
IoU_score = IoUMetric()

model_augmented = sm.Unet(BACKBONE, input_shape=(512, 512, 3), encoder_weights='imagenet', classes=8, activation='softmax')
model_augmented.compile('Adam', loss=combined_loss, metrics=['accuracy', IoU_score])

In [179]:
"""train_image_dir = "data/train/images"
train_mask_dir = "data/train/gen_masks"
image_files = os.listdir(train_image_dir)
mask_files = os.listdir(train_mask_dir)
for img_name, mask_name in zip(image_files, mask_files):
    #print(f"{img_name.split('_leftImg8bit')[0]} VS {mask_name.split('_gtFine_labelIds')[0]}")
    if not img_name.split('_leftImg8bit')[0] == mask_name.split('_gtFine_labelIds')[0]:
        print("Mismatch potentiel :", img_name, mask_name)
    else:
        print("all ok")"""
#proves that problem is not in potential mismatch

'train_image_dir = "data/train/images"\ntrain_mask_dir = "data/train/gen_masks"\nimage_files = os.listdir(train_image_dir)\nmask_files = os.listdir(train_mask_dir)\nfor img_name, mask_name in zip(image_files, mask_files):\n    #print(f"{img_name.split(\'_leftImg8bit\')[0]} VS {mask_name.split(\'_gtFine_labelIds\')[0]}")\n    if not img_name.split(\'_leftImg8bit\')[0] == mask_name.split(\'_gtFine_labelIds\')[0]:\n        print("Mismatch potentiel :", img_name, mask_name)\n    else:\n        print("all ok")'

In [183]:
t0 = time.time()

history = model_augmented.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=10
)

t1 = time.time() - t0
print("Took", t1, "seconds")

Epoch 1/10
Traitement de l'image : data/train/images\aachen_000022_000019_leftImg8bit.png is utf decoded
Traitement de l'image : data/train/gen_masks\aachen_000022_000019_gtFine_labelIds.png is utf decoded
Image OK :  data/train/images\aachen_000022_000019_leftImg8bit.png
Traitement de l'image : data/train/images\aachen_000025_000019_leftImg8bit.png is utf decoded
Traitement de l'image : data/train/gen_masks\aachen_000025_000019_gtFine_labelIds.png is utf decoded
Image OK :  data/train/images\aachen_000025_000019_leftImg8bit.png
Traitement de l'image : data/train/images\aachen_000028_000019_leftImg8bit.png is utf decoded
Traitement de l'image : data/train/gen_masks\aachen_000028_000019_gtFine_labelIds.png is utf decoded
Image OK :  data/train/images\aachen_000028_000019_leftImg8bit.png
Traitement de l'image : data/train/images\aachen_000031_000019_leftImg8bit.png is utf decoded
Traitement de l'image : data/train/gen_masks\aachen_000031_000019_gtFine_labelIds.png is utf decoded
Image OK

AttributeError: 'NoneType' object has no attribute 'items'

In [None]:
# Loss
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.legend()
plt.title("Courbe de Loss")
plt.show()

# IoU ou autre métrique
if 'iou' in history.history:
    plt.plot(history.history['iou'], label='Train IoU')
    plt.plot(history.history['val_iou'], label='Val IoU')
    plt.legend()
    plt.title("Courbe de IoU")
    plt.show()