Clothing segmentation model for the [iMaterialist (Fashion) 2019 at FGVC6](https://www.kaggle.com/c/imaterialist-fashion-2019-FGVC6) dataset, built with Keras using a UNet with SeResNet50 as the backbone.


**Loading Libraries**

In [None]:
! pip install git+https://github.com/qubvel/segmentation_models

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import cv2
plt.style.use("ggplot")

from tqdm import tqdm_notebook, tnrange, tqdm
from skimage.io import imread, imshow, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from sklearn.model_selection import train_test_split

from PIL import Image
import pandas as pd
import gc
import glob
import shutil
import json
from imgaug import augmenters as iaa
import imgaug as ia
from pathlib import Path
import seaborn as sns
import albumentations as albu

In [None]:
import tensorflow as tf
from keras import backend as K
import keras

from keras.models import Model, load_model
from keras.layers import Input, Conv2D, Conv2DTranspose, Dropout,BatchNormalization
from keras.layers import Concatenate, MaxPooling2D, LeakyReLU
from keras.layers import UpSampling2D, Add, ZeroPadding2D
from keras.layers import GlobalAveragePooling2D, Reshape, Dense, Permute
from keras.layers.merge import concatenate, add, multiply
from keras.layers.core import Dense, Lambda, Activation, SpatialDropout2D
from keras.optimizers import Adam

from keras.losses import binary_crossentropy
from keras.engine.training import Model
from keras.callbacks import Callback, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.utils import Sequence

from segmentation_models import Unet
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

%matplotlib inline

**Defining some important variables**

In [None]:
# Target size used by the cv2.resize calls in this notebook and as the
# model's input shape.
img_height = 512 
img_width = 512
img_channels = 3 
# NOTE(review): n_classes is not referenced anywhere else in the visible
# notebook (build_model relies on Unet's own default) - confirm it is needed.
n_classes = 1 
# Batch size for the training and validation generators.
batch_size = 8
# Encoder name handed to segmentation_models.Unet in build_model.
BACKBONE = 'seresnet50'
# NOTE(review): this binds the backbone *name* (a str), not a preprocessing
# callable, and it is not used in the visible notebook - presumably
# segmentation_models' backbone-specific preprocessing was intended; confirm.
preprocess_input = BACKBONE

In [None]:
# Seed every random number generator used by the notebook so that runs are
# as repeatable as the backend allows.
seed = 10
# NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it
# here only affects child processes, not the hashing of this kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)  # seeded once; the original repeated this call
tf.set_random_seed(seed)

In [None]:
# Root of the competition data (train.csv and label_descriptions.json are read from here).
input_dir = "../input/imaterialist-fashion-2019-FGVC6/"
# Folder holding the training images referenced by train.csv's ImageId column.
image_dir = "../input/imaterialist-fashion-2019-FGVC6/train/"
# NOTE(review): model_weights is not used in the visible notebook; the later
# load_weights cell reads 'unet_seresnet50_weights.h5' from the working directory.
model_weights = "../input/weights-unet/unet_seresnet50_weights.h5"

**Data analysis**

* There are 333,415 data items (rows in `train.csv`) in total.
* Counting each image only once, there are 45,625 unique images.
* One image can correspond to multiple classes.
* Not all clothing items in the photos are annotated.
* There are 46 categories and 92 attributes.

In [None]:
# Load the annotation table and report its size.
train_df = pd.read_csv(input_dir + "train.csv")
print("Number of data items: ", train_df.shape[0])
print("Number of unique images:",len(set(train_df['ImageId'])))

# Read the label descriptions inside a context manager so the file handle is
# closed deterministically (the original left it to the garbage collector).
with open(input_dir + "label_descriptions.json", encoding="utf-8") as json_file:
    json_data = json_file.read()
label_descriptions = json.loads(json_data)

categories_label_df = pd.DataFrame(label_descriptions['categories'])
print("The number of categories: ",len(categories_label_df))

attributes_label_df = pd.DataFrame(label_descriptions['attributes'])
print("The number of attributes: ",len(attributes_label_df))

In [None]:
# Distribution of "annotations per image": value_counts() gives the number of
# rows for every ImageId, and countplot histograms those counts.
train_df_ImageId_count = train_df['ImageId'].value_counts()
plt.figure(figsize=(20, 7))
plt.title('image labels count', size=20)
plt.xlabel('', size=15)
plt.ylabel('', size=15)
# Pass the series explicitly as `x`: a bare positional argument is interpreted
# as `data` by newer seaborn releases, while the keyword works in old and new.
sns.countplot(x=train_df_ImageId_count)
plt.show()

In [None]:
# The category id is the part of ClassId before the first underscore.
train_df['Category'] = train_df['ClassId'].apply(lambda x: int(x.split("_")[0]))

# Number of annotation rows per category, ordered by category id.
groupby_category = train_df.groupby('Category')['ImageId'].count()
# Use a concrete integer index; the original assigned a lazy `map` object.
groupby_category.index = groupby_category.index.astype(int)
groupby_category = groupby_category.sort_index()

fig = plt.figure(figsize=(20, 7))
x = groupby_category.index
y = groupby_category.values

# Keyword arguments: positional x/y are rejected by newer seaborn releases.
sns.barplot(x=x, y=y)
plt.title("Number of data items by category", fontsize=20)
plt.xlabel("Category", fontsize=20)
plt.ylabel("# of masks", fontsize=20)
plt.show()

**Loading data**

In [None]:
# Work on the first 10,000 annotation rows only.
df = pd.read_csv(input_dir + "train.csv").head(10000)
# The category id is the part of ClassId before the first underscore (kept as str).
df['CategoryId'] = df.ClassId.apply(lambda x: str(x).split("_")[0])

# One row per image: collect its RLE strings and category ids into lists.
# Columns are selected with a list ([[...]]); selecting several columns with a
# bare tuple is no longer supported by current pandas.
temp_df = df.groupby('ImageId')[['EncodedPixels', 'CategoryId']].agg(list).reset_index()
# Height/Width are taken as the per-image mean, so they come back as floats.
size_df = df.groupby('ImageId')[['Height', 'Width']].mean().reset_index()
df = temp_df.merge(size_df, on='ImageId', how='left')
df.head()

In [None]:
class DataLoader():
    """Collect image paths and decode one combined mask per image.

    Args:
        img_dir: Directory that the ``ImageId`` values are joined onto.
        df: Table with one row per image and the columns ``ImageId``,
            ``EncodedPixels`` (list of RLE strings), ``CategoryId``,
            ``Height`` and ``Width``.
        visualize: When true, every row whose index is a multiple of 100 is
            plotted while the data is being decoded.
    """

    def __init__(self, img_dir, df, visualize=True):
        self.img_dir = img_dir
        self.df = df
        self.images = []
        self.masks = []
        self.visualize = visualize

    def get_data(self):
        """Decode all rows of ``self.df``.

        Returns:
            ``(images, masks)``: the list of image paths and, aligned with
            it, the list of per-image masks (element-wise sum of all decoded
            annotation masks of that image).
        """
        # Start from empty lists so a second call does not duplicate entries.
        self.images = []
        self.masks = []

        for index, row in tqdm(self.df.iterrows(), total=len(self.df)):
            image_path = os.path.join(self.img_dir, row['ImageId'])
            self.images.append(image_path)

            shape = (row["Height"], row["Width"])
            mask = [
                self.rle_decode(annotation, shape)
                for annotation in row['EncodedPixels']
            ]
            self.masks.append(sum(mask))

            if self.visualize and index % 100 == 0:
                self.visualize_results(image_path, mask)

        return self.images, self.masks

    def rle_decode(self, mask_rle, shape):
        """Decode a run-length-encoded mask.

        Args:
            mask_rle: Space separated ``start length`` pairs; starts are
                1-based positions in the column-major flattened mask.
            shape: ``(height, width)`` of the mask. Float values (such as the
                per-image means produced upstream) are truncated to int.

        Returns:
            ``uint8`` array of shape ``(height, width)`` with 1 inside runs
            and 0 elsewhere.
        """
        # NumPy rejects float sizes, so coerce the dimensions first.
        height, width = int(shape[0]), int(shape[1])

        tokens = mask_rle.split()
        starts = np.asarray(tokens[0::2], dtype=int) - 1  # to 0-based
        lengths = np.asarray(tokens[1::2], dtype=int)

        flat = np.zeros(height * width, dtype=np.uint8)
        for lo, hi in zip(starts, starts + lengths):
            flat[lo:hi] = 1

        # Runs are column-major: fill as (width, height), then transpose.
        return flat.reshape((width, height)).T

    def visualize_results(self, image_path, mask):
        """Show the image next to the sum of its decoded masks.

        Args:
            image_path: Path of the image file to display.
            mask: List of decoded per-annotation masks for that image.
        """
        plt.subplot(1, 2, 1)
        # Close the file as soon as the RGB copy has been made.
        with Image.open(image_path) as image_file:
            img = image_file.convert("RGB")
        plt.imshow(img)
        plt.subplot(1, 2, 2)
        plt.imshow(sum(mask), cmap="Blues_r")
        plt.show()

    def get_datalen(self):
        """Return the number of image paths collected so far."""
        return len(self.images)

In [None]:
# Decode every row of the aggregated table. `visualize` keeps its default of
# True, so rows whose index is a multiple of 100 are plotted along the way.
dataloader = DataLoader(image_dir, df)
images, masks = dataloader.get_data()
print(dataloader.get_datalen())

**Splitting data into train and validation sets**

In [None]:
# Hold out 20% of the (image path, mask) pairs for validation; the fixed
# random_state keeps the split identical between runs.
train_image, valid_image, train_masks, valid_masks = train_test_split(
        images, masks,
        test_size = 0.2, 
        random_state=42
)

print('{} training images && {} masks'.format(len(train_image), len(train_masks)))
print('{} validation images && {} masks'.format(len(valid_image), len(valid_masks)))

**Data visualization**

In [None]:
def visualize(train_image, train_labels):
    """Plot one randomly chosen image with its mask outline, and the mask itself.

    Args:
        train_image: Sequence of image file paths.
        train_labels: Sequence of masks aligned with ``train_image``.
    """
    # randrange excludes the upper bound; randint(0, len(...)) could return
    # len(...) itself and index past the end of the list.
    ix = random.randrange(len(train_image))

    # Materialise the pixels inside the context manager so the file is closed.
    with Image.open(train_image[ix], 'r') as image_file:
        pil_im = np.asarray(image_file)
    pil_mask = np.asarray(train_labels[ix])

    # cv2.resize takes the target size as (width, height).
    pil_im = cv2.resize(pil_im, (img_width, img_height))
    pil_mask = cv2.resize(pil_mask, (img_width, img_height))

    print(pil_mask.shape)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (40, 30))

    # Left: the image with the mask contour drawn on top of it.
    ax1.imshow(pil_im, cmap="bone")
    ax1.contour(np.squeeze(pil_mask), colors='pink', linewidths = 5)
    ax1.set_title('Image')

    # Right: the mask on its own.
    ax2.imshow(pil_mask, alpha=0.5, cmap="Reds", interpolation = 'bilinear')
    ax2.set_title('Mask')

