# 1. Libraries

In [None]:
import numpy as np
import h5py
import os
import time
import datetime
import threading
import random
import cv2
import imageio

from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Input, Activation
from tensorflow.keras.layers import Conv2D, Reshape, Conv3D, AveragePooling2D, Lambda, UpSampling2D, UpSampling3D, GlobalAveragePooling3D
from tensorflow.keras.layers import Dropout, BatchNormalization
from tensorflow.keras.layers import concatenate, add, multiply

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
import tensorflow_addons as tfa

print(tf.__version__)

In [None]:
# GPU setting (our setup: RTX 3090, GPU index 0).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# NOTE(review): TensorFlow is already imported in the previous cell, so this
# may only silence C++ log messages emitted from this point on.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Select the CUDA async allocator. This has to be set in the kernel's own
# process: a `! export ...` shell escape runs in a throwaway subshell and never
# reaches this process. It must also be set before TensorFlow initializes the
# GPU devices, hence before the device configuration below.
os.environ['TF_GPU_ALLOCATOR'] = 'cuda_malloc_async'

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# 2. Functions

## 2.1. generate_traindata_noise

In [None]:
# Pools of LF image ids (0-15) sampled uniformly by
# `generate_traindata_for_train`. Images 4, 6 and 15 contain reflection
# regions, so the fold-1 pool repeats them less often than the other images.
_TRAIN_IMAGE_POOL_FOLD1 = np.array([
    0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 5, 7,
    8, 9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 12,
    13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
])
# Pool for fold 2 (also drop image 15 from the reflection-mask table).
_TRAIN_IMAGE_POOL_FOLD2 = np.array([
    0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 5, 7,
    8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 5, 7, 8, 9, 10, 11, 12,
    13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
])
# Pool for fold 3 (also empty the reflection-mask table).
_TRAIN_IMAGE_POOL_FOLD3 = np.array([
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
    13, 14, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
])


def _strided_window(start_x, start_y, offset, size, scale):
    """Return (row, col) slices taking `size` samples every `scale` pixels,
    starting `scale * offset` pixels below/right of (start_x, start_y)."""
    x0 = start_x + scale * offset
    y0 = start_y + scale * offset
    return (slice(x0, x0 + scale * size, scale),
            slice(y0, y0 + scale * size, scale))


def generate_traindata_for_train(traindata_all, traindata_label, input_size,
                                 label_size, batch_size,
                                 Setting02_AngualrViews, boolmask_img4,
                                 boolmask_img6, boolmask_img15):
    """
     input: traindata_all   (N x H x W x 9 x 9 x 3) uint8, e.g. 16x128x128x9x9x3
            traindata_label (N x H x W x 9 x 9) or (N x H x W) float32
            input_size 23~   int
            label_size 1~    int
            batch_size 16    int
            Setting02_AngualrViews angular view indices to keep,
                                   [0,1,2,3,4,5,6,7,8] for 9x9
            boolmask_img4 (H x W)  bool // reflection mask for images[4]
            boolmask_img6 (H x W)  bool // reflection mask for images[6]
            boolmask_img15 (H x W) bool // reflection mask for images[15]

     Generate traindata using LF image and disparity map
     by randomly chosen variables.
     1.  gray image: random R,G,B --> R*img_R + G*img_G + B*imgB
     2.  patch-wise learning: random x,y  --> LFimage[x:x+size1,y:y+size2]
     3.  scale augmentation: scale 1,2,3  --> ex> LFimage[x:x+2*size1:2,y:y+2*size2:2]
         (the disparity label is divided by the scale accordingly)

     A sample is redrawn until its central view has enough texture
     (sum |center - center pixel| >= 0.01 * input_size**2) and, for images
     4/6/15, neither the input nor the label window touches the reflection mask.

     output: traindata_batch       (batch_size x input_size x input_size x V x V) float32,
                                   scaled by 1/255, V = len(Setting02_AngualrViews)
             traindata_batch_label (batch_size x label_size x label_size)     float64

     np.random.randint raises ValueError if a drawn scale makes
     scale * input_size >= H (or W).
    """
    views = np.asarray(Setting02_AngualrViews)
    n_views = len(views)
    height, width = traindata_all.shape[1], traindata_all.shape[2]
    # Central sub-aperture view of the 9x9 grid: its disparity is the label
    # and its texture decides whether a patch is informative enough.
    center = 4
    # Images whose reflective regions must be avoided (see pool notes above).
    reflection_masks = {4: boolmask_img4, 6: boolmask_img6, 15: boolmask_img15}
    image_pool = _TRAIN_IMAGE_POOL_FOLD1

    traindata_batch = np.zeros(
        (batch_size, input_size, input_size, n_views, n_views),
        dtype=np.float32)
    traindata_batch_label = np.zeros((batch_size, label_size, label_size))

    # Offset of the label window inside the input window.
    crop_half1 = int(0.5 * (input_size - label_size))
    mid = int(0.5 * input_size)
    min_texture = 0.01 * input_size * input_size

    for ii in range(batch_size):
        sum_diff = 0
        valid = 0
        while sum_diff < min_texture or valid < 1:
            # Random gray-conversion weights, normalized to sum to 1.
            rand_3color = 0.05 + np.random.rand(3)
            R, G, B = rand_3color / np.sum(rand_3color)

            image_id = np.random.choice(image_pool)

            # Scale 1 / 2 / 3 with probability 8/17, 6/17, 3/17.
            kk = np.random.randint(17)
            if kk < 8:
                scale = 1
            elif kk < 14:
                scale = 2
            else:
                scale = 3

            idx_start = np.random.randint(0, height - scale * input_size)
            idy_start = np.random.randint(0, width - scale * input_size)

            in_rows, in_cols = _strided_window(idx_start, idy_start, 0,
                                               input_size, scale)
            lb_rows, lb_cols = _strided_window(idx_start, idy_start,
                                               crop_half1, label_size, scale)

            # Reject patches overlapping a reflection region.
            mask = reflection_masks.get(int(image_id))
            valid = 1
            if mask is not None and (np.sum(mask[lb_rows, lb_cols]) > 0
                                     or np.sum(mask[in_rows, in_cols]) > 0):
                valid = 0
                continue

            lf = traindata_all[image_id, in_rows, in_cols].astype('float32')
            gray = R * lf[..., 0] + G * lf[..., 1] + B * lf[..., 2]

            image_center = (1 / 255) * gray[:, :, center, center]
            sum_diff = np.sum(np.abs(image_center - image_center[mid, mid]))

            # Keep only the requested angular views (all 9x9 by default).
            traindata_batch[ii] = gray[:, :, views[:, None], views[None, :]]

            # Downsampling by `scale` shrinks disparities by the same factor.
            if len(traindata_label.shape) == 5:
                label_patch = traindata_label[image_id, lb_rows, lb_cols,
                                              center, center]
            else:
                label_patch = traindata_label[image_id, lb_rows, lb_cols]
            traindata_batch_label[ii] = (1.0 / scale) * label_patch

    traindata_batch = np.float32((1 / 255) * traindata_batch)

    return traindata_batch, traindata_batch_label


