## Self-training

Each execution of this notebook performs one iteration of the self-training process.

The oracle is assumed to be already trained and stored in `models/oracle`.

First, a U-Net is trained on the training data, which in the first iteration consists of the supervised data only.

The trained segmenter then computes segmentations for the unlabeled (unsupervised) data.

The oracle then selects the highly confident predictions in the `buckets_classification` modality: slices it assigns to class 5, and slices it assigns to class 0 whose predicted mask is empty.

The confident predictions are then sampled according to the size of the predicted liver, and all selected samples are inserted into the training data (see the last cell).

Re-execute the notebook to carry out a new iteration.

In [None]:
import h5py, random
import tensorflow as tf
import numpy as np
from libs.generators.utils import get_case_length, get_x_slice, get_y_slice
from libs.models.u_net import get_model_unet
from libs.postprocessing import self_ensembling

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
from keras import backend as K

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
# Register the session with Keras. Without this, Keras lazily creates its own
# session with its default config, and allow_growth is never applied to the
# models built and trained below.
K.set_session(session)

In [None]:
# U-Net segmenter: 416x416 single-channel input, 16 base feature maps and
# one sigmoid output layer.
model = get_model_unet(
    input_size=(416, 416, 1),
    feature_maps=16,
    output_layers=1,
    output_type='sigmoid',
)

In [None]:
from libs.generators.ssl_batch_generator import SemiSupervisedBatchGenerator

# Batch generator over the semi-supervised training set; the selected
# pseudo-labeled samples are added to it in the last cell of this notebook.
gen = SemiSupervisedBatchGenerator()

## Training

In [None]:
from keras.optimizers import Adam
from libs.metrics import dice_coef_sig

# First training phase: Adam at lr=1e-4 with gradient-norm and value clipping.
model.compile(
    optimizer=Adam(lr=1e-4, clipnorm=1., clipvalue=0.5),
    loss='binary_crossentropy',
    metrics=[dice_coef_sig, 'binary_accuracy'],
)

In [None]:
# 120 epochs at the initial learning rate.
model.fit_generator(
    gen,
    epochs=120,
    verbose=1,
    workers=32,
    max_queue_size=100,
)

In [None]:
# Second training phase: recompile with a new Adam instance whose learning
# rate is 10x smaller (1e-5); clipping settings are unchanged.
model.compile(
    optimizer=Adam(lr=1e-5, clipnorm=1., clipvalue=0.5),
    loss='binary_crossentropy',
    metrics=[dice_coef_sig, 'binary_accuracy'],
)

In [None]:
# 30 fine-tuning epochs at the reduced learning rate.
model.fit_generator(
    gen,
    epochs=30,
    verbose=1,
    workers=32,
    max_queue_size=100,
)

In [None]:
from libs.keras_checkpoints import load_model

# Oracle trained beforehand and stored under models/oracle.
oracle = load_model('models/oracle')

## Inference and collection of highly confident data by the oracle

In [None]:
def get_unsupervised_samples(supervised=0):
    '''List the (case, slice) indices of every unlabeled slice.

    Cases 0..19 of ``data/training_data.h5`` are scanned; the case whose id
    equals ``supervised`` is skipped.

    @param supervised The id of the supervised volume to exclude
    @return list of ``(case_id, slice_id)`` tuples, ordered by case then slice
    '''
    with h5py.File('data/training_data.h5', 'r') as hdf:
        return [
            (case_id, slice_id)
            for case_id in range(20)
            if case_id != supervised
            for slice_id in range(get_case_length(hdf, case_id))
        ]

In [None]:
# All (case, slice) pairs outside the supervised volume (case 0).
indices = get_unsupervised_samples(supervised=0)

In [None]:
def batch_normalization(x):
    """Standardize ``x`` to zero mean and unit variance over all its elements.

    Despite its name this uses a single mean/std for the whole array (it is
    not a batch-norm layer). A constant array has zero standard deviation; it
    is mapped to zeros instead of producing NaNs from a division by zero.
    """
    centered = x - x.mean()
    std = x.std()
    if std == 0:
        return centered
    return centered / std


