In [1]:
import numpy as np
import matplotlib.pyplot as plt
import sys, os
import tensorflow as tf
import keras
from keras.optimizers import Adam
import cv2
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from PIL import Image
from copy import deepcopy
import tqdm
import math, random

sys.path.append('../src')

from data_loading import load_datasets_singleduration
from util import get_model_by_name, create_losses

from losses_keras2 import *
from sal_imp_utilities import *
from cb import InteractivePlot

# Restrict the process to GPU 0 *before* TensorFlow initialises CUDA below:
# CUDA reads CUDA_VISIBLE_DEVICES only when it is first initialised, so setting
# it after the session exists (as the later GPU cell does) has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
sess = tf.Session(config=config)
# Register this session with Keras. Without this, Keras lazily creates its own
# default session with its own config, and allow_growth above is never used.
keras.backend.set_session(sess)

%load_ext autoreload
%autoreload 2

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


# Check GPU status

In [2]:
%%bash
nvidia-smi

Thu Aug 31 10:34:23 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB            Off| 00000000:1A:00.0 Off |                    0 |
| N/A   32C    P0               55W / 300W|  17515MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2-32GB            Off| 00000000:1B:00.0 Off |  

In [3]:
# Expose only GPU 0 to this process and echo the setting.
# NOTE: CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised; the first
# cell already created a tf.Session, so this assignment does not change which
# GPUs the running kernel uses. Set it before any session is created.
os.environ.update(CUDA_VISIBLE_DEVICES='0')
os.environ.get("CUDA_VISIBLE_DEVICES")

'0'

# Load data

In [4]:
# FILL THESE IN
# Root directory holding the datasets. Override it without editing the notebook
# by exporting SALIENCY_DATASETS_DIR; the default is the original cluster path.
bp = os.environ.get("SALIENCY_DATASETS_DIR", "/projects/wang/datasets/")

dataset_imp = "SalChartQA"    # importance-map dataset used for training/validation
dataset_sal = "UMSI_SALICON"  # extra natural-image saliency data mixed into training

In [5]:
# Load the SalChartQA file lists. Later cells index the result as
# [0]/[1] = training images/maps and [3]/[4] = validation images/maps.
data_imp = load_datasets_singleduration(dataset_imp, bp)

Using SalChartQA
Length of loaded files:
train images: 2113
train maps: 2113
val images: 595
val maps: 595


In [6]:
# Load the SALICON file lists, indexed like data_imp: [0]/[1] = training
# images/maps and [3]/[4] = validation images/maps.
data_sal = load_datasets_singleduration(dataset_sal, bp)

Using SALICON (no fixation coords)
Length of loaded files:
train images: 10000
train maps: 10000
val images: 5000
val maps: 5000
test images 5000
Length of loaded files:
train images: 10000
train maps: 10000
val images: 5000
val maps: 5000


# Model and training params

In [7]:
# Checkpointing. To resume a previous run, set load_weights = True and point
# weightspath at the saved weights file.
ckpt_savedir = "ckpt"
load_weights = False
weightspath = "./ckpt/weights.hdf5"

# FILL THESE IN: optimisation settings. The learning rate starts at init_lr and
# is multiplied by lr_reduce_by on a reduce_at_epoch-epoch cadence by the
# step_decay schedule.
n_epochs = 15
batch_size = 4
init_lr = 1e-4
lr_reduce_by = 0.1
reduce_at_epoch = 3
opt = Adam(lr=init_lr)

# Architecture and resolutions as (rows, cols).
model_name = "UMSI"
model_inp_size = (240, 320)
model_out_size = (480, 640)

In [8]:
# (rows, cols) of the input resolution followed by 3 colour channels.
input_shape = (*model_inp_size, 3)

In [9]:
# Build the network selected by model_name; `mode` tells downstream code how the
# model's outputs are laid out.
model_params = dict(input_shape=input_shape, n_outs=2)
model_func, mode = get_model_by_name(model_name)
model = model_func(**model_params)