""" (v, u) """
""" (0, 0) (0, 1) (0, 2) (0, 3) (0, 4) (0, 5) (0, 6) (0, 7) (0, 8)"""
""" (1, 0) (1, 1) (1, 2) (1, 3) (1, 4) (1, 5) (1, 6) (1, 7) (1, 8)"""
""" (2, 0) (2, 1) (2, 2) (2, 3) (2, 4) (2, 5) (2, 6) (2, 7) (2, 8)"""
""" (3, 0) (3, 1) (3, 2) (3, 3) (3, 4) (3, 5) (3, 6) (3, 7) (3, 8)"""
""" (4, 0) (4, 1) (4, 2) (4, 3) (4, 4) (4, 5) (4, 6) (4, 7) (4, 8)"""
""" (5, 0) (5, 1) (5, 2) (5, 3) (5, 4) (5, 5) (5, 6) (5, 7) (5, 8)"""
""" (6, 0) (6, 1) (6, 2) (6, 3) (6, 4) (6, 5) (6, 6) (6, 7) (6, 8)"""
""" (7, 0) (7, 1) (7, 2) (7, 3) (7, 4) (7, 5) (7, 6) (7, 7) (7, 8)"""
""" (8, 0) (8, 1) (8, 2) (8, 3) (8, 4) (8, 5) (8, 6) (8, 7) (8, 8)"""


def data_augmentation_for_train(traindata_batch, traindata_label_batchNxN,
                                batch_size):
    """
        Randomly augment the first `batch_size` samples in place and return
        (traindata_batch, traindata_label_batchNxN).

        Per sample, in this order:
          * gamma:    x ** g with g drawn from [0.8, 1.2)
          * flip:     with probability 1/2, rot90(transpose(.)) on the spatial
                      axes, with the first angular axis reversed; the label
                      gets the same spatial operation
          * rotation: k in {0, 1, 2, 3} quarter turns applied to the spatial
                      axes, the angular axes and the label alike
          * noise:    with probability 1/12, additive Gaussian noise with
                      sigma = U[0, 1) * sqrt(0.2), then clipped to [0, 1]
    """
    for n in range(batch_size):
        gamma = 0.4 * np.random.rand() + 0.8
        traindata_batch[n, :, :, :, :] = traindata_batch[n, :, :, :, :] ** gamma

        # "transpose" augmentation (a spatial flip plus angular reversal).
        if np.random.randint(0, 2) == 1:
            lf = np.squeeze(traindata_batch[n, :, :, :, :])
            flipped = np.copy(np.rot90(np.transpose(lf, (1, 0, 2, 3))))
            traindata_batch[n, :, :, :, :] = flipped[:, :, ::-1]
            label = traindata_label_batchNxN[n, :, :]
            traindata_label_batchNxN[n, :, :] = np.copy(
                np.rot90(np.transpose(label, (1, 0))))

        # k * 90 degree rotation; k == 0 leaves the sample untouched.
        k = np.random.randint(0, 4)
        if k != 0:
            lf = np.copy(np.rot90(np.squeeze(traindata_batch[n, :, :, :, :]), k))
            traindata_batch[n, :, :, :, :] = np.copy(np.rot90(lf, k, (2, 3)))
            traindata_label_batchNxN[n, :, :] = np.copy(
                np.rot90(traindata_label_batchNxN[n, :, :], k))

        # Occasional Gaussian noise.
        if np.random.randint(0, 12) == 0:
            sigma = np.random.uniform() * np.sqrt(0.2)
            noise = np.random.normal(0.0, sigma, traindata_batch.shape[1:5])
            traindata_batch[n, :, :, :, :] = np.clip(
                traindata_batch[n, :, :, :, :] + noise, 0.0, 1.0)

    return traindata_batch, traindata_label_batchNxN


def generate_traindata128(traindata_all, traindata_label,
                          Setting02_AngualrViews):
    """
    Generate the validation/test set from the top-left 128x128 crop of every
    LF image, converted to gray with fixed weights (0.299, 0.587, 0.114).

     input: traindata_all   (N x H x W x 9 x 9 x 3) uint8, H, W >= 128
            traindata_label (N x H x W x 9 x 9) or (N x H x W) float32
            Setting02_AngualrViews [0,1,2,3,4,5,6,7,8] for 9x9

     Label view: central view (4, 4) when N >= 12 and the last label axis has
     size 9; view (0, 0) for any other 5-D label; the label itself when 3-D.

     output: list of 9*9 arrays, each (N x 128 x 128 x 1) float32 clipped to
             [0, 1], ordered by (first angular index, second angular index);
             labels (N x 128 x 128) float64
    """
    crop = 128
    n_images = len(traindata_all)
    n_views = len(Setting02_AngualrViews)
    luma = (0.299, 0.587, 0.114)  # R, G, B weights

    gray_stack = np.zeros((n_images, crop, crop, n_views, n_views),
                          dtype=np.float32)
    labels = np.zeros((n_images, crop, crop))

    for img in range(n_images):
        lf = traindata_all[img:img + 1, 0:crop, 0:crop]
        gray = (luma[0] * lf[:, :, :, :, :, 0].astype('float32')
                + luma[1] * lf[:, :, :, :, :, 1].astype('float32')
                + luma[2] * lf[:, :, :, :, :, 2].astype('float32'))
        gray_stack[img, :, :, :, :] = np.squeeze(gray)

        if n_images >= 12 and traindata_label.shape[-1] == 9:
            view = (4, 4)
        elif len(traindata_label.shape) == 5:
            view = (0, 0)
        else:
            view = ()
        labels[img, :, :] = traindata_label[
            (img, slice(0, crop), slice(0, crop)) + view]

    gray_stack = np.clip(np.float32((1 / 255) * gray_stack), 0, 1)

    per_view = [gray_stack[:, :, :, a, b][..., np.newaxis]
                for a in range(gray_stack.shape[3])
                for b in range(gray_stack.shape[4])]

    return per_view, labels

## 2.2. model

In [None]:
def convbn(input, out_planes, kernel_size, stride, dilation):
    """Apply a bias-free Conv2D ('same' padding) followed by BatchNormalization.

    Args:
        input: feature tensor.
        out_planes: number of output filters.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dilation: dilation rate.

    Returns:
        The normalized tensor (no activation applied).
    """
    conv = Conv2D(filters=out_planes,
                  kernel_size=kernel_size,
                  strides=stride,
                  padding='same',
                  dilation_rate=dilation,
                  use_bias=False)(input)
    return BatchNormalization()(conv)


def convbn_3d(input, out_planes, kernel_size, stride):
    """Apply a bias-free Conv3D ('same' padding) followed by BatchNormalization.

    Returns the normalized tensor; no activation is applied.
    """
    conv = Conv3D(filters=out_planes,
                  kernel_size=kernel_size,
                  strides=stride,
                  padding='same',
                  use_bias=False)(input)
    return BatchNormalization()(conv)


def BasicBlock(input, planes, stride, downsample, dilation):
    """Residual block: convbn -> ReLU -> convbn, added to a shortcut.

    The shortcut is `downsample` when it is given, otherwise `input` itself.
    No activation is applied after the addition.
    """
    hidden = convbn(input, planes, 3, stride, dilation)
    hidden = Activation('relu')(hidden)
    residual = convbn(hidden, planes, 3, 1, dilation)
    shortcut = input if downsample is None else downsample
    return add([residual, shortcut])


