In [None]:
# -*- coding: utf-8 -*-
import os
import json
import time
import logging
import numpy as np 
import torch 

def save_option(option, path):
    """Write the attributes of *option* to ``<path>/options.json``.

    Args:
        option: object whose ``__dict__`` is JSON-serialisable (presumably an
            argparse ``Namespace`` — confirm at the call site).
        path: directory that must already exist; ``options.json`` inside it is
            created or overwritten.
    """
    target_file = os.path.join(path, "options.json")
    with open(target_file, 'w') as handle:
        # Sorted keys and 4-space indent keep the file diff-friendly.
        json.dump(obj=option.__dict__, fp=handle, sort_keys=True, indent=4)

def logger_setting(exp_name, save_dir, path, debug):
    """Configure and return a logger writing to ``<path>/train.log`` and a stream.

    Args:
        exp_name: name passed to ``logging.getLogger``; it also appears in
            each record's ``[name]`` prefix.
        save_dir: not used by this function; kept so existing callers work.
        path: existing directory in which ``train.log`` is opened (appended to,
            the ``FileHandler`` default).
        debug: if truthy the logger level is DEBUG, otherwise INFO.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(exp_name)

    # getLogger returns the same object for the same name, so a repeated call
    # would otherwise stack another pair of handlers and emit every record
    # multiple times. Detach and close whatever was attached previously.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('[%(name)s] %(levelname)s: %(message)s')
    file_handler = logging.FileHandler(os.path.join(path, 'train.log'))
    stream_handler = logging.StreamHandler()
    for handler in (file_handler, stream_handler):
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.setLevel(logging.DEBUG if debug else logging.INFO)
    return logger

class Timer(object):
    """Per-step wall-clock timer with a remaining-time estimate.

    Call the instance once per step. Each call logs, at DEBUG level, the
    seconds since the previous call and the estimated hours until
    ``max_step`` is reached, and returns both values.
    """

    def __init__(self, logger, max_step, last_step=0):
        """
        Args:
            logger: ``logging.Logger`` used to emit the timing message.
            max_step: step count at which the run is considered finished.
            last_step: number of steps already completed before this timer
                was created (e.g. when resuming); defaults to 0.
        """
        self.logger = logger
        self.max_step = max_step
        self.step = last_step
        # Step count at which timing began: the per-step rate must be derived
        # only from steps this timer actually measured, not from steps done
        # before it existed.
        self.first_step = last_step

        curr_time = time.time()
        self.start = curr_time
        self.last = curr_time

    def __call__(self):
        """Advance one step, log the timing info and return it.

        Returns:
            tuple ``(duration, remaining)``: seconds since the previous call
            (or since construction) and estimated hours left until
            ``max_step``.
        """
        curr_time = time.time()
        self.step += 1

        duration = curr_time - self.last
        # Guard against a zero divisor if ``step`` was modified externally.
        measured_steps = max(self.step - self.first_step, 1)
        seconds_per_step = (curr_time - self.start) / measured_steps
        remaining = (self.max_step - self.step) * seconds_per_step / 3600
        msg = 'TIMER, duration(s)|remaining(h), %f, %f' % (duration, remaining)
        self.logger.debug(msg)

        self.last = curr_time
        return duration, remaining

# The following helpers are used to save the images produced by the dataloaders
# in the trainer_merger script; see them in action in the /helpers folder:

def bias_label_interpretation(bias_label):
    """Turn the first sample of a bias-label batch into viewable uint8 images.

    Each label value ``v`` is mapped to the intensity ``32 * v + 16`` (the
    centre of the v-th 32-wide intensity bin) and saturated to [0, 255].

    Args:
        bias_label: tensor indexed by sample first; ``bias_label[0]`` must be
            3-D, laid out channel-first, with at least three channels. The
            tensor may live on any device and may require grad.

    Returns:
        tuple ``(label_trans, r, g, b)``:
            label_trans: uint8 array in the original channel-first layout.
            r, g, b: uint8 arrays transposed to (H, W, C) in which channel 2,
                1 and 0 respectively is kept and the other two of the first
                three channels are zeroed (channel 0 treated as blue, i.e.
                BGR order).
    """
    # .cpu() so tensors on a GPU can be converted, like the helpers below.
    label = bias_label[0].detach().cpu().numpy()

    # Clip before the cast so out-of-range values saturate instead of
    # wrapping around modulo 256.
    label_trans = np.clip(label * 32 + 16, 0, 255).astype(np.uint8)

    hwc = label_trans.transpose(1, 2, 0)  # (C, H, W) -> (H, W, C)

    def _keep_only(keep):
        """Return an (H, W, C) copy with channels 0-2 zeroed except *keep*."""
        img = hwc.copy()
        for idx in (0, 1, 2):
            if idx != keep:
                img[:, :, idx] = 0
        return img

    b = _keep_only(0)
    g = _keep_only(1)
    r = _keep_only(2)

    return label_trans, r, g, b

def un_normalise_image(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Reverse per-channel normalisation of the first image in a batch.

    Undoes ``(x - mean) / std`` and rescales to the 0-255 range. The result is
    a float (H, W, C) numpy array; it is neither clipped nor cast to uint8.

    Args:
        image: tensor indexed by sample first; ``image[0]`` must be 3-D and
            channel-first. It may live on any device and may require grad.
        mean: per-channel means used when normalising. The default is the
            value that was previously hard-coded here (the common ImageNet
            statistics).
        std: per-channel standard deviations used when normalising. The
            default is likewise the previously hard-coded value.

    Returns:
        numpy array of shape (H, W, C) with values on a 0-255 scale.
    """
    img = image[0].cpu().detach().numpy()
    img = np.transpose(img, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    # Invert x_norm = (x - mean) / std, channel-wise (broadcast on last axis).
    img = img * np.asarray(std)
    img = img + np.asarray(mean)
    return img * 255

def detach_transpose_labels(label):
    """Return the first element of a label batch as a numpy array.

    Despite the name, no transpose is performed: ``label[0]`` is moved to the
    CPU, detached from the autograd graph and converted with ``.numpy()``.
    """
    first_item = label[0]
    return first_item.cpu().detach().numpy()