In [None]:
# Show a random sample from the training split.
visualize(train_image, train_masks)

In [None]:
# Show a random sample from the validation split.
visualize(valid_image, valid_masks)

**Data Generator**

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that loads, resizes and optionally augments batches.

    Args:
        image_filenames: Sequence of image file paths.
        labels: Sequence of masks aligned with ``image_filenames``.
        batch_size: Number of samples per batch.
        transforms: Optional callable invoked as
            ``transforms(image=..., mask=...)`` that returns a mapping with
            the keys ``'image'`` and ``'mask'``; ``None`` disables
            augmentation.
    """

    def __init__(self, image_filenames, labels, batch_size, transforms=None):
        self.image_filenames, self.labels = image_filenames, labels
        self.batch_size = batch_size
        self.transforms = transforms

    def __len__(self):
        """Return the number of full batches (a trailing partial batch is not counted)."""
        return len(self.image_filenames) // self.batch_size

    def __getitem__(self, idx):
        """Build batch ``idx``.

        Returns:
            ``(batch_X, batch_Y)``: images scaled to [0, 1] and binary masks
            with a trailing channel axis, both as NumPy arrays.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size

        batch_X = []
        batch_Y = []

        for img_path, mask in zip(self.image_filenames[start:stop],
                                  self.labels[start:stop]):

            # Copy the pixels out while the file is open, then release it.
            with Image.open(img_path) as image_file:
                img = np.asarray(image_file.convert('RGB'))
            # cv2.resize takes the target size as (width, height).
            img = cv2.resize(img, (img_width, img_height)) / 255.0

            img_mask = cv2.resize(np.asarray(mask), (img_width, img_height))
            # Anything covered by at least one annotation becomes foreground.
            img_mask = (img_mask > 0).astype(np.uint8)
            img_mask = np.expand_dims(img_mask, axis=2)

            if self.transforms is not None:
                # Use the pipeline given to this instance, not whatever global
                # happens to be called `transforms`.
                augmented = self.transforms(image=img, mask=img_mask)
                img = augmented['image']
                img_mask = augmented['mask'].reshape(img_height, img_width, 1)

            batch_X.append(img)
            batch_Y.append(img_mask)

        return np.asarray(batch_X), np.asarray(batch_Y)

