In [None]:
!pip install madgrad

In [None]:
import os
from tqdm import tqdm

def _is_readable_image(path):
    """Return True when PIL can open *and fully decode* the image at `path`.

    `Image.open` alone only parses the header, so a truncated file would pass;
    `load()` forces the pixel data to be read as well.
    """
    # local import: this cell runs before the cell that imports PIL
    from PIL import Image

    try:
        with Image.open(path) as img:
            img.load()
        return True
    except Exception:  # missing file, unidentified format, truncated / corrupt data
        return False


def fix_png(class_nums, total_sample=6, img_size=(128, 128), root="images/hedged_images",
            background_src="images/sources/imagenet/data_VGG_label",
            hedge_mask_src="images/sources/hedge_masks"):
    """Validate the hedged-image dataset for `class_nums` and regenerate broken files.

    Directory layout walked below::

        <root>/<split>/<class_num>/<img_id>/<img_id>.png             background
        <root>/<split>/<class_num>/<img_id>/<density>/<i>.png        hedged image
        <root>/<split>/<class_num>/<img_id>/<density>/<i>_hedge.png  hedge mask

    * An unreadable background is rebuilt from
      ``<background_src>/<class_num>/<img_id>.JPEG``, resized to `img_size`
      (the size `resize_images` normalises ground truths to) and written back
      to its dataset location.
    * If either file of sample ``i`` is missing or unreadable, a random source
      mask from ``<hedge_mask_src>/<density>`` is combined with the background
      by ``generate_single_hedged_image`` (not defined in this cell), and both
      outputs are written back into the dataset (the source mask is untouched).

    :param class_nums: iterable of class ids (ints or strings) to check
    :param total_sample: number of samples expected in every density directory
    :param img_size: target size for repaired backgrounds / regenerated samples
    :param root: dataset root containing the split directories
    :param background_src: root of the original JPEG backgrounds
    :param hedge_mask_src: root of the source hedge masks, one dir per density
    """
    # local imports: this function may run (e.g. in Pool workers) before the
    # cells that import numpy / PIL have been executed
    import numpy as np
    from PIL import Image

    assert isinstance(total_sample, int)

    # val / test / train
    for dataset_type in os.listdir(root):
        dataset_type_path = os.path.join(root, dataset_type)

        # 0 - 1000
        for class_num in class_nums:
            class_num_path = os.path.join(dataset_type_path, str(class_num))
            if not os.path.isdir(class_num_path):
                continue  # this split has no images of this class

            # ILSVRC2012_XXX_XXXXX
            for img_id in os.listdir(class_num_path):
                img_id_path = os.path.join(class_num_path, img_id)
                background_path = os.path.join(img_id_path, img_id + ".png")

                # the background png plus one directory per density (0.1 - 0.8)
                for density in os.listdir(img_id_path):
                    if "png" in density:
                        # background validation
                        if not _is_readable_image(background_path):
                            print("background file {} is corrupted/missed".format(background_path))
                            background_src_item = os.path.join(background_src, str(class_num), img_id)
                            with Image.open(background_src_item + ".JPEG") as source:
                                source.convert('RGB').resize(img_size).save(background_path)
                        continue

                    density_path = os.path.join(img_id_path, density)
                    # 0-5
                    for i in range(total_sample):
                        # hedge_mask/img validation
                        hedge_mask_path = os.path.join(density_path, str(i) + "_hedge.png")
                        img_path = os.path.join(density_path, str(i) + ".png")
                        if _is_readable_image(hedge_mask_path) and _is_readable_image(img_path):
                            continue

                        print("corrupted/miss file: \n image file {} \n hedge file {}".format(img_path, hedge_mask_path))
                        # randomly pick one new hedge file from src (kept in its own
                        # variable so the dataset's hedge_mask_path is what gets rewritten)
                        hedge_root = os.path.join(hedge_mask_src, density)
                        src_mask_code = np.random.choice(os.listdir(hedge_root), size=1)[0]
                        src_mask_path = os.path.join(hedge_root, src_mask_code)

                        # generate hedged image and hedge mask then save them
                        with Image.open(src_mask_path) as src_mask, Image.open(background_path) as background:
                            hedged_image, hedge_mask = generate_single_hedged_image(background, src_mask, img_size)
                            hedged_image.convert('RGB').save(img_path)
                            hedge_mask.convert('RGB').save(hedge_mask_path)
                                
# Re-validate / repair the dataset for every class id (0-999) in parallel.
# NOTE(review): `multi_processing`, `Pool` and `cpus` are defined in later cells;
# on a fresh kernel those cells must run before this one.
if multi_processing:
    n_classes = 1000
    # one contiguous chunk of class ids per worker; ceil division covers every class
    chunk_size = -(-n_classes // cpus)
    class_chunks = [range(start, min(start + chunk_size, n_classes)) for start in range(0, n_classes, chunk_size)]
    print("multiprocessing will be run with {} threads".format(cpus))
    # the context manager shuts the worker processes down once map() has returned
    with Pool(cpus) as pool:
        pool.map(fix_png, class_chunks)
                    

# def fix_png():
#     path = "/home/jupyter/src/images/hedged_images"
#     for dataset_type in os.listdir(path):
#         dataset_type_path = os.path.join(path, dataset_type)
#         for class_num in tqdm(os.listdir(dataset_type_path)):
#             class_num_path = os.path.join(dataset_type_path, class_num)
#             for img_id in os.listdir(class_num_path):
#                 img_id_path = os.path.join(class_num_path, img_id)
#                 os.system('mogrify *.png')                
#                 for density in os.listdir(img_id_path):
#                     density_path = os.path.join(img_id_path, density)
#                     if not "png" in density_path:
#                         os.chdir(density_path)
#                         os.system('mogrify *.png')                
                    
                    
#     os.chdir("/home/jupyter/src")

# fix_png()


# image file /home/jupyter/src/images/hedged_images/train/743/ILSVRC2012_val_00031944/./0.png 
# or hedge file /home/jupyter/src/images/hedged_images/train/743/ILSVRC2012_val_00031944/./0_hedge.png is corrupted/missed


# TODO:
1. Make the logger object global instead of local; otherwise the logger echoes itself and prints every message repeatedly.

2. Make sure all image files are stored in the same format (e.g. PNG).

3. Check the torch.cuda memory allocation (observed before: 11 total, 3 reserved, then it ran out of memory).

In [None]:
import matplotlib.pyplot as plt
# from google.colab import drive

# Prefix joined in front of the relative data paths (e.g. os.path.join(root, "images/...")).
# An empty string keeps the paths relative to the current working directory.
root = ""

# select functions to run
# read by the `if multi_processing:` dataset-repair cell near the top of the notebook,
# which therefore has to run after this cell
multi_processing = True
# read by the hedge-mask generation cell (`if generate_hedge_masks:`)
generate_hedge_masks = False
# NOTE(review): the consumers of the flags below are not visible in this section
generate_hedge_image = False
validate_dataset = False
train_flag = True

# Check CPU/GPU status

In [None]:
!nvidia-smi

In [None]:
!cat /proc/cpuinfo

In [None]:
import subprocess
# Show disk usage: run `df -h` through the shell, capture its stdout and decode it as UTF-8.
p = subprocess.Popen('df -h', shell=True, stdout=subprocess.PIPE)
print(str(p.communicate()[0], 'utf-8'))

## Multiprocessing

In [None]:
import multiprocessing
from multiprocessing import Process, Pool

cpus = multiprocessing.cpu_count()

def split_data_into_batches(data, number_of_batch):
    """Split `data` into at most `number_of_batch` consecutive, order-preserving slices.

    The batch length is ceil(len(data) / number_of_batch), so every batch except
    possibly the last has the same length and no more than `number_of_batch`
    batches are produced (the old floor-based split could produce an extra,
    ragged batch, and crashed with a zero step when len(data) < number_of_batch).

    :param data: sliceable sequence (list, range, ndarray, ...)
    :param number_of_batch: requested number of batches, must be >= 1
    :return: a 2-D ``np.ndarray`` when all batches have the same length, otherwise a
        1-D object array holding the batches (NumPy >= 1.24 cannot build a regular
        array from ragged lists)
    :raises ValueError: if ``number_of_batch < 1``
    """
    if number_of_batch < 1:
        raise ValueError("number_of_batch must be >= 1, got {}".format(number_of_batch))

    # ceil division; at least 1 so range() never receives a zero step (e.g. empty data)
    portion_per_batch = max(1, int(np.ceil(len(data) / number_of_batch)))
    batches = [data[i:i + portion_per_batch] for i in range(0, len(data), portion_per_batch)]

    if len({len(batch) for batch in batches}) <= 1:
        return np.array(batches)

    ragged = np.empty(len(batches), dtype=object)
    for index, batch in enumerate(batches):
        ragged[index] = batch
    return ragged

# Config (path related)

In [None]:
import torch
import argparse

# Target device for the model and its tensors: the GPU when PyTorch can see one, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Training/validation statistics are printed once every `print_freq` batches.
print_freq = 100

# madgrad optimal lr is 0.001 (1e-3)

def parse_args(argv=None):
    """Build the experiment configuration as an ``argparse.Namespace``.

    The hard-coded values at the top of the function are the notebook's working
    configuration; ``in_channel`` and ``masked_img_per_item`` are derived from
    ``model_type``.

    :param argv: optional list of command-line style overrides, e.g.
        ``["--lr", "1e-4"]``. ``None`` (the default) parses an empty list, so the
        Jupyter kernel's own ``sys.argv`` is never interpreted.
    :return: the parsed ``argparse.Namespace``
    :raises ValueError: if ``model_type`` is not a supported U-Net variant
    """

    def str2bool(value):
        # argparse's `type=bool` would turn any non-empty string (even "False") into True
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))

    debug = False
    load_path_from_last_checkpoint = "checkpoint/recurrent_uncertainty_unet (madgrad+onecycle+no_dropout+512_complex+luv)/ep400(lr=0.001)/checkpoint.tar"
    # pretrained_UNet = "checkpoint/unet(madgrad+onecycle+no_dropout+512_complex)/ep200(lr=0.001)/BEST_checkpoint.tar"
    pretrained_UNet = "checkpoint/unet (madgrad+onecycle+no_dropout+512_complex+luv)/ep400(lr=0.0001)/BEST_checkpoint.tar"
    # pretrained_Uncertainty_UNet = "checkpoint/uncertainty_unet(madgrad+ReduceLROnPlateau+no_dropout+512_complex)/ep700(lr=0.0001)/BEST_checkpoint.tar"
    pretrained_Uncertainty_UNet = "checkpoint/uncertainty_unet (madgrad+onecycle+no_dropout+512_complex+luv)/ep400(lr=1e-05)/BEST_checkpoint.tar"

    model_type = "recurrent_uncertainty_unet"

    if model_type == "unet" or model_type == "uncertainty_unet":
        masked_img_per_item = 1
        in_channel = 3
    elif model_type == "recurrent_uncertainty_unet":
        masked_img_per_item = 2
        in_channel = 9
    else:
        # fail loudly instead of crashing later on the undefined in_channel / masked_img_per_item
        raise ValueError("model type not supported: {}".format(model_type))

    optimizer = "madgrad"

    change_scheduler = False
    scheduler = "onecycle"

    caption = "no_dropout+512_complex+luv"
    model_name = "{} ({}+{}+{})".format(model_type, optimizer, scheduler, caption)

    change_lr = False
    lr = 1e-3

    end_epoch = 600
    early_stop = 999
    batch_size = 64

    parser = argparse.ArgumentParser(description='Train u-net')

    parser.add_argument('--debug', type=str2bool, default=debug, help='debug mode, reduce the size of dataset to run through train/val process faster')
    parser.add_argument('--tensorboard-fileName', type=str, default="{} - ep{}(lr={})".format(model_name, end_epoch, lr), help='tensorboard runs file name')

    # model related
    parser.add_argument('--unet-type', type=str, default=model_type, help='unet/uncertainty_unet/recurrent_uncertainty_unet')
    parser.add_argument('--checkpoint-load-path', type=str, default=load_path_from_last_checkpoint, help='checkpoint to pick up from last training session')
    parser.add_argument('--checkpoint-save-path', type=str, default="checkpoint/{}/ep{}(lr={})".format(model_name, end_epoch, lr), help='checkpoint path')
    parser.add_argument('--in-channel', type=int, default=in_channel, help='channel count for model input')
    parser.add_argument('--out-channel', type=int, default=3, help='channel count for model output (prediction)')
    parser.add_argument('--out-uncertainty-channel', type=int, default=3, help='channel count for model output (uncertainty)')

    # optimizer related
    parser.add_argument('--optimizer', type=str, default=optimizer, help='optimizer(adam/SGD/madgrad)')
    parser.add_argument('--lr', type=float, default=lr, help='start learning rate')
    parser.add_argument('--end-epoch', type=int, default=end_epoch, help='training epoch size.')
    parser.add_argument('--change-lr', type=str2bool, default=change_lr, help='checkpoint')
    parser.add_argument('--clip-val', type=float, default=10, help='gradient clip value to prevent gradient explosion')
    #     SGD
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum, range=[0,1)')
    parser.add_argument('--weight-decay', type=float, default=0, help='weight_decay L2 penalty')
    parser.add_argument('--nesterov', type=str2bool, default=False, help='nesterov momentum')
    #     madgrad
    parser.add_argument('--eps', type=float, default=1e-6, help='Term added to the denominator outside of the root operation to improve numerical stability (default: 1e-6)')

    # early stopping
    parser.add_argument('--early-stop-ep', type=int, default=early_stop, help='ep to stay patient for no validation loss improvement, if exceed this ep, training stop automatically ')

    # scheduler
    parser.add_argument('--scheduler', type=str, default=scheduler, help="None/ReduceLROnPlateau/exp/cosine/onecycle")
    parser.add_argument('--scheduler-step', type=int, default=5, help="n step until scheduler step and perform weight decay")
    parser.add_argument('--change-scheduler', type=str2bool, default=change_scheduler, help="change scheduler or not")
    parser.add_argument('--verbose', type=str2bool, default=True, help="verbose for scheduler (notification on changing lr)")
    #     onecycle
    # learning rates are fractional, so parse them as float (int("1e-4") would fail)
    parser.add_argument('--max-lr', type=float, default=1e-4, help="Upper learning rate boundaries in the cycle for each parameter group")
    # no trailing space in the flag, otherwise the attribute becomes `args.total_steps ` (unreachable)
    parser.add_argument('--total-steps', type=int, default=None, help="The total number of steps in the cycle. Note that if a value is not provided here, then it must be inferred by providing a value for epochs and steps_per_epoch. Default: None")
    # parser.add_argument('--epochs', type=int, default=eps, help="The number of epochs to train for. This is used along with steps_per_epoch in order to infer the total number of steps in the cycle if a value for total_steps is not provided. Default: None")
    # parser.add_argument('--steps-per-epoch', type=int, default=20, help="The number of steps per epoch to train for. This is used along with epochs in order to infer the total number of steps in the cycle if a value for total_steps is not provided. Default: None")
    parser.add_argument('--pct-start', type=float, default=0.3, help="The percentage of the cycle (in number of steps) spent increasing the learning rate. Default: 0.3")
    parser.add_argument('--anneal-strategy', type=str, default='cos', help="Specifies the annealing strategy: “cos” for cosine annealing, “linear” for linear annealing. Default: ‘cos’")
    parser.add_argument('--cycle-momentum', type=str2bool, default=True, help="If True, momentum is cycled inversely to learning rate between ‘base_momentum’ and ‘max_momentum’. Default: True")
    parser.add_argument('--base-momentum', type=float, default=0.85, help="Lower momentum boundaries in the cycle for each parameter group. Note that momentum is cycled inversely to learning rate; at the peak of a cycle, momentum is ‘base_momentum’ and learning rate is ‘max_lr’. Default: 0.85")
    parser.add_argument('--max-momentum', type=float, default=0.95, help="Upper momentum boundaries in the cycle for each parameter group. Functionally, it defines the cycle amplitude (max_momentum - base_momentum). Note that momentum is cycled inversely to learning rate; at the start of a cycle, momentum is ‘max_momentum’ and learning rate is ‘base_lr’ Default: 0.95")
    parser.add_argument('--div-factor', type=float, default=25, help="Determines the initial learning rate via initial_lr = max_lr/div_factor Default: 25")
    parser.add_argument('--final-div-factor', type=float, default=1e4, help="Determines the minimum learning rate via min_lr = initial_lr/final_div_factor Default: 1e4")
    parser.add_argument('--three-phase', type=str2bool, default=False, help="If True, use a third phase of the schedule to annihilate the learning rate according to ‘final_div_factor’ instead of modifying the second phase (the first two phases will be symmetrical about the step indicated by ‘pct_start’).")
    #     cosineWarmRestart
    parser.add_argument('--T-0', type=int, default=13, help="Number of iterations for the first restart.")
    parser.add_argument('--T-mult', type=int, default=2, help="A factor increases T1  after a restart. Default: 1")
    parser.add_argument('--eta-min', type=float, default=0, help="minimum learning rate for cosine curve")
    #     ReduceLROnPlateau
    parser.add_argument('--factor', type=float, default=0.5, help="factor to decay weight, default=0.1")
    parser.add_argument('--min-lr', type=float, default=1e-12, help="minimum lr to be set, default=1e-4")
    parser.add_argument('--mode', type=str, default='min', help="choose min for loss, max for accuracy")
    parser.add_argument('--patience', type=int, default=10, help="stay patience for n epoch until decaying LR, default=10")

    # dataset related
    parser.add_argument('--train-path', type=str, default="images/hedged_images/train", help='path to training data')
    parser.add_argument('--val-path', type=str, default="images/hedged_images/val", help='path to val data')
    parser.add_argument('--test-path', type=str, default="images/hedged_images/test", help='path to test data')
    parser.add_argument('--dataset', type=str, default='static dataset', help='specify dynamic/static dataset')
    parser.add_argument('--masked-img-per-item', type=int, default=masked_img_per_item, help='number of (masked image + ground truth) pairs output by the trainloader each iteration')
    parser.add_argument('--batch-size', type=int, default=batch_size, help='batch size in each context')
    parser.add_argument('--num-workers', type=int, default=multiprocessing.cpu_count(), help='number of cpu workers')
    parser.add_argument('--pin-memory', type=str2bool, default=False, help='If True, the data loader will copy tensors into CUDA pinned memory before returning them, and data transfer to GPU can be faster')
    parser.add_argument('--color-mode', type=str, default="LUV", help='all loaded images will be convert to the color mode (RGB/HSV/LAB/LUV)')

    # image related
    parser.add_argument('--image-size', type=tuple, default=(128, 128), help='input image size')
    parser.add_argument('--min-density', type=float, default=0.1, help='min value for hedge density')
    parser.add_argument('--max-density', type=float, default=0.9, help='max value for hedge density')
    parser.add_argument('--sample-per-image', type=int, default=32, help='sample per image')

    # pretrained checkpoint
    parser.add_argument('--pretrained-MSE-unet', type=str, default=pretrained_UNet, help='pretrained single frame MSE unet')
    parser.add_argument('--pretrained-uncertainty-unet', type=str, default=pretrained_Uncertainty_UNet, help='pretrained uncertainty models for first frame')

    # an empty list (not sys.argv) by default: inside Jupyter, sys.argv holds the kernel's own flags
    args = parser.parse_args(args=[] if argv is None else argv)
    return args

# Utils

In [None]:
import os
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb
import cv2
import numpy as np
from random import getrandbits
import torch
import logging
import urllib.error
import urllib3.exceptions
import json
from multiprocessing import Pool

In [None]:
def random_bool():
    """Return True or False with equal probability, drawn from one random bit."""
    return getrandbits(1) == 1


def make_dir(path):
    """Create directory `path` with any missing parents; an existing directory is fine.

    An OSError (e.g. a path component is a regular file) is reported on stdout
    rather than raised.
    """
    try:
        os.makedirs(name=path, exist_ok=True)
    except OSError:
        message = "OSError when creating directory {}".format(path)
        print(message)

def save_json(item, dest_file_name):
    """Serialise `item` as JSON into `dest_file_name`, overwriting any existing file."""
    with open(dest_file_name, mode="w") as fp:
        json.dump(obj=item, fp=fp)
    
def load_json(path):
    """Read and return the JSON document stored at `path`.

    The file is opened in a `with` block so the handle is closed right away
    (the previous `json.load(open(path))` left it to the garbage collector).
    """
    with open(path) as json_file:
        return json.load(json_file)

## logger

In [None]:
import sys

def get_logger(path):
    """Configure the root logger to echo to the notebook and write to ``<path>/log_info.log``.

    Safe to call repeatedly (re-running the cell, switching checkpoint
    directories): handlers installed by an earlier call are removed and closed
    first, so messages are not printed once per call and old log files are not
    kept open. Handlers added by other code are left untouched (the previous
    version overwrote ``handlers[1]`` blindly and leaked the old file handle).

    :param path: existing directory that will contain ``log_info.log``
    :return: the root ``logging.Logger``
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # prevent adding multiple handlers causing the logger to echo
    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_added_by_get_logger", False):
            logger.removeHandler(old_handler)
            old_handler.close()

    formatter = logging.Formatter("%(asctime)s %(levelname)s \t%(message)s")

    # print in ide
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    handler._added_by_get_logger = True
    logger.addHandler(handler)

    # write to file
    fh = logging.FileHandler(os.path.join(path, "log_info.log"))
    fh.setLevel(logging.DEBUG)
    fh._added_by_get_logger = True
    logger.addHandler(fh)

    return logger

# Build the run configuration, make sure its checkpoint directory exists, and route
# logging into that directory. `args` and `logger` are module-level globals read by
# later helpers (e.g. imshow reads args.color_mode, save_checkpoint uses logger).
args = parse_args()
make_dir(args.checkpoint_save_path)
logger = get_logger(args.checkpoint_save_path)
    

In [None]:
# args = parse_args()
# to_tensor = torchvision.transforms.ToTensor()
# path = "images/hedged_images/train/115/ILSVRC2012_val_00018470/0.4/0.png"

# cv_im = cv2.imread(path)
# pil_im = np.array(Image.open(path).convert("RGB"))


# # RGB_im = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
# # print(RGB_im == pil_im)

# LAB_im = cv2.cvtColor(cv_im, cv2.COLOR_RGB2LAB)
# RGB_im = cv2.cvtColor(LAB_im, cv2.COLOR_LAB2RGB)


# LAB_im_tensor = to_tensor(cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2LAB))
# RGB_im_from_tensor = LAB_im_tensor.numpy()
# RGB_im_from_tensor = np.transpose(RGB_im_from_tensor, (1, 2, 0))

# RGB_im_from_tensor = cv2.cvtColor((RGB_im_from_tensor*255).astype('uint8'), cv2.COLOR_LAB2RGB)


# plt.imshow(RGB_im)
# plt.figure()
# plt.imshow(RGB_im_from_tensor)


## plot/images related

- `plot_graph` – plots 2D curves
- `imshow` – displays a single image
- `show_tensor_images` – displays several tensor images side by side

In [None]:
def pil_2_cv2(img):
    """Convert a float image scaled to [0, 1] into a uint8 array in [0, 255] for OpenCV.

    Values outside [0, 1] (e.g. raw network predictions) are clipped first: casting
    out-of-range floats straight to uint8 overflows (the result is undefined and
    typically wraps around), which shows up as speckle artifacts.
    """
    return np.clip(img * 255, 0, 255).astype('uint8')

def lab_to_rgb(img):
    """Convert a float L*a*b* image (channels-last, as imshow passes it) to 8-bit RGB with OpenCV.

    The input is first scaled to uint8 by pil_2_cv2.
    """
    as_uint8 = pil_2_cv2(img)
    return cv2.cvtColor(as_uint8, cv2.COLOR_LAB2RGB)

def luv_to_rgb(img):
    """Convert a float L*u*v* image (channels-last, as imshow passes it) to 8-bit RGB with OpenCV.

    The input is first scaled to uint8 by pil_2_cv2.
    """
    as_uint8 = pil_2_cv2(img)
    return cv2.cvtColor(as_uint8, cv2.COLOR_Luv2RGB)


def plot_graph(title, x_label, y_label, data_batch_x, data_batch_y, legend, path, x_lim=None, y_lim=None):
    """Plot several 2D curves on one figure, save it to `path` and show it.

    :param data_batch_x: sequence of x-value sequences, one per curve
    :param data_batch_y: sequence of y-value sequences, one per curve
    :param legend: sequence of curve labels (zipped with the data; shortest wins)
    :param path: file the figure is saved to
    :param x_lim: optional (min, max) for the x axis
    :param y_lim: optional (min, max) for the y axis; applied independently of
        `x_lim` (previously it was ignored unless `x_lim` was also given)
    """
    plt.figure(figsize=(12, 7))
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    if x_lim:
        plt.xlim(x_lim)
    if y_lim:
        plt.ylim(y_lim)
    # `label` rather than `legend` so the loop does not shadow the parameter
    for data_x, data_y, label in zip(data_batch_x, data_batch_y, legend):
        plt.plot(data_x, data_y, label=label)
    plt.legend()
    plt.savefig(path)
    plt.show()


    
def imshow(img, title="", clip=False, colorbar=False):
    """Draw one CPU tensor image on the current matplotlib axes.

    :param img: tensor shaped (C, H, W) or (H, W). 2-D and single-channel input
        is drawn in grey scale with a fixed [0, 1] range; multi-channel input is
        converted from ``args.color_mode`` (HSV/LAB/LUV) to RGB, then clipped
        (RGB + `clip`) or min-max normalised.
    :param title: axes title, also used as the prefix of the value-range caption
    :param clip: for RGB images, clip to [0, 1] instead of normalising
    :param colorbar: add a colour bar next to the image
    """
    np_img = img.numpy()
    # range of the raw values, reported in the caption
    min_val = np.min(np_img)
    max_val = np.max(np_img)

    # greyscale or RGB
    if np_img.ndim == 2 or np_img.shape[0] == 1:
        # (H, W) is drawn as-is; (1, H, W) drops the channel axis
        # (previously 2-D input was indexed with [0], leaving a single row)
        grey = np_img if np_img.ndim == 2 else np_img[0]
        plt.imshow(grey, cmap="gray", vmin=0, vmax=1)
    else:
        np_img = np.transpose(np_img, (1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib

        if args.color_mode == "HSV":
            np_img = hsv_to_rgb(np_img)
        elif args.color_mode == "LAB":
            np_img = lab_to_rgb(np_img)
        elif args.color_mode == "LUV":
            np_img = luv_to_rgb(np_img)

        # clip / normalize
        if clip and args.color_mode == "RGB":
            np_img = np.clip(np_img, 0, 1)
        else:
            # adding 1/256 to avoid division by 0
            np_img = (np_img - np.min(np_img)) / (np.ptp(np_img) + 1/256)

        plt.imshow(np_img)

    if colorbar:
        plt.colorbar()
    plt.title(title)
    # NOTE(review): y=150 puts the caption just below a 128-px-high image
    plt.text(0, 150, "{} range = [{min_val:.3f}, {max_val:.3f}]".format(title, min_val=min_val, max_val=max_val))


def show_tensor_images(tensor_imgs, title=None, path=None):
    """Show tensor images side by side in one row, optionally saving the figure.

    :param tensor_imgs: sequence of image tensors (any device; detached and moved to CPU here)
    :param title: optional sequence of titles, one per image
    :param path: if given, the figure is saved there before being shown
    """
    try:
        plt.figure(figsize=(len(tensor_imgs) * 5, 5))
        for i, tensor_img in enumerate(tensor_imgs):
            plt.subplot(1, len(tensor_imgs), i + 1)
            if title:
                imshow(tensor_img.detach().cpu(), title[i])
            else:
                imshow(tensor_img.detach().cpu())
        if path:
            plt.savefig(path)
        plt.show()
    # a tuple catches both; `A and B` evaluated to B alone, so urllib3's error escaped
    except (urllib3.exceptions.HTTPError, urllib.error.HTTPError):
        print("error: please close plots to free memory")

In [None]:
def add_images_to_imgdict(img_dict, image, name, add_rgb_map=False):
    """Store `image` under `name`; optionally add one single-channel map per channel.

    With `add_rgb_map`, channels 0, 1, 2 are stored as ``name + "(R)"``,
    ``"(G)"``, ``"(B)"``, each shaped (1, H, W). The spatial size now comes from
    the image itself instead of a hard-coded 128x128.

    :param img_dict: dict mutated in place
    :param image: channels-first image (tensor or ndarray) with at least 3 channels when `add_rgb_map` is set
    """
    img_dict[name] = image
    if add_rgb_map:
        channel = ["(R)", "(G)", "(B)"]
        for i, add_rgb in enumerate(channel):
            # slicing keeps a leading channel axis of size 1: (1, H, W)
            img_dict[name + add_rgb] = image[i:i + 1]


def get_img_dict(masked_image, pred, ground_truth, uncertainty):
    """ get a dict containing [key=name, val=image tensor] to show images

    `uncertainty` is treated as a log-variance: sigma = exp(uncertainty / 2).
    All maps are computed with torch ops so every value stays a CPU tensor
    (imshow calls ``.numpy()`` on them); previously numpy ufuncs were applied to
    tensors and relied on implicit array wrapping.
    NOTE(review): the per-channel maps assume 3-channel (C, H, W) inputs — confirm.
    """
    img_dict = dict()

    # detach from the autograd graph and move to CPU
    masked_image = masked_image.detach().cpu()
    pred = pred.detach().cpu()
    ground_truth = ground_truth.detach().cpu()
    uncertainty = uncertainty.detach().cpu()

    error = pred - ground_truth

    # loss function map: exp(-u) * err^2 + u
    loss_map = torch.exp(-uncertainty) * error ** 2 + uncertainty

    # sigma map (std deviation)
    sigma = torch.exp(uncertainty / 2)

    # abs error map (exact, unlike sqrt(err^2); also avoids shadowing the builtin `abs`)
    abs_error = torch.abs(error)

    # z-score image = abs/sig (we clip all negative values to 0)
    z_score = torch.clamp(abs_error / sigma, min=0)

    # store results in image dict so we could store the title as key and images as value
    add_images_to_imgdict(img_dict, masked_image, "masked image", add_rgb_map=False)
    add_images_to_imgdict(img_dict, pred, "clipped prediction", add_rgb_map=False)
    add_images_to_imgdict(img_dict, ground_truth, "ground truth", add_rgb_map=False)
    add_images_to_imgdict(img_dict, loss_map, "loss map", add_rgb_map=False)
    # for the following images we add R, G and B map to observe each color channel
    add_images_to_imgdict(img_dict, uncertainty, "uncertainty map", add_rgb_map=True)
    add_images_to_imgdict(img_dict, sigma, "sigma map", add_rgb_map=True)
    add_images_to_imgdict(img_dict, abs_error, "abs map", add_rgb_map=True)
    add_images_to_imgdict(img_dict, z_score, "z score", add_rgb_map=True)

    return img_dict


def plot_img_dicts(img_dict, path):
    """Draw every image of `img_dict` in a 4-column grid, titled by its key.

    Keys containing "clipped" are drawn with ``clip=True``.

    :param img_dict: ordered mapping name -> CPU image tensor
    :param path: optional file the figure is saved to before being shown
    """
    try:
        plt.figure(figsize=(20, 15))
        n_rows = int(np.ceil(len(img_dict) / 4))
        for i, (name, img) in enumerate(img_dict.items()):
            plt.subplot(n_rows, 4, i + 1)
            imshow(img, name, "clipped" in name)

        plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.4, hspace=0.4)
        if path:
            plt.savefig(path)
        plt.show()
    # a tuple catches both; `A and B` evaluated to B alone, so urllib3's error escaped
    except (urllib3.exceptions.HTTPError, urllib.error.HTTPError):
        print("error: please close plots to free memory")


def show_uncertainty_result(masked_image, pred, ground_truth, uncertainty, path=None):
    """Build the diagnostic image dict for one prediction and plot it (saved to `path` if given)."""
    plot_img_dicts(get_img_dict(masked_image, pred, ground_truth, uncertainty), path)

## image augmentation

In [None]:
# because the ground truth was stored as its original size
def resize_images(root_dir, size=(128, 128)):
    """ used to correct the size for stored ground truth image files

    Rewrites every ``<root_dir>/<label>/<image>/<image>.png`` in place at `size`.
    The source file is opened in a `with` block and closed before it is
    overwritten (previously the handle stayed open while saving to the same path).
    """
    for label in os.listdir(root_dir):
        label_path = os.path.join(root_dir, label)
        for image in os.listdir(label_path):
            image_path = os.path.join(label_path, image, image + ".png")
            with Image.open(image_path) as ground_truth_image:
                # resize() reads the pixel data, so the result outlives the closed file
                resized_image = ground_truth_image.resize(size)
            resized_image.save(image_path)

## AverageMeter

In [None]:
class AverageMeter:
    """Running tracker of the latest value, weighted sum, count and mean (e.g. of a loss)."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as observed `n` times and refresh the running mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def update_meters(meters, values):
    """Feed each value to the meter at the same position (stops at the shorter sequence)."""
    for pair in zip(meters, values):
        target_meter, new_value = pair
        target_meter.update(new_value)

## checkpoint related 

In [None]:
def save_checkpoint(epoch, epochs_since_improvement, u_net_model, optimizer, scheduler, val_loss, best_loss, is_best, checkpoint_path):
    """Pickle the full training state to ``<checkpoint_path>/checkpoint.tar``.

    When `is_best` is set, the same state is also written to
    ``BEST_checkpoint.tar``. A DataParallel model is unwrapped first, so the
    checkpoint holds the bare module rather than the wrapper. The model,
    optimizer and scheduler objects are stored whole (not their state dicts).
    """
    model_to_store = u_net_model.module if isinstance(u_net_model, torch.nn.DataParallel) else u_net_model

    state = dict(
        epoch=epoch,
        epoch_since_improvement=epochs_since_improvement,
        u_net_model=model_to_store,
        optimizer=optimizer,
        scheduler=scheduler,
        val_loss=val_loss,
        best_loss=best_loss,
        is_best=is_best,
    )

    torch.save(state, os.path.join(checkpoint_path, 'checkpoint.tar'))
    if is_best:
        logger.info("new best weight with validation loss = {}".format(val_loss))
        torch.save(state, os.path.join(checkpoint_path, 'BEST_checkpoint.tar'))
    logger.info("check point saved")


def load_checkpoint_model(checkpoint, map_location=None):
    """Load the checkpoint file at `checkpoint` and return the object stored under 'u_net_model'.

    save_checkpoint pickles whole objects, so full unpickling is requested
    explicitly: PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``,
    which rejects such checkpoints. Only load checkpoints you trust, because
    unpickling can execute arbitrary code.

    :param checkpoint: path to a file written with ``torch.save``
    :param map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to open a GPU
        checkpoint on a CPU-only machine); ``None`` keeps the saved devices
    :return: the stored ``u_net_model`` object
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    # older PyTorch releases have no `weights_only` argument
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    state = torch.load(checkpoint, **load_kwargs)
    return state['u_net_model']


def choose_best_checkpoint(path):
    """ get the best checkpoint based on validation loss from each epoch segments

    Looks at ``<path>/<segment>/BEST_checkpoint.tar`` for every sub-entry of
    `path` and returns the file with the lowest stored loss. save_checkpoint
    stores the loss under ``'val_loss'``; a legacy ``'loss'`` key is accepted as a
    fallback (the previous code read only ``'loss'`` and raised KeyError).
    Segments without a BEST file are skipped.

    :return: path of the best checkpoint, or "" if none was found
    """
    import inspect

    # load on CPU: only the loss is needed, so do not occupy GPU memory;
    # the files contain whole pickled objects, so full unpickling is required
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    min_loss = float('inf')
    best_path = ""
    for ep in sorted(os.listdir(path)):
        ep_path = os.path.join(path, ep, "BEST_checkpoint.tar")
        if not os.path.isfile(ep_path):
            continue
        checkpoint = torch.load(ep_path, **load_kwargs)
        loss = checkpoint["val_loss"] if "val_loss" in checkpoint else checkpoint["loss"]
        if loss < min_loss:
            min_loss = loss
            best_path = ep_path

    return best_path

# Hedge Generation

In [None]:
import os
import time
from queue import Queue
import numpy as np
from PIL import Image, ImageOps

# from utils import random_bool, make_dir

def rename_hedge_masks(src_path, start_num):
    """Shift the numeric code of every hedge mask by `start_num` (used when merging mask directories).

    Expects ``<src_path>/<density>/<code>_<rest>`` and renames each file to
    ``<code + start_num>_<rest>``. Non-directory entries of `src_path` are ignored.

    Fixes over the previous version: density names were taken from
    ``enumerate`` (tuples, so ``os.path.join`` failed), `src_path` was overwritten
    inside the loop, and renaming in arbitrary order could make ``os.rename``
    silently overwrite a file that had not been renamed yet.
    """
    def mask_code(file_name):
        return int(file_name.split("_", 1)[0])

    try:
        for density in os.listdir(src_path):
            density_path = os.path.join(src_path, density)
            if not os.path.isdir(density_path):
                continue
            # moving codes up: handle the largest first; moving down: smallest first.
            # Either way a target name is never a file that is still waiting to be renamed.
            for mask_name in sorted(os.listdir(density_path), key=mask_code, reverse=start_num > 0):
                code, rest = mask_name.split("_", 1)
                old_path = os.path.join(density_path, mask_name)
                new_path = os.path.join(density_path, str(start_num + int(code)) + "_" + rest)
                os.rename(old_path, new_path)

    except FileNotFoundError:
        logger.info("No such file or directory: " + src_path)


# rename_hedge_mask("/content/drive/MyDrive/Image Dehedger Project/images/sources/hedge masks", 212)


## algorithm 1: generate transparent holes
- generate holes in hedges


In [None]:
def generate_transparent_holes(hedge_img, hedge_density=0.3):
    """
    Punch transparent holes into an RGBA copy of `hedge_img`.

    The red channel is sorted and the value at position
    int((1 - hedge_density) * (n_pixels - 1)) becomes the cut-off; every pixel
    whose red value is <= the cut-off is replaced by transparent white
    (255, 255, 255, 0). Roughly a `hedge_density` fraction of pixels (those
    with the highest red values) therefore stays unchanged.
    :param hedge_img: PIL Image File
    :param hedge_density: decide how dense is the hedge mask
    :return hedge mask image (RGBA)
    """
    rgba = hedge_img.convert('RGBA')
    data = rgba.getdata()

    reds = sorted(px[0] for px in data)
    cutoff = reds[int((1 - hedge_density) * (len(reds) - 1))]

    transparent_white = (255, 255, 255, 0)
    rgba.putdata([transparent_white if px[0] <= cutoff else px for px in data])
    return rgba

## algorithm 2: hedge mask denoising
- remove hedge patches that are too small (area below a threshold)

In [None]:
def search_pixel_group(m_img, i, j, width, height, pixel_group_ids, pixel_group_id):
    """
    Breadth-first flood fill from pixel (i, j) over 8-connected pixels with m_img > 0.

    When called from hedge_denoising, m_img is the alpha channel, so the filled
    pixels are the opaque (hedge) pixels of one connected patch. The start pixel
    and every reached pixel whose id is still 0 get `pixel_group_id` written into
    `pixel_group_ids`, which is modified in place.
    :param m_img: 2-D array indexed [row, col]
    :param i: row of the start pixel
    :param j: column of the start pixel
    :param width: number of columns
    :param height: number of rows
    :param pixel_group_ids: 2-D int array of group ids (0 = unassigned), mutated in place
    :param pixel_group_id: id for the current pixel group (must be non-zero)
    :return: area (pixel count) of the group
    """
    pixel_group_ids[i, j] = pixel_group_id
    # a list plus a read cursor acts as the FIFO queue; every entry is one group pixel
    frontier = [(i, j)]
    cursor = 0

    while cursor < len(frontier):
        row, col = frontier[cursor]
        cursor += 1
        for nb_row in (row - 1, row, row + 1):
            if not 0 <= nb_row < height:
                continue
            for nb_col in (col - 1, col, col + 1):
                if not 0 <= nb_col < width:
                    continue
                if m_img[nb_row, nb_col] > 0 and pixel_group_ids[nb_row, nb_col] == 0:
                    pixel_group_ids[nb_row, nb_col] = pixel_group_id
                    frontier.append((nb_row, nb_col))

    return len(frontier)


def hedge_denoising(hedge_img, noise_area_threshold=200):
    """
    Remove every 8-connected patch of opaque pixels (alpha > 0, i.e. hedge pixels,
    the same pixels compute_density counts) whose area is below
    `noise_area_threshold`, by replacing them with transparent white.
    :param hedge_img: RGBA image (alpha is read from channel index 3); modified in place via putdata
    :param noise_area_threshold: minimum area a patch needs in order to be kept
    :return: the same image object, de-noised
    """
    width, height = hedge_img.size
    pixels = hedge_img.getdata()
    alpha = np.asarray(pixels)[..., 3].reshape(height, width)

    labels = np.zeros(alpha.shape).astype(int)      # patch id per pixel, 0 = not labelled yet
    patch_area = np.zeros(alpha.size).astype(int)   # area, indexed by patch id
    transparent_white = (255, 255, 255, 0)

    next_label = 0
    cleaned = []
    for row in range(height):
        for col in range(width):
            original_pixel = pixels[row * width + col]
            if alpha[row, col] <= 0:
                # transparent pixel: nothing to remove
                cleaned.append(original_pixel)
                continue

            if labels[row, col] == 0:
                # first visit of a new patch: label all of it and remember its area
                next_label += 1
                patch_area[next_label] = search_pixel_group(alpha, row, col, width, height, labels, next_label)

            if patch_area[labels[row, col]] < noise_area_threshold:
                cleaned.append(transparent_white)
            else:
                cleaned.append(original_pixel)

    hedge_img.putdata(cleaned)
    return hedge_img


## algorithm 3: compute density
- the density targeted in algorithm 1 changes once algorithm 2 removes hedge patches that are too small, so the density has to be recalculated

In [None]:
def compute_density(hedge_img):
    """Return the actual hedge density: the fraction of pixels whose alpha channel (index 3) is non-zero."""
    w, h = hedge_img.size
    alpha = np.asarray(hedge_img.getdata())[..., 3].reshape(h, w)
    opaque = alpha > 0
    return opaque.sum() / alpha.size

## generate hedge masks

In [None]:
def generate_hedge_mask(hedge_img, hedge_density, noise_area_threshold):
    """
    Build one hedge mask from a hedge image and measure its real density.

    1) generate transparent holes with algorithm 1 (generate_transparent_holes)
    2) remove small pixel groups with algorithm 2 (hedge_denoising)
    3) recompute the hedge density, because step 2 changes it

    :param hedge_img: source hedge image; the denoising step writes into the image it receives
        (putdata), so callers should pass a copy
    :param hedge_density: target density for algorithm 1
    :param noise_area_threshold: pixel groups with a smaller area are removed by the denoising step
    :return: (hedge_mask, measured density rounded to one decimal)
    """
    logger.info("step 1 - generate holes")
    hedge_mask = generate_transparent_holes(hedge_img, hedge_density)
    logger.info("step 2 - denoising")
    hedge_mask = hedge_denoising(hedge_mask, noise_area_threshold)
    logger.info("step 3 - density calculation")
    hedge_density = round(compute_density(hedge_mask), 1)
    return hedge_mask, hedge_density


def generate_all_hedge_masks(resource_dir, save_dir, temp_dir, min_density, max_density, density_step, noise_area_threshold, processed_count=0):
    """
    Generate hedge masks at several target densities for every hedge image in `resource_dir`.

    Each mask is saved as:
        save_dir/<measured density, 1 decimal>/<image index + processed_count>_<target density in %>.png

    Target densities are np.arange(min_density, max_density, density_step).

    A source image is moved to `temp_dir` once all of its masks are saved. An interrupted run can
    therefore be resumed; pass `processed_count` to keep file indices unique across runs.

    An image that fails with OSError (or NameError) is logged and left in `resource_dir`. Failures
    include an unreadable image or a failed save or move.
    """
    # os.replace below fails if the destination directory is missing
    os.makedirs(temp_dir, exist_ok=True)

    for i, item in enumerate(os.listdir(resource_dir)):
        logger.info("processing image {}: {}".format(i, item))
        hedge_path = os.path.join(resource_dir, item)
        try:
            # the context manager closes the source file before it is moved below
            with Image.open(hedge_path) as hedge_img:
                for density in np.arange(min_density, max_density, density_step):
                    logger.info("processing density: {}".format(density))
                    hedge_mask, result_density = generate_hedge_mask(hedge_img.copy(), density, noise_area_threshold)

                    # Round before truncating, otherwise float error such as
                    # 0.29 * 100 == 28.999999999999996 yields "28".
                    file_name = str(i + processed_count) + "_" + str(int(round(density * 100))) + ".png"

                    # The measured density may fall outside the pre-created directories
                    # (e.g. 0.0 or 1.0).
                    density_dir = os.path.join(save_dir, str(result_density))
                    os.makedirs(density_dir, exist_ok=True)
                    hedge_mask.save(os.path.join(density_dir, file_name))

            os.replace(hedge_path, os.path.join(temp_dir, item))
        except (OSError, NameError) as e:
            # was logger.info("OSError, Path:", hedge_path): the extra argument with no "%s"
            # placeholder made logging itself fail to format the record
            logger.info("OSError, Path: {} ({})".format(hedge_path, e))

## execute hedge mask generation code

In [None]:
if generate_hedge_masks:
    # Mask generation settings: target densities cover [min_density, max_density)
    # in density_step increments.
    min_density, max_density, density_step = 0.1, 0.81, 0.05
    noise_area_threshold = 800
    processed_count = 0

    source_path = os.path.join(root, "images/sources/hedge_images")
    save_dir = os.path.join(root, "images/sources/hedge_masks")
    # fully processed source images are moved here
    temp_dir = os.path.join(root, "images/sources/hedge_images(processed)")

    # One output directory per rounded density level: 0.1, 0.2, ..., 0.9.
    for density in np.arange(0.1, 0.91, 0.1):
        current_path = os.path.join(save_dir, str(round(density, 1)))
        os.makedirs(current_path, exist_ok=True)

    generate_all_hedge_masks(
        source_path,
        save_dir,
        temp_dir,
        min_density=min_density,
        max_density=max_density,
        density_step=density_step,
        noise_area_threshold=noise_area_threshold,
        processed_count=processed_count,
    )

    # Flush pending writes and detach the mounted drive (presumably Google Drive mounted earlier).
    drive.flush_and_unmount()

## generate hedged images

In [None]:
def generate_single_hedged_image(input_img, hedge_mask, img_size=(128, 128)):
    """
    Overlay a random square crop of `hedge_mask` onto `input_img`, then shrink both to `img_size`.

    The crop side is 1000 px, or the shorter side of the hedge mask if that is smaller. The
    background is resized to the crop size before pasting, so the hedge covers the whole image.

    :param input_img: background image (left untouched; a copy is used)
    :param hedge_mask: hedge mask image, used both as paste source and as paste mask
        (presumably RGBA, so that its alpha channel selects the hedge pixels)
    :param img_size: (width, height) of the two returned images
    :return: (hedged_image, cropped_hedge_mask), both resized to img_size
    """
    mask_w, mask_h = hedge_mask.size
    side = min(1000, mask_w, mask_h)

    hedged = input_img.copy().resize((side, side))

    # top-left corner of the crop, uniform over all valid positions (x drawn before y)
    x0 = np.random.randint(0, mask_w - side + 1)
    y0 = np.random.randint(0, mask_h - side + 1)
    patch = hedge_mask.crop((x0, y0, x0 + side, y0 + side))

    hedged.paste(patch, (0, 0), mask=patch)
    return hedged.resize(img_size), patch.resize(img_size)


def generate_all_hedged_images(hedge_mask_dir, background_paths, target_dir, img_size=(128, 128), sample_no_per_image=10, prefix=0):
    """
    Generate hedged images for every background image, storing them next to the ground truth and
    the hedge masks used.

    Output layout, relative to target_dir (class_name = parent directory of the background file):
        ground truth -> class_name/img_code/img_code.png
        hedged image -> class_name/img_code/<density>/<k>.png
        hedge mask   -> class_name/img_code/<density>/<k>_hedge.png
    k runs over prefix .. prefix + sample_no_per_image - 1.

    A sample whose hedged image and hedge mask both already exist is skipped, so an interrupted
    run can be resumed.

    :param hedge_mask_dir: directory holding one sub-directory of hedge masks per density level
    :param background_paths: 2D list; each inner list holds the full image paths of one class
    :param target_dir: directory that receives the generated files (layout above)
    :param img_size: size of the generated hedged images and masks
    :param sample_no_per_image: number of hedged images generated per background and density
    :param prefix: first sample index; used when adding samples to an existing dataset so earlier
        files are not overwritten
    """
    # List the hedge mask pool once instead of once per background image.
    hedge_masks_by_density = {}
    for hedge_density in os.listdir(hedge_mask_dir):
        hedge_masks_path = os.path.join(hedge_mask_dir, hedge_density)
        if not os.path.isdir(hedge_masks_path):
            continue
        hedge_masks = os.listdir(hedge_masks_path)
        if not hedge_masks:
            logger.info("no hedge masks found in {}, skipping this density".format(hedge_masks_path))
            continue
        hedge_masks_by_density[hedge_density] = (hedge_masks_path, hedge_masks)

    # iterate through the 2D background paths
    for background_in_each_class in background_paths:
        if len(background_in_each_class) == 0:
            continue
        dest_label_path = os.path.join(target_dir, background_in_each_class[0].split("/")[-2])
        make_dir(dest_label_path)

        for full_path in background_in_each_class:
            img_code, class_name = full_path.split("/")[-1].split(".")[0], full_path.split("/")[-2]
            logger.info("processing img name: {} with class label: {}".format(img_code, class_name))

            # create directory (target_dir/class_name/img_code)
            dest_img_code_path = os.path.join(dest_label_path, img_code)
            make_dir(dest_img_code_path)

            with Image.open(full_path) as background:
                # ground truth (target_dir/class_name/img_code/img_code.png)
                ground_truth = background.resize(img_size)
                ground_truth.convert('RGB').save(os.path.join(dest_img_code_path, img_code + ".png"))

                for hedge_density, (hedge_masks_path, hedge_masks) in hedge_masks_by_density.items():
                    # create directory (target_dir/class_name/img_code/<density>)
                    dest_hedge_density_path = os.path.join(dest_img_code_path, hedge_density)
                    make_dir(dest_hedge_density_path)

                    for i in range(sample_no_per_image):
                        img_name = os.path.join(dest_hedge_density_path, str(i + prefix) + ".png")
                        hedge_name = os.path.join(dest_hedge_density_path, str(i + prefix) + "_hedge.png")

                        # Skip samples that already exist. The old check compared these full paths
                        # against bare os.listdir() names, so it never matched.
                        if os.path.exists(img_name) and os.path.exists(hedge_name):
                            continue

                        # randomly choose a hedge mask from the current density level
                        hedge_mask_file = np.random.choice(hedge_masks, size=1)[0]
                        with Image.open(os.path.join(hedge_masks_path, hedge_mask_file)) as hedge_mask:
                            hedged_image, cropped_mask = generate_single_hedged_image(background, hedge_mask, img_size)
                        hedged_image.convert('RGB').save(img_name)
                        cropped_mask.convert('RGB').save(hedge_name)



## execute hedged image generation code

In [None]:
import os

# 40gb for 1 image in [train/val/test]

if generate_hedge_image:
    target_root_dir = os.path.join(root, "images/hedged_images")
    make_dir(target_root_dir)
    train_background_paths = os.path.join(root, "images/sources/imagenet/total_train_paths.json")
    val_background_paths = os.path.join(root, "images/sources/imagenet/total_val_paths.json")
    test_background_paths = os.path.join(root, "images/sources/imagenet/total_test_paths.json")

    # Each file presumably holds a 2D list (one list of background paths per class), the shape
    # generate_all_hedged_images iterates over.
    train_background_paths = load_json(train_background_paths)
    val_background_paths = load_json(val_background_paths)
    test_background_paths = load_json(test_background_paths)

    train_background_paths.sort()
    val_background_paths.sort()
    test_background_paths.sort()

    hedge_dir = os.path.join(root, "images/sources/hedge_masks")
    data_types = [["train", train_background_paths], ["val", val_background_paths], ["test", test_background_paths]]
    img_size = (128, 128)
    sample_no_per_image = 2
    # first sample index of this run; keeps earlier samples (0 .. prefix-1) from being overwritten
    prefix = 4

    for data_type, background_paths in data_types:
        target_dir = os.path.join(target_root_dir, data_type)
        make_dir(target_dir)

        # split the background_paths into #cpu parts
        path_batches = split_data_into_batches(background_paths, cpus)

        if multi_processing:
            # Closing the pool after starmap returns stops the workers; previously a new pool was
            # leaked for every data split.
            with Pool(cpus) as pool:
                print("multiprocessing will be run with {} threads".format(cpus))
                pool.starmap(generate_all_hedged_images, [(hedge_dir, background_path, target_dir, img_size, sample_no_per_image, prefix) for background_path in path_batches])
        else:
            # sequential fallback so the cell still generates data without multiprocessing
            generate_all_hedged_images(hedge_dir, background_paths, target_dir, img_size, sample_no_per_image, prefix)


## Validate synthetic images

In [None]:
def hedged_images_validation(path, label_count, img_count, hedge_density_count, sample_count):
    """
    Check the on-disk layout of one split of the hedged-image dataset.

    Expected layout:
        path/<class_label>/<img_label>/<img_label>.png   ground truth
        path/<class_label>/<img_label>/0.1, 0.2, ...     one directory per density level

    Structural problems raise AssertionError:
      - wrong number of classes, images or entries;
      - missing ground truth;
      - missing density directory.

    A density directory with the wrong number of files is only printed, so that every such
    directory is reported in a single pass.

    :param path: split directory (e.g. .../hedged_images/train)
    :param label_count: expected number of class directories
    :param img_count: expected number of image directories per class
    :param hedge_density_count: expected entries per image directory (density dirs + 1 ground truth);
        density directories are expected to be named "0.1" ... str(round((hedge_density_count - 1) / 10, 1))
    :param sample_count: expected number of files in each density directory
    """
    class_labels = os.listdir(path)
    assert len(class_labels) == label_count, \
        "{}: expected {} classes, found {}".format(path, label_count, len(class_labels))

    # derived from hedge_density_count instead of a hard-coded 8 levels
    expected_densities = [str(round(i / 10, 1)) for i in range(1, hedge_density_count)]

    for class_label in class_labels:
        class_path = os.path.join(path, class_label)
        img_labels = os.listdir(class_path)
        assert len(img_labels) == img_count, \
            "{}: expected {} images, found {}".format(class_path, img_count, len(img_labels))

        for img_label in img_labels:
            img_path = os.path.join(class_path, img_label)
            content_in_density_dir = os.listdir(img_path)
            assert len(content_in_density_dir) == hedge_density_count, \
                "{}: expected {} entries, found {}".format(img_path, hedge_density_count, len(content_in_density_dir))
            for density in expected_densities:
                assert density in content_in_density_dir, "{}: missing density {}".format(img_path, density)
            assert img_label + ".png" in content_in_density_dir, "{}: missing ground truth".format(img_path)

            for hedge_label in content_in_density_dir:
                if "png" in hedge_label:
                    continue
                # explicit comparison instead of assert + bare except, which also swallowed
                # unrelated errors
                file_count = len(os.listdir(os.path.join(img_path, hedge_label)))
                if file_count != sample_count:
                    print(img_path, hedge_label)
                    print(file_count, sample_count)
                        
if validate_dataset:
    label_count = 1000
    hedge_density_count = 8+1  # 8 density directories + 1 ground-truth image per image directory
    sample_count = 6
    # every sample is stored as two files: <k>.png and <k>_hedge.png
    files_per_density = sample_count * 2

    train_path = os.path.join(root, "images/hedged_images/train")
    val_path = os.path.join(root, "images/hedged_images/val")
    test_path = os.path.join(root, "images/hedged_images/test")

    # images per class: train 32, val 8, test 10
    hedged_images_validation(train_path, label_count, 32, hedge_density_count, files_per_density)
    hedged_images_validation(val_path, label_count, 8, hedge_density_count, files_per_density)
    hedged_images_validation(test_path, label_count, 10, hedge_density_count, files_per_density)

    assert len(os.listdir(train_path)) == len(os.listdir(val_path)) == len(os.listdir(test_path)) == label_count

# Loss Functions

In [None]:
import numpy as np
import torch
import torch.nn.functional as F


def vgg_accuracy(model, images, labels):
    """
    Top-1 and top-5 classification accuracy of `model` on one batch.

    The model is switched to eval mode and left there. The forward pass runs under
    torch.no_grad(): only predictions are needed, and building an autograd graph would waste
    memory.

    :param model: classifier whose output has shape (batch, num_classes); assumes num_classes >= 5
    :param images: input batch
    :param labels: tensor of class indices, one per image
    :return: (top1_accuracy, top5_accuracy) as fractions of the batch size
    """
    model.eval()
    with torch.no_grad():
        pred = model(images)
    label = labels.reshape(labels.shape[0], 1)

    # top-1
    _, predicted = torch.topk(pred, k=1, dim=1)
    top_1_result = torch.sum(predicted == label).item()

    # top-5: a row is a hit if its label is among the 5 highest scores (vectorised)
    _, predicted = torch.topk(pred, k=5, dim=1)
    top_5_result = (predicted == label).any(dim=1).sum().item()

    size = labels.shape[0]
    return top_1_result / size, top_5_result / size


def uncertainty_MSE_loss(pred, uncertainty, target):
    """
    Uncertainty-weighted MSE: mean(exp(-uncertainty) * (pred - target)^2 + uncertainty).

    Large uncertainty down-weights the squared error of a pixel but is penalised by the additive
    term.

    :return: (loss, plain MSE) where the plain MSE is detached and meant for logging only
    """
    per_element_mse = F.mse_loss(pred, target, reduction="none")
    precision = torch.exp(-uncertainty)
    weighted = precision * per_element_mse + uncertainty
    return weighted.mean(), per_element_mse.detach().mean()


# Datasets

In [None]:
import random
import numpy as np
import torch
from torch.utils.data import Dataset
import os
from PIL import Image, UnidentifiedImageError
import torchvision

## Static Hedge Dataset
Used to load pre-generated hedged images

In [None]:
class StaticHedgedDataset(Dataset):
    """
    Dataset over pre-generated hedged images, laid out on disk as:

        root_dir/<label>/<img_code>/<img_code>.png           ground truth
        root_dir/<label>/<img_code>/<density>/<k>.png        hedged image k
        root_dir/<label>/<img_code>/<density>/<k>_hedge.png  hedge mask k

    Each item picks a random label and a random image of that label. The density level comes
    from `density_range` by `idx`, so densities are spread uniformly across an epoch.

    Items are (masked_images, ground_truth_image, hedge_masks, label):
      - masked_images holds `masked_img_per_item` distinct samples of the same image at the same
        density;
      - hedge_masks is empty unless `require_hedge_mask` is set.
    """

    SUPPORTED_COLOR_MODES = ("RGB", "HSV", "LAB", "LUV")

    def __init__(self, root_dir, density_range, masked_img_per_item, require_hedge_mask=False, color_mode="RGB", debug=False):
        """
        :param root_dir: directory of one split (e.g. .../hedged_images/train)
        :param density_range: density levels to draw from; the directory used is str(round(d, 2))
        :param masked_img_per_item: number of hedged samples per item (drawn without replacement)
        :param require_hedge_mask: also load the hedge mask of every returned sample
        :param color_mode: colour space of the returned tensors: "RGB", "HSV", "LAB" or "LUV"
        :param debug: if True, __len__ reports only 100 items
        """
        # read args
        self.root_dir = root_dir
        self.density_range = density_range
        self.masked_img_per_item = masked_img_per_item
        self.require_hedge_mask = require_hedge_mask
        self.debug = debug
        self.color_mode = color_mode

        # dataset size, filled in by get_image_counts()
        self.label_count = -1
        self.image_count = -1
        self.density_count = -1
        self.sample_count = -1
        self.available_samples = None
        self.get_image_counts()

        # image directory paths, as a (label_count, image_count) array
        self.image_path = []
        self.read_image_code_path()

        # validation on dataset
        self.image_code_path_validation()
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self):
        if self.debug:
            return 100
        else:
            return self.label_count * self.image_count

    def read_img(self, path):
        """
        Load the image at `path` as RGB and return it as a tensor in `self.color_mode`.

        :raises ValueError: if `self.color_mode` is unsupported. Previously this only logged and
            returned None, which failed later with an unrelated error while batching.
        """
        if self.color_mode not in self.SUPPORTED_COLOR_MODES:
            raise ValueError("color mode {} not supported".format(self.color_mode))
        # the context manager closes the file once the pixels are copied into the array
        with Image.open(path) as img:
            im = np.array(img.convert("RGB"))
        if self.color_mode == "RGB":
            return self.to_tensor(im)
        if self.color_mode == "HSV":
            return self.to_tensor(cv2.cvtColor(im, cv2.COLOR_RGB2HSV))
        if self.color_mode == "LAB":
            return self.to_tensor(cv2.cvtColor(im, cv2.COLOR_RGB2LAB))
        return self.to_tensor(cv2.cvtColor(im, cv2.COLOR_RGB2Luv))

    def __getitem__(self, idx):
        """
        Return #masked_img_per_item masked images of one random image; also returns hedge masks if
        required.

        If a file is missing or broken, another random image is tried. This is a loop rather than
        recursion, so a run of bad files cannot exhaust the recursion limit.
        """
        while True:
            # path of the file being read, reported if reading it fails
            current_file = None
            try:
                # choose a random label + img_code
                current_image_path = self.image_path[random.randint(0, self.label_count - 1)][random.randint(0, self.image_count - 1)]
                image_code = os.path.basename(current_image_path)
                label = int(os.path.basename(os.path.dirname(current_image_path)))
                density = str(round(self.density_range[idx % len(self.density_range)], 2))  # uniform density across dataset

                # get the ground truth
                current_file = os.path.join(current_image_path, image_code + ".png")
                ground_truth_image = self.read_img(current_file)

                # choose n samples from the same density level and load all of them
                samples = np.random.choice(self.available_samples, self.masked_img_per_item, replace=False)
                masked_images = []
                for sample in samples:
                    current_file = os.path.join(current_image_path, density, str(int(sample)) + ".png")
                    masked_images.append(self.read_img(current_file))

                # return hedge masks if required; the sample id must be converted to str
                # (numpy int + str raised TypeError)
                hedge_masks = []
                if self.require_hedge_mask:
                    for sample in samples:
                        current_file = os.path.join(current_image_path, density, str(int(sample)) + "_hedge.png")
                        hedge_masks.append(self.read_img(current_file))

                return masked_images, ground_truth_image, hedge_masks, label

            # if a file can't be read, try another image in case a file is missing
            except (FileNotFoundError, UnidentifiedImageError):
                logger.info("{} not found / is broken".format(current_file))

    def get_image_counts(self):
        """
        Infer label, image and sample counts from the first label and first image directory.

        All labels and images are assumed to share this layout. image_code_path_validation checks
        the label and image counts.
        """
        label_names = os.listdir(self.root_dir)
        self.label_count = len(label_names)
        label_path = os.path.join(self.root_dir, label_names[0])

        image_names = os.listdir(label_path)
        self.image_count = len(image_names)
        image_path = os.path.join(label_path, image_names[0])

        self.density_count = len(self.density_range)
        # The image directory also contains the ground-truth png. Only sub-directories are density
        # levels; taking listdir()[0] could pick the png and then fail.
        density_dirs = [name for name in os.listdir(image_path) if os.path.isdir(os.path.join(image_path, name))]
        density_path = os.path.join(image_path, density_dirs[0])

        self.sample_count = len(os.listdir(density_path)) // 2  # hedged image and hedge mask, so div by 2
        if self.masked_img_per_item > self.sample_count:
            print("warning: sample per density isn't enough")

        self.available_samples = np.arange(0, self.sample_count)

    def read_image_code_path(self):
        """
        Store the path of every image directory in a 2D array indexed by [label, image_code].
        Density and sample are chosen in __getitem__().
        """
        for label_no in os.listdir(self.root_dir):
            current_label = []
            label_dir = os.path.join(self.root_dir, label_no)
            for image_code in os.listdir(label_dir):
                current_label.append(os.path.join(label_dir, image_code))
            self.image_path.append(current_label)
        self.image_path = np.array(self.image_path)

    def image_code_path_validation(self):
        """ make sure all files are loaded correctly (every label has the same number of images) """
        assert self.image_path.shape == (self.label_count, self.image_count), \
            "image path array has shape {}, expected {}".format(self.image_path.shape, (self.label_count, self.image_count))


# Model

In [None]:
import torch
import torch.nn as nn

import torchvision
from torchvision import transforms

## U-Net
Used to perform image dehedging

### U-Net Block (TODO try PReLU + Dropout2d)




In [None]:
class Block(nn.Module):
    """
    U-Net building block: two 3x3 convolutions with replicate padding, each followed by
    LeakyReLU and BatchNorm.

    Channels go in_channel -> out_channel -> out_channel; the spatial size is preserved.
    Weights of conv, transposed-conv and linear layers get Xavier-uniform initialisation; biases
    keep the PyTorch defaults.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Layer order, and therefore the Sequential indices 0..5 that appear in saved
        # state_dicts, is: conv, LeakyReLU, BatchNorm, conv, LeakyReLU, BatchNorm.
        # PReLU (too memory hungry) and Dropout2d(p=0.01) were tried and are disabled.
        layers = []
        for conv_in in (in_channel, out_channel):
            layers.extend([
                nn.Conv2d(conv_in, out_channel, 3, padding=1, padding_mode='replicate'),
                nn.LeakyReLU(),
                nn.BatchNorm2d(out_channel, affine=True),
            ])
        self.block = nn.Sequential(*layers)
        self.block.apply(self.weights_init)

    def forward(self, x):
        return self.block(x)

    def weights_init(self, m):
        # only layers with a weight matrix/kernel are re-initialised
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            return
        nn.init.xavier_uniform_(m.weight)


### U-Net Encoder/Decoder

In [None]:
class Encoder(nn.Module):
    """
    U-Net encoder.

    Block(in_channels[0] -> in_channels[1]) is applied first. Each further level down-samples with
    a stride-2 3x3 conv (reflect padding) from in_channels[i] to in_channels[i+1], then applies a
    Block that keeps in_channels[i+1] channels.

    forward returns (bottleneck features, skip connections). The skip connections are
    [input, output of every block that precedes a down-sampling step].
    """

    def __init__(self, in_channels):
        super().__init__()
        self.blocks = nn.ModuleList([Block(in_channels[0], in_channels[1])])
        self.down_samples = nn.ModuleList()

        for c_prev, c_next in zip(in_channels[1:-1], in_channels[2:]):
            self.blocks.append(Block(c_next, c_next))
            self.down_samples.append(
                nn.Conv2d(c_prev, c_next, 3, stride=2, padding=1, padding_mode='reflect'))

    def forward(self, x):
        skips = [x]
        # blocks has one more entry than down_samples; the extra last block is the bottleneck
        for block, down in zip(self.blocks, self.down_samples):
            x = block(x)
            skips.append(x)
            x = down(x)
        return self.blocks[-1](x), skips


class Decoder(nn.Module):
    """
    U-Net decoder.

    Level i up-samples out_channels[i] -> out_channels[i+1] with a stride-2 transposed conv
    (doubling the spatial size). It concatenates the skip connection connections[i] along the
    channel axis, then applies Block(out_channels[i] -> out_channels[i+1]). The skip at level i
    must therefore carry out_channels[i] - out_channels[i+1] channels. A final
    Block(out_channels[-2] -> out_channels[-1]) produces the output.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.up_samples = nn.ModuleList()

        for c_high, c_low in zip(out_channels[:-2], out_channels[1:-1]):
            self.up_samples.append(
                nn.ConvTranspose2d(c_high, c_low, 3, stride=2, padding=1, output_padding=1))
            self.blocks.append(Block(c_high, c_low))

        self.blocks.append(Block(out_channels[-2], out_channels[-1]))

    def forward(self, x, connections):
        for level, up in enumerate(self.up_samples):
            merged = torch.cat([up(x), connections[level]], dim=1)
            x = self.blocks[level](merged)
        return self.blocks[-1](x)


### U-Net

In [None]:
class UNet(nn.Module):
    """
    Plain U-Net for image de-hedging: c_in input channels -> c_out output channels.

    Feature widths are 64/128/256/512 with three down-sampling steps. The input height and width
    should be divisible by 8 so that up-sampled maps line up with the skip connections.
    A lighter 32/64/128/256 variant was also used before.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        widths = [64, 128, 256, 512]
        self.encoder = Encoder([c_in] + widths)
        self.decoder = Decoder(widths[::-1] + [c_out])

    def forward(self, x):
        bottleneck, skips = self.encoder(x)
        # the decoder consumes skip connections from the deepest level upwards
        return self.decoder(bottleneck, skips[::-1])

### Uncertainty U-Net

In [None]:
class UncertaintyUNet(nn.Module):
    """
    U-Net with a shared encoder and two decoders:
      - an image decoder predicting the de-hedged image (c_out channels, clamped to [0, 1]);
      - an uncertainty decoder predicting a per-pixel uncertainty map (u_out channels).

    Previously the uncertainty decoder was built with c_out channels, so `u_out` was ignored. When
    the map is used with uncertainty_MSE_loss it must broadcast against the image, so u_out is
    presumably 1 or c_out.
    """

    def __init__(self, c_in, c_out, u_out):
        super().__init__()
        # lighter variant: [c_in, 32, 64, 128, 256] / [256, 128, 64, 32, c_out] / [256, 128, 64, 32, u_out]
        in_channels = [c_in, 64, 128, 256, 512]
        out_channels = [512, 256, 128, 64, c_out]
        uncertainty_channels = [512, 256, 128, 64, u_out]

        self.encoder = Encoder(in_channels)
        self.decoder = Decoder(out_channels)
        self.uncertainty_decoder = Decoder(uncertainty_channels)

    def forward(self, x):
        x, connections = self.encoder(x)
        skips = connections[::-1]
        output_pred = self.decoder(x, skips)
        uncertainty_pred = self.uncertainty_decoder(x, skips)
        # clamp forces the image prediction into the valid pixel range [0, 1]
        output_pred = torch.clamp(output_pred, 0, 1)
        return output_pred, uncertainty_pred

## VGG Model
Used to perform a sanity check on the model output.

In [None]:
class VGGClass(torch.nn.Module):
    """
    ImageNet-pretrained VGG16 classifier, used as a sanity check on model outputs.

    Inputs are first clamped to [0, 1]. They are then resized with transforms.Resize(224) and
    normalised with the ImageNet mean/std. The VGG16 class scores are returned.
    """

    def __init__(self):
        super().__init__()
        self.vgg_net = torchvision.models.vgg16(pretrained=True)
        imagenet_mean = [0.485, 0.456, 0.406]
        imagenet_std = [0.229, 0.224, 0.225]
        self.preprocess = transforms.Compose([
            transforms.Resize(224),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
        ])

    def forward(self, input_img):
        clamped = input_img.clamp(0, 1)
        return self.vgg_net(self.preprocess(clamped))

### Validate VGG model

In [None]:
import os
import numpy as np
from torch.utils.data import DataLoader

# from dataset import StaticHedgedDataset
# from model import VGGClass
# from config import device
# from utils import AverageMeter
# from loss_functions import vgg_accuracy


class ImageNetDataset(Dataset):
    """
    Flat dataset over images stored as root_dir/<label>/<image file>.

    Items are (image tensor, integer label); the label is the name of the parent directory.
    Used to check the VGG classifier on the original images.
    """

    def __init__(self, root_dir):
        # read args
        self.root_dir = root_dir
        self.image_path = []
        self.read_image_code_path()
        self.to_tensor = torchvision.transforms.ToTensor()

    def __len__(self):
        return len(self.image_path)

    def read_img(self, path):
        """
        Load `path` with OpenCV and return it as an RGB tensor.

        :raises FileNotFoundError: if OpenCV cannot read the file. cv2.imread returns None instead
            of raising, which previously surfaced as a confusing cvtColor error.
        """
        bgr = cv2.imread(path)
        if bgr is None:
            raise FileNotFoundError(path)
        return self.to_tensor(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    def __getitem__(self, idx):
        """
        Return (image, label) for index `idx`.

        If that file is missing or unreadable, the following indices are tried in turn (wrapping
        around). Previously the same index was retried recursively, which could never succeed.

        :raises FileNotFoundError: if no image in the dataset can be read
        """
        total = len(self.image_path)
        for offset in range(total):
            current_image_path = self.image_path[(idx + offset) % total]
            try:
                image = self.read_img(current_image_path)
            except FileNotFoundError:
                print("{} not found".format(current_image_path))
                continue
            label = int(os.path.basename(os.path.dirname(current_image_path)))
            return image, label
        raise FileNotFoundError("no readable image found in {}".format(self.root_dir))

    def read_image_code_path(self):
        """
        Store the path of every image file in a flat array; the label is recovered from the
        parent directory in __getitem__().
        """
        for label_no in os.listdir(self.root_dir):
            label_dir = os.path.join(self.root_dir, label_no)
            for image_code in os.listdir(label_dir):
                self.image_path.append(os.path.join(label_dir, image_code))
        self.image_path = np.array(self.image_path)


In [None]:
# ## VGG on (224,224) imageNet dataset (original version)
# vgg_model = VGGClass().to(device)
# vgg_model.eval()

# path = os.path.join(root, "images/sources/imagenet/data_VGG_label")
# dataset = ImageNetDataset(path)
# dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
# top1_acc_meter = AverageMeter()
# top5_acc_meter = AverageMeter()

# # top 1 should be about 0.625, top 5 should be around 0.85
# for i, (image, label) in enumerate(dataloader):
#     image = image.to(device)
#     label = label.to(device)

#     top1_acc, top5_acc = vgg_accuracy(vgg_model, image, label)

#     update_meters([top1_acc_meter, top5_acc_meter], [top1_acc, top5_acc])

#     if i % 500 == 0:
#         logger.info("{}/{} top 1 {}, top 5 {}".format(i, len(dataloader), top1_acc_meter.avg, top5_acc_meter.avg))

In [None]:
# ## VGG on (128,128) imageNet dataset (our version)
# path = os.path.join(root, "images/hedged_images/test")
# density_range = np.arange(0.1, 0.9, 0.1)
# dataset = StaticHedgedDataset(path, density_range, masked_img_per_item=1, require_hedge_mask=False)
# dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
# top1_acc_meter.reset()
# top5_acc_meter.reset()

# # top 1 should be about 0.625, top 5 should be around 0.85
# for i, (_, ground_truths, _,  label) in enumerate(dataloader):
#     ground_truths = ground_truths.to(device)
#     label = label.to(device)

#     top1_acc, top5_acc = vgg_accuracy(vgg_model, ground_truths, label)
    
#     update_meters([top1_acc_meter, top5_acc_meter], [top1_acc, top5_acc])
    
#     if i % 50 == 0:
#         logger.info("{}/{} top 1 {}, top 5 {}".format(i, len(dataloader), top1_acc_meter.avg, top5_acc_meter.avg))

# Optimizer Wrapper

In [None]:
class OptimizerWrapper(object):
    """
    Thin wrapper around a torch optimizer that adds element-wise gradient clipping and a
    helper to set the learning rate of every parameter group.
    """

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def clip_gradient(self, clip_val):
        """Clamp every existing gradient in place to [-clip_val, clip_val]."""
        grads = (
            param.grad
            for group in self.optimizer.param_groups
            for param in group['params']
            if param.grad is not None
        )
        for grad in grads:
            grad.data.clamp_(-clip_val, clip_val)

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step(self):
        self.optimizer.step()

    def adjust_lr(self, lr):
        """Set the learning rate of all parameter groups to `lr`."""
        for group in self.optimizer.param_groups:
            group['lr'] = lr
            

# Train U-Net

Reason for adding a learning-rate scheduler to Adam:
https://arxiv.org/abs/1711.05101

"Adam can substantially benefit from a scheduled learning rate multiplier. The fact that Adam
is an adaptive gradient algorithm and as such adapts the learning rate for each parameter
does not rule out the possibility to substantially improve its performance by using a global
learning rate multiplier, scheduled, e.g., by cosine annealing."

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import math
import madgrad
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau, OneCycleLR
# from done.utils import AverageMeter, save_checkpoint, get_logger, show_tensor_images
# from done.VGGClass import VGGClass
# from done.model import UNet
# from done.dataset import StaticHedgedDataset
# from done.loss_functions import vgg_accuracy
# from done.config import device, print_freq, parse_args
# from done.optimizer_wrapper import OptimizerWrapper

In [None]:
def load_pretrained_mse_param(uncertainty_unet, path_to_pretrained_mse_unet, pretrained_fixed=True):
    """
    Initialise the encoder and image decoder of `uncertainty_unet` from a checkpoint of a U-Net
    trained with plain MSE. The uncertainty decoder is left untouched.

    The checkpoint must be a dict whose "u_net_model" entry is a whole pickled model, optionally
    wrapped in nn.DataParallel, exposing `.encoder` and `.decoder`.

    :param uncertainty_unet: model to initialise; needs `.encoder` and `.decoder` submodules
    :param path_to_pretrained_mse_unet: checkpoint file path
    :param pretrained_fixed: if True, freeze the copied encoder/decoder parameters so that only the
        remaining parts (e.g. the uncertainty decoder) are trained
    """
    # The checkpoint stores a pickled nn.Module, which the weights_only=True default of recent torch
    # versions refuses to load. NOTE: unpickling can execute arbitrary code, so only load trusted
    # checkpoints.
    # map_location="cpu": the tensors are only copied into uncertainty_unet by load_state_dict, so
    # there is no need to materialise them on the device they were saved from.
    try:
        checkpoint = torch.load(path_to_pretrained_mse_unet, map_location="cpu", weights_only=False)
    except TypeError:
        # older torch versions do not accept the weights_only argument
        checkpoint = torch.load(path_to_pretrained_mse_unet, map_location="cpu")

    pretrained_model = checkpoint["u_net_model"]
    if isinstance(pretrained_model, torch.nn.DataParallel):
        pretrained_model = pretrained_model.module
    uncertainty_unet.encoder.load_state_dict(pretrained_model.encoder.state_dict())
    uncertainty_unet.decoder.load_state_dict(pretrained_model.decoder.state_dict())

    if pretrained_fixed:
        for module in (uncertainty_unet.encoder, uncertainty_unet.decoder):
            for param in module.parameters():
                param.requires_grad = False

    # drop both references so the loaded model can be garbage-collected
    del pretrained_model, checkpoint
    
    
def log_and_record(logger, writer, train_loss, val_loss, ep, train_mse=None, train_uncertainty=None, val_mse=None, val_uncertainty=None):
    """
    Log one epoch's metrics and write them to TensorBoard.

    What is reported depends on the global `args.unet_type`:
      - "unet": only the train and validation losses;
      - "uncertainty_unet" / "recurrent_uncertainty_unet": additionally MSE and mean uncertainty
        for train and validation. These arguments must then be provided; None cannot be formatted.

    Validation MSE/uncertainty now go to their own "model/val_*" tags. They previously reused the
    train tags, so the val and train values collided in TensorBoard.
    """
    if args.unet_type == "unet":
        logger.info('Epoch: [{0}]\t'
                    'Train Loss {train_loss:.5f}\t'
                    'Val Loss {val_loss:.5f}\t'.format(ep, train_loss=train_loss, val_loss=val_loss))

        writer.add_scalar('model/train_loss', train_loss, ep)
        writer.add_scalar('model/val_loss', val_loss, ep)

    elif args.unet_type == "uncertainty_unet" or args.unet_type == "recurrent_uncertainty_unet":
        logger.info('\n\nEpoch: [{0}]\n'
                    'Train Loss {train_loss:.5f}\t'
                    'Train MSE {train_mse:.5f}\t'
                    'Train average uncertainty value {train_uncertainty:.5f}\n'
                    'Val Loss {val_loss:.5f}\t'
                    'Val MSE {val_mse:.5f}\t\t'
                    'Val average uncertainty {val_uncertainty:.5f}\n'.format(ep,
                                                                             train_loss=train_loss,
                                                                             train_mse=train_mse,
                                                                             train_uncertainty=train_uncertainty,
                                                                             val_loss=val_loss,
                                                                             val_mse=val_mse,
                                                                             val_uncertainty=val_uncertainty))

        writer.add_scalar('model/train_loss', train_loss, ep)
        writer.add_scalar('model/uncertainty', train_uncertainty, ep)
        writer.add_scalar('model/mse', train_mse, ep)
        writer.add_scalar('model/val_loss', val_loss, ep)
        writer.add_scalar('model/val_uncertainty', val_uncertainty, ep)
        writer.add_scalar('model/val_mse', val_mse, ep)

        
def get_scheduler(optimizer, steps_per_epoch, start_epoch):
    """
    Build the learning-rate scheduler selected by the global `args.scheduler`.

    :param optimizer: OptimizerWrapper; the scheduler is attached to its `.optimizer`
    :param steps_per_epoch: batches per epoch (used by OneCycleLR)
    :param start_epoch: first epoch of this run; OneCycleLR spans args.end_epoch - start_epoch epochs
    :return: the scheduler ("exp", "cosine", "ReduceLROnPlateau" or "onecycle")
    :raises TypeError: if args.scheduler names an unsupported scheduler
    """
    if args.scheduler == 'exp':
        scheduler = ExponentialLR(optimizer.optimizer, gamma=args.LR_gamma, verbose=args.verbose)
    elif args.scheduler == 'cosine':
        scheduler = CosineAnnealingWarmRestarts(optimizer.optimizer, T_0=args.T_0, T_mult=args.T_mult, eta_min=args.eta_min, verbose=args.verbose)
    elif args.scheduler == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer.optimizer, mode=args.mode, factor=args.factor, patience=args.patience, min_lr=args.min_lr, verbose=args.verbose)
    elif args.scheduler == 'onecycle':
        scheduler = OneCycleLR(optimizer.optimizer, max_lr=args.max_lr, steps_per_epoch=steps_per_epoch, epochs=args.end_epoch - start_epoch)
    else:
        raise TypeError('scheduler {} is not supported.'.format(args.scheduler))

    # this log line used to sit after the raise, where it could never run
    logger.info("using {} as scheduler".format(args.scheduler))
    return scheduler
        
def scheduler_step(scheduler, val_loss):
    """
    Advance `scheduler` by one epoch and log the resulting learning rate of the first param group.

    ReduceLROnPlateau (selected via the global args.scheduler) is fed `val_loss`; all other
    schedulers step without arguments.
    """
    step_args = (val_loss,) if args.scheduler == 'ReduceLROnPlateau' else ()
    scheduler.step(*step_args)

    current_lr = scheduler.optimizer.param_groups[0]['lr']
    logger.info("learning rate: {}".format(current_lr))
    
    
def get_data_parallel_model(model):
    """Wrap ``model`` in ``DataParallel`` when more than one CUDA device is visible.

    A model that is already a ``DataParallel`` is returned unchanged, so calling
    this on an already-wrapped model is a no-op.
    """
    multiple_gpus = torch.cuda.device_count() > 1
    already_wrapped = isinstance(model, torch.nn.DataParallel)
    if already_wrapped or not multiple_gpus:
        return model
    return nn.DataParallel(model)

## Unet

In [None]:
def unet_training(start_epoch, end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement):
    """Train the plain UNet for epochs ``start_epoch .. end_epoch - 1``.

    Each epoch runs a training pass and a validation pass, logs the losses,
    steps epoch-level schedulers, updates early-stopping state and saves a
    checkpoint. Returns early (without saving that epoch's checkpoint) once the
    validation loss has not improved for more than ``args.early_stop_ep`` epochs.
    """
    for ep in range(start_epoch, end_epoch):
        train_loss = train_unet(train_loader, model, criterion, optimizer, args.clip_val, ep, scheduler)
        val_loss = val_unet(val_loader, model, criterion, ep, image_save_path)
        log_and_record(logger, writer, train_loss, val_loss, ep)

        # "onecycle" and "cosine" are stepped per batch inside train_unet;
        # stepping them here as well would advance their schedule twice.
        if args.scheduler and args.scheduler not in ("onecycle", "cosine"):
            scheduler_step(scheduler, val_loss)

        # The loss can diverge to inf or nan; such an epoch never counts as an
        # improvement and must not overwrite best_loss.
        if math.isnan(val_loss) or math.isinf(val_loss):
            is_best = False
        else:
            is_best = val_loss < best_loss
            best_loss = min(val_loss, best_loss)

        if not is_best:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: {}".format(epochs_since_improvement))
            if epochs_since_improvement > args.early_stop_ep:
                logger.info("early stop at ep {} with no improvement for {} eps".format(ep, args.early_stop_ep))
                return
        else:
            epochs_since_improvement = 0

        save_checkpoint(ep, epochs_since_improvement, model, optimizer, scheduler, val_loss, best_loss, is_best, args.checkpoint_save_path)

                            

        
def train_unet(train_loader, model, criterion, optimizer, clip_val, epoch, scheduler=None):
    """Run one training epoch of the plain UNet and return the mean batch loss.

    Args:
        train_loader: yields ``(masked_images, ground_truths, _, _)`` batches;
            only ``masked_images[0]`` is fed to the model.
        model: network mapping a masked image batch to a prediction.
        criterion: loss called as ``criterion(prediction, ground_truths)``.
        optimizer: OptimizerWrapper exposing ``zero_grad``, ``step`` and
            ``clip_gradient``.
        clip_val: gradient clipping value; a falsy value disables clipping.
        epoch: current epoch index (drives the per-batch cosine schedule).
        scheduler: stepped after every batch when ``args.scheduler`` is
            "onecycle" or "cosine".

    Returns:
        ``AverageMeter.avg`` of ``loss.item()`` over the epoch's batches.
    """
    model.train()

    # CosineAnnealingWarmRestarts is driven with a fractional epoch, which needs
    # the number of batches per epoch.
    iters = len(train_loader) if args.scheduler == "cosine" else None
    training_loss = AverageMeter()

    for i, (masked_images, ground_truths, _, _) in enumerate(train_loader):
        ground_truths = ground_truths.to(device)
        masked_images[0] = masked_images[0].to(device)

        # L2 loss training
        dehedged_prediction = model(masked_images[0])
        loss = criterion(dehedged_prediction, ground_truths)

        # back prop
        optimizer.zero_grad()
        loss.backward()

        # update weights, optionally clipping gradients first
        if clip_val:
            optimizer.clip_gradient(clip_val)
        optimizer.step()
        training_loss.update(loss.item())

        # onecycle and cosine schedulers are stepped after every batch
        if args.scheduler == "onecycle":
            scheduler.step()
        elif args.scheduler == "cosine":
            # use this function's `epoch` parameter (there is no `ep` in scope)
            scheduler.step(epoch + i / iters)

    return training_loss.avg


def val_unet(val_loader, model, criterion, epoch, image_save_path):
    """Evaluate the plain UNet on ``val_loader`` and return the mean batch loss.

    Also saves masked input / prediction / ground truth for the first sample of
    the last batch to ``<image_save_path>/<epoch>.png`` via ``show_tensor_images``.

    Raises:
        ValueError: if ``val_loader`` yields no batches (nothing to evaluate or
            visualise).
    """
    with torch.no_grad():
        model.eval()
        validation_loss = AverageMeter()
        seen_batch = False
        for masked_images, ground_truths, _, _ in val_loader:
            seen_batch = True
            ground_truths = ground_truths.to(device)
            masked_images[0] = masked_images[0].to(device)
            # labels are not used here, so they are not copied to the device

            dehedged_predictions = model(masked_images[0])
            loss = criterion(dehedged_predictions, ground_truths)
            validation_loss.update(loss.item())

        if not seen_batch:
            raise ValueError("validation loader yielded no batches")

        show_tensor_images([masked_images[0][0], dehedged_predictions[0], ground_truths[0]], ["masked image", "pred", "ground truth"], os.path.join(image_save_path, str(epoch) + ".png"))

        return validation_loss.avg
    

## Uncertainty Unet

In [None]:
def uncertainty_unet_training(start_epoch, end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement):
    """Train the uncertainty UNet for epochs ``start_epoch .. end_epoch - 1``.

    Per epoch: train, validate, log loss / MSE / uncertainty, step epoch-level
    schedulers, update early stopping and save a checkpoint. Returns early
    (without saving that epoch) after more than ``args.early_stop_ep`` epochs
    without a validation-loss improvement.
    """
    for ep in range(start_epoch, end_epoch):
        train_loss, train_mse, train_uncertainty = train_uncertainty_unet(
            train_loader, model, criterion, optimizer, args.clip_val, ep, scheduler)
        val_loss, val_mse, val_uncertainty = val_uncertainty_unet(
            val_loader, model, criterion, ep, image_save_path)
        log_and_record(logger, writer, train_loss, val_loss, ep,
                       train_mse, train_uncertainty, val_mse, val_uncertainty)

        # per-batch schedulers (onecycle / cosine) are handled in the training pass
        steps_per_epoch_level = args.scheduler and args.scheduler not in ("onecycle", "cosine")
        if steps_per_epoch_level:
            scheduler_step(scheduler, val_loss)

        # A diverged (nan/inf) validation loss never counts as an improvement
        # and leaves best_loss untouched.
        loss_is_finite = math.isfinite(val_loss)
        is_best = loss_is_finite and val_loss < best_loss
        if loss_is_finite:
            best_loss = min(val_loss, best_loss)

        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: {}".format(epochs_since_improvement))
            if epochs_since_improvement > args.early_stop_ep:
                logger.info("early stop at ep {} with no improvement for {} eps".format(ep, args.early_stop_ep))
                return

        save_checkpoint(ep, epochs_since_improvement, model, optimizer, scheduler, val_loss, best_loss, is_best, args.checkpoint_save_path)


def train_uncertainty_unet(train_loader, model, criterion, optimizer, clip_val, epoch, scheduler=None):
    """Run one training epoch of the uncertainty UNet.

    The model returns ``(prediction, uncertainty)``; ``criterion`` is called as
    ``criterion(prediction, uncertainty, ground_truths)`` and returns ``(loss, mse)``.

    Args:
        train_loader: yields ``(masked_images, ground_truths, _, _)`` batches;
            only ``masked_images[0]`` is fed to the model.
        clip_val: gradient clipping value; a falsy value disables clipping.
        epoch: current epoch index (drives the per-batch cosine schedule).
        scheduler: stepped after every batch when ``args.scheduler`` is
            "onecycle" or "cosine".

    Returns:
        ``(loss avg, mse avg, uncertainty avg)`` from AverageMeters fed with
        per-batch Python floats.
    """
    model.train()

    # CosineAnnealingWarmRestarts is driven with a fractional epoch.
    iters = len(train_loader) if args.scheduler == "cosine" else None

    training_loss = AverageMeter()
    mse_meter = AverageMeter()
    uncertainty_meter = AverageMeter()

    for i, (masked_images, ground_truths, _, _) in enumerate(train_loader):
        ground_truths = ground_truths.to(device)
        masked_images[0] = masked_images[0].to(device)

        dehedged_prediction, uncertainty = model(masked_images[0])
        loss, mse = criterion(dehedged_prediction, uncertainty, ground_truths)

        # back prop
        optimizer.zero_grad()
        loss.backward()

        # update weights, optionally clipping gradients first
        if clip_val:
            optimizer.clip_gradient(clip_val)
        optimizer.step()

        training_loss.update(loss.item())
        mse_meter.update(mse.mean().item())
        # .item() so the meter accumulates a float rather than a GPU tensor that
        # is still attached to the autograd graph.
        uncertainty_meter.update(uncertainty.mean().item())

        # onecycle and cosine schedulers are stepped after every batch
        if args.scheduler == "onecycle":
            scheduler.step()
        elif args.scheduler == "cosine":
            scheduler.step(epoch + i / iters)

    return training_loss.avg, mse_meter.avg, uncertainty_meter.avg


def val_uncertainty_unet(val_loader, model, criterion, epoch, image_save_path):
    """Evaluate the uncertainty UNet on ``val_loader``.

    Saves a visualisation of the first sample of the last batch to
    ``<image_save_path>/<epoch>.png`` via ``show_uncertainty_result``.

    Returns:
        ``(loss avg, mse avg, uncertainty avg)`` over the validation batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    with torch.no_grad():
        model.eval()
        validation_loss = AverageMeter()
        mse_meter = AverageMeter()
        uncertainty_meter = AverageMeter()
        seen_batch = False
        for masked_images, ground_truths, _, _ in val_loader:
            seen_batch = True
            ground_truths = ground_truths.to(device)
            masked_images[0] = masked_images[0].to(device)
            # labels are not used here, so they are not copied to the device

            dehedged_predictions, uncertainty = model(masked_images[0])
            loss, mse = criterion(dehedged_predictions, uncertainty, ground_truths)

            validation_loss.update(loss.item())
            mse_meter.update(mse.mean().item())
            # float, not a GPU tensor, consistent with the other meters
            uncertainty_meter.update(uncertainty.mean().item())

        if not seen_batch:
            raise ValueError("validation loader yielded no batches")

        show_uncertainty_result(masked_images[0][0], dehedged_predictions[0], ground_truths[0], uncertainty[0], os.path.join(image_save_path, str(epoch) + ".png"))

        return validation_loss.avg, mse_meter.avg, uncertainty_meter.avg

## Recurrent Uncertainty Unet

In [None]:
def recurrent_uncertainty_unet_training(start_epoch, end_epoch, train_loader, val_loader, model, single_frame_model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement):
    """Train the recurrent uncertainty UNet for epochs ``start_epoch .. end_epoch - 1``.

    ``single_frame_model`` supplies extra inputs to ``model`` and is passed
    through to the per-epoch train/validation functions. Per epoch: train,
    validate, log, step epoch-level schedulers, update early stopping and save
    a checkpoint. Returns early (without saving that epoch) after more than
    ``args.early_stop_ep`` epochs without a validation-loss improvement.
    """
    for ep in range(start_epoch, end_epoch):
        train_loss, train_mse, train_uncertainty = train_recurrent_uncertainty_unet(
            train_loader, model, single_frame_model, criterion, optimizer, args.clip_val, ep, scheduler)
        val_loss, val_mse, val_uncertainty = val_recurrent_uncertainty_unet(
            val_loader, model, single_frame_model, criterion, ep, image_save_path)
        log_and_record(logger, writer, train_loss, val_loss, ep,
                       train_mse, train_uncertainty, val_mse, val_uncertainty)

        # per-batch schedulers (onecycle / cosine) are handled in the training pass
        steps_per_epoch_level = args.scheduler and args.scheduler not in ("onecycle", "cosine")
        if steps_per_epoch_level:
            scheduler_step(scheduler, val_loss)

        # A diverged (nan/inf) validation loss never counts as an improvement
        # and leaves best_loss untouched.
        loss_is_finite = math.isfinite(val_loss)
        is_best = loss_is_finite and val_loss < best_loss
        if loss_is_finite:
            best_loss = min(val_loss, best_loss)

        if is_best:
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
            logger.info("Epochs since last improvement: {}".format(epochs_since_improvement))
            if epochs_since_improvement > args.early_stop_ep:
                logger.info("early stop at ep {} with no improvement for {} eps".format(ep, args.early_stop_ep))
                return

        save_checkpoint(ep, epochs_since_improvement, model, optimizer, scheduler, val_loss, best_loss, is_best, args.checkpoint_save_path)
        
        
def train_recurrent_uncertainty_unet(train_loader, model, single_frame_model, criterion, optimizer, clip_val, epoch, scheduler=None):
    """Run one training epoch of the recurrent uncertainty UNet.

    For each batch, ``single_frame_model`` is run (without gradients) on
    ``masked_images[0]``; its prediction and uncertainty are concatenated with
    ``masked_images[1]`` along dim 1 to form the input of ``model``. Only
    ``model`` is updated here.

    Args:
        clip_val: gradient clipping value; a falsy value disables clipping.
        epoch: current epoch index (drives the per-batch cosine schedule).
        scheduler: stepped after every batch when ``args.scheduler`` is
            "onecycle" or "cosine".

    Returns:
        ``(loss avg, mse avg, uncertainty avg)`` from AverageMeters fed with
        per-batch Python floats.
    """
    model.train()

    # CosineAnnealingWarmRestarts is driven with a fractional epoch.
    iters = len(train_loader) if args.scheduler == "cosine" else None

    training_loss = AverageMeter()
    mse_meter = AverageMeter()
    uncertainty_meter = AverageMeter()

    for i, (masked_images, ground_truths, _, _) in enumerate(train_loader):
        ground_truths = ground_truths.to(device)
        masked_images[0] = masked_images[0].to(device)
        masked_images[1] = masked_images[1].to(device)

        # The single-frame model is not trained here; no_grad already keeps its
        # outputs out of the autograd graph, so no extra .detach() is needed.
        with torch.no_grad():
            prev_pred, prev_uncertainty = single_frame_model(masked_images[0])
            model_input = torch.cat([masked_images[1], prev_pred, prev_uncertainty], dim=1)
            model_input = model_input.to(device)

        dehedged_predictions, uncertainty = model(model_input)
        loss, mse = criterion(dehedged_predictions, uncertainty, ground_truths)

        # back prop
        optimizer.zero_grad()
        loss.backward()

        # update weights, optionally clipping gradients first
        if clip_val:
            optimizer.clip_gradient(clip_val)
        optimizer.step()

        training_loss.update(loss.item())
        mse_meter.update(mse.mean().item())
        # .item() so the meter accumulates a float rather than a GPU tensor that
        # is still attached to the autograd graph.
        uncertainty_meter.update(uncertainty.mean().item())

        # onecycle and cosine schedulers are stepped after every batch
        if args.scheduler == "onecycle":
            scheduler.step()
        elif args.scheduler == "cosine":
            scheduler.step(epoch + i / iters)

    return training_loss.avg, mse_meter.avg, uncertainty_meter.avg


def val_recurrent_uncertainty_unet(val_loader, model, single_frame_model, criterion, epoch, image_save_path):
    """Evaluate the recurrent uncertainty UNet on ``val_loader``.

    ``single_frame_model`` is run on ``masked_images[0]``; its prediction and
    uncertainty are concatenated with ``masked_images[1]`` along dim 1 as the
    input of ``model``. For the first sample of the last batch this saves the
    single-frame stage to ``<epoch>_compare.png`` and the final result to
    ``<epoch>.png`` under ``image_save_path``.

    Returns:
        ``(loss avg, mse avg, uncertainty avg)`` over the validation batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    with torch.no_grad():
        model.eval()
        validation_loss = AverageMeter()
        mse_meter = AverageMeter()
        uncertainty_meter = AverageMeter()
        seen_batch = False
        for masked_images, ground_truths, _, _ in val_loader:
            seen_batch = True
            ground_truths = ground_truths.to(device)
            masked_images[0] = masked_images[0].to(device)
            masked_images[1] = masked_images[1].to(device)
            # labels are not used here, so they are not copied to the device

            prev_pred, prev_uncertainty = single_frame_model(masked_images[0])
            model_input = torch.cat([masked_images[1], prev_pred, prev_uncertainty], dim=1)
            dehedged_predictions, uncertainty = model(model_input)
            loss, mse = criterion(dehedged_predictions, uncertainty, ground_truths)

            validation_loss.update(loss.item())
            mse_meter.update(mse.mean().item())
            # float, not a GPU tensor, consistent with the other meters
            uncertainty_meter.update(uncertainty.mean().item())

        if not seen_batch:
            raise ValueError("validation loader yielded no batches")

        show_tensor_images([masked_images[0][0], prev_pred[0], prev_uncertainty[0]], ["prev masked image", "prev pred", "prev uncertainty"], os.path.join(image_save_path, str(epoch) + "_compare.png"))
        show_uncertainty_result(masked_images[1][0], dehedged_predictions[0], ground_truths[0], uncertainty[0], os.path.join(image_save_path, str(epoch) + ".png"))

        return validation_loss.avg, mse_meter.avg, uncertainty_meter.avg

# Model Training

In [None]:
def model_training(args):
    """Set up data, model, optimizer and scheduler from ``args`` and train.

    Trains from scratch when ``args.checkpoint_load_path`` is None; otherwise
    resumes model, optimizer, scheduler and early-stopping state from that
    checkpoint. ``args.unet_type`` selects the architecture and training loop:
    "unet", "uncertainty_unet" or "recurrent_uncertainty_unet".

    Raises:
        TypeError: if ``args.unet_type`` is not supported, or (when training
            from scratch) ``args.optimizer`` is not supported.
    """
    supported_unet_types = ("unet", "uncertainty_unet", "recurrent_uncertainty_unet")
    if args.unet_type not in supported_unet_types:
        # Fail fast instead of continuing with an undefined model/criterion.
        raise TypeError('model {} is not supported.'.format(args.unet_type))

    checkpoint = args.checkpoint_load_path
    image_save_path = os.path.join(args.checkpoint_save_path, "images")
    make_dir(args.checkpoint_save_path)
    make_dir(image_save_path)
    make_dir("runs")

    writer = SummaryWriter(log_dir=os.path.join("runs", args.tensorboard_fileName))
    density_range = np.arange(args.min_density, args.max_density, 0.1)

    train_dataset = StaticHedgedDataset(args.train_path, density_range, masked_img_per_item=args.masked_img_per_item, require_hedge_mask=False, color_mode=args.color_mode, debug=args.debug)
    val_dataset = StaticHedgedDataset(args.val_path, density_range, masked_img_per_item=args.masked_img_per_item, require_hedge_mask=False, color_mode=args.color_mode, debug=args.debug)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory)
    logger.info("CPU count = {}".format(args.num_workers))

    # `x == "a" or "b"` is always truthy, so test membership explicitly.
    if args.unet_type == "unet":
        criterion = nn.MSELoss()
    else:
        # both uncertainty variants use uncertainty_MSE_loss
        criterion = uncertainty_MSE_loss

    if checkpoint is None:
        # fresh training run
        if args.unet_type == "unet":
            model = UNet(args.in_channel, args.out_channel).to(device)
        else:
            model = UncertaintyUNet(args.in_channel, args.out_channel, args.out_uncertainty_channel).to(device)
            if args.unet_type == "uncertainty_unet" and args.pretrained_MSE_unet:
                load_pretrained_mse_param(model, args.pretrained_MSE_unet, pretrained_fixed=False)

        logger.info("training model type: {}".format(args.unet_type))

        start_epoch = 0
        epochs_since_improvement = 0
        best_loss = float('inf')

        if args.optimizer == 'adam':
            optimizer = OptimizerWrapper(torch.optim.Adam(model.parameters(), lr=args.lr))
        elif args.optimizer == 'SGD':
            optimizer = OptimizerWrapper(torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov))
        elif args.optimizer == 'madgrad':
            optimizer = OptimizerWrapper(madgrad.MADGRAD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, eps=args.eps))
        else:
            raise TypeError('optimizer {} is not supported.'.format(args.optimizer))

        logger.info("using {} as optimizer, initial learning rate = {}".format(args.optimizer, args.lr))

        if args.scheduler:
            scheduler = get_scheduler(optimizer, len(train_loader), start_epoch)
        else:
            scheduler = None
            logger.info("not using any scheduler")
    else:
        # Resume from checkpoint.
        # NOTE: torch.load unpickles whole objects (model, optimizer, scheduler);
        # only resume from trusted checkpoint files.
        checkpoint = torch.load(checkpoint)
        model = checkpoint['u_net_model']

        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epoch_since_improvement']
        optimizer = checkpoint['optimizer']
        if args.change_scheduler:
            logger.info("change scheduler to {}".format(args.scheduler))
            scheduler = get_scheduler(optimizer, len(train_loader), start_epoch)
        else:
            scheduler = checkpoint['scheduler']
            if args.scheduler == 'onecycle':
                # OneCycleLR has a fixed total step count, so rebuild it for the
                # remaining epochs of this run.
                scheduler = OneCycleLR(optimizer.optimizer, max_lr=args.max_lr, steps_per_epoch=len(train_loader), epochs=args.end_epoch - start_epoch)
        best_loss = checkpoint['best_loss']

        logger.info("continue training from ep{}, best loss so far is {}".format(start_epoch, best_loss))

        if args.change_lr:
            # format the value into the message; passing it as a separate
            # logging arg with no placeholder produces a logging error instead
            logger.info("adjust lr to {}".format(args.lr))
            optimizer.adjust_lr(args.lr)

    model = get_data_parallel_model(model)

    if args.unet_type == "unet":
        unet_training(start_epoch, args.end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)
    elif args.unet_type == "uncertainty_unet":
        uncertainty_unet_training(start_epoch, args.end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)
    elif args.unet_type == "recurrent_uncertainty_unet":
        single_unet_model = get_data_parallel_model(load_checkpoint_model(args.pretrained_uncertainty_unet))
        single_unet_model.eval()
        recurrent_uncertainty_unet_training(start_epoch, args.end_epoch, train_loader, val_loader, model, single_unet_model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)

# Entry point: parse the run options via parse_args() and start training.
# `args` is bound at module level on purpose: the helpers above (e.g.
# get_scheduler, scheduler_step, unet_training) read the global `args`
# instead of receiving it as a parameter.
# NOTE(review): `train_flag` is defined outside this view — presumably in a
# config cell earlier in the notebook; confirm.
if train_flag:
    args = parse_args()
    model_training(args)

In [None]:
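# NOTE: recorded traceback from an earlier training run, kept for reference.
# A DataLoader worker failed because a PNG under images/hedged_images could not be
# decoded (PIL.UnidentifiedImageError); the fix_png() cell at the top of the
# notebook re-checks these image files.
# /tmp/ipykernel_31536/4134605674.py in <module>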
#      91 if train_flag:
#      92     args = parse_args()
# ---> 93     model_training(args)

# /tmp/ipykernel_31536/4134605674.py in model_training(args)
#      81 
#      82     if args.unet_type == "unet":
# ---> 83         unet_training(start_epoch, args.end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)
#      84     elif args.unet_type == "uncertainty_unet":
#      85         uncertainty_unet_training(start_epoch, args.end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)

# /tmp/ipykernel_31536/585321851.py in unet_training(start_epoch, end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement)
#       1 def unet_training(start_epoch, end_epoch, train_loader, val_loader, model, criterion, optimizer, image_save_path, logger, writer, scheduler, best_loss, epochs_since_improvement):
#       2     for ep in range(start_epoch, end_epoch):
# ----> 3         train_loss = train_unet(train_loader, model, criterion, optimizer, args.clip_val, ep, scheduler)
#       4         val_loss = val_unet(val_loader, model, criterion, ep, image_save_path)
#       5         log_and_record(logger, writer, train_loss, val_loss, ep)

# /tmp/ipykernel_31536/585321851.py in train_unet(train_loader, model, criterion, optimizer, clip_val, epoch, scheduler)
#      39     training_loss = AverageMeter()
#      40 
# ---> 41     for i, (masked_images, ground_truths, _,  _) in enumerate(train_loader):
#      42         ground_truths = ground_truths.to(device)
#      43         masked_images[0] = masked_images[0].to(device)

# /opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
#     519             if self._sampler_iter is None:
#     520                 self._reset()
# --> 521             data = self._next_data()
#     522             self._num_yielded += 1
#     523             if self._dataset_kind == _DatasetKind.Iterable and \

# /opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
#    1201             else:
#    1202                 del self._task_info[idx]
# -> 1203                 return self._process_data(data)
#    1204 
#    1205     def _try_put_index(self):

# /opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _process_data(self, data)
#    1227         self._try_put_index()
#    1228         if isinstance(data, ExceptionWrapper):
# -> 1229             data.reraise()
#    1230         return data
#    1231 

# /opt/conda/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
#     423             # have message field
#     424             raise self.exc_type(message=msg)
# --> 425         raise self.exc_type(msg)
#     426 
#     427 

# UnidentifiedImageError: Caught UnidentifiedImageError in DataLoader worker process 2.
# Original Traceback (most recent call last):
#   File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
#     data = fetcher.fetch(index)
#   File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
#     data = [self.dataset[idx] for idx in possibly_batched_index]
#   File "/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
#     data = [self.dataset[idx] for idx in possibly_batched_index]
#   File "/tmp/ipykernel_31536/3801022188.py", line 70, in __getitem__
#     masked_images.append(self.read_img(masked_image_path))
#   File "/tmp/ipykernel_31536/3801022188.py", line 37, in read_img
#     im = np.array(Image.open(path).convert("RGB"))
#   File "/opt/conda/lib/python3.7/site-packages/PIL/Image.py", line 3031, in open
#     "cannot identify image file %r" % (filename if filename else fp)
# PIL.UnidentifiedImageError: cannot identify image file 'images/hedged_images/train/794/ILSVRC2012_val_00000527/0.8/0.png'

In [None]:
# Inspect the learning rate stored in a saved checkpoint's optimizer.
# map_location="cpu" keeps this inspection off the GPU: it works without CUDA
# and does not inflate the GPU memory figures printed by the next cell.
# NOTE: torch.load unpickles arbitrary objects — only open trusted checkpoints.
checkpoint_path = "checkpoint/unet(madgrad+onecycle)/ep300(lr=0.001)/checkpoint.tar"
checkpoint = torch.load(checkpoint_path, map_location="cpu")
optimizer = checkpoint['optimizer']
learning_rate = optimizer.optimizer.param_groups[0]['lr']
learning_rate

In [None]:
# Print GPU 0 memory figures as reported by torch.cuda:
# total device memory, memory reserved by PyTorch, and memory allocated.
t, r, a = (
    torch.cuda.get_device_properties(0).total_memory,
    torch.cuda.memory_reserved(0),
    torch.cuda.memory_allocated(0),
)

for stat_name, num_bytes in (("total", t), ("reserved", r), ("allocated", a)):
    print(stat_name, num_bytes)