# Import Libraries

In [None]:
import numpy as np
import math
import natsort
import cv2
import os
import matplotlib.pyplot as plt
from glob import glob

from PIL import Image

import tensorflow as tf
from tensorflow import keras

import tensorflow_datasets as tfds

from tensorflow.keras.utils import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *

# Data Loader

In [None]:
def sortlist(filelist):
    """Return a new list with *filelist* in natural sort order.

    Natural ordering compares embedded numbers numerically, so a file named
    ``img2.png`` sorts before ``img10.png``.
    """
    return natsort.natsorted(filelist)

In [None]:
# Root directory of the cloud training data, built from $HOME. If HOME is
# unset, os.getenv returns None and building the string raises TypeError.
dir_path = "".join([os.getenv("HOME"), '/Cloud_data/cloud_train/'])

In [None]:
class CloudGenerator(Sequence):
    """Keras ``Sequence`` yielding ``(images, masks)`` batches for cloud segmentation.

    Input images are globbed from ``<dir_path>/patch_img/*png`` and label
    images from ``<dir_path>/patch_labeling/*png``. Both lists are naturally
    sorted and paired by position. The pairs are then split 80/20 (in sorted
    order) into a training part (``is_train=True``) and a held-out part
    (``is_train=False``).

    Args:
        dir_path: Dataset root containing ``patch_img`` and ``patch_labeling``.
        batch_size: Maximum number of samples per batch.
        img_size: Shape every input image must have, ``(H, W, C)``.
        output_size: Shape every label mask must have, ``(H, W)``.
        is_train: Select the training split (reshuffled every epoch) or the
            held-out split (fixed order).
    """

    def __init__(self,
                 dir_path,
                 batch_size = 4,
                 img_size = (800,800,3),
                 output_size = (800,800),
                 is_train = True):
        super().__init__()
        self.dir_path = dir_path
        self.batch_size = batch_size
        self.img_size = img_size
        self.output_size = output_size
        self.is_train = is_train

        self.data = self.load_dataset()
        # Keras calls on_epoch_end only at the *end* of each epoch, so build
        # the initial sample order here; __getitem__ reads it from the start.
        self.on_epoch_end()

    def load_dataset(self):
        """Return the list of ``(image_path, label_path)`` pairs for this split.

        Raises:
            ValueError: If the numbers of input and label images differ.
        """
        input_images = sortlist(glob(os.path.join(self.dir_path, "patch_img", "*png")))
        label_images = sortlist(glob(os.path.join(self.dir_path, "patch_labeling", "*png")))

        # Pairing is purely positional, so the counts must agree. Raise rather
        # than assert: asserts are stripped under ``python -O``.
        if len(input_images) != len(label_images):
            raise ValueError(
                f"found {len(input_images)} input images but "
                f"{len(label_images)} label images under {self.dir_path!r}"
            )
        data = list(zip(input_images, label_images))

        train_count = int(len(data) * 0.8)
        return data[:train_count] if self.is_train else data[train_count:]

    def __getitem__(self, index):
        """Return batch ``index`` as ``(inputs, outputs)`` float32 arrays.

        ``inputs`` has shape ``(n, *img_size)``, with pixel values divided by
        255. ``outputs`` has shape ``(n, *output_size)``, with 1 where the
        label pixel equals 50 and 0 elsewhere. ``n`` equals ``batch_size``
        except for a final partial batch, which holds only the remaining
        samples.

        Raises:
            FileNotFoundError: If OpenCV cannot read an image or label file.
        """
        batch_indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        batch_data = [self.data[i] for i in batch_indexes]

        # Size by the real sample count so a short final batch is not padded
        # with all-zero dummy samples that would still be trained on.
        inputs = np.zeros([len(batch_data), *self.img_size], dtype=np.float32)
        outputs = np.zeros([len(batch_data), *self.output_size], dtype=np.float32)

        for i, (input_img_path, output_path) in enumerate(batch_data):
            _input = cv2.imread(input_img_path)   # colour image, OpenCV BGR order
            _output = cv2.imread(output_path, 0)  # flag 0 == load as grayscale
            # cv2.imread signals failure by returning None instead of raising.
            if _input is None:
                raise FileNotFoundError(f"could not read input image {input_img_path!r}")
            if _output is None:
                raise FileNotFoundError(f"could not read label image {output_path!r}")
            inputs[i] = _input / 255
            # Binary mask: label value 50 is the positive class (presumably
            # "cloud" -- confirm against the labelling scheme).
            outputs[i] = _output == 50

        # Return only after the whole batch has been filled.
        return inputs, outputs

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return math.ceil(len(self.data) / self.batch_size)

    def on_epoch_end(self):
        """Rebuild the sample order used by ``__getitem__``.

        The order is shuffled for the training split and kept sequential
        otherwise. Returns the new index array.
        """
        self.indexes = np.arange(len(self.data))
        if self.is_train:
            np.random.shuffle(self.indexes)
        return self.indexes

In [None]:
# Build both data splits from the same directory: the training part first,
# then the held-out part used for validation.
train_generator, test_generator = (
    CloudGenerator(dir_path, is_train=split_is_train)
    for split_is_train in (True, False)
)

# Build Atrous U-Net Model