In [None]:
# Augmentation pipeline for training batches. Each transform fires
# independently with its own probability; the Compose itself always runs (p=1).
# NOTE(review): IAAEmboss is the imgaug-backed variant - presumably newer
# albumentations releases expose it as `Emboss`; confirm against the installed version.
transforms = albu.Compose([
    albu.HorizontalFlip(p=0.25),
    albu.Transpose(p=0.2),
    albu.RandomBrightnessContrast(p=0.25),
    albu.RandomGamma(p=0.25),
    albu.IAAEmboss(p=0.25),
    albu.Blur(p=0.2, blur_limit = 3)
], p = 1)

In [None]:
# Only the training generator augments; validation batches are left untouched.
train_generator = DataGenerator(train_image, train_masks, batch_size, transforms)
valid_generator = DataGenerator(valid_image, valid_masks, batch_size)

**Building a segmentation model**

Pretrained Unet with SeResNet50 encoder

In [None]:
def build_model(image_size):
    """Create a U-Net whose encoder is ``BACKBONE`` with ImageNet weights.

    Args:
        image_size: Input shape forwarded to ``Unet`` as ``input_shape``.

    Returns:
        The model returned by ``Unet`` (not compiled here).
    """
    return Unet(BACKBONE, input_shape=image_size, encoder_weights='imagenet')

