# Purpose
extract masks from cloth images

In [11]:
import os
from pathlib import Path

import numpy as np
import math

import matplotlib.pyplot as plt

import wandb
from wandb.keras import WandbCallback

from PIL import Image

from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.vgg16 import preprocess_input
from keras import models

# Path list

In [12]:
image_paths, mask_paths = [], []
datapath = '../cloth_segmentation/raw_data/'

for dirpath, _, filename in os.walk(datapath):
    for file in filename:
        path = Path(dirpath).joinpath(file)
        if path.parent.name == 'IMAGES':
            image_paths.append(str(path))
        else: mask_paths.append(str(path))
image_paths = np.asarray(image_paths)
mask_paths = np.asarray(mask_paths)

In [13]:
image_paths, len(image_paths)

(array(['..\\cloth_segmentation\\raw_data\\IMAGES\\img_0001.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0002.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0003.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0004.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0005.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0006.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0007.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0008.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0009.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0010.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0011.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0012.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0013.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0014.jpeg',
        '..\\cloth_segmentation\\raw_data\\IMAGES\\img_0015.jp

In [14]:
mask_paths, len(mask_paths)

(array(['..\\cloth_segmentation\\raw_data\\MASKS\\seg_0001.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0002.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0003.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0004.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0005.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0006.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0007.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0008.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0009.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0010.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0011.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0012.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0013.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0014.png',
        '..\\cloth_segmentation\\raw_data\\MASKS\\seg_0015.png',
        '..\\cloth_segmen

# Data splits

In [15]:
x_train, x_test, y_train, y_test = train_test_split(image_paths, mask_paths, train_size=0.8)
x_test, x_val, y_test, y_val = train_test_split(x_test, y_test, train_size=0.5)

# Data input generator

In [16]:
class ImageGenerator(keras.utils.Sequence):
    def __init__(self, x, y, batch_size):
        self.x, self.y = np.asarray(x), np.asarray(y)
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        batch_x = [np.asarray(Image.open(path).resize((512, 512), resample=Image.BICUBIC)) for path in batch_x]
        batch_x = np.asarray(batch_x)

        batch_y = [np.asarray(Image.open(path).resize((512, 512), resample=Image.BICUBIC)) for path in batch_y]
        batch_y = np.asarray(batch_y)
        batch_y[batch_y > 0] = 1
        batch_y = np.asarray(batch_y, dtype=np.float32)

        return batch_x, batch_y

    def on_epoch_end(self):
        indices = np.random.permutation(len(self.x))
        self.x, self.y = self.x[indices], self.y[indices]

In [17]:
train_gen = ImageGenerator(x_train, y_train, 8)
val_gen = ImageGenerator(x_val, y_val, 8)
test_gen = ImageGenerator(x_test, y_test, 8)

# Keras callback

In [25]:
checkpoint_filepath = '../cloth_segmentation/artifacts/w-{epoch:02d}-{val_loss:.2f}.h5'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_dice',
    mode='max',
    save_best_only=True)

# Wandb

In [19]:
wandb.init(project='cloth_segmentation')

wandb: Currently logged in as: porpoising (use `wandb login --relogin` to force relogin)


# Setup

In [20]:
def conv_block(filters, kernel_size, strides=1):
    return models.Sequential([
        layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same'),
        layers.BatchNormalization(),
        layers.Activation('relu')])

In [21]:
input = layers.Input(shape=(512, 512, 3))
conv_512x512 = conv_block(16, 3)(input)
conv_256x256 = conv_block(32, 3, 2)(conv_512x512)
conv_128x128 = conv_block(64, 3, 2)(conv_256x256)
conv_64x64 = conv_block(128, 3, 2)(conv_128x128)
conv_32x32 = conv_block(256, 3, 2)(conv_64x64)
conv_16x16 = conv_block(512, 3, 2)(conv_32x32)

up_32x32 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(conv_16x16)
conv_32x32_up = conv_block(256, 3)(up_32x32)
add_32 = layers.Add()([conv_32x32_up, conv_32x32])

up_64x64 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(add_32)
conv_64x64_up = conv_block(128, 3)(up_64x64)
add_64 = layers.Add()([conv_64x64_up, conv_64x64])

up_128x128 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(add_64)
conv_128x128_up = conv_block(64, 3)(up_128x128)
add_128 = layers.Add()([conv_128x128_up, conv_128x128])

up_256x256 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(add_128)
conv_256x256_up = conv_block(32, 3)(up_256x256)
add_256 = layers.Add()([conv_256x256_up, conv_256x256])

up_512x512 = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(add_256)
conv_512x512_up = conv_block(16, 3)(up_512x512)

final_conv = conv_block(1, 1)(conv_512x512_up)
head = layers.Activation("sigmoid")(final_conv)

model = keras.Model(input, head)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 sequential (Sequential)        (None, 512, 512, 16  512         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 sequential_1 (Sequential)      (None, 256, 256, 32  4768        ['sequential[0][0]']             
                                )                                                             

In [22]:
def dice(y_true, y_pred):
    y_true_c = tf.keras.backend.flatten(y_true)
    y_pred_c = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_c * y_pred_c, axis=-1)
    return (2. * intersection + tf.keras.backend.epsilon()) / (
            tf.keras.backend.sum(y_true_c, axis=-1) +
            tf.keras.backend.sum(y_pred_c, axis=-1) +
            tf.keras.backend.epsilon())


def mse_log_dice(y_true, y_pred):
    mse_walls = tf.keras.losses.mean_squared_error(y_true, y_pred)
    dice_walls = tf.keras.backend.log(dice(y_true, y_pred))
    return mse_walls - dice_walls

In [23]:
model.compile(
    optimizer=keras.optimizers.Adam(0.0001),
    loss=mse_log_dice,
    metrics=[dice]
)

In [24]:
model.fit(x=train_gen, validation_data=val_gen,
          epochs=30, callbacks=[model_checkpoint_callback, WandbCallback()])

Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<keras.callbacks.History at 0x13b7ac09b80>