# Optionally resume from previously saved weights.
if load_weights:
    model.load_weights(weightspath)
    print("load")


xception output shapes: (?, 30, 40, 2048)


TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'

In [None]:
# set up data generation and checkpoints
os.makedirs(ckpt_savedir, exist_ok=True)

# Generators
gen_train = ImpAndClassifGenerator(
        img_filenames=data_imp[0],
        imp_filenames=data_imp[1],
        fix_filenames=None,
        extra_imgs=data_sal[0], # For feeding a much larger dataset, e.g. salicon, that the generator will subsample to maintain class balance
        extra_imps=data_sal[1],
        extra_fixs=None,
        extras_per_epoch=160,
        # Use the configured batch size instead of a duplicate literal, so that
        # editing `batch_size` in the params cell actually takes effect.
        batch_size=batch_size,
        img_size=(shape_r,shape_c),
        map_size=(shape_r_out, shape_c_out),
        shuffle=True,
        augment=False,
        n_output_maps=1,
        concat_fix_and_maps=False,
        fix_as_mat=False,
        fix_key="",
        str2label=None,
        dummy_labels=False,
        num_classes=6,
        pad_imgs=True,
        pad_maps=True,
        return_names=False,
        return_labels=True,
        read_npy=False)

# NOTE(review): unlike gen_train, the validation generator relies on the
# generator's defaults for n_output_maps, num_classes, padding, etc. Confirm
# they produce targets with the same structure as the training batches.
gen_val = ImpAndClassifGenerator(
            img_filenames=data_imp[3], 
            imp_filenames=data_imp[4], 
            fix_filenames=None, 
            extra_imgs=data_sal[3], # For feeding a much larger dataset, e.g. salicon, that the generator will subsample to maintain class balance
            extra_imps=data_sal[4],
            extra_fixs=None,
            extras_per_epoch=40,
            batch_size=1, 
            img_size=(shape_r,shape_c), 
            map_size=(shape_r_out, shape_c_out),
            shuffle=False, 
            augment=False, 
            str2label=None,
            dummy_labels=False,
            #n_output_maps=1,
        )

# Callbacks

# where to save checkpoints (one epoch-stamped weights file per epoch)
filepath = os.path.join(ckpt_savedir, dataset_imp + '_kl+cc+bin_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')

print("Checkpoints will be saved with format %s" % filepath)

cb_chk = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_weights_only=True, period=1)
cb_plot = InteractivePlot()

def step_decay(epoch, *, base_lr=None, drop=None, epochs_per_drop=None):
    """Step-decay learning-rate schedule for ``LearningRateScheduler``.

    Returns ``base_lr * drop ** floor((epoch + 1) / epochs_per_drop)``. With
    epochs_per_drop=3, the rate first drops at 0-based epoch index 2 and then
    every 3 epochs after that.

    Parameters
    ----------
    epoch : int
        0-based epoch index supplied by Keras.
    base_lr, drop, epochs_per_drop : optional, keyword-only
        Override the notebook globals ``init_lr``, ``lr_reduce_by`` and
        ``reduce_at_epoch`` respectively. They are keyword-only so that Keras
        versions which first try ``schedule(epoch, lr)`` get a TypeError and
        fall back to ``schedule(epoch)``, instead of silently binding the
        current lr to one of them.

    Returns
    -------
    float
        The learning rate to use for `epoch`.
    """
    if base_lr is None:
        base_lr = init_lr
    if drop is None:
        drop = lr_reduce_by
    if epochs_per_drop is None:
        epochs_per_drop = reduce_at_epoch
    lrate = base_lr * math.pow(drop, math.floor((1 + epoch) / epochs_per_drop))
    # The exponent increases exactly when (epoch + 1) is a multiple of
    # epochs_per_drop, so that is the only epoch where the rate changes.
    if (1 + epoch) % epochs_per_drop == 0:
        print('Reducing lr. New lr is:', lrate)
    return lrate
# Learning-rate schedule callback driven by step_decay.
cb_sched = LearningRateScheduler(schedule=step_decay)
# Callbacks handed to fit_generator: checkpointing, LR schedule, live plot.
cbs = [cb_chk, cb_sched, cb_plot]

