In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append('..')
import os
import glob

import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split

from rainforest.models.unet import get_unet
from rainforest.data_generators import get_data_with_masks
import paths

from keras.callbacks import CSVLogger, ReduceLROnPlateau, ModelCheckpoint

%matplotlib inline

In [None]:
# Build the U-Net. It is presumably returned already compiled (the training cell calls
# fit_generator on it without a compile step) — confirm in rainforest.models.unet.
model = get_unet()

In [None]:
# Each '<name>.mask.0.png' mask pairs with the source image '<name>.bmp' in the same folder.
# Glob for the full mask suffix rather than any '*.png': for an unrelated PNG, str.replace
# would be a no-op and the "image" path would silently equal the mask path.
MASK_SUFFIX = '.mask.0.png'
mask_dir = os.path.join(paths.DATA_FOLDER, 'slash_burn')
# sorted() makes the order independent of filesystem listing order, so together with a fixed
# random_state the train/val split is identical on every Restart & Run All.
masks = sorted(glob.glob(os.path.join(mask_dir, '*' + MASK_SUFFIX)))
assert masks, 'no masks found under ' + mask_dir
images = [m[:-len(MASK_SUFFIX)] + '.bmp' for m in masks]
train_imgs, val_imgs, train_masks, val_masks = train_test_split(images, masks, random_state=42)
# Materialise as lists: zip() is a one-shot iterator in Python 3, so len() on it (used in the
# training cell) raises TypeError, and any consumer that loops over it more than once would
# see nothing after the first pass.
train_files = list(zip(train_imgs, train_masks))
val_files = list(zip(val_imgs, val_masks))

In [None]:
# Samples per batch; shared by both data generators and by the steps-per-epoch /
# validation-steps arithmetic in the training cell.
batch_size = 16

In [None]:
# Only the training generator receives augmentation options; the validation generator is
# built with get_data_with_masks' defaults for everything except batch size.
# NOTE(review): rot_range is negative (-5). Whether get_data_with_masks treats it as a
# symmetric +/-5 range or as a signed bound is not visible here — confirm.
train_augmentation = dict(
    hflip=True,
    vflip=True,
    shift_x=3,
    shift_y=3,
    rot_range=-5,
)
train_gen = get_data_with_masks(train_files, batch_size=batch_size, **train_augmentation)
val_gen = get_data_with_masks(val_files, batch_size=batch_size)

In [None]:
# Checkpoint location is configurable instead of baking in one machine's drive letter;
# the default keeps the original location. Create it up front so saving cannot fail on a
# missing directory.
MODEL_DIR = os.environ.get('RAINFOREST_MODEL_DIR', 'E:/Models/brainforest')
os.makedirs(MODEL_DIR, exist_ok=True)

csv_logger = CSVLogger('log.csv')
# Halve the learning rate after 3 epochs without val_loss improvement.
lr_plateau = ReduceLROnPlateau(monitor='val_loss', patience=3, verbose=1, factor=0.5)
# With save_best_only=True a file is written only when val_loss improves, so within a run
# the newest checkpoint file holds the best weights so far.
checkpoint = ModelCheckpoint(filepath=os.path.join(MODEL_DIR, 'unet1_model.{epoch:02d}-{val_loss:.2f}.hdf5'),
                             monitor='val_loss', verbose=1, save_best_only=True)

# Count steps from the split lists, which always support len() (the paired *_files may be
# one-shot zip iterators). Clamp to 1 so a split smaller than one batch does not yield 0 steps.
steps_per_epoch = max(1, len(train_imgs) // batch_size)
validation_steps = max(1, len(val_imgs) // batch_size)

model.fit_generator(train_gen, steps_per_epoch=steps_per_epoch,
                    epochs=100, verbose=1,
                    callbacks=[csv_logger, lr_plateau, checkpoint],
                    validation_data=val_gen, validation_steps=validation_steps)

In [None]:
# Load the most recent checkpoint instead of a hard-coded epoch/val_loss filename that only
# exists for one particular training run. ModelCheckpoint(save_best_only=True) writes a file
# only on val_loss improvement, so within a run the newest file holds the best weights.
MODEL_DIR = os.environ.get('RAINFOREST_MODEL_DIR', 'E:/Models/brainforest')
checkpoint_files = glob.glob(os.path.join(MODEL_DIR, 'unet1_model.*.hdf5'))
assert checkpoint_files, 'no unet1_model checkpoints found in ' + MODEL_DIR
model.load_weights(max(checkpoint_files, key=os.path.getmtime))

# `batch_masks` (not `masks`) so the mask file list from the data cell is not clobbered.
imgs, batch_masks = next(val_gen)
# One predict call for the whole batch instead of one call per sample.
batch_preds = model.predict(imgs)
for img, mask, pred in zip(imgs, batch_masks, batch_preds):
    fig, (ax_img, ax_mask, ax_pred) = plt.subplots(1, 3, figsize=(12, 4))
    # img is channels-first; transpose to (H, W, C) for imshow. The +0.5 offset presumably
    # undoes a zero-centred normalisation in the generator (TODO confirm); clipping to [0, 1]
    # matches what imshow would do anyway, without the out-of-range warning.
    ax_img.imshow(np.clip(img.transpose(1, 2, 0) + 0.5, 0, 1))
    ax_img.set_title('image')
    # NOTE(review): channel 1 is presumably the slash-and-burn class — confirm against
    # get_data_with_masks / get_unet.
    ax_mask.imshow(mask[1])
    ax_mask.set_title('mask[1]')
    ax_pred.imshow(pred[1], vmin=0, vmax=1)
    ax_pred.set_title('prediction[1]')
    plt.show()
    plt.close(fig)

In [None]:
# Sanity check on input scaling: the display cell adds 0.5 before imshow, which presumably
# assumes the generator centres pixel values around 0 (roughly [-0.5, 0.5]) — verify here.
{'mean': float(imgs.mean()), 'min': float(imgs.min()), 'max': float(imgs.max())}