Load packages and start the MATLAB engine

In [1]:
import tensorflow as tf

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# The setting is applied to every visible GPU because TensorFlow requires the
# memory-growth setting to be identical across GPUs. On a CPU-only machine the
# list is empty and nothing happens (indexing [0] used to raise IndexError there).
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [2]:
from tqdm import tqdm
from scipy import io as spio
import os
import sys
import numpy as np
import random as rnd

# Seed the NumPy and Python RNGs so tuning and evaluation runs are repeatable.
seed = 42
np.random.seed(seed)
rnd.seed(seed)
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects Python processes launched afterwards, not the running kernel.
os.environ['PYTHONHASHSEED']=str(seed)

import h5py
import matlab
import matlab.engine

import time
from datetime import timedelta

from hyperopt import fmin, tpe, hp, STATUS_OK, STATUS_FAIL, Trials
from hyperopt.pyll import scope
from hyperopt.fmin import generate_trials_to_calculate
from functools import reduce

from s2synth import rr_s2_data, mod_6_crop_s2_data
from sreval import rmse, sre, uiqi, ergas, sam, ssim
from mat_loaders import get_data

from S2_SSC_CNN.SSCwrap import SSCwrap

# NOTE(review): later cells also use `get_method` (run cell) and `plt` / `sns`
# (plotting cell), none of which is imported in any visible cell. Confirm where
# they come from (e.g. boars, matplotlib.pyplot, seaborn) and import them here.
from boars import opt_method, data_list, meth_list, best_pars

In [3]:
# Start a MATLAB engine session. MATLAB-side code (e.g. S2sharpwrap, permute) is
# reached later through `eng`.
eng = matlab.engine.start_matlab()

Initialize parameters

In [4]:
# Tuning configuration shared by the cells below:
#   max_evals  - number of evaluations handed to opt_method per dataset
#   metric     - name of the quality metric passed to opt_method
#   verbose    - print per-method progress in the run cell
#   eval_bands - band groups to tune/evaluate (presumably the scale-2 and scale-6
#                groups of `band_scales`; TODO confirm)
max_evals, metric, verbose = 150, 'sre', False
eval_bands = [2, 6]

In [6]:
# Make sure each dataset has a (possibly empty) dict to store tuned parameters in;
# existing entries are left untouched.
for dataset in data_list:
    best_pars.setdefault(dataset, {})

Find best parameters for S2 SSC

In [7]:
# SSC is only tuned when band group 2 is being evaluated, and it is scored on that
# group alone (eval_bands=[2] below).
if 2 in eval_bands:
    method = SSCwrap
    # Search space handed to opt_method; each entry is presumably
    # [type, default value, hyperopt distribution spec] (TODO confirm against boars.opt_method).
    SSCparameters = {'batch_size': ['int', 64, ('uniform_int', 16, 64)], 
                  'lr': ['double', 0.0005, ('lognormal', np.exp(0.1), 0.5)], 
                  'ndown': ['int', 3, ('constant')], 
                  'num_epochs': ['int', 200, ('constant')], 
                  'verbose': ['string', False, ('constant')]}

    for dataset in data_list:
        try:
            print('Tuning SSC parameters for', dataset)
            SSCpars = opt_method(method, SSCparameters, max_evals, dataset=dataset, metric=metric, eval_bands=[2], verbose=True, matlab_func=False, savefile='./opt_result/SSC_'+time.strftime("%Y%m%d%H%M"))
            best_pars[dataset]['SSC'] = SSCpars
        except Exception as err:
            # Best effort: carry on with the remaining datasets, but say why this one
            # failed. Catching Exception (not a bare except) keeps KeyboardInterrupt
            # working, so a long tuning run can still be stopped.
            print('Parameter tuning failed!', repr(err))

Tuning SSC parameters for apex
151trial [24:46,  9.87s/trial, best loss: -22.948764362941898]                      
pars:
 {'batch_size': 63.0, 'lr': 0.0005153685695537464}
best sre : -22.948764362941898


Find best parameters for S2Sharp