In [None]:
# Smoke-test the training generator: fetch batch 1 and report its structure.
img, outs = gen_train[1]
print("batch size: %d. Num inputs: %d. Num outputs: %d." % (batch_size, len(img), len(outs)))
print(outs[0].shape)
print(outs[1].shape)

# Train

In [None]:
# Map loss: kl_cc_combined on 'dec_c_cout'. Classification loss: binary
# cross-entropy on 'out_classif', weighted 5x relative to the map loss.
model.compile(
    optimizer=opt,
    loss={'dec_c_cout': kl_cc_combined, "out_classif": "binary_crossentropy"},
    loss_weights={'dec_c_cout': 1, "out_classif": 5},
)

print('Ready to train')
model.fit_generator(
    gen_train,
    epochs=n_epochs,
    verbose=1,
    callbacks=cbs,
    validation_data=gen_val,
    max_queue_size=10,
    workers=5,
)

## Visualization scripts to check the training result

In [None]:
# Reload weights for the visual checks below.
# NOTE: the ModelCheckpoint set up earlier writes epoch-stamped files
# (<dataset>_kl+cc+bin_epXX_vallossY.hdf5) into ckpt_savedir, never
# weights.hdf5. Point W at one of those to inspect the freshly trained model.
W = weightspath
model.load_weights(W)
print('load')

In [None]:
# Draw 50 (image, map, filename) examples from the SALICON validation split
# (data_sal[3] images, data_sal[4] maps) for the visual checks below.
gen = UMSI_eval_generator(data_sal[3], data_sal[4], inp_size=model_inp_size)

examples = [next(gen) for _ in range(50)]
len(examples)

In [None]:
# Show one fixed validation example (the 5th generated one): input image,
# ground-truth map and predicted map side by side.
# Index directly rather than looping over examples[4:] and breaking: with fewer
# than 5 examples the loop body never ran, and the plots below silently used
# stale `images`/`maps`/`preds_map` left over from earlier cells.
example = examples[4]
images, maps, img_filename_ = example
preds = model.predict(images[0])
preds_map = preds[0]
preds_classif = preds[1]

print("maps size", len(maps), maps[0].shape)
batch = 0

plt.figure(figsize=(14, 8))
plt.subplot(1, 3, 1)
plt.imshow(reverse_preprocess(np.squeeze(images[0])))
plt.title("natural images %d" % batch)

plt.subplot(1, 3, 2)
plt.imshow(maps[0])
plt.title('Gt ')

plt.subplot(1, 3, 3)
plt.imshow(postprocess_predictions(np.squeeze(preds_map[0]), maps[0].shape[0], maps[0].shape[1], zero_to_255=True))
plt.title('Prediction')
plt.show()


In [None]:
# Plot a random example; the input panel is titled with the predicted class.
images, maps, img_filename = random.choice(examples)

print("maps size", len(maps), maps[0].shape)
batch = 0
preds = model.predict(images[0])
preds_map, preds_classif = preds[0], preds[1]
cl = np.argmax(preds_classif)
print(preds_classif)

plt.figure(figsize=(14, 8))

# Panel 1: the input. Only class indices 0-4 get a title.
# NOTE(review): indices 1 and 3 both carry an "infographic" label, and index 5
# gets no title although the generators use num_classes=6. Confirm the label
# order against get_labels.
plt.subplot(1, 3, 1)
plt.imshow(reverse_preprocess(np.squeeze(images[0])))
if cl == 0:
    plt.title("advertisment %d" % batch)
elif cl == 1:
    plt.title("infographic %d" % batch)
elif cl == 2:
    plt.title("movie_posters %d" % batch)
elif cl == 3:
    plt.title("infographics %d" % batch)
elif cl == 4:
    plt.title("webpages %d" % batch)

# Panel 2: ground truth.
plt.subplot(1, 3, 2)
plt.imshow(maps[0])
plt.title('Gt ' )

