In [None]:
import os
import pickle
from glob import glob

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from keras.layers import (Add, BatchNormalization, Concatenate, Conv2D, Dense,
                          GlobalAveragePooling2D, GlobalMaxPooling2D, Input,
                          LeakyReLU, MaxPool2D, ReLU, Softmax, ZeroPadding2D)
from keras.losses import mae, sparse_categorical_crossentropy, binary_crossentropy
from keras.models import Model, load_model
from keras.optimizers import Adam, RMSprop
from sklearn.model_selection import train_test_split

from model.ResNeXt import ResNeXt50
from sklearn.model_selection import train_test_split


In [None]:
from albumentations import *
def Seq():
    """Build the random, training-time albumentations augmentation pipeline.

    The pipeline applies random 90-degree rotations, transposes and flips, then
    one randomly chosen photometric transform (CLAHE, sharpen, emboss,
    brightness/contrast, JPEG compression, blur or Gaussian noise). After that
    come a random hue/saturation/value shift and a random shift/scale/rotate.
    It always ends with ``Normalize``.

    Returns:
        An ``A.Compose`` object; call it as ``seq(image=img)['image']``.
    """
    # Every transform is referenced through the ``A`` namespace on purpose.
    # This cell also defines a module-level ``def CLAHE(img)`` that shadows the
    # star-imported albumentations ``CLAHE``. A bare ``CLAHE(clip_limit=2)``
    # evaluated when Seq() runs would therefore call that helper and raise
    # TypeError.
    # NOTE(review): IAASharpen / IAAEmboss / JpegCompression / Flip were
    # removed or renamed in newer albumentations releases; confirm the pinned version.
    seq = A.Compose([
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.Flip(p=0.5),
        A.OneOf([
            A.CLAHE(clip_limit=2),
            A.IAASharpen(),
            A.IAAEmboss(),
            A.RandomBrightnessContrast(),
            A.JpegCompression(),
            A.Blur(),
            A.GaussNoise()]),
        A.HueSaturationValue(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.15, rotate_limit=45, p=0.5),
        A.Normalize(p=1)], p=1)
    return seq

def get_id_from_file_path(file_path):
    """Return the image id encoded in ``file_path``.

    The id is the final ``os.path.sep``-separated component of the path with
    every ``'.tif'`` substring removed, e.g. ``train/abc.tif`` -> ``abc``.
    """
    base_name = file_path.rsplit(os.path.sep, 1)[-1]
    return base_name.replace('.tif', '')

def HistogramEqualize(img):
    """Histogram-equalize the lightness channel of a BGR image.

    The image is converted BGR -> LAB and only channel 0 (L, lightness) is
    passed through ``cv2.equalizeHist``.

    Args:
        img: image array in BGR channel order, e.g. from ``cv2.imread``.
            ``cv2.equalizeHist`` needs 8-bit data, so presumably uint8 (TODO confirm).

    Returns:
        The equalized image in LAB colour space. It is NOT converted back to BGR.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    # The module is imported as ``cv2``; the previous ``cv.`` reference raised NameError.
    img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
    return img

def CLAHE(img, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Apply OpenCV CLAHE to the lightness channel of a BGR image.

    The image is converted BGR -> LAB. Contrast Limited Adaptive Histogram
    Equalization is then applied to channel 0 (L) only.

    Note: this module-level name shadows ``albumentations.CLAHE`` brought in by
    ``from albumentations import *``. Code that wants the albumentations
    transform must reference it explicitly (``A.CLAHE``).

    Args:
        img: image array in BGR channel order, e.g. from ``cv2.imread``.
            Presumably uint8 (TODO confirm).
        clip_limit: contrast clip limit passed to ``cv2.createCLAHE``.
        tile_grid_size: ``(cols, rows)`` tile grid passed to ``cv2.createCLAHE``.

    Returns:
        The processed image in LAB colour space. It is NOT converted back to BGR.
    """
    img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    # The module is imported as ``cv2``; the previous ``cv.`` reference raised NameError.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    img[:, :, 0] = clahe.apply(img[:, :, 0])
    return img

def chunk(seq, size):
    """Lazily split ``seq`` into consecutive slices of at most ``size`` items.

    The last slice is shorter when ``len(seq)`` is not a multiple of ``size``.
    ``seq`` must support ``len()`` and slicing. Returns a generator.
    """
    starts = range(0, len(seq), size)
    return (seq[start:start + size] for start in starts)
 
