In [85]:
import numpy as np
import os
import h5py

In [86]:
def calculate_stat(values, stat_type):
    """Summarise ``values`` either as a normal-approximation interval or as "mean +- SD".

    Parameters
    ----------
    values : array-like of numbers
        Samples to summarise.
    stat_type : str
        ``'CI'`` -> return a ``(lower, upper)`` tuple equal to
        ``mean -/+ 1.96 * std / sqrt(len(values))``.
        ``'SD'`` -> return the string ``"<mean> +- <std>"`` with both values
        rounded to 4 decimal places.

    Raises
    ------
    ValueError
        If ``stat_type`` is neither ``'CI'`` nor ``'SD'``.

    Notes
    -----
    Both modes use ``np.std`` with its default ``ddof=0`` (population standard
    deviation), not the sample standard deviation.
    """
    if stat_type == 'SD':
        # Round through a fixed-precision string so the printed values are short.
        rounded_sd = float(format(np.std(values), '.4f'))
        rounded_mean = float(format(np.mean(values), '.4f'))
        return "{} +- {}".format(rounded_mean, rounded_sd)
    if stat_type != 'CI':
        raise ValueError('Stat type not defined!')
    # 1.96 is the two-sided 95% quantile of the standard normal distribution.
    center = np.mean(values)
    half_width = 1.96 * (np.std(values) / np.sqrt(len(values)))
    return (center - half_width, center + half_width)

def _read_test_set(path_to_test):
    data_dict = {}
    for file in os.listdir(path_to_test):
        if file.endswith('.h5'):
            h5 = h5py.File(os.path.join(path_to_test, file), 'r')
            if file.split('.h5')[0] not in data_dict:
                data_dict[file.split('.h5')[0]] = h5['data'].shape[0]
    return data_dict

def _gather_data(data, metrics):
    for slice_type in data:
        given_lines = data[slice_type]
        for line in given_lines:
            str_of_interest = line.split('INFO:')[-1]
            data_string = str_of_interest.split('-')[-1].strip()
            psnr_val = float(data_string.split('PSNR:')[-1].split('dB')[0].strip())
            ssim_val = float(data_string.split('SSIM:')[-1].split(';')[0].strip())
            pdist_val = float(data_string.split('pdist:')[-1][:-1])
            metrics['PSNR'][slice_type].append(psnr_val)
            metrics['SSIM'][slice_type].append(ssim_val)
            metrics['pdist'][slice_type].append(pdist_val)
    return metrics

def main_perline(test_log, path_to_test_set, start_at, stop_at, stat_type='SD',
                 slices_per_view=512, lines_between_subjects=3):
    """Print per-view PSNR / SSIM / pdist statistics parsed from a per-slice test log.

    Only ``lines[start_at:stop_at]`` of the log is used. Subjects are visited in the
    order returned by ``_read_test_set``. For each subject the section is read as:
    one line per axial slice (count from ``_read_test_set``), then ``slices_per_view``
    coronal lines, then ``slices_per_view`` sagittal lines, then
    ``lines_between_subjects`` lines that are skipped.

    Parameters
    ----------
    test_log : str
        Path to the test log file.
    path_to_test_set : str
        Directory of ``.h5`` test volumes, passed to ``_read_test_set``.
    start_at, stop_at : int or None
        Slice bounds (``stop_at`` exclusive) selecting the metric section of the log.
    stat_type : str
        ``'SD'`` or ``'CI'``, forwarded to ``calculate_stat``.
    slices_per_view : int
        Number of coronal and of sagittal lines per subject (previously hard-coded 512).
    lines_between_subjects : int
        Lines skipped after each subject's sagittal block (previously hard-coded 3).

    Warns
    -----
    UserWarning
        If a subject's Ax/Co/Sag chunk has fewer lines than expected. This usually
        means ``start_at``/``stop_at`` or the assumed layout is wrong.
    """
    import warnings

    axial_counts = _read_test_set(path_to_test_set)
    metrics = {name: {'Ax': [], 'Co': [], 'Sag': []} for name in ('PSNR', 'SSIM', 'pdist')}

    # Read once and close the file before doing any parsing or statistics.
    with open(test_log, 'r') as log_file:
        lines = log_file.readlines()[start_at:stop_at]

    cursor = 0
    for subject, n_axial in axial_counts.items():
        co_start = cursor + n_axial
        sag_start = co_start + slices_per_view
        sag_end = sag_start + slices_per_view
        views = {
            'Ax': lines[cursor:co_start],
            'Co': lines[co_start:sag_start],
            'Sag': lines[sag_start:sag_end],
        }
        expected = {'Ax': n_axial, 'Co': slices_per_view, 'Sag': slices_per_view}
        # List slicing never raises, so an exhausted log would otherwise be silently
        # summarised from truncated chunks.
        for view, chunk in views.items():
            if len(chunk) < expected[view]:
                warnings.warn(
                    '{}: expected {} {} lines but found {}; check start_at/stop_at '
                    'and the log layout'.format(subject, expected[view], view, len(chunk)))
        cursor = sag_end + lines_between_subjects
        metrics = _gather_data(views, metrics)

    # Report the requested statistic (SD by default, or CI) for every metric and view.
    for metric in metrics:
        print('Calculating {} {}'.format(metric, stat_type))
        print('-'*40)
        for view in metrics[metric].keys():
            print('{}: {}'.format(view, calculate_stat(metrics[metric][view], stat_type)))
        print('='*40)

In [87]:
# Test logs to evaluate, with the (start_at, stop_at) line range passed to main_perline.
# The SNGAN range comes from the original inline note next to its log path.
sngan_log = '/workspace/NormGAN/results/SNGAN-AAPM/test.log_211019-212834.log'
wgan_log = '/workspace/NormGAN/results/WGAN-AAPM/test.log_211020-131140.log'
SNGAN_LINE_RANGE = (139, 16342)
WGAN_LINE_RANGE = (66, 16269)
path_to_test_set = '/aapm_data/aapm_3d_lowdose_testset'

# To evaluate the SNGAN run instead, pass sngan_log with *SNGAN_LINE_RANGE.
main_perline(wgan_log, path_to_test_set, *WGAN_LINE_RANGE)

Calculating PSNR SD
----------------------------------------
Ax: 32.4978 +- 1.0229
Co: 33.7595 +- 4.2488
Sag: 32.7905 +- 1.9957
Calculating SSIM SD
----------------------------------------
Ax: 0.8304 +- 0.0289
Co: 0.83 +- 0.0296
Sag: 0.831 +- 0.0206
Calculating pdist SD
----------------------------------------
Ax: 0.2226 +- 0.0129
Co: 0.2943 +- 0.053
Sag: 0.2643 +- 0.0172
