In [None]:
import numpy as np
import keras
import matplotlib.pyplot as plt
import sys
import os
from keras.layers import Input, TimeDistributed, Lambda, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
import keras.backend as K
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
from keras.optimizers import Adam
import cv2
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from PIL import Image
from IPython.display import clear_output
import scipy.io
from copy import deepcopy
import tqdm 
import math
import random
import glob
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# TensorFlow GPU config: grow GPU memory use as needed instead of
# pre-allocating it, then open an interactive session with that config.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# Make the project modules in ./src importable (needed by the imports below).
sys.path.append('./src')

# Seed every RNG used in this notebook (TF graph-level, numpy, stdlib random).
_SEED = 1234
for _seed_fn in (tf.compat.v1.set_random_seed, np.random.seed, random.seed):
    _seed_fn(_SEED)
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here only affects child processes, not hash randomization in this kernel.
os.environ['PYTHONHASHSEED'] = str(_SEED)

from data_loading import load_datasets_singleduration
from util import get_model_by_name, create_losses
from losses_keras2 import kl_cc_combined, kl_cc_nss_combined, kl_cc_nss_combined_new

from sal_imp_utilities import *
from cb import InteractivePlot
from losses_keras2 import loss_wrapper

# Reload imported modules (e.g. the ./src helpers) before each cell executes,
# so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Check GPU status

In [None]:
%%bash
# Print NVIDIA GPU / driver status as reported by nvidia-smi.
nvidia-smi

In [None]:
# os.environ["CUDA_VISIBLE_DEVICES"] = '-1'   # presumably: uncomment to force CPU-only
# NOTE(review): TF already created a session in the first cell, so these
# environment variables are probably read too late to have any effect here.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
from tensorflow.python.client import device_lib


def get_available_devices():
    """Return the names of all local devices TensorFlow reports."""
    return list(map(lambda proto: proto.name, device_lib.list_local_devices()))


print(get_available_devices())

print(tf.test.is_gpu_available())

# Load data

In [None]:
# Training images and their ground-truth importance maps ('*g' matches e.g.
# .jpg/.jpeg/.png). glob returns files in arbitrary filesystem order, so both
# lists are sorted: this keeps the i-th image aligned with the i-th map
# (assuming corresponding files sort identically, e.g. same basenames) and
# makes runs reproducible across machines.
img_filenames_ours = sorted(glob.glob('./images/*g'))
imp_filenames_ours = sorted(glob.glob('./saliency_gt/*g'))

In [None]:
# The generator receives the image and map lists separately and presumably
# pairs them by index, so their lengths must agree.
assert len(img_filenames_ours) == len(imp_filenames_ours), (
    "found %d images but %d saliency maps"
    % (len(img_filenames_ours), len(imp_filenames_ours))
)
len(img_filenames_ours), len(imp_filenames_ours)

# Model and training params

In [None]:
# FILL THESE IN: set training parameters

# --- checkpoints / warm start ---------------------------------------------
ckpt_savedir = "ckpt"
load_weights = False   # when True, the model cell loads `weightspath`
weightspath = ""

# --- optimisation ---------------------------------------------------------
# step_decay multiplies the lr by `lr_reduce_by` every `reduce_at_epoch` epochs.
batch_size = 4
init_lr = 1e-4
lr_reduce_by = 0.1
reduce_at_epoch = 3
n_epochs = 50

# NOTE(review): newer Keras releases name this argument `learning_rate`;
# `lr` is the legacy spelling -- update if the installed Keras rejects it.
opt = Adam(lr=init_lr)

# losses is a dictionary mapping loss names to weights.
# NOTE(review): 'cc' is presumably negative because the correlation
# coefficient is a similarity to maximise -- confirm in create_losses.
losses = dict(kl=10, cc=-3)

# --- model ----------------------------------------------------------------
model_name = "UMSI"
model_inp_size = (256, 256)
model_out_size = (512, 512)

In [None]:
# get model
model_params = dict(
    input_shape=(*model_inp_size, 3),  # spatial size plus 3 channels
    n_outs=len(losses),                # number of configured losses
)
model_func, mode = get_model_by_name(model_name)
# This notebook only supports models whose mode is "simple".
assert mode == "simple"
model = model_func(**model_params)

# Optionally warm-start from previously saved weights.
if load_weights:
    model.load_weights(weightspath)

In [None]:
# set up data generation and checkpoints
# exist_ok avoids the check-then-create race and is a no-op if the dir exists.
os.makedirs(ckpt_savedir, exist_ok=True)

# sort the losses so that those that use a fixmap are last, by convention
l, lw, l_str, n_heatmaps = create_losses(losses, model_out_size)
n_fixmaps = len(l) - n_heatmaps
print("Loss string", l_str)

# Generators
# Batch size and image/map sizes come from the config cell. This way the batches
# match the resolution the model was built for (model_inp_size) and the resolution
# the losses were created for (model_out_size), instead of ignoring the
# configuration in favour of hard-coded values.
gen_train = ImpAndClassifGenerator(
        img_filenames=img_filenames_ours,
        imp_filenames=imp_filenames_ours,
        fix_filenames=None,
        extra_fixs=None,
        extras_per_epoch=160,
        batch_size=batch_size,
        img_size=model_inp_size,
        map_size=model_out_size,
        shuffle=True,
        augment=False,
        n_output_maps=1,
        concat_fix_and_maps=False,
        fix_as_mat=False,
        fix_key="",
        str2label=None,
        dummy_labels=False,
        num_classes=6,
        pad_imgs=True,
        pad_maps=True,
        return_names=False,
        return_labels=True,
        read_npy=False)