def stack_segmentation_rgb(x, y):
    """Build the standardized 3-channel image that is fed to the oracle.

    :param x: image slice with a trailing channel axis of size 1; this
        notebook uses ``(416, 416, 1)``.
    :param y: segmentation of ``x``; any shape with as many elements as ``x``
        (it is reshaped to ``x.shape``).
    :return: array of shape ``x.shape[:-1] + (3,)`` whose channels are
        (image, image + segmentation, constant image mean), standardized
        jointly with ``batch_normalization``.
    """
    overlay = x + np.reshape(y, x.shape)
    mean_plane = np.full(x.shape, x.mean(), 'float32')
    return batch_normalization(np.concatenate((x, overlay, mean_plane), axis=-1))

In [None]:
# Run the segmenter on every unlabeled slice and let the oracle keep the
# confident predictions, stored as (case_id, slice_id, P, P.sum()) tuples.
collected_high_iou = []
collected_undefined_iou = []

with h5py.File('data/training_data.h5', 'r') as hdf:
    print('Start evaluating unlabeled data')
    for case_id, slice_id in indices:
        x = get_x_slice(hdf, case_id, slice_id)
        P = self_ensembling(x, model)
        rgb = stack_segmentation_rgb(x, P)
        # Index of the oracle's most probable class for (image, segmentation).
        oracle_verdict = np.argmax(oracle.predict(rgb.reshape(1, 416, 416, 3))[0])
        # Summed prediction, used later as the predicted liver size.
        liver_area = P.sum()
        if oracle_verdict == 5:  # in P_{high}
            collected_high_iou.append((case_id, slice_id, P, liver_area))
        elif oracle_verdict == 0 and liver_area == 0:  # in P_{nan}
            collected_undefined_iou.append((case_id, slice_id, P, liver_area))

## Sampling using the size of the predicted livers

In [None]:
def sampling(collected_data, generator, liver_area=35000, *, tau=50, step=1000):
    """Select confident samples whose predicted liver area exceeds a threshold.

    The threshold starts at ``liver_area - step`` and is lowered by ``step``
    until ``generator.count_new_samples`` reports at least ``tau`` new samples
    in the selection, or the threshold is no longer positive.

    :param collected_data: tuples ``(case_id, slice_id, P, area)``; index 3 is
        the predicted liver area used for thresholding.
    :param generator: object exposing ``count_new_samples(samples)``.
    :param liver_area: initial area threshold.
    :param tau: minimum number of new samples to add at each iteration.
    :param step: amount the threshold is lowered on each attempt (> 0).
    :return: list of the selected tuples; empty if ``liver_area <= 0``.
    :raises ValueError: if ``step`` is not positive (the loop would not end).
    """
    if step <= 0:
        raise ValueError('step must be positive')
    sampled = []
    new_count = 0
    while new_count < tau and liver_area > 0:
        liver_area -= step
        # A list rather than a lazy filter: it must survive being iterated by
        # count_new_samples and support len() in the caller.
        sampled = [sample for sample in collected_data if sample[3] > liver_area]
        new_count = generator.count_new_samples(sampled)
    return sampled

In [None]:
# Materialize the selection: `sampling` may hand back a lazy iterable, and
# len() below requires a sized sequence.
high_confident_data = list(sampling(collected_high_iou, gen, 35000))
random.shuffle(collected_undefined_iou)
# Balance P_{nan} against P_{high}: take as many randomly chosen empty-mask
# slices as high-confidence slices were selected.
sampled_nan = collected_undefined_iou[:len(high_confident_data)]

## Add the selected data to the semi-supervised dataset

In [None]:
# Insert the selected high-confidence and empty-mask samples into the
# generator's training data.
gen.add_separate_samples(
    high_confident_data,
    sampled_nan,
)

## Restart the process