In [1]:
from tomoSegmentPipeline.utils.common import read_array, write_array
from tomoSegmentPipeline.utils import setup
from cryoS2Sdrop.predict import load_model


from pytorch_msssim import ssim
from torchmetrics.functional import peak_signal_noise_ratio, mean_squared_error
import torch

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from glob import glob
import os
import yaml
# from itables import init_notebook_mode
# init_notebook_mode(all_interactive=True)

# Project root; the data and log paths used in this notebook are joined onto it.
PARENT_PATH = setup.PARENT_PATH

# Display DataFrame floats with three decimals (and thousands separators).
pd.options.display.float_format = "{:,.3f}".format

%matplotlib inline
# NOTE(review): disabling jedi is presumably a tab-completion workaround — confirm it is still needed.
%config Completer.use_jedi = False
# Re-import edited modules before each cell executes, so project code changes apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
def standardize(X: torch.Tensor):
    """Return ``X`` shifted to zero mean and scaled to unit standard deviation.

    The mean and std are global (taken over every element), not per-axis.
    ``Tensor.std()`` uses the unbiased (N - 1) estimator by default.
    A constant input has zero std, so the result is NaN/inf rather than an error.
    """
    mean = X.mean()
    std = X.std()
    return (X - mean) / std

def clip(X, low=0.005, high=0.995):
    """Clamp ``X`` to its own ``[low, high]`` quantile range to suppress outliers.

    Args:
        X: array-like volume. The original note says CPU torch tensors work too;
            NumPy converts them for ``np.quantile``. NOTE(review): the tensor
            round-trip through ``np.clip`` relies on NumPy's wrapping fallback —
            re-check after NumPy/torch upgrades.
        low, high: quantiles in [0, 1] used as the lower/upper clip bounds.

    Returns:
        ``X`` with values below/above the bounds replaced by the bounds.
    """
    # One np.quantile call with both q values copies/partitions the (possibly
    # very large) volume once instead of twice; the values are identical.
    lo, hi = np.quantile(X, [low, high])
    # Python floats are "weak" scalars in NumPy promotion, so the bounds never
    # upcast X's dtype (e.g. float32 stays float32).
    return np.clip(X, float(lo), float(hi))

def _load_standardized(path):
    """Read the volume at ``path`` as a tensor with two leading size-1 dims, then clip and standardize it."""
    # Two unsqueeze(0) calls prepend batch/channel dims for the metric functions.
    # NOTE(review): assumes read_array returns a 3-D volume — confirm.
    volume = torch.tensor(read_array(path)).unsqueeze(0).unsqueeze(0)
    return standardize(clip(volume))


def get_metrics(tomo_path, gt_tomo_path):
    """Score a noisy tomogram and its Noise2Void prediction against a ground truth.

    Every volume is quantile-clipped and standardized (``clip``, ``standardize``)
    before PSNR/SSIM are computed against the equally-prepared reference.

    Args:
        tomo_path: path to the input tomogram, or None.
        gt_tomo_path: path to the ground-truth tomogram, or None.

    Returns:
        ``(psnr, ssim_idx, n2v_psnr, n2v_ssim_idx)`` as floats. All four are None
        when either path is None. The two n2v values are None when
        ``data/S2SDenoising/denoised/<name>_n2vDenoised.mrc`` cannot be read
        (``<name>`` is the input file name with ``.mrc`` replaced).

    NOTE(review): ``pytorch_msssim.ssim`` defaults to ``data_range=255``, which on
    standardized data likely pushes SSIM toward 1 (cf. the ~0.98-0.99 values).
    Only change it together with however ``full_tomo_ssim`` is computed, or the
    columns stop being comparable.
    """
    if tomo_path is None or gt_tomo_path is None:
        return None, None, None, None

    data = _load_standardized(tomo_path)
    reference = _load_standardized(gt_tomo_path)
    psnr = float(peak_signal_noise_ratio(data, reference))
    ssim_idx = float(ssim(data, reference))

    name = os.path.basename(tomo_path).replace('.mrc', '_n2vDenoised.mrc')
    n2v_pred_path = os.path.join(PARENT_PATH, "data/S2SDenoising/denoised/%s" % (name))
    try:
        n2v_data = _load_standardized(n2v_pred_path)
    except OSError:
        # No readable Noise2Void prediction for this tomogram.
        n2v_psnr, n2v_ssim_idx = None, None
    else:
        n2v_psnr = float(peak_signal_noise_ratio(n2v_data, reference))
        n2v_ssim_idx = float(ssim(n2v_data, reference))

    return psnr, ssim_idx, n2v_psnr, n2v_ssim_idx

In [3]:
logdir = os.path.join(PARENT_PATH, 'data/S2SDenoising/tryout_model_logs/')

# One hparams YAML per run, laid out as <logdir>/<model>/<version>/*.yaml.
all_logs = glob(os.path.join(logdir, '*', '*', '*.yaml'))

# hparams entries copied into the summary table (None when a run lacks one).
keys = ['Version_comment', 'transform', 'full_tomo_psnr', 'full_tomo_ssim', 'tomo_path', 'gt_tomo_path']
metric_cols = ['baseline_psnr', 'baseline_ssim', 'n2v_psnr', 'n2v_ssim']
path_cols = ['tomo_path', 'gt_tomo_path']