# Checkpoint filename template; Keras fills in {epoch} and {loss} at save time.
# NOTE(review): the "valloss" tag is filled from the *training* `loss` log key
# (monitor='loss'); the name is kept so existing checkpoint names still match.
ckpt_name = "".join(("umsi++_", l_str, "_ep{epoch:02d}_valloss{loss:.4f}.hdf5"))
filepath = os.path.join(ckpt_savedir, ckpt_name)
print("Checkpoints will be saved with format %s" % filepath)

# Save weights only, once per epoch.
cb_chk = ModelCheckpoint(
    filepath,
    monitor='loss',
    verbose=1,
    save_weights_only=True,
    period=1,
)
cb_plot = InteractivePlot()

def step_decay(epoch):
    """Step learning-rate schedule for ``LearningRateScheduler``.

    The rate starts at ``init_lr`` and is multiplied by ``lr_reduce_by`` each
    time ``epoch + 1`` reaches a multiple of ``reduce_at_epoch``. All three
    values are read from the notebook's globals at call time.

    Args:
        epoch: epoch index passed in by Keras (0-based).

    Returns:
        The learning rate to use for ``epoch``.
    """
    n_drops = math.floor((1 + epoch) / reduce_at_epoch)
    lrate = init_lr * math.pow(lr_reduce_by, n_drops)
    # Announce only on the epochs where n_drops has just incremented, i.e.
    # where the rate actually changes. This must test epoch + 1, the same
    # quantity used in n_drops above.
    if (1 + epoch) % reduce_at_epoch == 0:
        print('Reducing lr. New lr is:', lrate)
    return lrate
cb_sched = LearningRateScheduler(schedule=step_decay)

# Callbacks for training: checkpointing, lr schedule, interactive plot.
cbs = [cb_chk, cb_sched, cb_plot]

In [None]:
# Number of heatmap-based losses returned by create_losses; the remaining
# n_fixmaps losses are the fixation-map ones.
n_heatmaps

In [None]:
# Smoke test: fetch one training batch and run it through the model.
img, outs = gen_train[1]
print("batch size: %d. Num inputs: %d. Num outputs: %d." % (batch_size, len(img), len(outs)))
preds = model.predict(img)

## Evaluate on Our Dataset

In [None]:
# Load the pretrained weights from ./weights for evaluation; this replaces
# whatever weights the model held before.
W = "./weights/umsi++.hdf5"
model.load_weights(filepath=W)

In [None]:
# NOTE(review): these globs point at the same ./images and ./saliency_gt
# directories used for training, so this "test" set is the training set.
# Point them at a held-out split for a meaningful evaluation.
# Sorted because glob order is arbitrary and the lists are paired by index.
img_filenames_ours_test = sorted(glob.glob('./images/*g'))
imp_filenames_ours_test = sorted(glob.glob('./saliency_gt/*g'))
assert len(img_filenames_ours_test) == len(imp_filenames_ours_test), (
    "found %d test images but %d saliency maps"
    % (len(img_filenames_ours_test), len(imp_filenames_ours_test))
)
len(img_filenames_ours_test), len(imp_filenames_ours_test)

In [None]:
# Build the evaluation generator and draw one example from it per test filename.
gen = UMSI_eval_generator(
    img_filenames_ours_test,
    imp_filenames_ours_test,
    inp_size=model_inp_size,
)

n_test_images = len(img_filenames_ours_test)
examples = []
for example_idx in range(n_test_images):
    examples.append(next(gen))
len(examples)

In [None]:
# loop over the testing data: one figure per example showing the input image,
# its ground-truth map and the predicted map side by side.
# Same global effect as plt.gray() (default image colormap -> gray), but it
# does not create a stray empty figure before the first subplot grid.
plt.rcParams['image.cmap'] = 'gray'
for batch, example in enumerate(examples):
    images, maps, img_filename_ = example
    # NOTE(review): images[0] is presumably the batched model input -- confirm
    # against UMSI_eval_generator.
    preds = model.predict(images[0])
    # preds[0] holds the predicted maps; preds[1] is the model's second
    # output (presumably the classification head -- confirm).
    preds_map = preds[0]
    preds_classif = preds[1]
    print("maps size", len(maps), maps[0].shape)

    fig, (ax_img, ax_gt, ax_pred) = plt.subplots(1, 3, figsize=(14, 8))

    # show the original image
    ax_img.imshow(reverse_preprocess(np.squeeze(images[0])))
    ax_img.set_title("natural images %d" % batch)

    # show the ground truth heatmap
    ax_gt.imshow(maps[0])
    ax_gt.set_title('Gt ')

    # show the predicted heatmap, postprocessed at the ground-truth map's shape
    pred_result = postprocess_predictions(
        np.squeeze(preds_map[0]), maps[0].shape[0], maps[0].shape[1],
        normalize=False, zero_to_255=True)
    ax_pred.imshow(pred_result)
    ax_pred.set_title('Prediction')

    for ax in (ax_img, ax_gt, ax_pred):
        ax.axis('off')

    # uncomment to save to the folder save_dir
    # save_dir = ''
    # os.makedirs(save_dir, exist_ok=True)
    # plt.imsave(os.path.join(save_dir, os.path.basename(img_filename_)), pred_result)

    # Render this figure now and release it, rather than keeping one open
    # figure per example until the whole loop has finished.
    plt.show()
    plt.close(fig)