# In [None]:
#########################################################################
##########################  Helpers for  EDA ############################
#########################################################################


def visualize_img(df, nrows=1, ncols=1):
    """
    Plot the training images listed in ``df`` on an ``nrows`` x ``ncols`` grid.

    Each image is read from
    ``<mypath>/input/aptos2019-blindness-detection/train_images/<id_code>.png``
    and titled with the name of its ``diagnosis`` grade.

    df: dataframe with columns ``id_code`` (image id) and ``diagnosis``
        (integer class label 0-4)
    nrows : the number of rows of images
    ncols : the number of columns of images

    return None

    Raises AssertionError if the grid has fewer cells than ``df`` has rows,
    and FileNotFoundError if an image file cannot be read.
    """
    assert (nrows * ncols >= len(df)), "Not all the images are shown. Please increase the number of ncols or nrows."

    labels = {0: 'No DR', 1: 'Mild', 2: 'Moderate', 3: 'Severe', 4: 'Proliferative DR'}
    plt.figure(figsize=[20, 3.4 * nrows])
    sns.set_style("white")

    df_copy = df.copy().reset_index(drop=True)
    rows = zip(df_copy['id_code'], df_copy['diagnosis'])
    for position, (img_id, grade) in enumerate(rows, start=1):
        img_file = mypath + '/input/aptos2019-blindness-detection/train_images/{}.png'.format(img_id)
        img_bgr = cv2.imread(img_file)
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        if img_bgr is None:
            raise FileNotFoundError('Could not read image: {}'.format(img_file))
        plt.subplot(nrows, ncols, position)
        # OpenCV loads channels as BGR; reverse them so matplotlib shows true colours.
        plt.imshow(img_bgr[:, :, ::-1])
        plt.title(labels[grade], fontsize=14)
    # tight_layout must run before show(), otherwise it does not affect the displayed figure.
    plt.tight_layout()
    plt.show()
    
    


def rgb2hsv(df, value_threshold=70, saturation_threshold=90):
    """
    Return the rows of ``df`` whose images are both dark and weakly saturated.

    Each image is loaded with OpenCV (BGR order) and converted to HSV (Hue,
    Saturation, Value). Hue is the colour, Saturation is how far the colour is
    from grey, and Value is the brightness. An image is flagged when its mean
    Value is below ``value_threshold`` AND its mean Saturation is below
    ``saturation_threshold``.

    df: dataframe with columns ``id_code`` (image id) and ``diagnosis`` (class label)
    value_threshold: flag images whose mean V channel is below this (default 70)
    saturation_threshold: flag images whose mean S channel is below this (default 90)

    return dataframe of the flagged rows of ``df``, re-indexed from 0

    Raises FileNotFoundError if an image file cannot be read.
    """
    df_copy = df.copy().reset_index(drop=True)
    flagged = []

    for row_pos, img_id in enumerate(df_copy['id_code']):
        img_file = mypath + '/input/aptos2019-blindness-detection/train_images/{}.png'.format(img_id)
        img = cv2.imread(img_file)
        # cv2.imread signals a missing/unreadable file by returning None rather than raising.
        if img is None:
            raise FileNotFoundError('Could not read image: {}'.format(img_file))
        hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
        if hsv[..., 2].mean() < value_threshold and hsv[..., 1].mean() < saturation_threshold:
            flagged.append(row_pos)

    # DataFrame.append was removed in pandas 2.0 (and was quadratic row by row);
    # select all flagged rows in a single step instead.
    return df_copy.loc[flagged].reset_index(drop=True)