In [None]:
def atrous_unet_encoder(x, channel = 64, kernel_size = (3,3), strides=1,activation='relu'):
    """Run the four-level U-Net contracting path on tensor ``x``.

    Each level applies two 'same'-padded convolutions with
    ``channel * 2**level`` filters (64, 128, 256, 512 by default). It records
    the result as a skip connection, then halves the spatial size with 2x2
    max pooling.

    Args:
        x: Input Keras tensor.
        channel: Filter count of the first level; doubled at every level.
        kernel_size: Convolution kernel size.
        strides: Accepted for API compatibility but currently unused.
        activation: Activation applied by every convolution.

    Returns:
        ``(x, skip_connection)``: the pooled output of the last level, and the
        list of the four pre-pooling feature maps, shallowest first.
    """
    skips = []
    for level in range(4):
        filters = channel * 2 ** level
        for _ in range(2):
            x = Conv2D(filters, kernel_size, activation=activation,
                       padding='same', kernel_initializer='he_normal')(x)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
    return x, skips

def atrous_unet_bottleneck(x,kernel_size = (3,3),activation='relu',
                           filters=1024, num_layers=6, dropout_rate=0.5):
    """Apply a stack of dilated ('atrous') convolutions followed by dropout.

    Layer ``i`` (0-based) uses dilation rate ``2**i``, so the default six
    layers use rates 1, 2, 4, 8, 16 and 32. All layers use 'same' padding,
    so the spatial size is preserved.

    NOTE(review): with the default 800x800 model input, the encoder's four
    2x2 poolings leave a 50x50 map here. The largest dilation (32, spanning
    65 pixels with a 3x3 kernel) therefore reaches mostly into the zero
    padding. Consider a smaller ``num_layers`` if that is not intended.

    Args:
        x: Input Keras tensor (the encoder output).
        kernel_size: Convolution kernel size.
        activation: Activation applied by every convolution.
        filters: Filter count of every dilated convolution.
        num_layers: Number of dilated convolutions to stack.
        dropout_rate: Rate of the final ``Dropout`` layer.

    Returns:
        The transformed Keras tensor.
    """
    for i in range(num_layers):
        rate = 2 ** i
        x = Conv2D(filters, kernel_size, activation=activation, padding='same',
                   dilation_rate=(rate, rate), kernel_initializer='he_normal')(x)
    x = Dropout(dropout_rate)(x)
    return x

def atrous_unet_decoder(x,skip_connection,channel = 64, kernel_size = (3,3),activation = 'relu'):
    """Run the four-level U-Net expanding path.

    Levels are processed deepest first (3, 2, 1, 0). At each level:

    - A stride-2 transposed convolution doubles the spatial size.
    - The result is concatenated (channels axis 3) after
      ``skip_connection[level]``.
    - Two 'same'-padded convolutions with ``channel * 2**level`` filters
      follow.

    Args:
        x: Bottleneck output tensor.
        skip_connection: Four encoder feature maps, shallowest first. Each
            must match the spatial size of the upsampled tensor at its level.
        channel: Filter count of the shallowest level; doubled per level.
        kernel_size: Kernel size of the regular convolutions.
        activation: Activation applied by every layer here.

    Returns:
        The full-resolution decoded Keras tensor.
    """
    for level in range(3, -1, -1):
        filters = channel * 2 ** level
        x = Conv2DTranspose(filters, 2, strides=(2, 2), activation=activation,
                            kernel_initializer='he_normal')(x)
        x = concatenate([skip_connection[level], x], axis=3)
        for _ in range(2):
            x = Conv2D(filters, kernel_size, activation=activation,
                       padding='same', kernel_initializer='he_normal')(x)
    return x

def atrous_unet_model(input_shape=(800,800,3)):
    """Assemble the Atrous U-Net and return it as an uncompiled Keras ``Model``.

    The pipeline is encoder, then dilated bottleneck, then decoder, then a
    1x1 sigmoid convolution that produces a single-channel per-pixel
    probability map.

    Args:
        input_shape: ``(H, W, C)`` of the input images. H and W must be
            divisible by 16 so the skip connections line up after four
            poolings.
    """
    inputs = Input(input_shape)
    bottom, skips = atrous_unet_encoder(inputs)
    features = atrous_unet_decoder(atrous_unet_bottleneck(bottom), skips)
    outputs = Conv2D(1, 1, activation='sigmoid')(features)
    return Model(inputs=inputs, outputs=outputs)
    
# Instantiate the network for 800x800, 3-channel inputs and print its layers.
Atrous_unet_model = atrous_unet_model(input_shape=(800, 800, 3))
Atrous_unet_model.summary()

In [None]:
Atrous_unet_model_path = os.getenv("HOME")+'/Cloud_data/cloud_model/seg_atrous_unet_model.h5'

# Use `learning_rate`: `lr` is a deprecated alias. Newer Keras optimizers
# warn about it (and may not apply it), and Keras 3 rejects it outright, so
# the intended 1e-4 would not reliably take effect.
Atrous_unet_model.compile(optimizer=Adam(learning_rate=1e-4), loss='binary_crossentropy')

# `fit` accepts a keras Sequence directly; `fit_generator` is deprecated and
# was removed in Keras 3.
Atrous_unet_model.fit(
    train_generator,
    validation_data=test_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
)

# Save the trained model, creating the target directory first if needed.
os.makedirs(os.path.dirname(Atrous_unet_model_path), exist_ok=True)
Atrous_unet_model.save(Atrous_unet_model_path)

In [None]:
# Path of the saved model to reload.
# NOTE(review): this points at a Google Drive (Colab) location, whereas the
# training cell saves to $HOME/Cloud_data/cloud_model/seg_atrous_unet_model.h5.
# Confirm which environment this notebook runs in and keep the two in sync.
atrous_unet_model_path = '/content/drive/MyDrive/Cloud_data/cloud_model/seg_atrous_unet_model.h5'

In [None]:
# Reload the saved Atrous U-Net from `atrous_unet_model_path` (set in the
# previous cell). `keras` here is `tensorflow.keras`, as imported at the top.
Atrous_unet_model = keras.models.load_model(atrous_unet_model_path)