# Panel 3: prediction resized to the ground-truth resolution.
plt.subplot(1, 3, 3)
plt.imshow(postprocess_predictions(np.squeeze(preds_map[0]), maps[0].shape[0], maps[0].shape[1], zero_to_255=True))
plt.title('Prediction')


In [None]:
# Plot a random example: input, ground-truth map and predicted map.
images, maps, img_filename = random.choice(examples)

print("maps size", len(maps), maps[0].shape)
batch = 0
preds = model.predict(images[0])
preds_map, preds_classif = preds[0], preds[1]
print(preds_classif)

plt.figure(figsize=(14, 8))

# Panel 1: input image (preprocessing undone for display).
plt.subplot(1, 3, 1)
plt.imshow(reverse_preprocess(np.squeeze(images[0])))
plt.title("original image %d" % batch)

# Panel 2: ground truth.
plt.subplot(1, 3, 2)
plt.imshow(maps[0])
plt.title('Gt ' )

# Panel 3: prediction resized to the ground-truth resolution.
plt.subplot(1, 3, 3)
plt.imshow(postprocess_predictions(np.squeeze(preds_map[0]), maps[0].shape[0], maps[0].shape[1], zero_to_255=True))
plt.title('Prediction')


# Evaluate

In [None]:
# Checkpoint used for evaluation.
# NOTE(review): this name has an "imp1k" prefix, whereas this notebook's
# ModelCheckpoint names files after dataset_imp ("SalChartQA"); it presumably
# comes from a different run. Update the path if needed.
W = "./ckpt/imp1k_kl+cc+bin_ep13_valloss-2.4641.hdf5"
model.load_weights(filepath=W)
print("load weights")

In [None]:
def get_prediction(model, test_img, gt_map, inp_size, mode='simple', blur=False):
    """Predict a saliency map (and class scores) for every test image.

    Each image is loaded with ``preprocess_images`` at `inp_size` and run
    through ``model.predict``. The output is unpacked according to `mode` and
    resized to its ground-truth map with
    ``postprocess_predictions(..., zero_to_255=True)``.

    Parameters
    ----------
    model : object with a Keras-style ``predict`` method.
    test_img : sequence of image file paths.
    gt_map : sequence of ground-truth map file paths, aligned with `test_img`.
        Each is read as an 8-bit grayscale image with OpenCV.
    inp_size : (rows, cols) passed to ``preprocess_images``.
    mode : one of:
        'simple' -- map at ``preds[0]``, class scores at ``preds[1]``;
        'singlestream';
        'multistream_concat'.
    blur : forwarded to ``postprocess_predictions``.

    Returns
    -------
    tuple ``(pre, cla, maps)``:
        pre  -- ``np.array`` of the post-processed predicted maps;
        cla  -- list of per-image class scores (``None`` entries for modes
                without a classification output);
        maps -- list of the ground-truth maps.

    Raises
    ------
    ValueError
        If `mode` is unknown or a ground-truth map cannot be read.
    """
    if mode not in ('multistream_concat', 'simple', 'singlestream'):
        raise ValueError('Unknown mode')
    # Nothing is written to disk here, so no output directory is involved.
    print('BLURRING PREDICTIONS' if blur else 'NOT BLURRING PREDICTIONS')

    # One image per predict() call, so always take the first batch/time/map entry.
    batch = time_idx = map_idx = 0
    pre, cla, maps = [], [], []
    for i in tqdm.tqdm(range(len(test_img))):
        heatmap = cv2.imread(gt_map[i], cv2.IMREAD_GRAYSCALE)
        if heatmap is None:  # cv2.imread returns None for missing/unreadable files
            raise ValueError('Could not read ground-truth map: %s' % gt_map[i])
        img = preprocess_images([test_img[i]], inp_size[0], inp_size[1])
        preds = model.predict(img)

        classif = None  # only 'simple' mode exposes class scores
        if mode == 'multistream_concat':
            p = preds[time_idx][batch][map_idx][:, :, 0]
        elif mode == 'simple':
            # Layout of this notebook's own model: map first, class scores second.
            p = preds[0][batch][:, :, 0]
            classif = preds[1][0]
        else:  # 'singlestream'
            p = preds[0][batch][time_idx][:, :, 0]

        # zero_to_255=True matches this notebook's own model.
        p = postprocess_predictions(p, heatmap.shape[0], heatmap.shape[1], blur, normalize=False, zero_to_255=True)
        pre.append(p)
        cla.append(classif)
        maps.append(heatmap)
    return np.array(pre), cla, maps
    

