I don't own the code for SASSI. If you get permission from Vishwanath Saragadam (https://vishwa91.github.io/), I can share the code for SASSI and for SASSI_testing (my wrapper function around SASSI).

In [None]:
# Libraries and Modules
from matplotlib import pyplot as plt
import numpy as np
import scipy as sp
import cv2
import time
import os
import sys
import pickle
nb_dir = os.getcwd()
sys.path.insert(0,nb_dir+'/SASSI')
import aux_funcs.noise as noise
import aux_funcs.visualization as vis
import aux_funcs.data_import as data
import aux_funcs.colorization as clr
import aux_funcs.spectral_dimensionality as dim
import aux_funcs.metrics as qm
import aux_funcs.initialization as init
import aux_funcs.comparioson_helpers as comp
from SASSI.demo import SASSI_testing

In [None]:
# Load the low-resolution Bear-and-Fruit cube and preview it.
lams, hpim = data.importer.BearAndFruit_low_res()
l, h, w = hpim.shape
print(l, h, w)
vis.draw_hpim(hpim, lams=lams, draw=True, method='1931')

# Spectral basis sampled on the 400-10-700 grid (per its file name).
sd = dim.SpectralDim()
sd.loadSpectralBasis('data/spectral_basis_data_400-10-700.npy')

# Metric functions keyed by name; `hib` carries the "higher is better" info.
metrics, hib, _ = qm.getQualityMetrics()

In [None]:
def _crop(cube: np.ndarray, zoom_region: tuple) -> np.ndarray:
    """Return ``cube[:, y0:y1, x0:x1]`` for ``zoom_region = ((x0, y0), (x1, y1))``."""
    top_left, bottom_right = zoom_region[0], zoom_region[1]
    return cube[:, top_left[1]:bottom_right[1], top_left[0]:bottom_right[0]]


