In [None]:
import os
import glob

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from skimage.io import imread

from sklearn.model_selection import train_test_split

%env SM_FRAMEWORK=tf.keras
import segmentation_models as sm
from segmentation_models.losses import *
from segmentation_models.metrics import *

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import *
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam

# Pin TensorFlow to the first GPU and use NHWC ("channels_last") tensors.
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before
# TensorFlow initializes its GPU devices; tensorflow is already imported above,
# so moving this line before the imports would be the safest placement.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
tf.keras.backend.set_image_data_format('channels_last')

import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage

In [None]:
# Class short codes; a label's integer value is its position in this list
# (label 0 is 'Background').
short_codes = [
    'Background',
    'LC_Fol', 
    'LC_Branch', 
    'Ar_TA', 
    'LC_Por', 
    'UDC_CCA', 
    'Ro_TA', 
    'Pa_FA', 
    'Mu_Ba', 
    'CR_TA', 
    'Pa_Cy', 
    'LC_Encr', 
    'Sa_Ba', 
    'Pa_H', 
    'Pa_TA', 
    'UDC_H', 
    'CR_FA', 
    'Ro_CCA', 
    'Mu_Cy', 
    'UDC_TA', 
    'Sa_Cy', 
    'UDC_FA', 
    'Ro_H', 
    'Sa_FA', 
    'Pa_CCA', 
    'Ro_Ba', 
    'Ro_Cy', 
    'Sa_TA', 
    'Ar_Ba', 
    'UDC_Cy', 
    'Ar_FA', 
    'CR_CCA'
]

# Creating a mapping of short codes to label values
class2label = {k: i for i, k in enumerate(short_codes)}

# One RGB color per class, used to visualize masks.
# The qualitative 'Paired' palette only has 12 colors and seaborn cycles it when
# more are requested, so with 32 classes labels i, i+12 and i+24 were drawn with
# the same color (e.g. Background, Sa_Ba and Pa_CCA). 'husl' instead returns
# len(short_codes) evenly spaced hues, one distinct color per class.
colors = sns.color_palette('husl', n_colors=len(short_codes))
colors = (np.array(colors) * 255).astype(np.uint8)

# Creating a mapping of label to color values
label2color = {i: colors[i] for i in range(len(short_codes))}

In [None]:
# Dataset root. Set the DATA_PATH environment variable to run this notebook on
# another machine; the default is the original Azure ML location.
DATA_PATH = os.environ.get(
    "DATA_PATH",
    "/home/azureuser/cloudfiles/code/Users/jordan.pierce/Data/Guam_Saipan/3653/",
)
assert os.path.exists(DATA_PATH), f"DATA_PATH does not exist: {DATA_PATH}"

EXP_DIR = "Experiments/"
EXP_NAME = "112"
EXP_FOLDER = EXP_DIR + EXP_NAME + "/"
WEIGHTS_DIR = EXP_FOLDER + "Weights/"
LOGS_DIR = EXP_FOLDER + "Logs/"

# os.makedirs also creates the missing parents (EXP_DIR, EXP_FOLDER).
for directory in (WEIGHTS_DIR, LOGS_DIR):
    os.makedirs(directory, exist_ok=True)

# os.path.join tolerates a DATA_PATH given without a trailing slash.
label_path = os.path.join(DATA_PATH, "Updated_CNet_Segmentation_Masks.csv")
data = pd.read_csv(label_path, index_col=0)

# DataGenerator reads image and mask paths from these two columns.
missing_columns = {"Image", "Mask"} - set(data.columns)
assert not missing_columns, f"{label_path} is missing columns: {sorted(missing_columns)}"

In [None]:
# Hold out 10% of the rows for validation. A fixed random_state makes the split,
# and therefore every validation metric reported below, reproducible.
SPLIT_SEED = 42
train, valid = train_test_split(data, test_size=.1, random_state=SPLIT_SEED)

# Fresh 0..n-1 indices for each subset (new frames, no in-place mutation).
train = train.reset_index(drop=True)
valid = valid.reset_index(drop=True)

len(train), len(valid)

In [None]:
def colorize_mask(mask, palette=None):
    """Convert a 2-D integer label mask into an RGB image for display.

    Args:
        mask: 2-D array of integer class labels, shape (H, W).
        palette: optional mapping (or indexable array) from label to an RGB
            triple. Defaults to the module-level ``label2color``.

    Returns:
        (H, W, 3) uint8 array in which every pixel carries its label's color.

    Raises:
        KeyError: if ``palette`` is a dict with no entry for a label in ``mask``.
    """
    if palette is None:
        palette = label2color

    mask = np.asarray(mask)
    colored_mask = np.zeros(shape=(mask.shape[0], mask.shape[1], 3), dtype=np.uint8)

    # Only labels actually present are painted; everything else stays black.
    for label in np.unique(mask):
        colored_mask[mask == label] = palette[int(label)]

    return colored_mask

In [None]:
height, width = 736, 1280 