In [8]:
# Put the S2Sharp MATLAB code on the engine's path and fetch its wrapper function;
# it is exposed both under its own name and as the generic `method`.
eng.addpath('./S2sharp')
method = S2sharpwrap = eng.S2sharpwrap
# Search space for S2Sharp: the rank `r`, ten weights q1..q10 that all share the
# same lognormal search spec (only their defaults differ), and the regularisation `lam`.
S2sharpparameters = {
    'r': ['int', 8, ('uniform_int', 5, 9)],
    **{'q{}'.format(q_idx): ['double', q_default, ('lognormal', 0, 0.5)]
       for q_idx, q_default in enumerate(
           [1, 0.3851, 6.9039, 19.9581, 47.8967, 27.5518, 2.7100, 34.8689, 1, 1],
           start=1)},
    'lam': ['double', 1.8998e-04, ('lognormal', np.exp(0.005), 0.5)],
}


In [9]:
# Tune S2Sharp on every dataset and record the best parameters found.
for dataset in data_list:
    try:
        print('Tuning S2sharp parameters for', dataset)
        # Pass the wrapper explicitly rather than via the `method` global (which other
        # cells reassign), and honour the configured `metric` instead of hard-coding 'sre'.
        S2sharppars = opt_method(S2sharpwrap, S2sharpparameters, max_evals, dataset=dataset, metric=metric, eval_bands=eval_bands, verbose=True, savefile='./opt_result/S2sharp_'+time.strftime("%Y%m%d%H%M"))
        best_pars[dataset]['S2sharp'] = S2sharppars
    except Exception as err:
        # Best effort per dataset; report the cause, and let KeyboardInterrupt stop the run.
        print('Parameter tuning failed!', repr(err))

Tuning S2sharp parameters for han_iceland_rr
167trial [37:45, 28.30s/trial, best loss: -25.872527212971523]                      
Tuning S2sharp parameters for usa_rr
167trial [24:23,  8.46s/trial, best loss: -28.175247434945277]                      
Tuning S2sharp parameters for vietnam_rr
167trial [2:29:39, 116.70s/trial, best loss: -22.632212686319985]                      
Tuning S2sharp parameters for australia_rr
167trial [57:11, 26.21s/trial, best loss: -26.252225111550036]                       


## Once the method code has been placed in the appropriate folders, change the following cells from Raw to Code to run them.

Find best parameters for ATPRK

Find best parameters for SSSS

Find best parameters for SupReME

Find best parameters for MuSA

Display the best parameters; copy the printed `best_pars` into `boars.py` to update the stored values.

In [19]:
import json

# Render the tuned parameters as an indented, key-sorted assignment that can be
# copied into boars.py.
best_pars_json = json.dumps(best_pars, indent=4, sort_keys=True)
print('best_pars =', best_pars_json)