def _make_layer(input, planes, blocks, stride, dilation):
    """Stack `blocks` BasicBlocks; the first one may use a 1x1 projection shortcut.

    NOTE(review): the input channel count is hard-coded as 4, so the
    projection (1x1 Conv2D + BatchNormalization) is created whenever
    `stride != 1` or `planes != 4`, even if the actual input already has
    `planes` channels. Changing this alters the architecture and breaks
    checkpoint compatibility.
    """
    assumed_inplanes = 4
    projection = None
    if stride != 1 or assumed_inplanes != planes:
        projection = Conv2D(planes, 1, stride, 'same', use_bias=False)(input)
        projection = BatchNormalization()(projection)

    out = BasicBlock(input, planes, stride, projection, dilation)
    for _ in range(blocks - 1):
        out = BasicBlock(out, planes, 1, None, dilation)
    return out


def UpSampling2DBilinear(size):
    """Return a Lambda layer that bilinearly resizes its 4-D input to `size`
    (align_corners=True, via the TF1-compat resize op)."""
    def _resize(x):
        return tf.compat.v1.image.resize_bilinear(x, size, align_corners=True)

    return Lambda(_resize)


def UpSampling3DBilinear(size):
    """Return a Lambda layer that bilinearly resizes axes 2 and 3 of a 5-D
    tensor to `size` = (rows, cols), keeping axes 0, 1 and 4 unchanged
    (align_corners=True).
    """

    def UpSampling3DBilinear_(x, size):
        shape = K.shape(x)
        # Fold axis 1 into the batch axis so a 2-D resize can be applied.
        x = K.reshape(x, (shape[0] * shape[1], shape[2], shape[3], shape[4]))
        # tf.image.resize_bilinear no longer exists in TF2; use the compat op
        # (as UpSampling2DBilinear does) to keep align_corners semantics.
        x = tf.compat.v1.image.resize_bilinear(x, size, align_corners=True)
        x = K.reshape(x, (shape[0], shape[1], size[0], size[1], shape[4]))
        return x

    return Lambda(lambda x: UpSampling3DBilinear_(x, size))


def feature_extraction(sz_input, sz_input2):
    """Build the per-view feature extractor shared by all views (a Keras Model).

    Input: (sz_input, sz_input2, 1). Output: 4 channels at the same resolution.
    The residual stages are followed by a 4-level pyramid (average pooling by
    2/4/8/16, 1x1 convbn + ReLU, bilinear upsampling back to full size). The
    pyramid is fused with the layer2/layer4 features and a 1x1 Conv2D gives
    the final 4 channels.
    """
    view = Input(shape=(sz_input, sz_input2, 1))

    stem = view
    for _ in range(2):
        stem = convbn(stem, 4, 3, 1, 1)
        stem = Activation('relu')(stem)

    layer1 = _make_layer(stem, 4, 2, 1, 1)
    layer2 = _make_layer(layer1, 8, 8, 1, 1)
    layer3 = _make_layer(layer2, 16, 2, 1, 1)
    layer4 = _make_layer(layer3, 16, 2, 1, 2)

    height, width = layer4.get_shape().as_list()[1:3]
    target_size = (height, width)

    pyramid = []
    for pool in (2, 4, 8, 16):
        branch = AveragePooling2D((pool, pool), (pool, pool), 'same')(layer4)
        branch = convbn(branch, 4, 1, 1, 1)
        branch = Activation('relu')(branch)
        pyramid.append(UpSampling2DBilinear(target_size)(branch))

    # Coarsest pyramid level first: [pool16, pool8, pool4, pool2].
    fused = concatenate([layer2, layer4] + pyramid[::-1])
    fused = convbn(fused, 16, 3, 1, 1)
    fused = Activation('relu')(fused)
    fused = Conv2D(4, 1, (1, 1), 'same', use_bias=False)(fused)

    return Model(inputs=[view], outputs=[fused])

def _getCostVolume_(inputs):
    """Build the 5-D cost volume from the list of 81 per-view feature maps.

    View index i corresponds to (v, u) = divmod(i, 9). For each candidate
    disparity d in np.linspace(-4, 4, 9), every view is shifted by
    [d * (u - 4), d * (v - 4)] with bilinear tfa.image.translate (no shift at
    d == 0). The shifted views are concatenated on axis 3, and the 9 slices
    are stacked on axis 1 and reshaped to (B, 9, H, W, 4 * 81).

    NOTE: assumes every feature map has 4 channels (the 4 * 81 reshape).
    """
    feature_shape = K.shape(inputs[0])
    disparities = np.linspace(-4, 4, 9)

    def _shift_views(disp):
        if disp == 0:
            return list(inputs)
        shifted = []
        for view_idx, view in enumerate(inputs):
            row, col = divmod(view_idx, 9)
            shifted.append(
                tfa.image.translate(view,
                                    [disp * (col - 4), disp * (row - 4)],
                                    'BILINEAR'))
        return shifted

    slices = [K.concatenate(_shift_views(disp), axis=3) for disp in disparities]
    volume = K.stack(slices, axis=1)
    return K.reshape(volume,
                     (feature_shape[0], len(disparities), feature_shape[1],
                      feature_shape[2], 4 * 81))