In [None]:
# Predicted maps, predicted class scores and ground-truth maps for the
# SalChartQA validation split (data_imp[3] images, data_imp[4] maps).
p, p_labels, gt_map = get_prediction(
    model,
    data_imp[3],
    data_imp[4],
    inp_size=(shape_r, shape_c),
    mode='simple',
    blur=False,
)

In [None]:
# Labels for the SALICON validation images (the same file list passed to
# sal_eval at the end). The bare expression displays entry 1 as a spot check.
gt_labels = get_labels(data_sal[3])
gt_labels[1]

In [None]:
from eval import calculate_metrics
def get_eval_result(p, gt_map, gt_fix_map=None, gt_fix_points=None, gt_labels=None, p_labels=None):
    """Average per-image metrics from ``calculate_metrics`` over a dataset.

    Parameters
    ----------
    p : sequence of predicted maps, aligned with `gt_map`.
    gt_map : sequence of ground-truth maps.
    gt_fix_map, gt_fix_points, gt_labels, p_labels : optional sequences.
        Presumably these are aligned per image with `gt_map`, as sal_eval does
        for labels (TODO confirm against ``eval.calculate_metrics``). Entry
        ``i`` is forwarded for image ``i``. Labels are only forwarded when
        both `gt_labels` and `p_labels` are given.

    Returns
    -------
    dict
        Metric name -> mean over images (``nan`` if ``calculate_metrics``
        never produced it). When labels are given, it also contains 'Acc' and
        'Acc_per_class'. 'Acc_per_class' is a list with, for each column of
        the stacked per-image values, the fraction of non-zero entries.
    """
    with_labels = gt_labels is not None and p_labels is not None

    def _ith(seq, i):
        # Entry i of an optional per-image sequence.
        return None if seq is None else seq[i]

    names = ["R2", 'RMSE', 'CC', 'CC (saliconeval)', 'KL', 'SIM']
    if with_labels:
        names += ['Acc', 'Acc_per_class']
    collected = {name: [] for name in names}

    for i in range(len(gt_map)):
        m = calculate_metrics(
            p[i],
            gt_map=gt_map[i],
            gt_fix_map=_ith(gt_fix_map, i),
            gt_fix_points=_ith(gt_fix_points, i),
            gt_labels=_ith(gt_labels, i) if with_labels else None,
            p_labels=_ith(p_labels, i) if with_labels else None,
        )
        for name in names:
            if name in m:
                collected[name].append(m[name][0])

    metrics = {}
    for name, values in collected.items():
        if name == 'Acc_per_class':
            # Stack per-image vectors to (n_images, n_cols); score each column.
            metrics[name] = [np.sum(row != 0) / len(row) for row in np.array(values).T]
        else:
            metrics[name] = np.mean(values) if values else np.nan
    return metrics

# Saliency metrics only: no labels are passed, because gt_labels come from the
# SALICON split while p was predicted on SalChartQA.
get_eval_result(p, gt_map)

