In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
import os

import keras
from keras.models import Model
from keras.layers.merge import concatenate
from keras.layers import Activation, Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import LeakyReLU


from keras.optimizers import adam


from keras.layers import Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D, Conv2DTranspose
from keras.layers import Input, AveragePooling2D, MaxPooling2D, Dropout, Lambda, AlphaDropout
from keras.layers.merge import concatenate
from keras import backend as K
from keras.models import Model


from sklearn.utils import shuffle

from keras.callbacks import ModelCheckpoint

import tensorflow as tf
from keras.callbacks import TensorBoard

import rasterio
import glob

# Utils

In [None]:
def make_trainable(net, val):
    """Toggle the ``trainable`` flag on each layer of ``net``.

    Only the per-layer flags are changed; the model-level ``net.trainable``
    attribute itself is left as it was. In Keras the flag takes effect for a
    model the next time it is compiled.

    input:
        net: model exposing a ``layers`` iterable
        val: bool, value assigned to every layer's ``trainable`` attribute
    """
    layer_iter = iter(net.layers)
    for layer in layer_iter:
        setattr(layer, 'trainable', val)

# Networks

In [None]:
def _unet_branch(inp, blocks_list, k_size, activation, n_labels, concat_axis, scope,
                 k_init='lecun_normal'):
    """Build one U-Net branch of the W-Net and return its sigmoid output tensor.

    input:
        inp: keras tensor, branch input
        blocks_list: list of int, number of filters in each encoder block
        k_size: tuple, convolution kernel size
        activation: string, activation of the hidden convolutions
        n_labels: int, number of output channels
        concat_axis: int, channel axis used for skip concatenations
        scope: string, name scope wrapping every layer of this branch
        k_init: string, kernel initializer of the hidden convolutions
    output:
        keras tensor, per-pixel sigmoid output with ``n_labels`` channels
    """
    encoder = inp
    skips = []  # output of every encoder block, indexed by block number

    with K.name_scope(scope):
        for l_idx, n_ch in enumerate(blocks_list):
            with K.name_scope('Encoder_block_{0}'.format(l_idx)):
                encoder = Conv2D(filters=n_ch,
                                 kernel_size=k_size,
                                 activation=activation,
                                 padding='same',
                                 kernel_initializer=k_init)(encoder)
                # dropout rate grows with depth (0 for the first block)
                encoder = AlphaDropout(0.1*l_idx, )(encoder)
                encoder = Conv2D(filters=n_ch,
                                 kernel_size=k_size,
                                 dilation_rate=(2, 2),
                                 activation=activation,
                                 padding='same',
                                 kernel_initializer=k_init)(encoder)
                skips.append(encoder)
                # downsample after every block except the deepest one
                if l_idx < len(blocks_list) - 1:
                    encoder = MaxPooling2D(pool_size=(2, 2))(encoder)

        # decoders
        decoder = encoder
        dec_n_ch_list = blocks_list[::-1][1:]
        print(dec_n_ch_list)
        for l_idx, n_ch in enumerate(dec_n_ch_list):
            with K.name_scope('Decoder_block_{0}'.format(l_idx)):
                # NOTE(review): for l_idx == 0 this concatenates the bottleneck with itself
                # (skips[-1] is the decoder input), and skips[0] (full resolution) is never
                # used, because concatenation happens before upsampling. Kept as-is to
                # preserve the architecture.
                l_idx_rev = len(blocks_list) - 1 - l_idx
                decoder = concatenate([decoder, skips[l_idx_rev]], axis=concat_axis)
                decoder = Conv2D(filters=n_ch,
                                 kernel_size=k_size,
                                 activation=activation,
                                 padding='same',
                                 dilation_rate=(2, 2),
                                 kernel_initializer=k_init)(decoder)
                decoder = AlphaDropout(0.1*l_idx, )(decoder)
                decoder = Conv2D(filters=n_ch,
                                 kernel_size=k_size,
                                 activation=activation,
                                 padding='same',
                                 kernel_initializer=k_init)(decoder)
                # strided transposed conv doubles the spatial resolution
                decoder = Conv2DTranspose(filters=n_ch,
                                          kernel_size=k_size,
                                          strides=(2, 2),
                                          activation=activation,
                                          padding='same',
                                          kernel_initializer=k_init)(decoder)

        # per-pixel sigmoid output of this branch
        outp = Conv2DTranspose(filters=n_labels,
                               kernel_size=k_size,
                               activation='sigmoid',
                               padding='same',
                               kernel_initializer='glorot_normal')(decoder)
    return outp