def channel_attention(cost_volume):
    """Symmetric per-view attention applied to the cost volume.

    The gate is squeeze-and-excitation style:
    GlobalAveragePooling3D -> 1x1x1 Conv3D(170) + ReLU -> 1x1x1 Conv3D(15) + sigmoid.
    The 15 values fill the upper triangle of a symmetric 5x5 grid. REFLECT
    padding mirrors that grid to 9x9, giving one weight per view (81 total).
    Each weight is repeated 4 times along the last axis, matching the
    4 * 81 channel layout built by _getCostVolume_, and then multiplied into
    `cost_volume`.

    Returns:
        (weighted cost volume, attention) where attention has shape
        (B, 1, 1, 1, 81).
    """
    x = GlobalAveragePooling3D()(cost_volume)
    # (B, C) -> (B, 1, 1, 1, C) so the 1x1x1 Conv3D layers can be applied.
    x = Lambda(
        lambda y: K.expand_dims(K.expand_dims(K.expand_dims(y, 1), 1), 1))(x)
    x = Conv3D(170, 1, 1, 'same')(x)
    x = Activation('relu')(x)
    x = Conv3D(15, 1, 1, 'same')(x)  # [B, 1, 1, 1, 15]
    x = Activation('sigmoid')(x)

    # 15 -> 25: expand the upper triangle (left) into a symmetric 5x5 grid
    # (right, row-major), so w[r][c] == w[c][r].
    # 0  1  2  3  4
    #    5  6  7  8
    #       9 10 11
    #         12 13
    #            14
    #
    # 0  1  2  3  4
    # 1  5  6  7  8
    # 2  6  9 10 11
    # 3  7 10 12 13
    # 4  8 11 13 14

    x = Lambda(lambda y: K.concatenate([
        y[:, :, :, :, 0:5], y[:, :, :, :, 1:2], y[:, :, :, :, 5:9],
        y[:, :, :, :, 2:3], y[:, :, :, :, 6:7], y[:, :, :, :, 9:12],
        y[:, :, :, :, 3:4], y[:, :, :, :, 7:8], y[:, :, :, :, 10:11],
        y[:, :, :, :, 12:14], y[:, :, :, :, 4:5], y[:, :, :, :, 8:9],
        y[:, :, :, :, 11:12], y[:, :, :, :, 13:15]
    ],
                                       axis=-1))(x)

    x = Lambda(lambda y: K.reshape(y, (K.shape(y)[0], 5, 5)))(x)
    # REFLECT padding by 4 on the far side maps index i > 4 to 8 - i, so the
    # 9x9 grid is mirror-symmetric about the central row/column (index 4).
    x = Lambda(lambda y: tf.pad(y, [[0, 0], [0, 4], [0, 4]], 'REFLECT'))(x)
    attention = Lambda(lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(x)
    # Repeat each view weight for that view's 4 feature channels.
    x = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([x, cost_volume]), attention


def channel_attention_free(cost_volume):
    """Unconstrained per-view attention: one independent sigmoid weight per view.

    GlobalAveragePooling3D -> 1x1x1 Conv3D(170) + ReLU -> 1x1x1 Conv3D(81) +
    sigmoid. Each of the 81 weights is repeated 4 times along the last axis
    and multiplied into `cost_volume`.

    Returns:
        (weighted cost volume, attention of shape (B, 1, 1, 1, 81)).
    """
    gate = GlobalAveragePooling3D()(cost_volume)
    gate = Lambda(lambda y: y[:, None, None, None, :])(gate)
    for filters, act in ((170, 'relu'), (81, 'sigmoid')):
        gate = Conv3D(filters, 1, 1, 'same')(gate)
        gate = Activation(act)(gate)
    attention = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(gate)
    weights = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([weights, cost_volume]), attention


def channel_attention_mirror(cost_volume):
    """Mirror-symmetric per-view attention built from 25 free weights.

    GlobalAveragePooling3D -> 1x1x1 Conv3D(170) + ReLU -> 1x1x1 Conv3D(25) +
    sigmoid. The 25 values form a 5x5 grid, which REFLECT padding extends to
    9x9 (one weight per view). Each weight is repeated 4 times along the last
    axis and multiplied into `cost_volume`.

    Returns:
        (weighted cost volume, attention of shape (B, 1, 1, 1, 81)).
    """
    gate = GlobalAveragePooling3D()(cost_volume)
    gate = Lambda(lambda y: y[:, None, None, None, :])(gate)
    for filters, act in ((170, 'relu'), (25, 'sigmoid')):
        gate = Conv3D(filters, 1, 1, 'same')(gate)
        gate = Activation(act)(gate)
    grid = Lambda(lambda y: K.reshape(y, (K.shape(y)[0], 5, 5)))(gate)
    grid = Lambda(
        lambda y: tf.pad(y, [[0, 0], [0, 4], [0, 4]], 'REFLECT'))(grid)
    attention = Lambda(
        lambda y: K.reshape(y, (K.shape(y)[0], 1, 1, 1, 81)))(grid)
    weights = Lambda(lambda y: K.repeat_elements(y, 4, -1))(attention)
    return multiply([weights, cost_volume]), attention


def basic(cost_volume):
    """Regularize the cost volume with 3-D convolutions.

    The first stage is convbn_3d -> ReLU -> convbn_3d -> ReLU. It is followed
    by two residual stages (convbn_3d -> ReLU -> convbn_3d, added to the
    running volume) and a convbn_3d -> ReLU head. A final bias-free Conv3D
    maps the result to a single output channel.
    """
    feature = 2 * 75

    def _conv_pair(x):
        x = convbn_3d(x, feature, 3, 1)
        x = Activation('relu')(x)
        return convbn_3d(x, feature, 3, 1)

    cost0 = _conv_pair(cost_volume)
    cost0 = Activation('relu')(cost0)

    for _ in range(2):
        cost0 = add([_conv_pair(cost0), cost0])

    classify = convbn_3d(cost0, feature, 3, 1)
    classify = Activation('relu')(classify)
    return Conv3D(1, 3, 1, 'same', use_bias=False)(classify)

class BSplineLayer(tf.keras.layers.Layer):
    """Interpolating cubic spline over the 9-sample disparity cost.

    The 9 input channels c0..c8 are treated as samples of a cost curve at
    d = -4, -3, ..., 4. The layer solves for spline coefficients a0..a8 so that
    the spline passes through every sample: a0 = c0, a8 = c8, and a1..a7 come
    from the precomputed 7x7 system. It then evaluates the spline at the 81
    points of tf.linspace(-4.0, 4.0, 81).

    Input:  (B, H, W, 9) tensor (cast to float32).
    Output: (B, H, W, 81) float32 tensor.
    """
    def __init__(self, name="BSpline", **kwargs):
        super().__init__(name=name, **kwargs)
        # Sample positions of the 9 input channels (kept as float32; not
        # referenced elsewhere in this class).
        self.d = tf.constant([-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0], dtype=tf.float32)
        # Required number of input channels (checked in call()).
        self.C_dim = 9
        # Small epsilon constant (not referenced elsewhere in this class).
        self.eps = tf.constant(1e-8, dtype=tf.float32)

        # Interior interpolation matrix: row k expresses the spline value at
        # d = k - 3 (k = 0..6) in terms of a1..a7. The a0 / a8 contributions
        # (0.125 * c0 at d = -3, 0.125 * c8 at d = 3) are moved to the
        # right-hand side in call(). Its inverse is computed once here.
        M = tf.constant([
            [0.51388889, 0.31944444, 0.04166667, 0., 0., 0., 0.],
            [0.11111111, 0.55555556, 0.33333333, 0., 0., 0., 0.],
            [0., 0.125, 0.70833333, 0.16666667, 0., 0., 0.],
            [0., 0., 0.16666667, 0.66666667, 0.16666667, 0., 0.],
            [0., 0., 0., 0.16666667, 0.70833333, 0.125, 0.],
            [0., 0., 0., 0., 0.33333333, 0.55555556, 0.11111111],
            [0., 0., 0., 0., 0.04166667, 0.31944444, 0.51388889]
        ], dtype=tf.float32)
        M_inv_value = tf.linalg.inv(M)

        # Store inverse as a non-trainable weight for saving/loading with the model
        self.M_inv = self.add_weight(
            name="M_inv",
            shape=(7, 7),
            dtype=tf.float32,
            initializer=tf.keras.initializers.Constant(M_inv_value),
            trainable=False
        )

    # Piecewise-cubic spline evaluation used by call().
    @tf.function
    def evaluate_f_piecewise(self, x, a0, a1, a2, a3, a4, a5, a6, a7, a8):
        """
        Evaluate the piecewise cubic spline with coefficients a0..a8 at x.

        The pieces change at x = -2, -1, 0, 1, 2. Values of x below -4 or
        above 4 are extrapolated with the outermost pieces.

        Args:
            x: Evaluation points, broadcast-compatible with the coefficients
               (in call(): shape (1, 1, 1, 81)).
            a0..a8: Coefficient tensors (in call(): shape (B, H, W, 1)).

        Returns:
            f(x) as float32, with the broadcast shape of x and the
            coefficients (in call(): (B, H, W, 81)).
        """
        # Ensure inputs have correct dtype
        x = tf.cast(x, dtype=tf.float32)
        a0 = tf.cast(a0, dtype=tf.float32)
        a1 = tf.cast(a1, dtype=tf.float32)
        a2 = tf.cast(a2, dtype=tf.float32)
        a3 = tf.cast(a3, dtype=tf.float32)
        a4 = tf.cast(a4, dtype=tf.float32)
        a5 = tf.cast(a5, dtype=tf.float32)
        a6 = tf.cast(a6, dtype=tf.float32)
        a7 = tf.cast(a7, dtype=tf.float32)
        a8 = tf.cast(a8, dtype=tf.float32)

        # Per-piece coefficients of x^3 (P), x^2 (Q), x (S) and 1 (T).
        P0 = -9*a0 + 19*a1 - 13*a2 + 3*a3
        Q0 = -54*a0 + 138*a1 - 120*a2 + 36*a3
        S0 = -108*a0 + 300*a1 - 336*a2 + 144*a3
        T0 = -72*a0 + 208*a1 - 256*a2 + 192*a3

        P1 = -8*a1 + 23*a2 - 27*a3 + 12*a4
        Q1 = -24*a1 + 96*a2 - 144*a3 + 72*a4
        S1 = -24*a1 + 96*a2 - 216*a3 + 144*a4
        T1 = -8*a1 + 32*a2 - 48*a3 + 96*a4

        P2 = -3*a2 + 11*a3 - 12*a4 + 4*a5
        Q2 = 12*a3 - 24*a4 + 12*a5
        S2 = -12*a3 + 12*a5
        T2 =  4*a3 + 16*a4 + 4*a5

        P3 = -4*a3 + 12*a4 - 11*a5 + 3*a6
        Q3 = 12*a3 - 24*a4 + 12*a5
        S3 = -12*a3 + 12*a5
        T3 = 4*a3 + 16*a4 + 4*a5

        P4 = -12*a4 + 27*a5 - 23*a6 + 8*a7
        Q4 = 72*a4 - 144*a5 + 96*a6 - 24*a7
        S4 = -144*a4 + 216*a5 - 96*a6 + 24*a7
        T4 = 96*a4 - 48*a5 + 32*a6 - 8*a7

        P5 = -3*a5 + 13*a6 - 19*a7 + 9*a8
        Q5 = 36*a5 - 120*a6 + 138*a7 - 54*a8
        S5 = -144*a5 + 336*a6 - 300*a7 + 108*a8
        T5 = 192*a5 - 256*a6 + 208*a7 - 72*a8

        # Powers of x shared by all pieces.

        x2 = x * x
        x3 = x2 * x

        # Piece for x < -2 (interval [-4, -2))
        f_0 = (P0 * x3 + Q0 * x2 + S0 * x + T0) / 72

        # Evaluate polynomial for interval [-2, -1)
        f_1 = (P1 * x3 + Q1 * x2 + S1 * x + T1) / 72

        # Evaluate polynomial for interval [-1, 0)
        f_2 = (P2 * x3 + Q2 * x2 + S2 * x + T2) / 24

        # Evaluate polynomial for interval [0, 1)
        f_3 = (P3 * x3 + Q3 * x2 + S3 * x + T3) / 24

        # Evaluate polynomial for interval [1, 2)
        f_4 = (P4 * x3 + Q4 * x2 + S4 * x + T4) / 72

        # Piece for x >= 2 (interval [2, 4])
        f_5 = (P5 * x3 + Q5 * x2 + S5 * x + T5) / 72

        # Select the piece whose interval contains x.
        f_val = tf.where(x < -2.0, f_0,
                tf.where(x < -1.0, f_1,
                tf.where(x <  0.0, f_2,
                tf.where(x <  1.0, f_3,
                tf.where(x <  2.0, f_4, f_5)))))

        return f_val

    @tf.function
    def call(self, features):
        """
        Fit the interpolating spline to the 9 cost samples and evaluate it on
        a finer grid.

        Args:
            features: Tensor of shape (B, H, W, 9). Channel k is the cost at
                disparity d = k - 4.

        Returns:
            float32 tensor of shape (B, H, W, 81) with the spline values at
            tf.linspace(-4.0, 4.0, 81).
        """
        # Ensure input is float32
        features = tf.cast(features, dtype=tf.float32)

        # (shape / B / H / W are not used further below.)
        shape = tf.shape(features)
        B, H, W = shape[0], shape[1], shape[2]

        # Assert shape dynamically
        tf.debugging.assert_equal(tf.shape(features)[3], self.C_dim,
                                message=f"Input tensor must have C={self.C_dim} channels.")

        # --- Compute the spline coefficients a0..a8 from the 9 samples ---

        # 1. Extract Channels
        c0 = features[..., 0]
        c1 = features[..., 1]
        c2 = features[..., 2]
        c3 = features[..., 3]
        c4 = features[..., 4]
        c5 = features[..., 5]
        c6 = features[..., 6]
        c7 = features[..., 7]
        c8 = features[..., 8]

        # 2. End coefficients equal the end samples (f(-4) = c0, f(4) = c8).
        a0 = c0
        a8 = c8

        # 3. Solve the 7x7 interior system for a1..a7; the end contributions
        #    0.125 * c0 and 0.125 * c8 are moved to the right-hand side.
        R_val = tf.stack([c1 - 0.125 * c0, c2, c3, c4, c5, c6, c7 - 0.125*c8], axis=-1)
        a_vec = tf.linalg.matvec(self.M_inv, R_val) # Use stored M_inv
        a1 = a_vec[..., 0]
        a2 = a_vec[..., 1]
        a3 = a_vec[..., 2]
        a4 = a_vec[..., 3]
        a5 = a_vec[..., 4]
        a6 = a_vec[..., 5]
        a7 = a_vec[..., 6]

        # 4. Evaluate the spline on 81 evenly spaced points in [-4, 4].
        a_coeffs = [a0, a1, a2, a3, a4, a5, a6, a7, a8]
        R_values = tf.linspace(-4.0, 4.0, 81)  # Shape: (81,)

        # Broadcast (1, 1, 1, 81) points against (B, H, W, 1) coefficients.
        x_broadcastable = tf.reshape(R_values, (1, 1, 1, 81))
        a_coeffs_broadcastable = [tf.expand_dims(a, axis=-1) for a in a_coeffs]
        all_candidate_fs = self.evaluate_f_piecewise(x_broadcastable, *a_coeffs_broadcastable)

        # Ensure output is float32
        return tf.cast(all_candidate_fs, dtype=tf.float32)

def disparityregression(input):
    """Regress a disparity value as the expectation over 81 disparity candidates.

    The candidates are ``np.linspace(-4, 4, 81)``. Each candidate is weighted by
    the matching entry of the last axis of ``input`` (in this file, the softmax
    output of the cost volume), and the weighted candidates are summed.

    Args:
        input: tensor whose last axis has 81 entries. The ``tf.tile`` call below
            uses four multiples, so a rank-4 tensor (B, H, W, 81) is expected.

    Returns:
        Tensor of shape (B, H, W) holding ``sum_k input[..., k] * d_k``.
    """
    dims = K.shape(input)
    candidates = K.constant(np.linspace(-4, 4, 81), shape=[81])
    # (81,) -> (1, 1, 1, 81), then repeat over batch, height and width.
    candidates = tf.reshape(candidates, (1, 1, 1, 81))
    candidates = tf.tile(candidates, [dims[0], dims[1], dims[2], 1])
    weighted = multiply([input, candidates])
    return K.sum(weighted, -1)


def define_9BS_SubFocal(sz_input, sz_input2, view_n, learning_rate):
    """Build, summarise and compile the SubFocal disparity network.

    Args:
        sz_input: height of every single-channel view input.
        sz_input2: width of every single-channel view input.
        view_n: sequence of angular view indices; the model takes
            ``len(view_n) ** 2`` inputs (81 for 9 views).
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        A Keras ``Model`` compiled with Adam and MAE loss. Its single output is
        the regressed disparity per pixel.
    """
    n_inputs = len(view_n) * len(view_n)

    # One (H, W, 1) input per sub-aperture view. Creation order is kept so
    # that auto-generated layer names stay stable across builds.
    input_list = [Input(shape=(sz_input, sz_input2, 1)) for _ in range(n_inputs)]

    # The same feature_extraction_layer object is applied to every view.
    feature_extraction_layer = feature_extraction(sz_input, sz_input2)
    feature_list = [feature_extraction_layer(view) for view in input_list]

    # Cost volume -> channel attention -> cost regression.
    cv = Lambda(_getCostVolume_)(feature_list)
    cv, attention = channel_attention(cv)
    cost = basic(cv)

    # Drop the trailing singleton axis and move the candidate axis last.
    cost = Lambda(lambda x: K.permute_dimensions(K.squeeze(x, -1),
                                                 (0, 2, 3, 1)))(cost)
    BSpline_layer = BSplineLayer(name='BSplineInterpolation')
    cost_refined = BSpline_layer(cost)

    # Probability per candidate, then expected disparity.
    pred = Activation('softmax')(cost_refined)
    pred = Lambda(disparityregression)(pred)

    model = Model(inputs=input_list, outputs=[pred])
    model.summary()

    model.compile(optimizer=Adam(learning_rate=learning_rate), loss='mae')
    return model


if __name__ == '__main__':
    # Smoke test: build the network on 32x32 patches and report how long
    # model construction takes.
    input_size = 32  # Input size should be greater than or equal to 23
    label_size = 32  # Since label_size should be greater than or equal to 1
    AngualrViews = np.arange(9)  # angular view indices 0..8 for a 9x9 grid
    T1 = time.time()
    model = define_9BS_SubFocal(input_size, input_size, AngualrViews, 0.001)
    T2 = time.time()
    print(f'model load: {T2 - T1} s')

# 2.3. pfm

In [None]:
def read_pfm(fpath, expected_identifier="Pf"):
    """Read a PFM (Portable Float Map) file into a NumPy array.

    PFM format definition: http://netpbm.sourceforge.net/doc/pfm.html

    The header has three text lines: identifier, "width height", and a scale
    whose sign gives the byte order (negative = little-endian, positive =
    big-endian). Lines starting with '#' are skipped as comments. Rows are
    stored bottom-to-top, so the array is flipped vertically. Values are
    multiplied by abs(scale).

    Args:
        fpath: path to the .pfm file.
        expected_identifier: "Pf" for single-channel maps (default) or "PF"
            for three-channel colour maps.

    Returns:
        float32 array of shape (height, width) for "Pf", or
        (height, width, 3) for "PF".

    Raises:
        Exception: if the identifier, dimensions, scale or binary payload are
            invalid. Errors from ``open`` (e.g. a missing file) propagate
            unchanged.
    """

    def _get_next_line(f):
        # The file is opened in binary mode, so every header line must be
        # decoded, including the ones read while skipping '#' comments.
        next_line = f.readline().decode('utf-8').rstrip()
        while next_line.startswith('#'):
            next_line = f.readline().decode('utf-8').rstrip()
        return next_line

    with open(fpath, 'rb') as f:
        #  header
        identifier = _get_next_line(f)
        if identifier != expected_identifier:
            raise Exception('Unknown identifier. Expected: "%s", got: "%s".' % (expected_identifier, identifier))

        line_dimensions = ''
        try:
            line_dimensions = _get_next_line(f)
            dimensions = line_dimensions.split()
            width = int(dimensions[0])
            height = int(dimensions[1])
        except (ValueError, IndexError) as err:
            raise Exception('Could not parse dimensions: "%s". '
                            'Expected "width height", e.g. "128 128".' % line_dimensions) from err

        scale_error = ('Could not parse max value / endianess information: "%s". '
                       'Should be a non-zero number.')
        line_scale = ''
        try:
            line_scale = _get_next_line(f)
            scale = float(line_scale)
        except ValueError as err:
            raise Exception(scale_error % line_scale) from err
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if scale == 0:
            raise Exception(scale_error % line_scale)
        endianness = "<" if scale < 0 else ">"

        shape = (height, width, 3) if identifier == "PF" else (height, width)
        try:
            data = np.fromfile(f, "%sf" % endianness)
            data = np.reshape(data, shape)
        except (ValueError, OSError) as err:
            raise Exception('Invalid binary values. Could not create %dx%d array from input.' % (height, width)) from err

        # Rows are stored bottom-to-top.
        data = np.flipud(data)
        with np.errstate(invalid="ignore"):
            data *= abs(scale)

        return data

# 2.4. savedata

In [None]:
def display_current_output(train_output,
                           traindata_label,
                           iter00,
                           directory_save,
                           train_val='train'):
    '''
    Compare predicted disparity maps with the ground truth and save a preview.

    A 4-pixel border is cropped from every side of the label. The prediction is
    cropped to the same region, shifted by half the size difference when the
    prediction is smaller than the label. The preview JPEG stacks the cropped
    labels (top row) above the cropped predictions (bottom row), with scenes
    side by side. Grey level is 25 * disparity + 100, clipped to 0..255.

    Args:
        train_output: predicted disparities, (N, H', W'). Singleton non-batch
            axes (e.g. a trailing channel axis) are dropped; the batch axis is
            always kept, so N == 1 works.
        traindata_label: ground truth, (N, H, W) or (N, H, W, 9, 9). For the
            5-D form the view at angular index [4, 4] is used.
        iter00: iteration number written into the file name.
        directory_save: directory receiving '<train_val>_iter%05d.jpg'.
        train_val: file-name prefix, e.g. 'train' or 'val'.

    Returns:
        train_diff: absolute error on the cropped region, (N, H-8, W-8).
        train_bp: boolean "bad pixel" mask, train_diff >= 0.07.
    '''
    border = 4  # pixels cropped from each side of the label
    sz = len(traindata_label)

    train_output = np.asarray(train_output)
    # Plain np.squeeze would also remove the batch axis when N == 1.
    singleton_axes = tuple(ax for ax in range(1, train_output.ndim)
                           if train_output.shape[ax] == 1)
    if singleton_axes:
        train_output = np.squeeze(train_output, axis=singleton_axes)

    if (len(traindata_label.shape) > 3
            and traindata_label.shape[-1] == 9):  # traindata
        train_label_crop = traindata_label[:, border:-border, border:-border, 4, 4]
    else:  # valdata
        train_label_crop = traindata_label[:, border:-border, border:-border]

    # Half of the size difference between label and prediction.
    pad1_half = int(
        0.5 * (np.size(traindata_label, 1) - np.size(train_output, 1)))
    crop_h, crop_w = train_label_crop.shape[1], train_label_crop.shape[2]
    start = border - pad1_half
    train_output_crop = train_output[:, start:start + crop_h,
                                     start:start + crop_w]

    train_diff = np.abs(train_output_crop - train_label_crop)
    train_bp = (train_diff >= 0.07)

    def _to_gray(maps):
        # (N, h, w) -> (h, N*w) with scenes side by side. Clip before the
        # uint8 cast: out-of-range disparities would otherwise wrap around.
        tiled = np.reshape(np.transpose(maps, (1, 0, 2)), (crop_h, sz * crop_w))
        return np.uint8(np.clip(25 * tiled + 100, 0, 255))

    train_output_all = np.zeros((2 * crop_h, sz * crop_w), np.uint8)
    train_output_all[:crop_h, :] = _to_gray(train_label_crop)
    train_output_all[crop_h:, :] = _to_gray(train_output_crop)

    imageio.imsave(
        directory_save + '/' + train_val + '_iter%05d.jpg' % (iter00),
        np.squeeze(train_output_all))

    return train_diff, train_bp

# 2.5. util

In [None]:
def load_LFdata(dir_LFimages, data_root='full_data/', target_shape=128):
    """Load 9x9 light-field views and ground-truth disparity for each scene.

    For every scene directory ``<data_root>/<scene>`` this reads the 81 views
    ``input_Cam000.png`` ... ``input_Cam080.png`` and ``gt_disp_lowres.pfm``.
    Each is resized with ``cv2.INTER_AREA`` to ``target_shape`` x
    ``target_shape``.

    Args:
        dir_LFimages: scene directories, relative to ``data_root``.
        data_root: dataset root directory (default ``'full_data/'``).
        target_shape: side length in pixels of the resized views and labels.

    Returns:
        traindata_all: uint8 array (num_scenes, target_shape, target_shape,
            9, 9, 3). View ``i`` is stored at angular index (i // 9, i % 9).
        traindata_label: float32 array (num_scenes, target_shape,
            target_shape).

    A file that cannot be loaded is reported on stdout and its slot is left
    zero-filled. Loading continues with the remaining files.
    """
    num_scenes = len(dir_LFimages)
    # (number of scenes, height, width, 9, 9, RGB)
    traindata_all = np.zeros((num_scenes, target_shape, target_shape, 9, 9, 3), np.uint8)
    traindata_label = np.zeros((num_scenes, target_shape, target_shape), np.float32)

    def _resize(img):
        # Area interpolation, as used for both views and labels.
        return cv2.resize(img, (target_shape, target_shape), interpolation=cv2.INTER_AREA)

    for image_id, dir_LFimage in enumerate(dir_LFimages):
        print(dir_LFimage)
        scene_dir = os.path.join(data_root, dir_LFimage)

        for i in range(81):  # 9x9 views, row-major
            view_path = os.path.join(scene_dir, 'input_Cam0%.2d.png' % i)
            try:
                view = _resize(np.float32(imageio.imread(view_path)))
            except Exception as err:  # best effort: report, keep slot zero-filled
                print('%s..could not be loaded (%s)' % (view_path, err))
                continue
            row, col = divmod(i, 9)
            traindata_all[image_id, :, :, row, col, :] = view

        label_path = os.path.join(scene_dir, 'gt_disp_lowres.pfm')
        try:
            # NOTE(review): disparity values are resized spatially but not
            # rescaled by the resize factor. If they are in pixels of the
            # original resolution, confirm whether they should be scaled.
            label = _resize(np.float32(read_pfm(label_path)))
        except Exception as err:  # best effort: report, keep slot zero-filled
            print('%s..could not be loaded (%s)' % (label_path, err))
            continue
        traindata_label[image_id, :, :] = label

    return traindata_all, traindata_label

# 3. train

## 3.1. Load the data

In [None]:
if __name__ == '__main__':

    # Training scenes (default split: all 16 'additional' scenes).
    print('Load training data...')
    dir_LFimages = ['additional/' + scene for scene in (
        'antinous', 'boardgames', 'dishes', 'greek', 'kitchen', 'medieval2',
        'museum', 'pens', 'pillows', 'platonic', 'rosemary', 'table',
        'tomb', 'tower', 'town', 'vinyl')]
    # For fold 2 use
    # print('Load training data...')
    # dir_LFimages = [
    #     'additional/antinous', 'additional/boardgames', 'additional/dishes',
    #     'additional/greek', 'additional/kitchen', 'additional/medieval2',
    #     'additional/museum', 'stratified/backgammon', 'stratified/dots',
    #     'stratified/pyramids', 'stratified/stripes', 'training/boxes',
    #     'training/cotton', 'training/dino', 'training/sideboard'
    # ]

    # For fold 3 use
    # dir_LFimages = [
    #     'stratified/backgammon', 'stratified/dots', 'stratified/pyramids',
    #     'stratified/stripes', 'training/boxes', 'training/cotton',
    #     'training/dino', 'training/sideboard', 'additional/pillows',
    #     'additional/platonic', 'additional/rosemary', 'additional/table',
    #     'additional/tomb', 'additional/tower', 'additional/town',
    #     'additional/vinyl'
    # ]

    traindata_all, traindata_label = load_LFdata(dir_LFimages)
    print('Load training data... Complete')

    # Test scenes: the 'stratified' and 'training' sets.
    print('Load test data...')
    dir_LFimages = (
        ['stratified/' + scene for scene in ('backgammon', 'dots', 'pyramids', 'stripes')]
        + ['training/' + scene for scene in ('boxes', 'cotton', 'dino', 'sideboard')]
    )

    # For fold 2 use
    # dir_LFimages = [
    #     'stratified/backgammon', 'stratified/dots', 'stratified/pyramids',
    #     'stratified/stripes', 'training/boxes', 'training/cotton',
    #     'training/dino', 'training/sideboard'
    # ]
    # For fold 3 use
    # dir_LFimages = [
    #     'additional/antinous', 'additional/boardgames', 'additional/dishes',
    #     'additional/greek', 'additional/kitchen', 'additional/medieval2',
    #     'additional/museum', 'additional/pens'
    # ]

    valdata_all, valdata_label = load_LFdata(dir_LFimages)
    print('Load test data... Complete')

In [None]:
# Report the shapes of the loaded training and validation arrays.
print('train:{}  label:{}'.format(traindata_all.shape, traindata_label.shape))
print('val:{}  label:{}'.format(valdata_all.shape, valdata_label.shape))

## 3.2. Training

In [None]:
def setup_seed(seed):
    """Seed NumPy's global RNG, Python's `random` and TensorFlow with `seed`."""
    for seed_fn in (np.random.seed, random.seed, tf.random.set_seed):
        seed_fn(seed)


def save_disparity_jet(disparity, filename):
    """Write a disparity map to `filename` as a JET colour-mapped image.

    Finite values are min-max normalised to 0..255 before colour mapping.
    Non-finite pixels (NaN, +inf, -inf) get the colour of the minimum, and so
    does every pixel of a constant map, instead of dividing by zero or casting
    NaN/inf to uint8.

    Args:
        disparity: 2-D array-like of disparity values.
        filename: output path; the format follows the extension (cv2.imwrite).

    Returns:
        The boolean returned by ``cv2.imwrite`` (False if writing failed).
    """
    disparity = np.asarray(disparity)
    finite = np.isfinite(disparity)
    if finite.any():
        min_disp = np.min(disparity[finite])
        max_disp = np.max(disparity[finite])
    else:
        min_disp = max_disp = 0
    value_range = max_disp - min_disp

    if value_range > 0:
        scaled = (disparity - min_disp) / value_range
        scaled = np.where(finite, scaled, 0.0)
    else:
        scaled = np.zeros(disparity.shape)

    disparity_u8 = (scaled * 255.0).astype(np.uint8)
    colored = cv2.applyColorMap(disparity_u8, cv2.COLORMAP_JET)
    return cv2.imwrite(filename, colored)


if __name__ == '__main__':
    import functools

    # Keras pulls training batches from a Python generator and may advance it
    # from several worker threads, so the generator is wrapped to serialise
    # `next()` calls.

    class threadsafe_iter:
        """
        Takes an iterator/generator and makes it thread-safe by
        serializing calls to the `next` method of the wrapped iterator.
        Concurrent `next()` calls on a plain generator are not allowed.
        """

        def __init__(self, it):
            self.it = it
            self.lock = threading.Lock()

        def __iter__(self):
            return self

        def __next__(self):
            with self.lock:
                return next(self.it)

    def threadsafe_generator(f):
        """
        A decorator that takes a generator function and makes it thread-safe.
        The wrapper keeps the name and docstring of `f`.
        """

        @functools.wraps(f)
        def g(*a, **kw):
            return threadsafe_iter(f(*a, **kw))

        return g

    @threadsafe_generator
    def myGenerator(traindata_all, traindata_label, input_size, label_size,
                    batch_size, AngualrViews, boolmask_img4, boolmask_img6,
                    boolmask_img15):
        """Yield (list of per-view inputs, labels) training batches forever.

        Each batch comes from generate_traindata_for_train followed by
        data_augmentation_for_train.
        """
        while True:
            (traindata_batch,
             traindata_label_batchNxN) = generate_traindata_for_train(
                 traindata_all, traindata_label, input_size, label_size,
                 batch_size, AngualrViews, boolmask_img4, boolmask_img6,
                 boolmask_img15)

            (traindata_batch,
             traindata_label_batchNxN) = data_augmentation_for_train(
                 traindata_batch, traindata_label_batchNxN, batch_size)

            # Split the batch along angular axes 3 and 4 into one input per
            # view, row-major in (i, j). Presumably the batch is
            # (B, H, W, U, V), giving (B, H, W, 1) inputs. TODO confirm.
            traindata_batch_list = [
                np.expand_dims(traindata_batch[:, :, :, i, j], axis=-1)
                for i in range(traindata_batch.shape[3])
                for j in range(traindata_batch.shape[4])
            ]

            yield (traindata_batch_list, traindata_label_batchNxN)

    # ------------------------------------------------------------------
    # Patch-wise training with full-resolution validation after each epoch.
    # ------------------------------------------------------------------

    seed = 42
    setup_seed(seed)
    print("random seed: ", seed)

    # Patch-wise training parameters.
    input_size = 32  # Input size should be greater than or equal to 23 32
    label_size = 32  # Since label_size should be greater than or equal to 1
    # number of views ( 0~8 for 9x9 )
    AngualrViews = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8])

    batch_size = 8
    workers_num = 8  # number of threads

    display_status_ratio = 2000  # training steps per epoch (was 10000)

    iter00 = 20  # epoch counter; training resumes at this epoch

    load_weight_is = False

    model_learning_rate = 0.001
    networkname = '9BS_SubFocal'

    # NOTE(review): `traindata` is never used below. The call is kept in case
    # generate_traindata128 has side effects; remove it if it does not.
    traindata, _ = generate_traindata128(traindata_all, traindata_label,
                                         AngualrViews)

    valdata, valdata_label = generate_traindata128(valdata_all, valdata_label,
                                                   AngualrViews)

    # Invalid regions (e.g. reflections) of three training scenes. A pixel is
    # invalid where channel 3 (presumably alpha) of the mask image is > 0.
    boolmask_img4 = imageio.imread(
        'full_data/additional_invalid_area/kitchen/input_Cam040_invalid_ver2.png'
    )
    boolmask_img6 = imageio.imread(
        'full_data/additional_invalid_area/museum/input_Cam040_invalid_ver2.png'
    )
    boolmask_img15 = imageio.imread(
        'full_data/additional_invalid_area/vinyl/input_Cam040_invalid_ver2.png'
    )

    boolmask_img4 = boolmask_img4[:, :, 3] > 0
    boolmask_img6 = boolmask_img6[:, :, 3] > 0
    boolmask_img15 = boolmask_img15[:, :, 3] > 0

    # Model trained on input_size x input_size patches, resumed from the
    # epoch-20 checkpoint (consistent with iter00 and best_bad_pixel).
    model = define_9BS_SubFocal(input_size, input_size, AngualrViews,
                                model_learning_rate)
    model.load_weights('9BSfold2_LF_checkpoint/9BSfold2_SubFocal_sub_0.5_ckp/iter0020_valmse4.963_bp25.53.hdf5')

    # Same architecture at full image size; receives the patch model's
    # weights before every validation pass.
    image_w = 128
    image_h = 128
    model_128 = define_9BS_SubFocal(image_w, image_h, AngualrViews,
                                    model_learning_rate)

    # Optionally load another checkpoint (overrides the one loaded above).
    if load_weight_is:
        model.load_weights(
            'LF_checkpoint/SubFocal_sub_0.5_ckp/iter0049_valmse0.845_bp2.04.hdf5'
        )

    # Directories for checkpoints and disparity preview images.
    LF_checkpoints_path = '9BSfold2_LF_checkpoint/'
    LF_output_path = '9BSfold2_LF_output/'
    directory_ckp = LF_checkpoints_path + "%s_ckp" % (networkname)
    directory_t = LF_output_path + '%s' % (networkname)
    for directory in (directory_ckp, LF_output_path, directory_t):
        os.makedirs(directory, exist_ok=True)

    # Append a header (date & time, hyper-parameters) to the training log.
    txt_name = LF_checkpoints_path + 'lf_%s.txt' % (networkname)
    with open(txt_name, 'a') as f1:
        f1.write('\n' + str(datetime.datetime.now()) + '\n\n')
        f1.write('Learning rate: {}\n'.format(model_learning_rate))
        f1.write('Batch size: {}\n\n'.format(batch_size))

    my_generator = myGenerator(traindata_all, traindata_label, input_size,
                               label_size, batch_size, AngualrViews,
                               boolmask_img4, boolmask_img6, boolmask_img15)
    best_bad_pixel = 25.51  # best so far; original fresh-start value: 100.0

    for iter02 in range(20, 40):
        print(f'EPOCH {iter02+1} STARTS')
        # One epoch of display_status_ratio patch-wise training steps.
        # `fit` accepts generators directly; `fit_generator` is deprecated.
        history = model.fit(my_generator,
                            steps_per_epoch=int(display_status_ratio),
                            epochs=iter00 + 1,
                            class_weight=None,
                            max_queue_size=10,
                            initial_epoch=iter00,
                            verbose=1,
                            workers=workers_num)

        iter00 = iter00 + 1
        loss = history.history['loss'][-1]

        # Validate on full-size images with the freshly trained weights and
        # save a preview image in directory_t.
        model_128.set_weights(model.get_weights())
        val_output = model_128.predict(valdata, batch_size=1)
        val_error, val_bp = display_current_output(val_output, valdata_label,
                                                   iter00, directory_t, 'val')

        validation_mean_squared_error_x100 = 100 * \
            np.average(np.square(val_error))
        validation_bad_pixel_ratio = 100 * np.average(val_bp)

        save_path_file_new = (directory_ckp +
                              '/iter%04d_valmse%.3f_bp%.2f.hdf5' %
                              (iter00, validation_mean_squared_error_x100,
                               validation_bad_pixel_ratio))

        # Log bad-pixel ratio & mean squared error.
        print(save_path_file_new)
        with open(txt_name, 'a') as f1:
            f1.write('.' + save_path_file_new + f'  loss={loss:.4f}\n')

        # A checkpoint is written every epoch; improvements of the bad-pixel
        # ratio are tracked and announced.
        model.save(save_path_file_new)
        if validation_bad_pixel_ratio < best_bad_pixel:
            best_bad_pixel = validation_bad_pixel_ratio
            print("saved!!!")

        print(f'EPOCH {iter02+1} ENDS')