In [1]:
import tensorflow as tf

In [2]:
# Record the TensorFlow version this notebook was run with (tf.version.VERSION
# is the same string as tf.__version__).
print(tf.version.VERSION)

2.3.0


In [3]:
from model import *
from data import *

# UNET SEGMENTATION (Fully convolutional network)

## Train your Unet with membrane data

Membrane data is in the folder `data/membrane/` (relative to the working directory, as used by the code below).

   - training membrane data is in the folder `data/membrane/train`
       * training data: greyscale images (subfolder `image`)
       * training labels: binary masks (edge: 1, cytoplasm: 0) (subfolder `label`)

   - testing membrane data is in the folder `data/membrane/test`
       * test data: greyscale images
       * test labels: binary masks (edge: 1, cytoplasm: 0)


Images and masks have the same input shape:
`(batch_size, rows, cols, channels=1)` — for this U-Net the input layer is `(None, 256, 256, 1)`, as shown in the model summary below.

## Train 

### Data generator

In [4]:
# Set of random transformations allowed for data augmentation of each
# training image/mask pair.
# NOTE(review): presumably these are keras ImageDataGenerator keyword arguments
# consumed by trainGenerator (data.py). If so, rotation_range is in *degrees*,
# so 0.2 means at most +/-0.2 deg of rotation -- confirm this is intended.
data_gen_args = dict(rotation_range=0.2,
                    width_shift_range=0.05,
                    height_shift_range=0.05,
                    shear_range=0.05,
                    zoom_range=0.05,
                    horizontal_flip=True,
                    fill_mode='nearest')

# Images (and masks) per training batch.
# Assumed to be trainGenerator's batch-size argument (see data.py) -- TODO confirm.
BATCH_SIZE = 2

# Training generator: images come from data/membrane/train/image and masks from
# data/membrane/train/label, augmented with data_gen_args.
# save_to_dir=None presumably means augmented samples are not written to disk.
myGene = trainGenerator(BATCH_SIZE,
                        'data/membrane/train',
                        'image',
                        'label',
                        data_gen_args,
                        save_to_dir = None)

### Checkpoint

In [5]:
# Checkpoint callback: write to 'unet_membrane.hdf5' only when the monitored
# training loss improves (save_best_only=True), so the file holds the best
# result seen so far. 'loss' is monitored because fit() is not given any
# validation data in this notebook.
model_checkpoint = ModelCheckpoint(
    'unet_membrane.hdf5',
    save_best_only=True,
    monitor='loss',
    verbose=1,
)

### Training step

In [6]:
# Training schedule. Only 30 training images are found (see the training log),
# so each epoch draws STEPS_PER_EPOCH randomly augmented batches from them;
# the logged ETA was roughly 1.5 h per epoch at this setting.
EPOCHS, STEPS_PER_EPOCH = 5, 2000

In [7]:
# Build the U-Net (unet() is defined in model.py) and train it from the
# augmenting generator, so the whole dataset never has to fit in memory.
# Model.fit accepts Python generators directly; Model.fit_generator is
# deprecated in TF 2.x and merely forwards to fit.
model = unet()
# Keep the History object so the per-epoch loss/accuracy can be inspected later.
history = model.fit(myGene,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS,
                    callbacks=[model_checkpoint])

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
ccp_block (CCPBlock)            ((None, 256, 256, 64 37568       input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 128, 128, 128 73856       ccp_block[0][1]                  
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 128, 128, 128 147584      conv2d_2[0][0]                   
_______________________________________________________________________________________

Found 30 images belonging to 1 classes.
Found 30 images belonging to 1 classes.
Epoch 1/5
   5/2000 [..............................] - ETA: 1:26:06 - loss: 0.7002 - accuracy: 0.7886

KeyboardInterrupt: 

## Test

### Load your model 

In [None]:
# Generator over the test images in data/membrane/test (testGenerator is
# presumably defined in data.py, like the other data helpers).
testGene = testGenerator(
    "data/membrane/test",
)

In [None]:
# Rebuild a fresh U-Net (unet() is defined in model.py) and restore the
# checkpoint file written by the ModelCheckpoint callback during training.
# NOTE(review): the training run above was interrupted during epoch 1, so this
# file presumably only exists if an earlier run completed an epoch -- confirm.
model = unet()
model.load_weights(
    "unet_membrane.hdf5"
)



### Batch testing and save predicted results

In [None]:
# Number of prediction steps. Presumably testGenerator yields one test image per
# step, so this equals the number of test images -- confirm against data.py.
N_TEST_STEPS = 30

# Re-create the test generator here so this cell can be re-run: testGenerator
# presumably returns a single-pass Python generator, which would be exhausted
# after the first prediction if the one from the earlier cell were reused.
testGene = testGenerator("data/membrane/test")

# Predict on the test set. Model.predict accepts generators directly;
# Model.predict_generator is deprecated in TF 2.x.
results = model.predict(testGene,
                        steps=N_TEST_STEPS,
                        verbose=1)

# saveResult (function defined in data.py) presumably writes the predicted masks
# into the given folder.
saveResult("data/membrane/test", results)