In [None]:
def sal_eval(model, test_img, gt_map, inp_size, mode='simple', blur=False, gt_labels=None):
    """Compute saliency and classification metrics for `model` over a test set.

    For each image, the prediction is unpacked according to `mode`, resized to
    its ground-truth map with ``postprocess_predictions(..., zero_to_255=False)``
    and scored with ``calculate_metrics``.

    Parameters
    ----------
    model : object with a Keras-style ``predict`` method.
    test_img : sequence of image file paths.
    gt_map : sequence of ground-truth map paths aligned with `test_img`.
        Each is read as 8-bit grayscale with OpenCV.
    inp_size : (rows, cols) passed to ``preprocess_images``.
    mode : 'simple', 'singlestream' or 'multistream_concat'.
    blur : forwarded to ``postprocess_predictions``.
    gt_labels : optional sequence aligned with `test_img`. Entry ``i`` is
        forwarded to ``calculate_metrics``; ``None`` is forwarded if omitted.

    Returns
    -------
    dict
        Metric name -> mean over images (``nan`` when ``calculate_metrics``
        produced no value for it). 'Acc_per_class' is the fraction of images
        whose per-class entry at index 5 is non-zero, or ``nan`` when no
        per-class values were produced.

    Raises
    ------
    ValueError
        If `mode` is unknown or a ground-truth map cannot be read.
    """
    if mode not in ('multistream_concat', 'simple', 'singlestream'):
        raise ValueError('Unknown mode')
    # Nothing is written to disk here, so no output directory is involved.
    print('BLURRING PREDICTIONS' if blur else 'NOT BLURRING PREDICTIONS')

    metrics = {"R2": [], 'RMSE': [], 'CC': [], 'CC (saliconeval)': [], 'KL': [], 'SIM': [], 'Acc': [], 'Acc_per_class': []}
    # One image per predict() call, so always take the first batch/time/map entry.
    batch = time_idx = map_idx = 0
    for i in tqdm.tqdm(range(len(test_img))):
        heatmap = cv2.imread(gt_map[i], cv2.IMREAD_GRAYSCALE)
        if heatmap is None:  # cv2.imread returns None for missing/unreadable files
            raise ValueError('Could not read ground-truth map: %s' % gt_map[i])
        img = preprocess_images([test_img[i]], inp_size[0], inp_size[1])
        preds = model.predict(img)

        classif = None  # only 'simple' mode exposes class scores
        if mode == 'multistream_concat':
            p = preds[time_idx][batch][map_idx][:, :, 0]
        elif mode == 'simple':
            p = preds[0][batch][:, :, 0]
            # NOTE(review): output index 3 with 6 class scores matches the
            # externally loaded UMSI model evaluated below. For this notebook's
            # own 2-output model, use preds[1][0] instead (as get_prediction does).
            classif = preds[3].reshape(6,)
        else:  # 'singlestream'
            p = preds[0][batch][time_idx][:, :, 0]

        # zero_to_255=False matches the external model; use True for our own model.
        p = postprocess_predictions(p, heatmap.shape[0], heatmap.shape[1], blur, normalize=False, zero_to_255=False)
        labels_i = None if gt_labels is None else gt_labels[i]
        m = calculate_metrics(p, gt_map=heatmap, gt_fix_map=None, gt_fix_points=None, gt_labels=labels_i, p_labels=classif)
        for key in metrics:
            if key in m:
                metrics[key].append(m[key][0])

    per_class = metrics['Acc_per_class']
    for key in metrics:
        if key != 'Acc_per_class':
            values = metrics[key]
            metrics[key] = np.mean(values) if values else np.nan
    if per_class:
        # NOTE(review): only the column at index 5 is scored. Presumably this is
        # the class shared by the SALICON evaluation images; confirm against get_labels.
        class_hits = np.array(per_class).T[5]
        metrics['Acc_per_class'] = np.sum(class_hits != 0) / len(class_hits)
    else:
        metrics['Acc_per_class'] = np.nan
    return metrics
    

In [None]:
# Evaluate an externally trained UMSI model (FILL IN the path) on the SALICON
# validation split (data_sal[3] images, data_sal[4] maps) using gt_labels.
# compile=False: only inference is needed here. Deserialising the training
# configuration would otherwise require custom_objects for any custom loss
# (e.g. kl_cc_combined). Custom layers, if any, still need custom_objects.
model_UMSI = keras.models.load_model('/path/to/model.hdf5', compile=False)
sal_eval(model_UMSI, data_sal[3], data_sal[4], inp_size=(shape_r, shape_c), mode='simple', blur=False, gt_labels=gt_labels)