def Wnet(inp_dsm, inp_pan, blocks_list, k_size, activation, n_labels=1, name=None):
    """
    Build the W-Net: two independent U-Nets (PAN and DSM) fused by a 1x1 convolution.

    input:
        inp_dsm: keras Input, DSM tile
        inp_pan: keras Input, PAN tile
        blocks_list: list, number of filters in each block (shared by both branches)
        k_size: tuple, filter size
        activation: string, activation function of the hidden layers
        n_labels: int, number of labels = 1
        name: string, model name

    output:
        keras model with inputs [inp_dsm, inp_pan] and a per-pixel sigmoid output
        with n_labels channels
    """
    concat_axis = 1 if K.image_data_format() == 'channels_first' else 3

    # PAN branch is built first so layer creation order matches earlier versions
    print('Building Unet for PAN Image')
    print(blocks_list)
    outp_pan = _unet_branch(inp_pan, blocks_list, k_size, activation, n_labels,
                            concat_axis, 'PAN_UNet')

    # DSM branch uses its own encoder skips (it previously reused the PAN ones by mistake)
    print('Building Unet for DSM')
    print(blocks_list)
    outp_dsm = _unet_branch(inp_dsm, blocks_list, k_size, activation, n_labels,
                            concat_axis, 'DSM_UNet')

    with K.name_scope('Fusion'):
        outp = concatenate([outp_dsm, outp_pan], axis=concat_axis)
        # sigmoid keeps the fused map in [0, 1]: it is trained with binary cross-entropy
        # and fed to the discriminator alongside real label maps
        outp = Conv2D(filters=n_labels,
                      kernel_size=(1, 1),
                      activation='sigmoid',
                      padding='same',
                      kernel_initializer='lecun_normal')(outp)

    return Model(inputs=[inp_dsm, inp_pan], outputs=[outp], name=name)

In [None]:
def DiscriminatorNet(inp_DSM, inp_Label, block_list, activation, k_size=(3,3), inputs_ch=64, name='DISCR'):
    """Conditional discriminator scoring a (DSM, label map) pair.

    Each input is lifted to ``inputs_ch`` features by a 1x1 convolution, the two
    feature maps are concatenated, then passed through one convolution per entry
    of ``block_list`` (max-pooling between blocks, not after the last). The
    result is flattened into a single sigmoid unit.

    input:
        inp_DSM, inp_Label: keras tensors, the conditioning DSM and the label map
        block_list: list of int, filters per convolution block, e.g. [32, 32, 32, 32]
        activation: string, activation of all hidden convolutions
        k_size: tuple, kernel size of the block convolutions
        inputs_ch: int, filters of the two 1x1 input projections
        name: string, model name
    output:
        keras model with inputs [inp_DSM, inp_Label] and a (batch, 1) sigmoid output
    """
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3
    initializer = 'lecun_normal'
    n_blocks = len(block_list)

    def project(tensor):
        # 1x1 convolution mapping an input to `inputs_ch` feature channels
        return Conv2D(filters=inputs_ch,
                      kernel_size=(1, 1),
                      activation=activation,
                      padding='same',
                      kernel_initializer=initializer)(tensor)

    with K.name_scope('DiscriminatorNet'):
        with K.name_scope('DSM_input_conv'):
            dsm_features = project(inp_DSM)
        with K.name_scope('Label_input_conv'):
            label_features = project(inp_Label)

        features = concatenate([dsm_features, label_features], axis=channel_axis)
        for block_no, filters in enumerate(block_list):
            with K.name_scope('Discr_block_{0}'.format(block_no)):
                features = Conv2D(filters=filters,
                                  kernel_size=k_size,
                                  activation=activation,
                                  padding='same',
                                  kernel_initializer=initializer)(features)
                is_last_block = block_no == n_blocks - 1
                if not is_last_block:
                    features = MaxPooling2D(pool_size=(2, 2))(features)
        features = Flatten()(features)
        outp = Dense(1, activation='sigmoid')(features)

    return Model(inputs=[inp_DSM, inp_Label], outputs=outp, name=name)