def _save_spectral_curves(path, lams, point, curves, label_font_size, tick_font_size, pad):
    """Plot each cube's spectrum at pixel ``point = (x, y)`` against ``lams`` and save it.

    ``curves`` is a sequence of ``(label, cube, plot_kwargs)``. Returns the
    ``(handles, labels)`` of the plotted lines so a legend can be rendered separately.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for label, cube, style in curves:
        ax.plot(lams, cube[:, point[1], point[0]], label=label, **style)
    ax.set_xlabel('λ', fontsize=label_font_size)
    ax.set_ylabel('Intensity', fontsize=label_font_size)
    ax.tick_params(axis='both', which='major', labelsize=tick_font_size)
    handles, labels = ax.get_legend_handles_labels()
    fig.savefig(path, bbox_inches='tight', pad_inches=pad)
    plt.close(fig)
    return handles, labels


def _save_legend(path, handles, labels, pad):
    """Render ``handles``/``labels`` as a stand-alone legend image at ``path``."""
    fig_legend = plt.figure(figsize=(2, 2))
    axi = fig_legend.add_subplot(111)
    fig_legend.legend(handles, labels, loc='center', scatterpoints=1)
    axi.xaxis.set_visible(False)
    axi.yaxis.set_visible(False)
    fig_legend.canvas.draw()
    fig_legend.savefig(path, bbox_inches='tight', pad_inches=pad)
    plt.close(fig_legend)


def compare_methods(hpim: np.ndarray, lams: np.ndarray, sd: dim.SpectralDim, metrics: dict, sampling_ratio: float = 0.03,
                    zoom_region: tuple = None, interest_point1: tuple = None, interest_point2: tuple = None,
                    data_name: str = None, metric_names: tuple = ('PSNR', 'EMD', 'GFC', 'SSV')):
    """Compare HyperColorization, SASSI and Cao et al. on one cube at three noise levels.

    For each simulated exposure time, Poisson noise is added to ``hpim``. Each
    method reconstructs the cube, and the results are scored against the clean
    ``hpim``. Images and metrics are written under
    ``results/comparison/<data_name>/<exposure>/`` and
    ``results/comparison/<data_name>/results.pkl``.

    Args:
        hpim: hyperspectral cube; axis 0 is the spectral axis (it is plotted
            against ``lams``), axes 1/2 are spatial.
        lams: wavelengths for axis 0 of ``hpim``.
        sd: spectral basis object forwarded to the HyperColorization wrapper.
        metrics: maps each name in ``metric_names`` to ``f(gt, estimate)``.
        sampling_ratio: spectral sampling ratio given to every method.
        zoom_region: optional ``((x0, y0), (x1, y1))`` crop to save per method.
        interest_point1: optional ``(x, y)`` pixel whose spectra are plotted.
        interest_point2: optional ``(x, y)`` pixel whose spectra are plotted;
            a separate legend image is also saved for it.
        data_name: output sub-directory name (required).
        metric_names: metrics to record; defaults to the original four.

    Raises:
        TypeError: if ``data_name`` is not given.
    """
    if data_name is None:
        raise TypeError('compare_methods requires data_name to build its output directory')
    base_path = os.path.join('results', 'comparison', data_name)
    # exist_ok so an interrupted run can be re-executed without FileExistsError.
    os.makedirs(base_path, exist_ok=True)
    vis_method = '1931'
    tick_font_size = 12
    label_font_size = 14
    pad = 0.1
    # Exposure times are scaled by the band count relative to a 31-band cube.
    exposure_time_scale = hpim.shape[0] / 31
    exposure_times = [exposure_time_scale * 1e-3, exposure_time_scale * 1e-4, exposure_time_scale * 1e-5]
    n_exposures = len(exposure_times)
    gray = np.sum(hpim, axis=0, keepdims=True)
    rgb = vis.draw_hpim(hpim, lams, False, vis_method)

    result_dict = {method: {**{name: [] for name in metric_names}, 'time': 0}
                   for method in ('cao', 'sassi', 'hc')}
    result_dict['exposure_times'] = exposure_times

    for exposure_time in exposure_times:
        current_path = os.path.join(base_path, '{:.0e}'.format(exposure_time))
        os.makedirs(current_path, exist_ok=True)
        hpim_n = noise.addPoissonNoise(hpim, exposure_time)
        t0 = time.time()
        hc_result = comp.hyperColorizationWrapper(gray, hpim_n, lams, sd, sampling_ratio, False)
        t1 = time.time()
        cao_result = comp.Cao2015Wrapper(hpim_n, rgb, sampling_ratio)
        t2 = time.time()
        sassi_result, _ = SASSI_testing(hpim, hpim_n, lams, sampling_ratio)
        # NOTE(review): SASSI's output is rescaled to the ground-truth total energy
        # (np.sum(hpim)), which HC and Cao do not receive; confirm this is intended.
        sassi_total = np.sum(sassi_result)
        if sassi_total != 0:  # avoid inf/nan from an all-zero reconstruction
            sassi_result *= np.sum(hpim) / sassi_total
        t3 = time.time()

        results = {'cao': cao_result, 'sassi': sassi_result, 'hc': hc_result}
        timings = {'cao': t2 - t1, 'sassi': t3 - t2, 'hc': t1 - t0}
        for method, cube in results.items():
            vis.save_hpim(os.path.join(current_path, method + '.png'), cube, lams, method=vis_method)
            result_dict[method]['time'] += timings[method] / n_exposures
            for name in metric_names:
                result_dict[method][name].append(metrics[name](hpim, cube))

        if zoom_region is not None:
            for method in ('hc', 'sassi', 'cao'):
                vis.save_hpim(os.path.join(current_path, method + '_zoomed.png'),
                              _crop(results[method], zoom_region), lams, method=vis_method)

        curves = (
            ('Ground Truth', hpim, {'color': 'black', 'linewidth': 3}),
            ('HC', hc_result, {'alpha': 1.0}),
            ('SASSI', sassi_result, {'alpha': 0.8}),
            ('Cao et al.', cao_result, {'alpha': 0.6}),
        )
        if interest_point1 is not None:
            _save_spectral_curves(os.path.join(current_path, 'spectral_curve1.png'), lams, interest_point1,
                                  curves, label_font_size, tick_font_size, pad)
        if interest_point2 is not None:
            handles, labels = _save_spectral_curves(os.path.join(current_path, 'spectral_curve2.png'), lams,
                                                    interest_point2, curves, label_font_size, tick_font_size, pad)
            _save_legend(os.path.join(current_path, 'legend.png'), handles, labels, pad)

        print('{} Exposure time {:.0e} finished'.format(data_name, exposure_time))
        if 'PSNR' in metric_names:
            print('PSNR: HC {:.2f}, SASSI {:.2f}, Cao {:.2f}'.format(
                result_dict['hc']['PSNR'][-1], result_dict['sassi']['PSNR'][-1], result_dict['cao']['PSNR'][-1]))

    with open(os.path.join(base_path, 'results.pkl'), 'wb') as f:
        pickle.dump(result_dict, f)

In [None]:
# Single-image comparison with a zoom crop and two spectral probe pixels.
# NOTE(review): `hpim` here is the Bear-and-Fruit low-res cube from the setup
# cell, but results are stored under 'ICVL' — confirm the intended label.
compare_methods(
    hpim, lams, sd, metrics,
    sampling_ratio=0.03,
    zoom_region=((69, 149), (154, 216)),
    interest_point1=(362, 300),
    interest_point2=(86, 302),
    data_name='ICVL',
)

In [None]:
def checkPerformanceNumbers(data_name: str, metric: str):
    """Print the stored run times and per-exposure ``metric`` scores for one result set.

    Reads ``results/comparison/<data_name>/results.pkl`` (the file
    ``compare_methods`` writes) and prints one line per stored exposure time.

    Args:
        data_name: sub-directory of ``results/comparison`` holding ``results.pkl``.
        metric: metric key stored for each method (e.g. 'EMD').
    """
    # Local result file produced by this notebook; do not point this at untrusted pickles.
    with open(os.path.join('results', 'comparison', data_name, 'results.pkl'), 'rb') as f:
        results = pickle.load(f)
    print('Average Run time: HC {:.2f}, SASSI {:.2f}, Cao {:.2f}'.format(
        results['hc']['time'], results['sassi']['time'], results['cao']['time']))
    # Iterate over however many exposures were stored instead of assuming three.
    for i, exposure_time in enumerate(results['exposure_times']):
        print('Exposure Time:{:.2e} {}: HC {:.2e}, SASSI {:.2e}, Cao {:.2e}'.format(
            exposure_time, metric,
            results['hc'][metric][i], results['sassi'][metric][i], results['cao'][metric][i]))
# Print run times and EMD per exposure for the 'bear_and_fruit' result set.
# NOTE(review): 'results/comparison/bear_and_fruit' is produced by
# compare_methods_over_dataset in a later cell (the cell above saves under
# 'ICVL'), so on a fresh top-to-bottom run this cell fails — confirm ordering.
checkPerformanceNumbers('bear_and_fruit', 'EMD')

In [None]:
def compare_methods_over_dataset(data_dict, metrics, sampling_ratio=0.03):
    """Run ``compare_methods`` for every requested image of every supported dataset.

    Args:
        data_dict: maps a dataset name ('bear_and_fruit', 'Harvard', 'CAVE' or
            'Kaist') to the list of image indices to process. The list given
            for 'bear_and_fruit' is not used: that dataset is loaded as a single cube.
        metrics: metric functions forwarded to ``compare_methods``.
        sampling_ratio: spectral sampling ratio forwarded to ``compare_methods``.

    Images whose output directory under ``results/comparison`` already exists
    are skipped, so an interrupted sweep can be resumed. Unknown dataset names
    emit a warning and are otherwise ignored.
    """
    import warnings

    root = os.path.join('results', 'comparison')
    # dataset name -> (spectral basis file, per-image loader). The lambdas defer
    # looking up `data` until an image is actually loaded.
    multi_image_datasets = {
        'Harvard': ('data/spectral_basis_data_420-10-720.npy',
                    lambda entry: data.importer.load_Harvard_img(entry)),
        'CAVE': ('data/spectral_basis_data_400-10-700.npy',
                 lambda entry: data.importer.load_CAVE_img(entry)),
        'Kaist': ('data/spectral_basis_data_420-10-720.npy',
                  lambda entry: data.importer.load_KAIST_img(entry)),
    }
    for key, entries in data_dict.items():
        if key == 'bear_and_fruit':
            if os.path.exists(os.path.join(root, key)):
                continue
            lams, hpim = data.importer.BearAndFruit()
            sd = dim.SpectralDim()
            sd.loadSpectralBasis('data/spectral_basis_data_400-10-700.npy')
            compare_methods(hpim, lams, sd, metrics, sampling_ratio, data_name=key)
            continue
        if key not in multi_image_datasets:
            warnings.warn('compare_methods_over_dataset: unknown dataset {!r} skipped'.format(key))
            continue
        basis_path, loader = multi_image_datasets[key]
        sd = dim.SpectralDim()
        sd.loadSpectralBasis(basis_path)
        # makedirs also creates 'results/comparison' itself on a fresh checkout.
        os.makedirs(os.path.join(root, key), exist_ok=True)
        for entry in entries:
            if os.path.exists(os.path.join(root, key, str(entry))):
                continue
            lams, hpim = loader(entry)
            compare_methods(hpim, lams, sd, metrics, sampling_ratio, data_name=os.path.join(key, str(entry)))

In [None]:
# Images to evaluate per dataset. The indices are passed to the dataset loaders;
# the 'bear_and_fruit' list is not used by compare_methods_over_dataset.
data_dict = dict(
    bear_and_fruit=[0],
    Harvard=[0, 26, 44, 48, 49],
    Kaist=[2, 22, 26],
    CAVE=[0, 1, 5, 7],
)
compare_methods_over_dataset(data_dict, metrics, sampling_ratio=0.03)

In [None]:
def accumulatorHelper(accumulator: dict, pickleLoc: str, metric: str):
    """Add one ``results.pkl`` file's scores into ``accumulator`` in place.

    Args:
        accumulator: maps each method name ('hc', 'sassi', 'cao') to a list of
            running per-exposure totals, plus a 'time' entry that maps method
            name to a running run-time total.
        pickleLoc: path to a results pickle written by ``compare_methods``.
        metric: metric key whose per-exposure scores are summed.

    Raises:
        KeyError: if the pickle lacks a method, or does not store ``metric``.
            The message names the file and the available metrics.
    """
    # Local result file produced by this notebook; do not point this at untrusted pickles.
    with open(pickleLoc, 'rb') as f:
        results = pickle.load(f)
    for key, totals in accumulator.items():
        if key == 'time':
            for method in totals:
                totals[method] += results[method]['time']
            continue
        method_results = results[key]
        if metric not in method_results:
            available = sorted(k for k in method_results if k != 'time')
            raise KeyError('metric {!r} not stored in {} (available: {})'.format(metric, pickleLoc, available))
        scores = method_results[metric]
        for i in range(len(totals)):
            totals[i] += scores[i]
        
def checkPerformanceNumbersOverDataset(data_dict: dict, metric: str):
    """Print ``metric`` and run time averaged over every image listed in ``data_dict``.

    Each listed image's ``results/comparison/.../results.pkl`` is summed with
    ``accumulatorHelper`` and the totals are divided by the number of images.
    Keys other than 'bear_and_fruit', 'Harvard', 'CAVE' and 'Kaist' are ignored.
    If nothing is listed, a message is printed instead of dividing by zero.

    Args:
        data_dict: dataset name -> image indices, as passed to
            ``compare_methods_over_dataset``.
        metric: metric key stored in the result pickles.
    """
    accumulator = {
        'hc': [0, 0, 0],
        'sassi': [0, 0, 0],
        'cao': [0, 0, 0],
        'time': {'hc': 0, 'sassi': 0, 'cao': 0},
    }
    root = os.path.join('results', 'comparison')
    pickle_paths = []
    for key, entries in data_dict.items():
        if key == 'bear_and_fruit':
            pickle_paths.append(os.path.join(root, key, 'results.pkl'))
        elif key in ('Harvard', 'CAVE', 'Kaist'):
            pickle_paths.extend(os.path.join(root, key, str(entry), 'results.pkl') for entry in entries)
    for path_to_pickle in pickle_paths:
        accumulatorHelper(accumulator, path_to_pickle, metric)
    num_data = len(pickle_paths)
    if num_data == 0:
        print('No results to average for ' + metric + ': data_dict lists no supported dataset entries')
        return
    print('Average ' + metric + ' score over ' + str(num_data) + ' hyperspectral images')
    # Nominal 31-band exposure labels; compare_methods scales the actual times by bands/31.
    for i, label in enumerate(('1e-3', '1e-4', '1e-5')):
        print('Exposure Time: {}: HC {:.2e}, SASSI {:.2e}, Cao {:.2e}'.format(
            label, accumulator['hc'][i] / num_data, accumulator['sassi'][i] / num_data,
            accumulator['cao'][i] / num_data))
    times = accumulator['time']
    print('Average Run time: HC {:.2e}, SASSI {:.2e}, Cao {:.2e}'.format(
        times['hc'] / num_data, times['sassi'] / num_data, times['cao'] / num_data))

In [None]:
# Average scores across every image in data_dict.
# NOTE(review): compare_methods as written in this notebook records only PSNR,
# EMD, GFC and SSV per method, so 'SSIM' is found only in results.pkl files
# that actually store it — confirm the metric name before relying on this cell.
checkPerformanceNumbersOverDataset(data_dict, 'SSIM')

In [None]:
def load_KAIST_recon_and_gt(recon_loc: str, gt_loc: str, out_dir: str = 'results/Kaist'):
    """Load a KAIST reconstruction and its ground truth, save RGB previews, print scores.

    Args:
        recon_loc: .mat file holding the reconstruction under ``x_recon`` and
            its wavelengths under ``wvls2b``.
        gt_loc: .mat file holding the ground-truth cube under ``img_hs``.
        out_dir: directory for ``KAIST_recon.png`` and ``KAIST_gt.png``;
            created if missing.

    Returns:
        ``(x_gt, x_recon, wvls2b)``. In both cubes the last axis of the .mat
        array is moved to axis 0, and ``x_gt`` is divided by its maximum.
    """
    # `import scipy` alone does not make `scipy.io` available on older SciPy releases.
    from scipy import io as sio

    os.makedirs(out_dir, exist_ok=True)
    recon_file = sio.loadmat(recon_loc)
    wvls2b = np.squeeze(recon_file['wvls2b'])
    x_recon = np.moveaxis(recon_file['x_recon'], -1, 0)
    print('Reconstruction: ')
    # draw_hpim's positional order is (hpim, lams, draw, method), as used elsewhere
    # in this notebook; passing '1931' positionally landed in `draw`, not `method`.
    rgb_recon = vis.draw_hpim(x_recon, wvls2b, draw=True, method='1931')
    plt.imsave(os.path.join(out_dir, 'KAIST_recon.png'), rgb_recon)
    plt.show()

    gt_file = sio.loadmat(gt_loc)
    x_gt = np.moveaxis(gt_file['img_hs'], -1, 0)
    x_gt = x_gt / np.max(x_gt)
    print('Ground Truth: ')
    rgb_gt = vis.draw_hpim(x_gt, wvls2b, draw=True, method='1931')
    plt.imsave(os.path.join(out_dir, 'KAIST_gt.png'), rgb_gt)
    plt.show()

    print('EMD: {:.4e}'.format(qm.EMD(x_gt, x_recon)))
    print('SSIM: {:.4e}'.format(qm.SSIM(x_gt, x_recon)))
    return x_gt, x_recon, wvls2b

# Load KAIST scene 03 (reconstruction + ground truth). Note: this rebinds the
# global `lams` to the KAIST wavelengths returned as `wvls2b`.
kaist_gt, kaist_recon, lams = load_KAIST_recon_and_gt('data/scene03_recon.mat', 'data/scene03.mat')

In [None]:
# Run HyperColorization on the KAIST ground-truth cube. As in compare_methods,
# the band-summed image is passed as the first argument.
# NOTE(review): `sd` is the global basis loaded in the setup cell
# (400-10-700), while compare_methods_over_dataset uses the 420-10-720 basis
# for 'Kaist' — confirm the basis matches `lams`.
# NOTE(review): despite its name, `rgb_gt` holds the rendering of `hc_result`.
kaist_guide = np.sum(kaist_gt, axis=0, keepdims=True)
hc_result = comp.hyperColorizationWrapper(kaist_guide, kaist_gt, lams, sd, 0.03, False)
rgb_gt = vis.draw_hpim(hc_result, lams, method='1931', draw=True)

In [None]:
# Compare the spectrum at one KAIST pixel (axis-1 index 641, axis-2 index 434):
# ground truth vs. HyperColorization vs. the Choi et al. reconstruction.
kaist_row, kaist_col = 641, 434
plt.plot(lams, kaist_gt[:, kaist_row, kaist_col], label='GT', color='black', linewidth=2.4)
plt.plot(lams, hc_result[:, kaist_row, kaist_col], label='HyperColorization')
plt.plot(lams, kaist_recon[:, kaist_row, kaist_col], label='Choi et al.')
plt.legend()
plt.show()