# Augmentation pipelines. Each pipeline is applied to an image and its label mask
# in ONE call (mask wrapped in SegmentationMapsOnImage). Both therefore receive
# the same random flips/rotations, and imgaug resizes the segmentation map with
# nearest-neighbour interpolation. Masks must not be passed as `image=`:
# Rot90(keep_size=True) resizes `image=` inputs with image interpolation, which
# would blend one-hot class channels.
# NOTE(review): for 90/270 degree rotations of this non-square 736x1280 frame,
# keep_size=True squashes the rotated content back to 736x1280.
augs_for_images = iaa.Sequential([
    iaa.Resize(size={'height': height, 'width': width}, interpolation='linear', random_state=5),
    iaa.Fliplr(0.25, random_state=1),
    iaa.Flipud(0.25, random_state=2),
    iaa.Rot90([1, 2, 3, 4], True, random_state=3),
])

# Evaluation pipeline: resize only, no random augmentation.
resize_for_images = iaa.Sequential([
    iaa.Resize(size={'height': height, 'width': width}, interpolation='linear', random_state=1),
])


# Image data generator class
class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence yielding ``(images, one_hot_masks)`` batches.

    ``dataframe`` must provide an 'Image' column (paths read with
    ``plt.imread``) and a 'Mask' column (paths to ``.npy`` integer label
    arrays). Batch ``index`` covers rows ``[index * batch_size,
    (index + 1) * batch_size)``. Trailing rows that do not fill a whole batch
    are skipped for that epoch.

    Args:
        dataframe: table of image/mask paths.
        batch_size: samples per batch.
        augment: use ``augs_for_images`` when True, otherwise ``resize_for_images``.
        n_classes: depth of the one-hot mask encoding.
        shuffle: reshuffle the rows after every epoch. Defaults to ``augment``,
            so the training stream is shuffled and validation/test order is fixed.
    """

    def __init__(self, dataframe, batch_size, augment, n_classes, shuffle=None):
        super().__init__()
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.augment = augment
        self.shuffle = augment if shuffle is None else shuffle

    # Steps per epoch
    def __len__(self):
        return len(self.dataframe) // self.batch_size

    # Shuffles (when enabled) so each epoch draws different batches
    def on_epoch_end(self):
        if self.shuffle:
            self.dataframe = self.dataframe.sample(frac=1).reset_index(drop=True)

    def _load_sample(self, image_path, mask_path):
        """Read one image/mask pair, augment or resize both jointly, and one-hot the mask."""
        image = plt.imread(image_path)
        mask = np.load(mask_path).astype('uint8')
        # NOTE(review): assumes the mask has the same H x W as its image; imgaug
        # scales a segmentation map proportionally to the image it belongs to.
        segmap = SegmentationMapsOnImage(mask, shape=image.shape)

        pipeline = augs_for_images if self.augment else resize_for_images
        image_aug, segmap_aug = pipeline(image=image, segmentation_maps=segmap)

        one_hot_mask = to_categorical(segmap_aug.get_arr(), self.n_classes)
        # preprocess_input is the backbone-specific function defined with the model.
        return preprocess_input(image_aug), one_hot_mask

    # Generates data, feeds to training
    def __getitem__(self, index):
        start = index * self.batch_size
        rows = self.dataframe.iloc[start:start + self.batch_size]

        processed_images = []
        processed_masks = []
        for image_path, mask_path in zip(rows['Image'], rows['Mask']):
            processed_image, processed_mask = self._load_sample(image_path, mask_path)
            processed_images.append(processed_image)
            processed_masks.append(processed_mask)

        batch_x = np.array(processed_images)
        batch_y = np.array(processed_masks)

        return (batch_x, batch_y)


In [None]:
# Parameters for training
batch_size = 2
num_epochs = 20

# Keras Sequences: augmented stream for training, resize-only for validation.
train_gen = DataGenerator(train, batch_size=batch_size, augment=True,
                          n_classes=len(short_codes))
valid_gen = DataGenerator(valid, batch_size=batch_size, augment=False,
                          n_classes=len(short_codes))

# Whole batches per epoch; leftover rows that do not fill a batch are skipped.
steps_per_epoch_train = len(train) // batch_size
steps_per_epoch_valid = len(valid) // batch_size
print(steps_per_epoch_train)
print(steps_per_epoch_valid)

In [None]:
# U-Net on an ImageNet-pretrained EfficientNet-B0 encoder, with the encoder
# layers frozen (encoder_freeze=True) and a softmax over every class.
BACKBONE = 'efficientnetb0'
preprocess_input = sm.get_preprocessing(BACKBONE)

model = sm.Unet(
    backbone_name=BACKBONE,
    input_shape=(None, None, 3),
    classes=len(short_codes),
    activation='softmax',
    encoder_weights='imagenet',
    encoder_freeze=True,
    decoder_use_batchnorm=True,
)

In [None]:
# Per-class loss weights: 1.0 for every class except Background, weighted 0.
class_weights = [0 if name == 'Background' else 1.0 for name in short_codes]

In [None]:
# Jaccard (IoU) loss from segmentation_models, weighted per class by class_weights.
jaccard_loss = JaccardLoss(class_weights=class_weights)

In [None]:
# `lr` is a deprecated alias of `learning_rate` in tf.keras optimizers and may
# be rejected by newer Keras versions; use the documented argument name.
model.compile(optimizer=Adam(learning_rate=0.001),
              loss=[jaccard_loss],
              metrics=['accuracy', iou_score, precision, recall])

In [None]:

# Callbacks: lower the learning rate when val_loss stops improving, and write the
# weights to WEIGHTS_DIR after every epoch (save_best_only=False), so any epoch
# can be reloaded afterwards.
callbacks = [
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=.65,
        patience=2,
        verbose=1,
    ),
    ModelCheckpoint(
        filepath=WEIGHTS_DIR + 'model-{epoch:03d}.h5',
        monitor='val_loss',
        save_best_only=False,
        save_weights_only=True,
        verbose=1,
    ),
]

In [None]:
# Model.fit accepts a keras Sequence directly. fit_generator is deprecated in
# TF 2.x and no longer exists in Keras 3.
history = model.fit(train_gen,
                    steps_per_epoch=steps_per_epoch_train,
                    epochs=num_epochs,
                    validation_data=valid_gen,
                    validation_steps=steps_per_epoch_valid,
                    verbose=1,
                    callbacks=callbacks)

In [None]:
def plot_history_metric(history, metric, title, ylabel, path,
                        train_label=None, mark_best=False):
    """Plot the train/validation curves of one Keras history metric and save them.

    Args:
        history: History object returned by ``model.fit``.
        metric: key in ``history.history``; the validation curve is ``'val_' + metric``.
        title: figure title.
        ylabel: y-axis label.
        path: file the figure is written to.
        train_label: legend label of the training curve (defaults to ``metric``).
        mark_best: if True, mark the epoch with the lowest validation value.
    """
    val_key = "val_" + metric
    val_values = history.history[val_key]

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(history.history[metric], label=train_label or metric)
    ax.plot(val_values, label=val_key)
    if mark_best:
        best_epoch = int(np.argmin(val_values))
        ax.plot(best_epoch, val_values[best_epoch],
                marker="x", color="b", label="best model")
    ax.set_title(title)
    ax.set_xlabel("Epoch #")
    ax.set_ylabel(ylabel)
    ax.legend(loc="upper right")
    # pyplot has no `save`; Figure.savefig writes the image file.
    fig.savefig(path)
    plt.show()
    plt.close(fig)


print(history.history.keys())

plot_history_metric(history, "loss", "Training Loss", "Loss",
                    EXP_FOLDER + "Loss.png", train_label="train_loss", mark_best=True)
plot_history_metric(history, "precision", "Training Precision", "Precision",
                    EXP_FOLDER + "Precision.png")
plot_history_metric(history, "recall", "Training Recall", "Recall",
                    EXP_FOLDER + "Recall.png")

In [None]:
# Saved checkpoints, oldest to newest by modification time, printed with their index.
weights = glob.glob(WEIGHTS_DIR + "*.h5")
weights.sort(key=os.path.getmtime)
for index, weight_path in enumerate(weights):
    print(weight_path, index)

In [None]:
# Reload the checkpoint of the epoch with the lowest validation loss (the
# "best model" marked on the loss plot), instead of a hard-coded list index.
# ModelCheckpoint formats {epoch} 1-based, while history.history is 0-based.
best_epoch = int(np.argmin(history.history["val_loss"])) + 1
best_weights = WEIGHTS_DIR + "model-{:03d}.h5".format(best_epoch)
assert os.path.exists(best_weights), f"missing checkpoint: {best_weights}"
print("Best Weights: ", best_weights)
model.load_weights(best_weights)

In [None]:
# Making predictions with the trained model

test_gen = DataGenerator(valid, batch_size=1, augment=False, n_classes=len(short_codes))

# Show up to 5 validation samples: input image, ground-truth mask, prediction.
num_examples = min(5, len(test_gen))
for sample_index in range(num_examples):
    image, mask = test_gen[sample_index]
    prediction = model.predict(image)

    # Collapse one-hot / softmax channels to integer labels and drop the batch axis.
    mask = np.argmax(mask, axis=-1).astype("uint8").squeeze()
    prediction = np.argmax(prediction, axis=-1).astype("uint8").squeeze()
    # NOTE(review): this is the image after preprocess_input; depending on the
    # backbone's preprocessing its value range may not suit imshow directly.
    image = image.squeeze()

    fig, axes = plt.subplots(1, 3, figsize=(20, 20))
    panels = [
        (image, "Image"),
        (colorize_mask(mask), "Ground truth"),
        (colorize_mask(prediction), "Prediction"),
    ]
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel)
        ax.set_title(title)
    plt.show()
    # Release each figure so the loop does not accumulate open figures.
    plt.close(fig)