def enhance_bright_contrast(df, bright_factor=1.4, saturation_factor=1.6,
                            contrast_factor=1.1, ncols=5):
    """
    Enhance the brightness, colour saturation and contrast of each image in ``df``
    and plot the modified images on a grid ``ncols`` wide.

    The enhancements are applied in order: brightness, then colour saturation,
    then contrast. A factor of 1 leaves the image unchanged for that step.

    df: dataframe with columns ``id_code`` (image id) and ``diagnosis``
        (integer class label 0-4)
    bright_factor: PIL ``ImageEnhance.Brightness`` factor (default 1.4)
    saturation_factor: PIL ``ImageEnhance.Color`` factor (default 1.6)
    contrast_factor: PIL ``ImageEnhance.Contrast`` factor (default 1.1)
    ncols: number of images per grid row (default 5)

    return None
    """
    df_copy = df.copy().reset_index(drop=True)
    if df_copy.empty:
        return
    labels = {0: 'No DR', 1: 'Mild', 2: 'Moderate', 3: 'Severe', 4: 'Proliferative DR'}

    # plt.subplot requires an integer row count; np.ceil returns a float.
    nrows = int(np.ceil(len(df_copy) / ncols))
    plt.figure(figsize=[20, 3.4 * nrows])
    sns.set_style("white")

    rows = zip(df_copy['id_code'], df_copy['diagnosis'])
    for position, (img_id, grade) in enumerate(rows, start=1):
        img_file = mypath + '/input/aptos2019-blindness-detection/train_images/{}.png'.format(img_id)
        # Context manager closes the underlying file handle once the enhanced copy is built.
        with Image.open(img_file) as img:
            brightened = ImageEnhance.Brightness(img).enhance(bright_factor)
            saturated = ImageEnhance.Color(brightened).enhance(saturation_factor)
            enhanced = ImageEnhance.Contrast(saturated).enhance(contrast_factor)
        plt.subplot(nrows, ncols, position)
        plt.imshow(enhanced)
        plt.title(labels[grade], fontsize=14)
    # tight_layout must run before show(), otherwise it does not affect the displayed figure.
    plt.tight_layout()
    plt.show()


#########################################################################
###########  Helpers for Preprocessing and Modeling #####################
#########################################################################




def warmup_lr_scheduler(epoch, lr, warmup_epochs=5, max_lr=1e-3):
    '''
    Linear learning-rate warmup step with the ``schedule(epoch, lr)`` signature
    used by Keras' ``LearningRateScheduler``.

    Each call adds ``max_lr / warmup_epochs`` to the current learning rate, so
    starting from 0 the rate reaches ``max_lr`` after ``warmup_epochs`` calls.
    The result is capped at ``max_lr`` so that training past the warmup phase
    does not keep raising the learning rate; a rate already at or above
    ``max_lr`` is returned unchanged.

    epoch: current epoch (not used; the step depends only on ``lr``)
    lr: current learning rate
    warmup_epochs: number of epochs the warmup spans (default 5)
    max_lr: learning rate reached at the end of the warmup (default 1e-3)

    return float number

    Raises ValueError if ``warmup_epochs`` is not positive.
    '''
    if warmup_epochs <= 0:
        raise ValueError('warmup_epochs must be positive, got %r' % (warmup_epochs,))
    if lr >= max_lr:
        # Warmup is finished; never push the rate further up (or pull it down).
        return lr
    increase_rate = max_lr / warmup_epochs
    return min(lr + increase_rate, max_lr)