# Architecture (Model)

In [None]:
class Wnet_cgan:
    """W-Net generator trained adversarially against a conditional discriminator.

    ``self.wnet`` maps (DSM, PAN) tiles to a label map. ``self.discriminator``
    scores (DSM, label map) pairs as real (1) or generated (0). ``self.wnet_cgan``
    stacks the W-Net with a frozen view of the discriminator. The generator is
    trained on the segmentation loss plus ``lambda_`` times the adversarial loss.
    """

    def __init__(self,
                 height,
                 width,
                 n_labels=1):
        """
        input:
            height, width: int, tile size in pixels
            n_labels: int, number of output channels of the W-Net
        """
        # DSM, PAN and label tiles are all single-channel
        if K.image_data_format() == 'channels_first':
            input_shape = (1, height, width)  # define label_shape separately in case of multiple labels of roof
        else:
            input_shape = (height, width, 1)

        self.pan_shape = self.dsm_shape = self.label_shape = input_shape
        self.init_epoch = 0
        self.n_labels = n_labels

    def build_wnet_cgan(self,
                        wnet_block_list,
                        wnet_k_size,
                        wnet_activation='selu',
                        wnet_lr=1e-4,
                        discr_inp_channels = 64,
                        discr_block_list=[32,32,32,32,32],
                        discr_k_size=(3,3),
                        discr_activation='relu',
                        discr_lr=1e-4,
                        lambda_=1e-1):
        """Build and compile the discriminator, the W-Net and the stacked W-Net-CGAN.

        Sets ``self.discriminator``, ``self.wnet`` and ``self.wnet_cgan``.
        ``lambda_`` weights the adversarial loss relative to the segmentation loss.
        """
        inp_dsm = Input(self.dsm_shape, name='dsm_input')
        inp_pan = Input(self.pan_shape, name='pan_input')
        inp_label = Input(self.label_shape, name='label_input')

        # NOTE: each optimizer instance is passed to two compile() calls; only
        # wnet_cgan and discriminator are ever fit (see fit_wnet_cgan).
        wnet_opt = adam(lr=wnet_lr)
        discr_opt = adam(lr=discr_lr)

        # build the Discriminator
        print('Build discr')
        self.discriminator = DiscriminatorNet(inp_dsm,
                                              inp_label,
                                              discr_block_list,
                                              discr_activation,
                                              discr_k_size,
                                              discr_inp_channels,
                                              'Discriminator')
        print('Done')

        # Freeze the discriminator layers *before* compiling the stacked model so
        # the generator update does not touch discriminator weights.
        make_trainable(self.discriminator, False)
        frozen_discriminator = Model(inputs=self.discriminator.inputs,
                                     outputs=self.discriminator.outputs,
                                     name='frozen_discriminator')
        frozen_discriminator.compile(discr_opt,
                                     loss = 'binary_crossentropy',
                                     metrics=['accuracy'])

        self.wnet = Wnet(inp_dsm,
                         inp_pan,
                         wnet_block_list,
                         wnet_k_size,
                         wnet_activation,
                         self.n_labels,
                         name='Wnet')
        self.wnet.compile(wnet_opt,
                          loss = 'binary_crossentropy',
                          metrics=['accuracy'])  # TODO: change to mIoU

        # stack W-Net and frozen discriminator: outputs are (label map, real/fake probability)
        pred = self.wnet([inp_dsm, inp_pan])
        prob = frozen_discriminator([inp_dsm, pred])
        self.wnet_cgan = Model(inputs=[inp_dsm, inp_pan, inp_label],
                               outputs=[pred, prob],
                               name='WNet-CGAN')
        self.wnet_cgan.compile(wnet_opt,
                               loss=['binary_crossentropy', 'binary_crossentropy'],
                               loss_weights=[1., lambda_],
                               metrics=['accuracy'])

        # Unfreeze and compile the standalone discriminator for its own training phase.
        make_trainable(self.discriminator, True)
        self.discriminator.compile(discr_opt,
                                   loss='binary_crossentropy',
                                   metrics=['accuracy'])

    def fit_wnet_cgan(self,
                      train_generator,
                      valid_generator,
                      adv_epochs=10,
                      adv_steps_epoch=100,
                      gen_epochs=20,
                      gen_steps_epoch=100,
                      validation_steps=4,
                      n_rounds=10):
        """Alternate generator and discriminator training for ``n_rounds`` rounds.

        Each round first fits ``wnet_cgan`` for ``gen_epochs`` epochs with the
        generators in phase 'gen'. It then fits the discriminator for ``adv_epochs``
        epochs with the generators in phase 'discr', using the current W-Net
        predictions as fakes. Epoch counters continue across rounds via
        ``initial_epoch``.
        """
        discr_callbacks = self.build_callbacks(monitor='val_acc', phase='discr')
        gen_callbacks = self.build_callbacks(monitor='val_acc', phase='gen')

        for i in range(n_rounds):
            # generator phase
            train_generator.phase='gen'
            valid_generator.phase='gen'
            self.wnet_cgan.fit_generator(generator=train_generator,
                               validation_data=valid_generator,
                               epochs=(i+1)*gen_epochs,
                               callbacks=gen_callbacks,
                               validation_steps=validation_steps,
                               shuffle=True,
                               steps_per_epoch=gen_steps_epoch,
                               initial_epoch=i*gen_epochs,
                               verbose=1)

            # Build the predict function up front; presumably so it can be called
            # from the data generator (TF1 graph/thread idiom) — TODO confirm.
            self.wnet._make_predict_function()
            train_generator.pred_fn = self.wnet.predict
            valid_generator.pred_fn = self.wnet.predict
            train_generator.phase='discr'
            valid_generator.phase='discr'

            # discriminator phase: real labels vs current W-Net predictions
            self.discriminator.fit_generator(generator=train_generator,
                                   validation_data=valid_generator,
                                   epochs=(i+1)*adv_epochs,
                                   callbacks=discr_callbacks,
                                   validation_steps=validation_steps,
                                   shuffle=True,
                                   steps_per_epoch=adv_steps_epoch,
                                   initial_epoch=i*adv_epochs,
                                   verbose=0)

    @staticmethod
    def _checkpoint_dir(phase):
        """Return the checkpoint directory for ``phase`` ('./results' when phase is None)."""
        if phase is None:
            return './results'
        return './results/{0}'.format(phase)

    def build_callbacks(self, use_tfboard=True, monitor=None, phase=None, save=False):
        """Assemble the Keras callbacks for one training phase.

        input:
            use_tfboard: bool, add a TrainValTensorBoard logging to ./logs[/phase]
            monitor: string or None, metric monitored by ModelCheckpoint; None disables checkpoints
            phase: string or None, sub-directory name ('gen', 'discr', ...)
            save: bool, forwarded to ModelCheckpoint as save_best_only
        output:
            list of callbacks
        """
        callbackList = []

        # Model checkpoints (weights only, one file per epoch unless save=True)
        if monitor is not None:
            path = self._checkpoint_dir(phase)
            os.makedirs(path, exist_ok=True)
            filepath = path + '/weights-{epoch:02d}.hdf5'
            checkpoint = ModelCheckpoint(filepath,
                                         monitor=monitor,
                                         verbose=1,
                                         save_best_only=save,
                                         save_weights_only=True,
                                         mode='max')
            callbackList.append(checkpoint)

        # Tensorboard
        if use_tfboard:
            tfpath = './logs' if phase is None else './logs/{0}'.format(phase)
            callbackList.append(TrainValTensorBoard(log_dir=tfpath))
        return callbackList
    
    