best_pars = {
    "apex": {
        "ATPRK": {
            "H": 11.0,
            "L_range": 17.0,
            "L_sill": 13.0,
            "Range_min": 0.9295475471200471,
            "Sill_min": 2.0,
            "rate": 1.1093648601650439
        },
        "MuSA": {
            "lam": 0.0005075231839284101,
            "mu": 0.21136587946368565
        },
        "S2sharp": {
            "lam": 1.9341127858603557,
            "q1": 0.006073349262165454,
            "q10": 0.47647853130518464,
            "q2": 7.208822829845069e-05,
            "q3": 0.10893797078726071,
            "q4": 1.135516401017811,
            "q5": 14.103620956839348,
            "q6": 25.39252883568256,
            "q7": 0.9735978109673542,
            "q8": 66.43736365056141,
            "q9": 3.2947729421976493,
            "r": 9.0
        },
        "SSC": {
            "batch_size": 63.0,
            "lr": 0.0005153685695537464
        },
        "SSSS": {
            "lam": 0.16738800822718586,
     

Run all methods with the tuned parameters and compare the results.

In [11]:
# Run every method in `meth_list` on every dataset with its tuned parameters, save
# the super-resolved outputs per dataset and, when evaluation bands are available,
# score them and save the comparison tables.
result_dir = './results/' + time.strftime("%Y%m%d%H%M")
multi_run=False
limsub = 6
eval_bands=None
eval_dfs = {}

band_scales = np.array([6, 1, 1, 1, 2, 2, 2, 1, 2, 6, 2, 2])
band_names = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B9', 'B11', 'B12']
band_idxs = {'B1':0, 'B2':1, 'B3':2, 'B4':3, 'B5':4, 'B6':5, 'B7':6, 'B8':7, 'B8A':8, 'B9':9, 'B11':10, 'B12':11}

# Methods called from Python with NumPy input and keyword parameters; all other
# methods go through the MATLAB engine.
python_methods = ('SSC', 'DSen2')

for dataset in data_list:
    # Load data
    (Yim, Xm_im, eval_bands) = get_data(dataset, datadir='./data/')
    (Yim, mtf) = get_data(dataset, datadir='./data/', get_mtf=True)[:2]
    # Format for MATLAB methods: one matlab.double per input band, and the reference
    # cube permuted with [2,3,1] on the MATLAB side.
    mYim = [matlab.double(b.tolist()) for b in Yim]
    mXm_im = matlab.double([matlab.double(Xm_im[:, :, b].tolist()) for b in range(Xm_im.shape[-1])])
    mXm_im = eng.permute(mXm_im, matlab.int16([2,3,1]))
    # Make directory to save results
    os.makedirs(result_dir + '/' + dataset, exist_ok = True)
    # Fresh per dataset: previously `Xhats` was never created (NameError) and would
    # otherwise carry results from one dataset into the next.
    Xhats = {}
    for meth_name in meth_list:
        try:
            if verbose: print('Processing', dataset, 'with', meth_name, '!')
            # NOTE(review): `get_method` is not imported in any visible cell; confirm it
            # is provided (e.g. by boars).
            method = get_method(meth_name, matlab_handle=eng)
            pars = best_pars.get(dataset, best_pars['default']).get(meth_name, {})
            # Membership test instead of `is`, which compared string identity.
            if meth_name in python_methods:
                Xhat = method(Yim, **pars, verbose=False)
            else:
                # NOTE(review): `*pars` unpacks only the dict KEYS (parameter names),
                # not their values. Confirm the MATLAB wrapper's calling convention
                # (name-value pairs? positional values in which order?) and pass
                # the values accordingly.
                Xhat = np.array(method(mYim, *pars))
            Xhats[meth_name] = Xhat
            if verbose: print(meth_name, 'done!')
        except Exception as err:
            # Best effort: one failing method must not stop the others, but report why.
            print(meth_name, 'failed!', repr(err))
    np.savez_compressed(result_dir + '/' + dataset + '/srdata.npz', **Xhats)
    if eval_bands is not None:
        from sreval import sreval, evaluate_performance, dataframe_from_res_list

        res_list = sreval(Xm_im, Xhats.values(), limsub=limsub, bands=eval_bands)
        evals_df = dataframe_from_res_list(res_list, Xhats.keys(), band_idxs, multi_run=multi_run)
        # Drop the input-resolution bands, plus the aggregate rows of any band group
        # that is not evaluated. Labels are collected first and de-duplicated so that
        # 'All' is not dropped twice (KeyError) when neither 2 nor 6 is evaluated.
        drop_labels = ['B2', 'B3', 'B4', 'B8']
        if 2 not in eval_bands:
            drop_labels += ['B5', 'B6', 'B7', 'B8A', 'B11', 'B12', '20m', 'All']
        if 6 not in eval_bands:
            drop_labels += ['B1', 'B9', '60m', 'All']
        evals_df = evals_df.drop(list(dict.fromkeys(drop_labels)), level=1)
        evals_df.to_csv(result_dir + '/' + dataset + '/tuned_comparison.csv',sep='\t',decimal=',')
        evals_df.to_hdf(result_dir + '/' + dataset + '/tuned_comparison_dataframe.h5', key=dataset)
        eval_dfs[dataset] = evals_df

Processing with  ATPRK !
ATPRK done!
Processing with  DSen2 !
Symbolic Model Created.
Predicting using file: ./DSen2/models/s2_030_lr_1e-05.hdf5


Using TensorFlow backend.


(2, 324, 324)
Symbolic Model Created.
Predicting using file: ./DSen2/models/s2_032_lr_1e-04.hdf5
(6, 324, 324)
DSen2 done!
Processing with  MuSA !
MuSA done!
Processing with  S2sharp !
S2sharp done!
Processing with  SSC !
SSC failed!
Processing with  SSSS !
SSSS done!
Processing with  SupReME !
SupReME done!


Display results.

In [None]:
# Save heatmaps of the input, the ground truth and every method's output for each
# super-resolved band of each dataset.
# NOTE(review): `plt` (matplotlib.pyplot) and `sns` (seaborn) are not imported in any
# visible cell; add them to the import cell if no hidden cell provides them.
clip=10000      # upper end (vmax) of the fixed colour scale shared by all images
cmap='gray_r'
zoom = 96       # side length, in ground-truth pixels, of the zoomed-in crops


def _save_heatmap(img, out_stem, show=False):
    """Save `img` as a heatmap to `<out_stem>.png` and its colorbar to `<out_stem>_cbar.png`.

    The colour scale is fixed to [0, clip] with colormap `cmap`, so images are
    comparable. If `show` is true, the figure is also displayed. Both figures are
    closed afterwards so looping over many bands does not accumulate open figures.
    """
    cbar_fig, cbar_ax = plt.subplots()
    fig, ax = plt.subplots()
    # The colorbar gets its own narrow figure so it can be placed next to any image.
    cbar_fig.set_figwidth(fig.get_figwidth() / 30)
    sns.heatmap(img, xticklabels=False, yticklabels=False, vmin=0, vmax=clip,
                square=True, cmap=cmap, ax=ax, cbar_ax=cbar_ax)
    ax.grid(False)
    fig.savefig(out_stem + '.png', bbox_inches='tight')
    cbar_fig.savefig(out_stem + '_cbar.png', bbox_inches='tight')
    if show:
        plt.show()
    plt.close(fig)
    plt.close(cbar_fig)


def _save_heatmap_unless_nan(img, out_stem, show=False):
    """Same as `_save_heatmap`, but print 'NaN' and skip images that contain NaNs."""
    if np.isnan(img).any():
        print('NaN')
    else:
        _save_heatmap(img, out_stem, show=show)


# Plot only the bands with scale > 1, i.e. the ones being super-resolved
# (B1, B5, B6, B7, B8A, B9, B11, B12).
sr_bands = [name for name, scale in zip(band_names, band_scales) if scale > 1]

for dataset in data_list:
    # Reload this dataset's inputs and saved outputs: the run cell's `Yim`, `Xm_im`
    # and `Xhats` only hold the LAST dataset, which was previously plotted for all of
    # them. Loading mirrors the run cell (Yim comes from the get_mtf=True call there).
    Xm_ds = get_data(dataset, datadir='./data/')[1]
    Yim_ds = get_data(dataset, datadir='./data/', get_mtf=True)[0]
    with np.load(result_dir + '/' + dataset + '/srdata.npz') as npz:
        Xhats_ds = {name: npz[name] for name in npz.files}

    imdir = result_dir + '/' + dataset + '/images/'
    for band in sr_bands:
        idx = band_idxs[band]
        band_dir = imdir + band
        os.makedirs(band_dir, exist_ok = True)
        # Input crop size: zoom divided by the band's scale (16 for scale-6 bands,
        # 48 for scale-2 bands, as before).
        z = zoom // int(band_scales[idx])

        print('Input', band)
        _save_heatmap(Yim_ds[idx], band_dir + '/input_f')

        print('Ground truth')
        _save_heatmap_unless_nan(Xm_ds[:, :, idx], band_dir + '/gt_f')

        print('Input zoom')
        _save_heatmap(Yim_ds[idx][:z, :z], band_dir + '/input_z', show=True)

        print('Ground truth zoom')
        _save_heatmap_unless_nan(Xm_ds[:zoom, :zoom, idx], band_dir + '/gt_z')

        for meth_name, Xhat_im in Xhats_ds.items():
            print(band, meth_name)
            # Negative predictions are clamped to 0 for display. np.maximum propagates
            # NaN, so the NaN check still sees the raw crop's NaNs.
            _save_heatmap_unless_nan(np.maximum(0, Xhat_im[:zoom, :zoom, idx]),
                                     band_dir + '/' + meth_name + '_z', show=True)