def parse_log_hparams(yaml_path):
    """Summarize one run's hparams YAML file as a flat record (dict).

    ``model`` and ``version`` are the names of the two directories above the file.
    The file is parsed with ``yaml.BaseLoader``, so scalar values stay strings.
    A missing ``Dataloader`` line gives ``'Unknown'``; missing hparams give None.
    """
    parts = os.path.normpath(yaml_path).split(os.sep)
    with open(yaml_path) as f:
        text = f.read()
    hparams = yaml.load(text, Loader=yaml.BaseLoader)

    # NOTE(review): the dataloader class is scraped from the raw text rather than
    # read from `hparams`, presumably because the parsed value does not keep it
    # (original comment: "yaml is stupid") — confirm against a log file.
    dataloader_line = next((line for line in text.splitlines() if 'Dataloader' in line), None)
    if dataloader_line is None:
        dataloader = 'Unknown'
    else:
        dataloader = dataloader_line.split('.')[-1].replace('\'', '').strip()

    loss_fn = hparams.get('loss_fn') or {}
    record = {
        'model': parts[-3],
        'version': parts[-2],
        'dataloader': dataloader,
        'TV_alpha': loss_fn.get('alpha'),
    }
    record.update({k: hparams.get(k) for k in keys})
    return record


data_log = pd.DataFrame(
    [parse_log_hparams(p) for p in all_logs],
    columns=['model', 'version', 'dataloader', 'TV_alpha'] + keys,
).sort_values(['model', 'version'])

# Several runs share the same (noisy, ground-truth) pair and get_metrics reads
# full tomograms, so evaluate each distinct pair only once.
path_pairs = list(zip(data_log['tomo_path'], data_log['gt_tomo_path']))
metrics_by_pair = {
    pair: get_metrics(*pair)
    for pair in tqdm(dict.fromkeys(path_pairs), desc='baseline metrics')
}
baseline_metrics = pd.DataFrame(
    [metrics_by_pair[pair] for pair in path_pairs],
    columns=metric_cols,
    index=data_log.index,
)

data_log = (
    data_log
    .join(baseline_metrics)
    .astype({'full_tomo_psnr': float, 'full_tomo_ssim': float})
)

# Show only file names, and move the path columns to the end of the table.
for col in path_cols:
    data_log[col] = data_log.pop(col).map(lambda p: os.path.basename(p) if p is not None else p)

data_log

Unnamed: 0,model,version,dataloader,TV_alpha,Version_comment,transform,full_tomo_psnr,full_tomo_ssim,baseline_psnr,baseline_ssim,n2v_psnr,n2v_ssim,tomo_path,gt_tomo_path
13,model14,version_0,singleCET_dataset,0.0,Model 14 baseline using regular bernoulli samp...,{'p': '0.5'},20.272,0.984,15.567,0.982,18.805,0.991,tomoPhantom_model14_Poisson5000+Gauss5+stripes...,tomoPhantom_model14.mrc
14,model14,version_1,singleCET_FourierDataset,0.0,Use Fourier samples,{'p': '0.5'},19.347,0.979,15.567,0.982,18.805,0.991,tomoPhantom_model14_Poisson5000+Gauss5+stripes...,tomoPhantom_model14.mrc
9,model16,version_0,singleCET_dataset,0.0,Model 16 baseline,{'p': '0.5'},13.139,0.979,10.708,0.977,12.455,0.98,tomoPhantom_model16_Poisson5000+Gauss5+stripes...,tomoPhantom_model16.mrc
10,model16,version_1,singleCET_FourierDataset,0.0,Fourier sample version,{'p': '0.5'},13.011,0.979,10.708,0.977,12.455,0.98,tomoPhantom_model16_Poisson5000+Gauss5+stripes...,tomoPhantom_model16.mrc
12,model16,version_2,singleCET_FourierDataset,0.0,Fourier sample version with more epochs,{'p': '0.5'},13.194,0.979,10.708,0.977,12.455,0.98,tomoPhantom_model16_Poisson5000+Gauss5+stripes...,tomoPhantom_model16.mrc
11,model16,version_3,singleCET_dataset,0.0,Bernoulli sample version with more epochs,{'p': '0.5'},13.108,0.979,10.708,0.977,12.455,0.98,tomoPhantom_model16_Poisson5000+Gauss5+stripes...,tomoPhantom_model16.mrc
0,model9,version_0,singleCET_FourierDataset,0.0,Compare Fourier and regular Bernoulli sampling.,,23.267,0.992,,,,,,
1,model9,version_1,singleCET_dataset,0.0,Compare Fourier and regular Bernoulli sampling.,{'p': '0.5'},25.995,0.996,19.76,0.991,24.658,0.997,tomoPhantom_model9_Poisson5000+Gauss5.mrc,tomoPhantom_model9.mrc
5,model9,version_2,singleCET_FourierDataset,0.0,Compare Fourier and regular Bernoulli sampling.,{'p': '0.5'},25.598,0.995,19.76,0.991,24.658,0.997,tomoPhantom_model9_Poisson5000+Gauss5.mrc,tomoPhantom_model9.mrc
3,model9,version_3,singleCET_FourierDataset,0.0,Check behaviour on the Fourier dataset when tu...,{'p': '0.5'},24.083,0.992,19.76,0.991,24.658,0.997,tomoPhantom_model9_Poisson5000+Gauss5.mrc,tomoPhantom_model9.mrc