class TrainValTensorBoard(TensorBoard):
    """TensorBoard callback writing training and validation metrics to sibling folders.

    Training metrics go to ``<log_dir>/training`` (via the base class). Validation
    metrics (``val_*`` keys, prefix stripped) go to ``<log_dir>/validation``, so both
    curves share a tag and overlay in one TensorBoard plot.

    input:
        log_dir: string, root log directory
        phase: accepted for call-site compatibility; unused
        hist_freq: int, forwarded to TensorBoard as histogram_freq
        **kwargs: forwarded to TensorBoard
    """

    def __init__(self, log_dir='./logs', phase=None, hist_freq=0, **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, 'training')
        super(TrainValTensorBoard, self).__init__(training_log_dir, histogram_freq=hist_freq, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, 'validation')
        self.val_writer = None  # created in set_model

    def set_model(self, model):
        # Setup writer for validation metrics (re-created for every fit call)
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and write them with `self.val_writer`, renamed so
        # they plot on the same figure as the matching training metric.
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            # float() accepts Python floats and NumPy scalars alike (`.item()` does not)
            summary_value.simple_value = float(value)
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        if self.val_writer is not None:
            self.val_writer.close()

# Data Generator

In [None]:
class DataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` loading batches of DSM / PAN / LABEL raster tiles from disk.

    The mutable ``phase`` attribute selects the batch layout:

    * ``'gen'``: inputs ``[DSM, PAN, label]``, targets ``[label, ones]``. This
      matches the stacked W-Net-CGAN (segmentation target plus a "real" target
      for the frozen discriminator).
    * ``'discr'``: inputs ``[DSM ++ DSM, label ++ pred_fn([DSM, PAN])]``, targets 1
      for the real labels and 0 for the predictions. The three arrays are shuffled
      together.
    """

    def __init__(self, DSM_IDs,
                 PAN_IDs,
                 LABEL_IDs,
                 batch_size=32,
                 shuffle=True,
                 pred_fn=None):
        """
        input:
            DSM_IDs, PAN_IDs, LABEL_IDs: list of str, raster paths paired by position
            batch_size: int, number of tiles per batch
            shuffle: bool, reshuffle the tile order at the end of every epoch
            pred_fn: callable taking [DSM, PAN]; required for phase 'discr'
        """
        self.DSM_IDs = DSM_IDs
        self.PAN_IDs = PAN_IDs
        self.LABEL_IDs = LABEL_IDs
        self.phase = 'gen'
        self.pred_fn = pred_fn
        if len(self.PAN_IDs) != len(self.DSM_IDs) or len(self.DSM_IDs) != len(self.LABEL_IDs):
            raise ValueError('DSM, PAN or LABEL do not match')
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch (a trailing partial batch is dropped)'
        return int(np.floor(len(self.DSM_IDs) / self.batch_size))

    def getitem(self, index):
        'Alias of ``__getitem__`` kept for backward compatibility'
        return self.__getitem__(index)

    def __getitem__(self, index):
        'Generate one batch of data for the current ``phase``'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        DSM_IDs_temp = [self.DSM_IDs[k] for k in indexes]
        PAN_IDs_temp = [self.PAN_IDs[k] for k in indexes]
        LABEL_IDs_temp = [self.LABEL_IDs[k] for k in indexes]

        # Generate data
        DSM, PAN, label = self.__data_generation(DSM_IDs_temp, PAN_IDs_temp, LABEL_IDs_temp)

        if self.phase == 'gen':
            # second target: the stacked model is rewarded for fooling the discriminator
            y1 = np.ones([label.shape[0], 1])
            return [DSM, PAN, label], [label, y1]

        if self.phase == 'discr':
            if self.pred_fn is None:
                raise ValueError("pred_fn must be set before using phase 'discr'")
            pred = self.pred_fn([DSM, PAN])

            discr_X_1 = np.concatenate((DSM, DSM), axis=0)
            discr_X_2 = np.concatenate((label, pred), axis=0)

            y1 = np.ones(shape=(len(label), 1))   # real label maps
            y0 = np.zeros(shape=(len(pred), 1))   # predicted label maps
            prob = np.concatenate([y1, y0], axis=0)

            # Shuffle all three arrays with one permutation so every input pair keeps
            # its target (previously the shuffled second input was dropped).
            discr_X_1, discr_X_2, prob = shuffle(discr_X_1, discr_X_2, prob, random_state=42)
            return [discr_X_1, discr_X_2], prob

        raise ValueError("Unknown phase {0!r}; expected 'gen' or 'discr'".format(self.phase))

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.DSM_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    @staticmethod
    def _read_tile(path):
        'Read a raster as a (rows, cols, bands) array, closing the dataset afterwards'
        with rasterio.open(path) as src:
            return np.moveaxis(src.read(), 0, 2)

    def __data_generation(self, DSM_IDs_temp, PAN_IDs_temp, LABEL_IDs_temp):
        'Load the given tiles and stack them into (n_samples, rows, cols, bands) arrays'
        DSM_out = []
        PAN_out = []
        LABEL_out = []
        for dsm_path, pan_path, label_path in zip(DSM_IDs_temp, PAN_IDs_temp, LABEL_IDs_temp):
            DSM_out.append(self._read_tile(dsm_path))
            PAN_out.append(self._read_tile(pan_path))
            LABEL_out.append(self._read_tile(label_path))

        return np.asarray(DSM_out), np.asarray(PAN_out), np.asarray(LABEL_out)
    
class Data:
    """Index of DSM / PAN / LABEL GeoTIFF tiles under one root folder.

    Expects ``<path>/DSM/*.tif``, ``<path>/PAN/*.tif`` and ``<path>/LABEL/*.tif``.
    The three sorted file lists are paired by position.
    """

    def __init__(self, path, random=False):
        """
        input:
            path: path to the folder with subfolders: DSM, PAN, LABEL
            random: unused; kept for backward compatibility
        raises:
            ValueError if the three folders hold different numbers of tiles
        """
        self.DSM = sorted(glob.glob(path+"/DSM/*.tif"))
        self.PAN = sorted(glob.glob(path+"/PAN/*.tif"))
        self.LABEL = sorted(glob.glob(path+"/LABEL/*.tif"))
        if len(self.DSM) != len(self.PAN) or len(self.LABEL) != len(self.PAN):
            raise ValueError('DSM, PAN or LABEL do not match')

    @staticmethod
    def _read_tile(path):
        """Read a raster as a (rows, cols, bands) array, closing the dataset afterwards."""
        with rasterio.open(path) as src:
            return np.moveaxis(src.read(), 0, 2)

    def get_data(self, start=0, num=10, as_arr=True, random=False):
        """
        function: load ``num`` tiles and remove them from the index
        input:
            start: int, first tile index (ignored when random=True)
            num: int, number of tiles to load
            as_arr: bool, return stacked numpy arrays instead of lists
            random: bool, pick ``num`` distinct tiles at random (global numpy RNG)
        output: DSM, PAN, LABEL — numpy arrays (or lists of arrays)
        note: loaded tiles are removed from self.DSM / self.PAN / self.LABEL, so later
              calls (including split_trn_vld_tst) only see the remaining tiles
        """
        if random:
            idx = np.random.choice(len(self.DSM), num, replace=False)
            print('randomly loading {0} tiles from {1} tiles'.format(num, len(self.DSM)))
        else:
            idx = list(range(start, start+num))
            print('loading {0} - {1} image tiles'.format(start, start+num-1))

        DSM_out = [self._read_tile(self.DSM[i]) for i in idx]
        PAN_out = [self._read_tile(self.PAN[i]) for i in idx]
        LABEL_out = [self._read_tile(self.LABEL[i]) for i in idx]

        # Drop the loaded tiles in place (slice assignment keeps external aliases in sync)
        taken = set(int(i) for i in idx)
        self.DSM[:] = [p for i, p in enumerate(self.DSM) if i not in taken]
        self.PAN[:] = [p for i, p in enumerate(self.PAN) if i not in taken]
        self.LABEL[:] = [p for i, p in enumerate(self.LABEL) if i not in taken]

        if as_arr:
            return np.asarray(DSM_out), np.asarray(PAN_out), np.asarray(LABEL_out)
        return DSM_out, PAN_out, LABEL_out

    def _select(self, sel):
        """Return the (DSM, PAN, LABEL) path lists for the tile indices in ``sel``."""
        return ([self.DSM[k] for k in sel],
                [self.PAN[k] for k in sel],
                [self.LABEL[k] for k in sel])

    def split_trn_vld_tst(self, vld_rate=0.2, tst_rate=0.05, random=True, seed=10):
        """Split the tile paths into train / validation / test lists.

        Seeds the *global* numpy RNG with ``seed`` (this also affects later
        np.random calls). Counts are truncated: int(num * rate).

        output: DSM_trn, PAN_trn, LABEL_trn, DSM_vld, PAN_vld, LABEL_vld,
                DSM_tst, PAN_tst, LABEL_tst
        """
        np.random.seed(seed)

        num = len(self.DSM)
        vld_num = int(num*vld_rate)
        tst_num = int(num*tst_rate)

        print('split into {0} train, {1} validation, {2} test samples'.format(num-vld_num-tst_num, vld_num, tst_num))
        idx = np.arange(num)
        if random:
            np.random.shuffle(idx)

        DSM_tst, PAN_tst, LABEL_tst = self._select(idx[:tst_num])
        DSM_vld, PAN_vld, LABEL_vld = self._select(idx[tst_num:tst_num+vld_num])
        DSM_trn, PAN_trn, LABEL_trn = self._select(idx[tst_num+vld_num:])

        return DSM_trn, PAN_trn, LABEL_trn, DSM_vld, PAN_vld, LABEL_vld, DSM_tst, PAN_tst, LABEL_tst

# Data

In [None]:
# Root folder containing the DSM/, PAN/ and LABEL/ tile sub-folders.
# Set the WNET_DATA_DIR environment variable to point elsewhere without editing the notebook.
DATA_DIR = os.environ.get('WNET_DATA_DIR', r'D:\source\TRAIN_DATA')
dd = Data(DATA_DIR)

In [None]:
# Split the tile index into train / validation / test path lists (default seed and rates).
splits = dd.split_trn_vld_tst()
(dsm_train, pan_train, label_train,
 dsm_vld, pan_vld, label_vld,
 dsm_tst, pan_tst, label_tst) = splits

In [None]:
# Both generators share the same batching options; pred_fn is attached later by fit_wnet_cgan.
generator_kwargs = dict(pred_fn=None, batch_size=8, shuffle=True)
train_gen = DataGenerator(dsm_train, pan_train, label_train, **generator_kwargs)
valid_gen = DataGenerator(dsm_vld, pan_vld, label_vld, **generator_kwargs)

# Training

In [None]:
# 256 x 256 single-channel tiles, one output label channel.
myModel = Wnet_cgan(height=256, width=256, n_labels=1)

In [None]:
myModel.build_wnet_cgan(wnet_block_list=[32, 32, 32, 32],
                        wnet_k_size=(3, 3),
                        wnet_activation='selu',
                        wnet_lr=1e-4,
                        discr_inp_channels=16,
                        discr_block_list=[32, 32, 32, 32],
                        discr_k_size=(3, 3),
                        discr_activation='relu',
                        discr_lr=1e-4,
                        lambda_=1e-1)

In [None]:
# For a quick smoke test, set every epoch/step count to 2 (and n_rounds=10).
myModel.fit_wnet_cgan(train_generator=train_gen,
                      valid_generator=valid_gen,
                      adv_epochs=25,
                      adv_steps_epoch=50,
                      gen_epochs=75,
                      gen_steps_epoch=100,
                      n_rounds=20)