def data_gen(list_files, id_label_map, batch_size, augment=False):
    """Endlessly yield ``(images, labels)`` batches for Keras ``fit_generator``.

    Files are visited in the given order (no shuffling), in slices of
    ``batch_size``. After the last slice the generator starts over from the
    beginning.

    Args:
        list_files: sequence of image file paths (read with ``cv2.imread``).
        id_label_map: mapping from image id (``get_id_from_file_path``) to label.
        batch_size: maximum images per batch; the final batch of each pass may
            be smaller.
        augment: if True, run each image through the random ``Seq()`` pipeline
            (which ends with ``Normalize``). Otherwise apply only
            ``A.Normalize(p=1)``, so training and validation inputs share the
            same scaling.

    Yields:
        ``(np.array(images), np.array(labels))`` for every batch.

    Raises:
        ValueError: if ``list_files`` is empty (the loop would otherwise spin forever).
        FileNotFoundError: if ``cv2.imread`` cannot read a file.
        KeyError: if a file's id is missing from ``id_label_map``.
    """
    if len(list_files) == 0:
        raise ValueError("data_gen: list_files is empty")
    # The non-augment path used to call an undefined ``preprocess_input``.
    # Reuse the same final Normalize step as Seq() instead.
    # NOTE(review): cv2.imread yields BGR, while Normalize's default mean/std
    # are the usual ImageNet RGB values; consider cv2.cvtColor(..., COLOR_BGR2RGB).
    transform = Seq() if augment else A.Normalize(p=1)
    while True:
        for batch in chunk(list_files, batch_size):
            images = []
            for path in batch:
                img = cv2.imread(path)
                if img is None:
                    raise FileNotFoundError("cv2.imread could not read {!r}".format(path))
                images.append(transform(image=img)['image'])
            labels = [id_label_map[get_id_from_file_path(path)] for path in batch]
            # Yield inside the for-loop so every batch is emitted. Previously
            # only the last batch of each pass was ever yielded.
            yield np.array(images), np.array(labels)
        

def _dense_bn_relu(x, units):
    """Apply ``Dense(units) -> BatchNormalization -> ReLU`` to tensor ``x``."""
    x = Dense(units)(x)
    x = BatchNormalization()(x)
    return ReLU()(x)


def resnext(input_shape=(96, 96, 3), learning_rate=0.001, head_units=(3072, 512, 256)):
    """Build, compile and summarize a ResNeXt50-based binary classifier.

    The backbone output is pooled with both global average and global max
    pooling, and the two results are concatenated. They then pass through one
    ``Dense -> BatchNorm -> ReLU`` stage per entry in ``head_units``, followed
    by a single sigmoid unit. The model is compiled with Adam and binary
    cross-entropy, tracking accuracy (``'acc'``).

    Args:
        input_shape: shape of one input image (default ``(96, 96, 3)``).
        learning_rate: Adam learning rate (default ``0.001``).
        head_units: widths of the fully connected head stages, in order.

    Returns:
        The compiled Keras ``Model`` (``model.summary()`` is printed as a side effect).
    """
    inputs = Input(input_shape)
    # Assumes ResNeXt50(inputs=...) returns a 4-D feature-map tensor, since it is
    # fed straight into 2-D global pooling layers. TODO confirm in model/ResNeXt.py.
    base = ResNeXt50(inputs=inputs)
    x = Concatenate()([GlobalAveragePooling2D()(base), GlobalMaxPooling2D()(base)])
    for units in head_units:
        x = _dense_bn_relu(x, units)

    out = Dense(1, activation="sigmoid")(x)
    model = Model(inputs, out)
    model.compile(optimizer=Adam(learning_rate), loss='binary_crossentropy', metrics=['acc'])
    model.summary()

    return model



In [None]:
def main(labels_csv="./train_labels.csv", train_glob='./train/*.tif',
         test_glob='./test/*.tif', batch_size=32, epochs=40, val_size=0.01):
    """Train the ResNeXt classifier on the labeled ``.tif`` images.

    Reads labels from ``labels_csv`` (columns ``id`` and ``label``). The files
    matched by ``train_glob`` are split into train/validation (``val_size``
    fraction, ``random_state=1``) and the model is trained for ``epochs``
    epochs. Outputs written:

    * ``model.json``: the model architecture.
    * ``ResNeXt_val_acc_<val_acc>.h5``: weights, saved whenever ``val_loss`` improves.
    * ``hist.h5``: a *pickle* (despite the extension) of ``history.history``.
    """
    df_train = pd.read_csv(labels_csv)
    # This used to be built as ``id_label_dic`` while ``id_label_map`` was used below (NameError).
    id_label_map = dict(zip(df_train.id.values, df_train.label.values))

    labeled_files = glob(train_glob)
    test_files = glob(test_glob)
    print("labeled_files size :", len(labeled_files))
    print("test_files size :", len(test_files))
    train, val = train_test_split(labeled_files, test_size=val_size, random_state=1)
    model = resnext()

    h5_path = "ResNeXt_val_acc_{val_acc:.5f}.h5"
    # The best val_loss is the minimum; the previous mode='max' kept the worst epoch's weights.
    checkpoint = ModelCheckpoint(h5_path, monitor='val_loss', verbose=1, save_best_only=True,
                                 save_weights_only=True, mode='min')
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_lr=0.00001, verbose=1)
    with open('model.json', 'w') as f:
        f.write(model.to_json())

    # chunk() produces ceil(len / batch_size) slices per pass. Ceil division
    # keeps an epoch aligned with one pass over the files, so the same
    # validation images are scored every epoch.
    steps_per_epoch = (len(train) + batch_size - 1) // batch_size
    validation_steps = (len(val) + batch_size - 1) // batch_size

    history = model.fit_generator(
        data_gen(train, id_label_map, batch_size, augment=True),
        validation_data=data_gen(val, id_label_map, batch_size),
        epochs=epochs, verbose=1,
        callbacks=[checkpoint, reduce_lr],
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps)
    # Pickle only the per-epoch metrics dict. The History callback object holds
    # a reference to the model, which is not reliably picklable.
    with open('hist.h5', 'wb') as f:
        pickle.dump(history.history, f)


main()