In [None]:
# Drop any graph/state left over from a previous run of this cell.
K.clear_session()

model = build_model(image_size=(img_height, img_width, img_channels))

adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, decay=0)
# Pass the configured optimizer object; the original passed the string 'adam',
# which silently ignored the settings above.
model.compile(optimizer=adam, loss=bce_jaccard_loss, metrics=[iou_score])

# summary() prints by itself and returns None, so it is not wrapped in print().
model.summary()

In [None]:
class CustomCallback(Callback):
    """Plot one validation sample with its prediction after every epoch.

    Args:
        model: The model whose predictions are visualised.
    """

    def __init__(self, model):
        super(CustomCallback, self).__init__()
        self.batch_size = 10
        self.generator = DataGenerator(valid_image, valid_masks, self.batch_size)
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the first validation batch and show a four-panel figure.

        Panels: input image, ground-truth mask, raw prediction, and the
        prediction thresholded at 0.5.
        """
        X, y_true = self.generator[0]
        if len(X) == 0:
            # Nothing to show when the validation set is empty.
            return

        # Use the model attached to this callback rather than a global.
        y_pred = self.model.predict(X)

        # Sample 4 is shown when available; smaller batches fall back to the last one.
        i = min(4, len(X) - 1)
        img = X[i]
        mask_true = y_true[i]
        mask_pred = y_pred[i]
        mask_binary = (mask_pred > 0.5).astype(np.uint8)

        fig, ax = plt.subplots(1, 4, figsize = (40, 30))

        ax[0].imshow(img)
        ax[0].set_title('Image')

        ax[1].imshow(mask_true.squeeze(), cmap = 'gray', interpolation = 'bilinear')
        ax[1].set_title('Mask')

        ax[2].imshow(mask_pred.squeeze(), cmap = 'gray', interpolation = 'bilinear')
        ax[2].set_title('Predicted_Mask')

        ax[3].imshow(mask_binary.squeeze(), cmap = 'gray', interpolation = 'bilinear')
        ax[3].set_title('Binary_Mask')

        plt.show()

**Training the model**

In [None]:
epochs = 30
# NOTE: the original ran tf.global_variables_initializer() here. That op
# re-runs the initializer of *every* variable, which overwrites the ImageNet
# encoder weights loaded by build_model with fresh random values. Keras
# initialises any still-uninitialised variables on its own, so the call is
# dropped.

history = model.fit_generator(
      generator = train_generator,
      epochs = epochs,
      callbacks = [
           # Per-epoch visual check on one validation sample.
           CustomCallback(model),
           # Keep only checkpoints that improve the validation loss.
           ModelCheckpoint('unet_weights_epoch-{epoch:02d}_loss-{loss:.4f}.h5',
                           monitor='val_loss',
                           verbose=1,
                           save_best_only=True,
                           save_weights_only=True,
                           mode='auto',
                           period=1),
           # Halve the learning rate whenever val_loss fails to improve by at
           # least min_delta (`epsilon` is the deprecated name of this argument).
           ReduceLROnPlateau(monitor='val_loss',
                             factor=0.5,
                             patience=0,
                             min_delta=0.001,
                             cooldown=0)
          ],
      validation_data = valid_generator,
      max_queue_size=10
)

**Saving the trained model**

In [None]:
# Persist the result twice: the full model (architecture + weights) and the
# weights alone, both in the current working directory.
model_name = 'unet_seresnet50'
model.save('{}.h5'.format(model_name))
model.save_weights('{}_weights.h5'.format(model_name))

print()
print("Model saved under {}.h5".format(model_name))
print("Weights also saved separately under {}_weights.h5".format(model_name))
print()

**Visualizing the results of the trained model**
* history of the training: loss and metrics
* images with corresponding true and predicted masks

In [None]:
def visualize_history(history):
    """Plot training and validation curves for IoU and loss side by side.

    Args:
        history: Object with a ``history`` mapping containing the keys
            ``iou_score``, ``val_iou_score``, ``loss`` and ``val_loss``.
            The first epoch of every curve is skipped.
    """
    plt.figure(figsize=(16,4))

    # (history key, y-axis label, panel title) for each subplot.
    panels = (
        ('iou_score', 'iou', 'model IOU'),
        # The axis shows both train and validation loss, so label it 'loss'
        # (the original labelled it 'val_loss').
        ('loss', 'loss', 'model loss'),
    )
    for position, (metric, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.history[metric][1:])
        plt.plot(history.history['val_' + metric][1:])
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        plt.legend(['train','Validation'], loc='upper left')
        plt.title(title)

    gc.collect()

In [None]:
# Plot the curves recorded by fit_generator.
visualize_history(history)

In [None]:
def plot_results(num, photos, labels):
    """Plot predictions for one batch of ten samples.

    Each row shows the image with the predicted outline, the ground-truth
    mask, and the predicted mask with the ground-truth outline.

    Args:
        num: Index of the batch (of size 10) to plot.
        photos: Sequence of image file paths.
        labels: Sequence of masks aligned with ``photos``.

    Raises:
        IndexError: If batch ``num`` contains no samples.
    """
    batch_size = 10
    test_gen = DataGenerator(photos, labels, batch_size)

    X, y_true = test_gen[num]
    if len(X) == 0:
        raise IndexError('batch {} is out of range'.format(num))

    y_pred = model.predict(X)
    y_pred = (y_pred > 0.4).astype(np.uint8)

    # Size the grid from the batch actually returned (a trailing batch may be
    # shorter than 10); squeeze=False keeps `ax` two-dimensional for one row.
    rows = len(X)
    fig, ax = plt.subplots(rows, 3, figsize = (40, 10 * rows), squeeze=False)

    for i in range(rows):
        img = X[i]
        mask_true = y_true[i]
        mask_pred = y_pred[i]

        ax[i][0].imshow(img)
        ax[i][0].set_title('Image')
        ax[i][0].contour(mask_pred.squeeze(), colors='yellow', levels=[0.5], linewidths = 3)
        ax[i][0].axis('off')

        ax[i][1].imshow(mask_true.squeeze(), cmap = 'gray', interpolation = 'bilinear')
        ax[i][1].set_title('Mask')
        ax[i][1].axis('off')

        ax[i][2].imshow(mask_pred.squeeze(), alpha=0.5, cmap="Reds")
        ax[i][2].contour(mask_true.squeeze(), colors='green', levels=[0.5], linewidths = 3)
        ax[i][2].set_title('Predicted_Mask')
        ax[i][2].axis('off')

    plt.show()

In [None]:
# Validation batch 0.
plot_results(0, valid_image, valid_masks)

In [None]:
# Validation batch 1.
plot_results(1, valid_image, valid_masks)

In [None]:
# Validation batch 2.
plot_results(2, valid_image, valid_masks)

In [None]:
# Training batch 6, for comparison with the validation batches.
plot_results(6, train_image, train_masks)

**Loading model and testing it on the new data**

In [None]:
# Reload the weights file written by the save cell into the working directory.
# NOTE(review): the `model_weights` path defined near the top is not used here - confirm which file is intended.
model.load_weights('unet_seresnet50_weights.h5')

In [None]:
def files(path_img, limit=200):
    """Return the paths of regular files directly inside ``path_img``.

    Args:
        path_img: Directory to list (sub-directories are skipped, not entered).
        limit: Maximum number of paths to return; ``None`` returns all of
            them. Defaults to 200.

    Returns:
        List of paths, each ``path_img`` joined with the file name, in sorted
        name order so the selection is the same on every run.
    """
    input_images = []

    # os.listdir order is arbitrary; sorting makes the result deterministic.
    for name in sorted(os.listdir(path_img)):
        if limit is not None and len(input_images) >= limit:
            break

        candidate = os.path.join(path_img, name)
        if os.path.isfile(candidate):
            input_images.append(candidate)

    return input_images

In [None]:
# NOTE(review): this rebinds the global `input_dir` (previously the dataset
# root) to the test folder; re-running the earlier cells that read train.csv
# via `input_dir` after this one would look in the wrong place - consider a
# separate name.
input_dir = "../input/imaterialist-fashion-2019-FGVC6/test/"
test = files(input_dir)
print(len(test))

In [None]:
def plot_test(X):
    """Predict and plot masks for ten randomly chosen images.

    Each row shows the image with the predicted outline and the predicted
    mask on its own. Images are drawn with replacement.

    Args:
        X: Non-empty sequence of image file paths.
    """
    n_rows = 10
    fig, ax = plt.subplots(n_rows, 2, figsize = (20, 100))

    for i in range(n_rows):
        # randrange excludes the upper bound; randint(0, len(X)) could return
        # len(X) itself and index past the end of the list.
        ix = random.randrange(len(X))

        # Copy the pixels out while the file is open, then release it.
        with Image.open(X[ix]) as image_file:
            img = np.asarray(image_file.convert('RGB'))
        # cv2.resize takes the target size as (width, height).
        img = cv2.resize(img, (img_width, img_height)) / 255.0

        y_pred = model.predict(np.expand_dims(img, axis=0))
        mask_pred = (y_pred > 0.4).astype(np.uint8)

        ax[i][0].imshow(img)
        ax[i][0].set_title('Image')
        ax[i][0].contour(mask_pred.squeeze(), colors='yellow', levels=[0.5], linewidths = 3)
        ax[i][0].axis('off')

        ax[i][1].imshow(mask_pred.squeeze(), alpha=0.5, cmap="Reds")
        ax[i][1].set_title('Predicted_Mask')
        ax[i][1].axis('off')

    plt.show()

In [None]:
# First random draw of ten test images.
plot_test(test)

In [None]:
# Second random draw of ten test images.
plot_test(test)

**Conclusion**
* UNet with SeResNet50 gives the best results compared to UNet++ and DLA on ResNet50.
* Not all the data items have correct masks.
* Some images contain many people but have masks for only one person.
* Images can contain many textures, which makes segmentation difficult.
* Not enough resources :c

**Future work**
* Add item classification
* Try other backbones 