class f1_score(Callback):
    '''
    Keras callback that reports the macro-averaged F1 score on a validation
    generator at the end of every epoch, optionally saving the model weights
    whenever the score beats the best seen so far.

    NOTE: this class name shadows ``sklearn.metrics.f1_score`` at module level,
    so the metric is imported under a private alias inside ``on_epoch_end``.

    val_data: validation generator; must support ``len()`` (number of batches)
              and ``next()`` returning ``(x, y_one_hot)``
    batch_size: model batch size (stored only; not used by this callback)
    save_best_mod: whether or not to save the model weights when the macro f1_score is higher than before
    file_name: prefix of the weights file, written to
               ``<mypath>/output/model/<file_name>_best_valf1_<score>.h5``

    return None
    '''
    def __init__(self, val_data, batch_size=16, save_best_mod=False, file_name=''):
        super().__init__()
        self.validation_data = val_data
        self.batch_size = batch_size
        self.save_best_mod = save_best_mod
        self.file_name = file_name

    def on_train_begin(self, logs=None):
        '''Reset the per-epoch F1 history and the best F1 score seen so far.'''
        self.val_f1s = []
        # -inf (not 0) so the first epoch always counts as an improvement.
        self.best_val_f1 = -np.inf

    def on_epoch_end(self, epoch, logs=None):
        '''Compute the validation macro F1 score and save the weights when it is a new best.'''
        # The module-level name ``f1_score`` is this class, so calling it here would
        # construct a new callback instead of scoring; use sklearn's under an alias.
        from sklearn.metrics import f1_score as _macro_f1_metric

        pred_batches, true_batches = [], []
        for _ in range(len(self.validation_data)):
            x_val, y_val = next(self.validation_data)
            pred_batches.append(np.argmax(self.model.predict(x_val), -1))
            true_batches.append(np.argmax(y_val, -1))
        # Concatenate once instead of re-copying a growing array every batch; the
        # leading empty list keeps the float dtype and copes with zero batches.
        val_pred = np.concatenate([[]] + pred_batches, axis=None)
        val_true = np.concatenate([[]] + true_batches, axis=None)

        _val_f1 = _macro_f1_metric(val_true, val_pred, average='macro')

        if self.save_best_mod and _val_f1 > self.best_val_f1:
            self.model.save_weights(mypath + '/output/model/%s_best_valf1_%.4f.h5' % (self.file_name, _val_f1))
            self.best_val_f1 = _val_f1
            print('Best validation macro f1_score: %.4f so far! Saving model...' % (_val_f1))

        self.val_f1s.append(_val_f1)
        print(' - val_macro_f1: %f' % (_val_f1))




class QWK_Score(Callback):
    '''
    Keras callback that reports the quadratic weighted kappa (QWK) on a
    validation generator at the end of every epoch, optionally saving the
    model weights whenever the score beats the best seen so far.

    val_data: validation generator; must support ``len()`` (number of batches)
              and ``next()`` returning ``(x, y_one_hot)``
    batch_size: model batch size (stored only; not used by this callback)
    save_best_mod: whether or not to save the model weights when the QWK score is higher than before
    file_name: prefix of the weights file, written to
               ``<mypath>/output/model/<file_name>_best_qwk_<score>.h5``

    return None
    '''
    def __init__(self, val_data, batch_size=16, save_best_mod=False, file_name=''):
        super().__init__()
        self.validation_data = val_data
        self.batch_size = batch_size
        self.save_best_mod = save_best_mod
        self.file_name = file_name

    def on_train_begin(self, logs=None):
        '''Reset the per-epoch QWK history and the best QWK score seen so far.'''
        self.qwk_scores = []
        # QWK can be negative, so start from -inf rather than 0; otherwise a model
        # whose QWK never rises above 0 would never have its weights saved.
        self.best_qwk_score = -np.inf

    def on_epoch_end(self, epoch, logs=None):
        '''Compute the validation QWK over labels 0-4 and save the weights when it is a new best.'''
        pred_batches, true_batches = [], []
        for _ in range(len(self.validation_data)):
            x_val, y_val = next(self.validation_data)
            pred_batches.append(np.argmax(self.model.predict(x_val), -1))
            true_batches.append(np.argmax(y_val, -1))
        # Concatenate once instead of re-copying a growing array every batch; the
        # leading empty list keeps the float dtype and copes with zero batches.
        val_pred = np.concatenate([[]] + pred_batches, axis=None)
        val_true = np.concatenate([[]] + true_batches, axis=None)

        _qwk_score = cohen_kappa_score(val_true, val_pred, labels=[0, 1, 2, 3, 4], weights='quadratic')

        if self.save_best_mod and _qwk_score > self.best_qwk_score:
            self.model.save_weights(mypath + '/output/model/%s_best_qwk_%.4f.h5' % (self.file_name, _qwk_score))
            self.best_qwk_score = _qwk_score
            print('Best qwk_score: %.4f so far! Saving model...' % (_qwk_score))

        self.qwk_scores.append(_qwk_score)
        print(' - qwk_score: %f' % (_qwk_score))




