# **Import Dataset**

In [None]:
# Install DenseCRF bindings, download the image archive and the two annotation
# CSVs, and lay out the dataset folders used by the rest of the notebook.
!pip3 install git+https://github.com/lucasb-eyer/pydensecrf.git
!wget -q https://www.dropbox.com/s/8conv524x6xid27/dataset.tar.bz2?dl=0 && mv "dataset.tar.bz2?dl=0" "dataset.tar.bz2" && tar -jxf dataset.tar.bz2 && rm dataset.tar.bz2
!wget -q https://www.dropbox.com/s/5kcxh2dcjtn35k7/Object.csv?dl=0 && mv "Object.csv?dl=0" "Object.csv"
!wget -q https://www.dropbox.com/s/7bfw646mpadynt2/Semantic.csv?dl=0 && mv "Semantic.csv?dl=0" "Semantic.csv"
# -p: do not fail when the cell is re-run and the folders already exist.
!mkdir -p dataset/Ori dataset/Train dataset/Annotations dataset/Valid dataset/Valid_Annotations dataset/Translated_Annotations
!mv dataset/*.jpg dataset/Ori
# sample_data/ only exists on Colab; -f keeps this quiet everywhere else.
!rm -rf sample_data/

In [None]:
# Keep only Micrasterias for training: Object.csv is not used anywhere below,
# and every genus listed here is removed from dataset/Ori.
# ("Colsterium" is spelled the same way as in the LABELS list used later.)
!rm Object.csv
!rm dataset/Ori/Colsterium*
!rm dataset/Ori/Cylindrocystis*
!rm dataset/Ori/Lepocinclis*
!rm dataset/Ori/Peridinium*
!rm dataset/Ori/Pinnularia*
!rm dataset/Ori/Pleurotaenium*
!rm dataset/Ori/Pyrocystis*
!rm dataset/Ori/Volvox*
!rm dataset/Ori/Ceratium*
!rm dataset/Ori/Coleps*
!rm dataset/Ori/Collodictyon*
!rm dataset/Ori/Didinium*
!rm dataset/Ori/Dinobryon*
!rm dataset/Ori/Frontonia*
!rm dataset/Ori/Phacus*
!rm dataset/Ori/Paramecium_b*
# Escaped space: these file names contain "Paramecium s...".
!rm dataset/Ori/Paramecium\ s*

In [None]:
# Build S.csv with the Micrasterias rows only. grep alone would drop the CSV
# header, and a reader that skips the first line would then lose the first
# Micrasterias record, so copy the header row first.
!head -n 1 Semantic.csv > S.csv && grep Micrasterias Semantic.csv >> S.csv

# **Augment Images**

In [None]:
import os
import math
from PIL import Image

def crop(directory, out_dir='./dataset/Cropped', multiple=32):
	''' Crop every image in `directory` so its dimensions are multiples of `multiple`.

	The crop keeps the top-left corner: the right and bottom edges are trimmed
	by (size % multiple) pixels. Results are saved under the same file name in
	`out_dir`, which is created if needed. Images smaller than `multiple` in
	either dimension would give an empty crop, so they are skipped with a
	message instead of failing on save.
	'''
	os.makedirs(out_dir, exist_ok=True)
	for name in os.listdir(directory):
		# Context manager closes the file handle PIL keeps open.
		with Image.open(os.path.join(directory, name)) as img:
			W, H = img.size
			w = W - W % multiple
			h = H - H % multiple
			if w == 0 or h == 0:
				print('[-] Skipping {}: smaller than {} px'.format(name, multiple))
				continue
			img.crop((0, 0, w, h)).save(os.path.join(out_dir, name))
crop('dataset/Ori')

In [None]:
import cv2
import json
import imgaug as ia
from collections import defaultdict
from imgaug import augmenters as iaa

def _read_via_csv(image_path, csv_path):
    '''Parse VIA polygon CSV rows, grouped per image.

    Returns a defaultdict mapping
    (filename, lineColor_json, fillColor_json, imagePath, width, height)
    to a list of [label, line_color, fill_color, points, shape_type].
    The colours in the key are JSON strings so the key stays hashable.
    '''
    import csv
    polys = defaultdict(list)
    line_json = json.dumps([0, 255, 0, 128])
    fill_json = json.dumps([255, 255, 0, 128])
    sizes = {}  # imagePath -> (W, H); each image is opened only once
    with open(csv_path, 'r', newline='') as handle:
        for row in csv.reader(handle):
            # Skip blank lines and the VIA header when present. Output of
            # `grep` has no header, so the first line must not be dropped blindly.
            if not row or row[0] == 'filename':
                continue
            filename = row[0]
            region = json.loads(row[5])
            if 'all_points_x' not in region:
                continue  # image listed without a polygon region
            attributes = json.loads(row[6])
            # The label is the first key of region_attributes, e.g. {"Micrasterias": ""}.
            label = next(iter(attributes), '')
            points = [[x, y] for x, y in zip(region['all_points_x'],
                                              region['all_points_y'])]
            path = '{}/{}'.format(image_path, filename)
            if path not in sizes:
                with Image.open(path) as img:
                    sizes[path] = img.size
            W, H = sizes[path]
            polys[filename, line_json, fill_json, path, W, H].append(
                [label, None, None, points, region.get('name', 'polygon')])
    return polys


def _read_labelme_dir(ann_dir):
    '''Load every labelme-style ``*.json`` file in `ann_dir` into the same grouping as `_read_via_csv`.'''
    polys = defaultdict(list)
    for name in sorted(os.listdir(ann_dir)):
        if not name.endswith('.json'):
            continue
        with open(os.path.join(ann_dir, name), 'r') as handle:
            d = json.load(handle)
        filename = name.split('.')[0] + '.jpg'
        key = (filename,
               json.dumps(d.get('lineColor')),
               json.dumps(d.get('fillColor')),
               d['imagePath'], d['imageWidth'], d['imageHeight'])
        for item in d['shapes']:
            polys[key].append([item['label'],
                               item.get('line_color'),
                               item.get('fill_color'),
                               item['points'],
                               item.get('shape_type', 'polygon')])
    return polys


def _write_via_csv(image_path, out_dir, polys):
    '''Write all regions to `out_dir`/Translated.csv in VIA CSV layout.'''
    header = 'filename,file_size,file_attributes,region_count,region_id,region_shape_attributes,region_attributes\n'
    with open(os.path.join(out_dir, 'Translated.csv'), 'w') as f:
        f.write(header)
        for key, shapes in polys.items():
            filename = key[0]
            size = os.stat('{}/{}'.format(image_path, filename)).st_size
            total = len(shapes)
            for num, (label, _line, _fill, points, shape_type) in enumerate(shapes):
                x = [p[0] for p in points]
                y = [p[1] for p in points]
                f.write('{},{},"{{}}",{},{},"{{""name"":""{}"",""all_points_x"":{},""all_points_y"":{}}}","{{""{}"":""""}}"\n'
                        .format(filename, size, total, num, shape_type, x, y, label))


def _write_labelme(out_dir, polys):
    '''Write one labelme-style JSON file per image into `out_dir`.'''
    for key, shapes in polys.items():
        filename, line_json, fill_json, path, W, H = key
        doc = {
            'version': '3.11.2',
            'flags': {},
            'lineColor': json.loads(line_json),
            'fillColor': json.loads(fill_json),
            'imagePath': path,
            'imageData': '',
            'imageHeight': H,
            'imageWidth': W,
            'shapes': [{'label': label,
                        'line_color': line_color,
                        'fill_color': fill_color,
                        'points': points,
                        'shape_type': shape_type}
                       for label, line_color, fill_color, points, shape_type in shapes],
        }
        out_file = os.path.join(out_dir, '{}.json'.format(filename.split('.')[0]))
        # json.dump always emits valid JSON (null colours, quoted labels).
        with open(out_file, 'w') as f:
            json.dump(doc, f, indent=1)


def translate_poly(image_path='./dataset/Train',
                   ann_input='./dataset/Annotations',
                   ann_output='./dataset/Augmented_Annotations',
                   input_format='csv',
                   output_format='json'):
    ''' Translate polygon annotations between VIA CSV and labelme-style JSON.

    image_path    : directory holding the annotated images. Used for image
                    size (csv input) and file size (csv output).
    ann_input     : the CSV file (input_format='csv') or a directory of
                    ``*.json`` files (input_format='json').
    ann_output    : directory that receives Translated.csv or one JSON per image.
    input_format  : 'csv' or 'json'.
    output_format : 'csv' or 'json'.

    Raises ValueError for an unsupported format.
    '''
    for kind, value in (('input_format', input_format), ('output_format', output_format)):
        if value not in ('csv', 'json'):
            raise ValueError("{} must be 'csv' or 'json', got {!r}".format(kind, value))
    if input_format == 'csv':
        polys = _read_via_csv(image_path, ann_input)
    else:
        polys = _read_labelme_dir(ann_input)
    if output_format == 'csv':
        _write_via_csv(image_path, ann_output, polys)
    else:
        _write_labelme(ann_output, polys)
    print('[+] Done')
 
def augment_poly(TheImage, im_out, ann_path, ann_output, iterations):
    '''Write `iterations` randomly augmented copies of an image and its polygons.

    TheImage   : path of the source image (read with OpenCV, BGR).
    im_out     : directory receiving ``Aug_<stem>-<k>.jpg`` (k = 1..iterations).
    ann_path   : labelme-style JSON holding the source polygons.
    ann_output : directory receiving the matching ``Aug_<stem>-<k>.json``.
    iterations : number of augmented copies; cast with int().

    Each copy draws a fresh deterministic instance of the same augmentation
    pipeline, so the image and its polygon vertices receive identical
    geometric transforms. Raises FileNotFoundError if the image cannot be read.
    '''
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Multiply((0.7, 1.0)),
        iaa.Affine(
                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
                rotate=(-90, 90),
                shear=(-16, 16),
                mode=ia.ALL),
        iaa.Sometimes(0.5, iaa.Dropout((0.001, 0.01), per_channel=0.5))
        ], random_order=True)
    im = cv2.imread(TheImage, 1)
    if im is None:
        raise FileNotFoundError('Could not read image: {}'.format(TheImage))
    with open(ann_path) as handle:
        data = json.load(handle)
    # Flatten all polygon vertices into one keypoint list and remember, per
    # shape, which slice of that list belongs to it.
    points = []
    spans = []
    for shape in data['shapes']:
        start = len(points)
        points.extend(ia.Keypoint(x=px, y=py) for px, py in shape['points'])
        spans.append((shape['label'], start, len(points)))
    # imgaug expects the image *array* shape, i.e. (height, width, channels).
    keypoints = ia.KeypointsOnImage(points, shape=im.shape)
    H, W = im.shape[:2]
    stem = TheImage.split('/')[-1].split('.')[0]
    for iters in range(int(iterations)):
        seq_det = seq.to_deterministic()
        image_aug = seq_det.augment_images([im])[0]
        keypoints_aug = seq_det.augment_keypoints([keypoints])[0]
        out_name = 'Aug_{}-{}'.format(stem, str(iters+1))
        image_file = '{}/{}.jpg'.format(im_out, out_name)
        cv2.imwrite(image_file, image_aug)
        shapes = [{'label': label,
                   'line_color': None,
                   'fill_color': None,
                   # float(): imgaug may return numpy scalars, which json cannot encode
                   'points': [[float(kp.x), float(kp.y)]
                              for kp in keypoints_aug.keypoints[start:end]],
                   'shape_type': 'polygon'}
                  for label, start, end in spans]
        doc = {'version': data['version'],
               'flags': data['flags'],
               'lineColor': data['lineColor'],
               'fillColor': data['fillColor'],
               'imagePath': image_file,
               'imageData': data['imageData'],
               'imageHeight': H,
               'imageWidth': W,
               'shapes': shapes}
        with open('{}/{}.json'.format(ann_output, out_name), 'w') as f:
            json.dump(doc, f, indent=1)

# Convert the Micrasterias VIA rows into one labelme JSON per cropped image.
translate_poly(image_path='./dataset/Cropped',
               ann_input='S.csv',
               ann_output='./dataset/Translated_Annotations',
               input_format='csv',
               output_format='json')

# Write two augmented copies (image + JSON) per cropped image into Train/Annotations.
# Images with no translated annotation would crash augment_poly, so skip them.
for image_name in sorted(os.listdir('./dataset/Cropped')):
    ann_file = './dataset/Translated_Annotations/{}.json'.format(image_name.split('.')[0])
    if not os.path.exists(ann_file):
        print('[-] No annotation for {}, skipping'.format(image_name))
        continue
    augment_poly('./dataset/Cropped/{}'.format(image_name),
                 './dataset/Train',
                 ann_file,
                 './dataset/Annotations',
                 2)

[+] Done


# **Segmentation models (UNet / FCN-8): training and prediction**

In [None]:
import os
import sys
import cv2
import json
import random
import numpy as np
from PIL import Image
import tensorflow as tf
import pydensecrf.densecrf as dcrf
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from pydensecrf.utils import unary_from_softmax
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dropout, Lambda, Conv2DTranspose, Add
from tensorflow.keras.layers import Conv2D, Input, MaxPooling2D, concatenate

# Infer the network input shape from one training image. DataGenerator stacks
# images into fixed-size batches, so all training images must share this shape.
_train_images = sorted(f for f in os.listdir('./dataset/Train')
                       if f.lower().endswith(('.jpg', '.jpeg', '.png')))
if not _train_images:
    raise FileNotFoundError('No training images found in ./dataset/Train')
exam = _train_images[0]
with Image.open('./dataset/Train/{}'.format(exam)) as _im:
    WW, HH = _im.size
imshape = (HH, WW, 3)
mode = 'multi'
model_name = 'unet_'+mode
LABELS = 'Colsterium Cylindrocystis Lepocinclis Micrasterias Peridinium Pinnularia Pleurotaenium Pyrocystis Volvox Ceratium Coleps Collodictyon Didinium Dinobryon Frontonia Phacus'
# One random display hue per class name. split(): iterating the raw string
# would yield single characters. OpenCV's 8-bit hue channel spans 0-179.
hues = {name: random.randint(0, 179) for name in LABELS.split()}
labels = sorted(hues.keys())
if mode == 'binary':
    n_classes = 1
elif mode == 'multi':
    n_classes = len(labels) + 1  # +1 for the background channel
else:
    raise ValueError("mode must be 'binary' or 'multi', got {!r}".format(mode))
assert imshape[0]%32 == 0 and imshape[1]%32 == 0,\
    "imshape should be multiples of 32. comment out to test different imshapes."

class DataGenerator(tf.keras.utils.Sequence):
    '''Keras Sequence yielding (X, Y) batches from images plus labelme-style JSON polygons.

    image_paths[k] is paired with annot_paths[k]. The module-level globals
    ``imshape``, ``n_classes`` and ``labels`` define the batch shapes and the
    mask channel order. __len__ floors, so the last partial batch is dropped.
    '''

    def __init__(self, image_paths, annot_paths, batch_size=32, shuffle=True):
        super().__init__()
        self.image_paths = image_paths
        self.annot_paths = annot_paths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.image_paths) / self.batch_size))

    def __getitem__(self, index):
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        image_paths = [self.image_paths[k] for k in indexes]
        annot_paths = [self.annot_paths[k] for k in indexes]
        X, y = self.__data_generation(image_paths, annot_paths)
        return X, y

    def on_epoch_end(self):
        '''Rebuild the sample order, shuffled when self.shuffle is truthy.'''
        self.indexes = np.arange(len(self.image_paths))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_poly(self, annot_path):
        '''Return the ``shapes`` list of a labelme-style JSON file.'''
        with open(annot_path) as handle:
            data = json.load(handle)
        return data['shapes']

    def create_binary_masks(self, im, shape_dicts):
        '''Single-channel {0, 1} mask: union of every non-background polygon.'''
        blank = np.zeros(shape=(im.shape[0], im.shape[1]), dtype=np.float32)
        for shape in shape_dicts:
            if shape['label'] != 'background':
                points = np.array(shape['points'], dtype=np.int32)
                cv2.fillPoly(blank, [points], 255)
        blank = blank / 255.0
        return np.expand_dims(blank, axis=2)

    def create_multi_masks(self, im, shape_dicts):
        '''One {0, 1} channel per entry of ``labels``, plus a final background channel.

        All polygons of a label are drawn. Several instances of the same class
        in one image are common, so each label maps to a list of polygons.
        Background is the explicit 'background' polygons when present,
        otherwise the complement of every labelled region.
        '''
        h, w = im.shape[0], im.shape[1]
        label2polys = {}
        for shape in shape_dicts:
            label2polys.setdefault(shape['label'], []).append(
                np.array(shape['points'], dtype=np.int32))
        covered = np.zeros(shape=(h, w), dtype=np.float32)
        channels = []
        for label in labels:
            blank = np.zeros(shape=(h, w), dtype=np.float32)
            # Fill polygons one at a time: a single fillPoly call with several
            # contours can cut holes where they overlap.
            for poly in label2polys.get(label, []):
                cv2.fillPoly(blank, [poly], 255)
                cv2.fillPoly(covered, [poly], 255)
            channels.append(blank)
        if 'background' in label2polys:
            background = np.zeros(shape=(h, w), dtype=np.float32)
            for poly in label2polys['background']:
                cv2.fillPoly(background, [poly], 255)
        else:
            _, background = cv2.threshold(covered, 127, 255,
                                          cv2.THRESH_BINARY_INV)
        channels.append(background)
        return np.stack(channels, axis=2) / 255.0

    def __data_generation(self, image_paths, annot_paths):
        X = np.empty((self.batch_size,
                      imshape[0], imshape[1], imshape[2]), dtype=np.float32)
        Y = np.empty((self.batch_size,
                      imshape[0], imshape[1], n_classes),  dtype=np.float32)
        for i, (im_path, annot_path) in enumerate(zip(image_paths, annot_paths)):
            if imshape[2] == 1:
                im = cv2.imread(im_path, 0)
                if im is None:
                    raise FileNotFoundError('Could not read image: {}'.format(im_path))
                im = np.expand_dims(im, axis=2)
            elif imshape[2] == 3:
                im = cv2.imread(im_path, 1)
                if im is None:
                    raise FileNotFoundError('Could not read image: {}'.format(im_path))
                im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            shape_dicts = self.get_poly(annot_path)
            if n_classes == 1:
                mask = self.create_binary_masks(im, shape_dicts)
            elif n_classes > 1:
                mask = self.create_multi_masks(im, shape_dicts)
            X[i,] = im
            Y[i,] = mask
        return X, Y

def unet(pretrained=False, base=4):
    '''Build and compile a four-level U-Net for inputs of shape ``imshape``.

    Encoder levels use 2**base .. 2**(base+3) filters, the bottleneck uses
    2**(base+4), and the decoder mirrors the encoder with 2x2 transposed
    convolutions plus skip concatenations. The output is a 1x1 conv with
    ``n_classes`` channels (sigmoid for one class, softmax otherwise),
    compiled with Adam(1e-4) and the ``dice`` metric. With ``pretrained``,
    weights are loaded from ``<model_name>.h5`` when that file exists.
    '''
    if n_classes == 1:
        loss, final_act = 'binary_crossentropy', 'sigmoid'
    elif n_classes > 1:
        loss, final_act = 'categorical_crossentropy', 'softmax'

    def conv_pair(tensor, filters, rate):
        # 3x3 conv -> dropout -> 3x3 conv; layers are created in this order.
        tensor = Conv2D(filters, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (tensor)
        tensor = Dropout(rate) (tensor)
        return Conv2D(filters, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (tensor)

    inputs = Input((imshape[0], imshape[1], imshape[2]))
    x = Lambda(lambda x: preprocess_input(x)) (inputs)

    # Contracting path: keep each level's output for the skip connections.
    skips = []
    for level, rate in enumerate((0.1, 0.1, 0.2, 0.2)):
        x = conv_pair(x, 2**(base + level), rate)
        skips.append(x)
        x = MaxPooling2D((2, 2)) (x)

    # Bottleneck.
    x = conv_pair(x, 2**(base + 4), 0.3)

    # Expanding path, deepest level first.
    for level, rate in zip((3, 2, 1, 0), (0.2, 0.2, 0.1, 0.1)):
        x = Conv2DTranspose(2**(base + level), (2, 2), strides=(2, 2), padding='same') (x)
        if level == 0:
            x = concatenate([x, skips[level]], axis=3)
        else:
            x = concatenate([x, skips[level]])
        x = conv_pair(x, 2**(base + level), rate)

    outputs = Conv2D(n_classes, (1, 1), activation=final_act) (x)
    model = Model(inputs=inputs, outputs=outputs, name=model_name)
    model.compile(optimizer=Adam(1e-4), loss=loss, metrics=[dice])
    if not pretrained:
        return model
    weights_file = model_name + '.h5'
    if not os.path.exists(weights_file):
        print('Failed to load existing weights at: {}'.format(weights_file))
        return model
    model.load_weights(weights_file)
    print('Loaded weights')
    return model

def fcn_8(pretrained=False, base=4):
    '''Build and compile an FCN-8-style network for inputs of shape ``imshape``.

    It uses a five-block VGG-like encoder (ELU 3x3 convs, 2x2 max-pool per
    block) and a 2048-filter fully-convolutional head. Class scores are
    fused with skips from block4_pool and block3_pool, then upsampled 8x back
    to input resolution. The loss and final activation follow ``n_classes``.
    With ``pretrained``, weights are loaded from ``<model_name>.h5`` when it exists.
    '''
    if n_classes == 1:
        loss, final_act = 'binary_crossentropy', 'sigmoid'
    elif n_classes > 1:
        loss, final_act = 'categorical_crossentropy', 'softmax'

    inputs = Input(shape=imshape)
    x = Lambda(lambda x: preprocess_input(x)) (inputs)

    # (number of 3x3 convs, filter exponent offset) for blocks 1..5.
    block_specs = ((2, 0), (2, 1), (3, 2), (3, 3), (3, 3))
    pooled = []
    for block_no, (n_convs, offset) in enumerate(block_specs, start=1):
        for conv_no in range(1, n_convs + 1):
            x = Conv2D(2**(base + offset), (3, 3), activation='elu', padding='same',
                       name='block{}_conv{}'.format(block_no, conv_no))(x)
        x = MaxPooling2D((2, 2), strides=(2, 2), name='block{}_pool'.format(block_no))(x)
        pooled.append(x)
    pool3, pool4, pool5 = pooled[2], pooled[3], pooled[4]

    # Fully-convolutional head ("fc6"/"fc7").
    head = Conv2D(2048 , (7, 7) , activation='elu' , padding='same', name="conv6")(pool5)
    head = Dropout(0.5)(head)
    head = Conv2D(2048 , (1, 1) , activation='elu' , padding='same', name="conv7")(head)
    head = Dropout(0.5)(head)

    # Fuse at 1/16 resolution (pool4), then at 1/8 (pool3).
    score4 = Conv2D(n_classes, (1, 1), activation='elu', padding='same')(pool4)
    up2 = Conv2DTranspose(n_classes, kernel_size=(2, 2), strides=(2, 2), padding='same')(head)
    fused4 = Add()([score4, up2])
    score3 = Conv2D(n_classes, (1, 1), activation='elu', padding='same')(pool3)
    up4 = Conv2DTranspose(n_classes, kernel_size=(2, 2), strides=(2, 2), padding='same')(fused4)
    fused3 = Add()([score3, up4])
    outputs = Conv2DTranspose(n_classes, kernel_size=(8, 8), strides=(8, 8), padding='same', activation=final_act)(fused3)

    model = Model(inputs=inputs, outputs=outputs, name=model_name)
    model.compile(optimizer=Adam(1e-4), loss=loss, metrics=[dice])
    if not pretrained:
        return model
    weights_file = model_name + '.h5'
    if not os.path.exists(weights_file):
        print('Failed to load existing weights at: {}'.format(weights_file))
        return model
    model.load_weights(weights_file)
    print('Loaded weights')
    return model

def sorted_fns(dir):
    '''List the entries of `dir`, ordered by the part of each name before the first dot.'''
    def stem(name):
        return name.partition('.')[0]
    entries = os.listdir(dir)
    entries.sort(key=stem)
    return entries

def preprocess_input(x):
    '''Map pixel values from [0, 255] to [-1, 1] as (x / 255 - 0.5) * 2.

    Returns a new value and never modifies `x`. The previous in-place version
    mutated numpy arrays passed by the caller and raised on integer arrays.
    The arithmetic order is unchanged, so tensor results are identical.
    '''
    return (x / 255. - 0.5) * 2.

def dice(y_true, y_pred, smooth=1.):
    '''Soft Dice coefficient over the whole batch: (2*sum(t*p) + s) / (sum(t) + sum(p) + s).'''
    flat_true = K.flatten(y_true)
    flat_pred = K.flatten(y_pred)
    overlap = K.sum(flat_true * flat_pred)
    denominator = K.sum(flat_true) + K.sum(flat_pred) + smooth
    return (2. * overlap + smooth) / denominator

def add_masks(pred):
    '''Render per-class score maps as one RGB uint8 image.

    pred[:, :, i] (values presumably in [0, 255]) is the brightness of class
    labels[i], drawn with that class's hue from ``hues``. Channels beyond
    len(labels), i.e. the background, are ignored. Class images are combined
    with saturating addition. The canvas size is taken from `pred` itself.
    '''
    height, width = pred.shape[0], pred.shape[1]
    blank = np.zeros(shape=(height, width, 3), dtype=np.uint8)
    sat = np.full(shape=(height, width), fill_value=255, dtype=np.uint8)
    for i, label in enumerate(labels):
        # OpenCV 8-bit hue wraps at 180; the modulo also keeps the value in uint8 range.
        hue = np.full(shape=(height, width), fill_value=hues[label] % 180, dtype=np.uint8)
        val = np.clip(pred[:, :, i], 0, 255).astype(np.uint8)
        im_rgb = cv2.cvtColor(cv2.merge([hue, sat, val]), cv2.COLOR_HSV2RGB)
        blank = cv2.add(blank, im_rgb)
    return blank

def crf(im_softmax, im_rgb):
    '''Refine a per-pixel class probability map with a dense CRF (pydensecrf).

    im_softmax : (H, W, C) class probabilities, classes on the last axis.
    im_rgb     : the matching image used by the bilateral term. Presumably
                 uint8 RGB of shape (H, W, 3), as pydensecrf expects; confirm.

    Runs 5 mean-field iterations and takes the per-pixel argmax. In 'binary'
    mode it returns that (H, W) label map times 255. In 'multi' mode it
    returns an RGB visualisation built by add_masks. Raises ValueError for
    any other mode.
    '''
    n_labels = im_softmax.shape[2]
    feat_first = im_softmax.transpose((2, 0, 1)).reshape(n_labels, -1)
    unary = np.ascontiguousarray(unary_from_softmax(feat_first))
    im_rgb = np.ascontiguousarray(im_rgb)
    d = dcrf.DenseCRF2D(im_rgb.shape[1], im_rgb.shape[0], n_labels)
    d.setUnaryEnergy(unary)
    d.addPairwiseGaussian(sxy=(5, 5), compat=3, kernel=dcrf.DIAG_KERNEL,
                          normalization=dcrf.NORMALIZE_SYMMETRIC)
    d.addPairwiseBilateral(sxy=(5, 5), srgb=(13, 13, 13), rgbim=im_rgb,
                           compat=10,
                           kernel=dcrf.DIAG_KERNEL,
                           normalization=dcrf.NORMALIZE_SYMMETRIC)
    Q = d.inference(5)
    res = np.argmax(Q, axis=0).reshape((im_rgb.shape[0], im_rgb.shape[1]))
    # Compare strings with ==; `is` tests identity and is not reliable for str.
    if mode == 'binary':
        return res * 255.0
    if mode == 'multi':
        # Fix the channel count: if the highest class never wins, to_categorical
        # would otherwise infer fewer channels than add_masks indexes.
        res_hot = to_categorical(res, num_classes=n_labels) * 255.0
        return add_masks(res_hot)
    raise ValueError("mode must be 'binary' or 'multi', got {!r}".format(mode))

def train():
    '''Train the model selected by ``model_name`` on ./dataset/Train + ./dataset/Annotations.

    Images and annotations are paired by sorted file stem, so both folders
    must hold the same number of files. Existing ``<model_name>.h5`` weights
    are loaded first if present. Every 10 epochs, the weights are saved to
    that file when the training ``dice`` metric has improved.

    Raises ValueError for an unknown model_name or a count mismatch.
    '''
    image_paths = [os.path.join('./dataset/Train', x) for x in sorted_fns('./dataset/Train')]
    annot_paths = [os.path.join('./dataset/Annotations', x) for x in sorted_fns('./dataset/Annotations')]
    if len(image_paths) != len(annot_paths):
        raise ValueError('Found {} images but {} annotations; pairs would be misaligned'
                         .format(len(image_paths), len(annot_paths)))
    if 'unet' in model_name:
        model = unet(pretrained=True, base=4)
    elif 'fcn_8' in model_name:
        model = fcn_8(pretrained=True, base=4)
    else:
        raise ValueError('Unknown model_name: {}'.format(model_name))
    tg = DataGenerator(image_paths=image_paths,
                       annot_paths=annot_paths,
                       batch_size=5)
    # save_freq counts batches: 10 epochs' worth replaces the deprecated period=10.
    checkpoint = ModelCheckpoint(model_name+'.h5',
                                 monitor='dice', verbose=1, mode='max', save_best_only=True,
                                 save_weights_only=True, save_freq=10 * len(tg))
    # Model.fit accepts a keras Sequence directly; fit_generator is deprecated.
    model.fit(tg,
              steps_per_epoch=len(tg),
              epochs=500,
              verbose=1,
              callbacks=[checkpoint])

def predict(filename, CALC_CRF=True):
    '''Segment one image with the trained model and save a colour mask next to it.

    filename : path of an RGB image whose size matches ``imshape``.
    CALC_CRF : refine the network output with ``crf`` before saving.

    Loads ``<model_name>.h5`` (raises if it is missing). It prints the count of
    non-zero mask values divided by 3 and writes ``masked_<name>`` into the
    image's directory. Returns the RGB mask array.
    '''
    if 'unet' in model_name:
        model = unet(pretrained=False, base=4)
    elif 'fcn_8' in model_name:
        model = fcn_8(pretrained=False, base=4)
    else:
        raise ValueError('Unknown model_name: {}'.format(model_name))
    model.load_weights(model_name+'.h5')
    im_cv = cv2.imread(filename)
    if im_cv is None:
        raise FileNotFoundError('Could not read image: {}'.format(filename))
    im = cv2.cvtColor(im_cv, cv2.COLOR_BGR2RGB)
    # The model's Input is float32, and its preprocess step divides in place.
    tmp = np.expand_dims(im.astype(np.float32), axis=0)
    roi_pred = model.predict(tmp)
    if n_classes == 1:
        roi_mask = roi_pred.squeeze()*255.0
        roi_mask = cv2.cvtColor(roi_mask, cv2.COLOR_GRAY2RGB)
    elif n_classes > 1:
        roi_mask = add_masks(roi_pred.squeeze()*255.0)
    if CALC_CRF:
        if n_classes == 1:
            roi_pred = roi_pred.squeeze()
            roi_softmax = np.stack([1-roi_pred, roi_pred], axis=2)
            roi_mask = crf(roi_softmax, im)
            roi_mask = np.array(roi_mask, dtype=np.float32)
            roi_mask = cv2.cvtColor(roi_mask, cv2.COLOR_GRAY2RGB)
        elif n_classes > 1:
            roi_mask = crf(roi_pred.squeeze(), im)
    pos = np.count_nonzero(roi_mask)/3  # non-zero channel values / 3 channels
    print('Positive white pixels {}'.format(pos))
    # Prefix the base name only, so paths such as ./dir/x.jpg stay valid.
    out_dir, base_name = os.path.split(filename)
    out_path = os.path.join(out_dir, 'masked_{}'.format(base_name))
    out_img = np.clip(roi_mask, 0, 255).astype(np.uint8)
    # The mask is RGB; OpenCV writes BGR.
    cv2.imwrite(out_path, cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR))
    return roi_mask

# Start training as soon as this cell runs. It uses the module-level settings
# (model_name, imshape, labels, n_classes) defined earlier in this cell.
train()