class LRFinder(Callback):
    '''
    Callback for a learning-rate range test: the learning rate is swept
    linearly from ``min_lr`` to ``max_lr`` over ``steps_per_epoch * epochs``
    batches, and the learning rate plus every logged metric are recorded after
    each batch.

    min_lr: learning rate at the start of the sweep.
    max_lr: learning rate reached after the last scheduled batch.
    steps_per_epoch: number of mini-batches per epoch, i.e. `np.ceil(epoch_size/batch_size)`.
    epochs: number of epochs the experiment runs for (2-4 is usually enough).

    References:
    Blog post: jeremyjordan.me/nn-learning-rate
    Original paper: https://arxiv.org/abs/1506.01186
    '''

    def __init__(self, min_lr=1e-5, max_lr= 2e-3, steps_per_epoch=None, epochs=None):
        super().__init__()
        self.min_lr, self.max_lr = min_lr, max_lr
        # Total number of batches the sweep is spread over.
        self.total_iterations = steps_per_epoch * epochs
        self.iteration = 0
        # Maps 'lr', 'iterations' and each logged metric name to its per-batch values.
        self.history = {}

    def clr(self):
        '''Return the learning rate for the current iteration (linear interpolation).'''
        fraction_done = self.iteration / self.total_iterations
        lr_span = self.max_lr - self.min_lr
        return self.min_lr + lr_span * fraction_done

    def on_train_begin(self, logs=None):
        '''Start the sweep from ``min_lr``.'''
        K.set_value(self.model.optimizer.lr, self.min_lr)

    def on_batch_end(self, epoch, logs=None):
        '''Log the rate used for the finished batch and its metrics, then step the rate.'''
        # NOTE(review): Keras presumably passes the batch index here despite the name `epoch`.
        self.iteration += 1
        record = self.history.setdefault
        record('lr', []).append(K.get_value(self.model.optimizer.lr))
        record('iterations', []).append(self.iteration)
        for metric_name, metric_value in (logs or {}).items():
            record(metric_name, []).append(metric_value)
        K.set_value(self.model.optimizer.lr, self.clr())

    def plot_lr(self):
        '''Plot the recorded learning rate against the iteration number (log y-axis).'''
        iterations = self.history['iterations']
        rates = self.history['lr']
        plt.plot(iterations, rates)
        plt.yscale('log')
        plt.xlabel('Iteration')
        plt.ylabel('Learning rate')
        plt.show()

    def plot_loss(self):
        '''Plot the recorded training loss against the learning rate (log x-axis).'''
        rates = self.history['lr']
        losses = self.history['loss']
        plt.plot(rates, losses)
        plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel('Loss')
        plt.show()




def plot_fit_history(fit_hist=None):
    '''
    Plot the training and validation loss per epoch, concatenated across one
    or more fit runs.

    fit_hist: a list of fit histories. Each entry is a dict with 'loss' and
              'val_loss' lists (such as ``model.fit(...).history``); an object
              with a ``.history`` dict attribute is also accepted. If there is
              more than one fit history, pass them in order, from the earliest
              to the latest, so the epochs line up. Defaults to no histories.

    return None
    '''
    if fit_hist is None:
        fit_hist = []
    keys = ['loss', 'val_loss']

    append_fit_hist = defaultdict(list)
    for h in fit_hist:
        # Accept either a plain history dict or an object carrying it as ``.history``.
        hist_dict = getattr(h, 'history', h)
        for key in keys:
            append_fit_hist[key].extend(hist_dict[key])

    total_epoches = len(append_fit_hist[keys[0]])
    plt.plot(append_fit_hist['loss'], label='Train loss')
    plt.plot(append_fit_hist['val_loss'], label='Validation loss')

    # Label ticks 1..N so epochs are shown 1-based.
    plt.xticks(np.arange(total_epoches), np.arange(1, total_epoches + 1))
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='best')
    sns.despine()
    plt.show()