In [None]:
'''
Code to perform evaluation on the Syn3D_GC dataset
'''

import os
# Relative paths used below (e.g. 'options/test_options.yaml') are resolved from
# the parent directory of where the notebook starts. Only step up when that file
# is not already reachable, so re-running this cell does not keep climbing the tree.
if not os.path.isfile(os.path.join('options', 'test_options.yaml')):
    os.chdir("..")

import numpy as np
from copy import deepcopy
import itertools
import torch
import tifffile
import random
from scipy.ndimage import uniform_filter

import data
from data.sampler.synGC3D_sampler import SynGC3DSampler
from options.options import Options
from util.mesh_handler import *
from util.util import *
from metric_funcs import *

from scipy.ndimage import uniform_filter

from models.biospade_model import BioSPADEModel
import models.networks as networks

import matplotlib.pyplot as plt

n_patches = 100 # Number of patches to test
n_instances = 25 # Number of noisy instances generated per patch
name = 'Syn3D_GC' # Name of experiments

# Seed every RNG this notebook imports so that patch sampling and noise
# generation are reproducible across Restart & Run All.
random_seed = 0
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Generate Real Data
def generate(stack, powers, frames, n_instances=1, std_scalar=None):
    '''Generate noisy "real" acquisitions of a clean stack for each style.

    Parameters
    ----------
    stack : array
        Clean image stack. Its last three dimensions size the output, and it is
        passed unchanged to add_gauss_noise.
    powers, frames : sequences of equal length
        Element j of each is forwarded to add_gauss_noise for style j.
    n_instances : int
        Number of independent noisy realisations per style.
    std_scalar : optional
        Noise scale forwarded to add_gauss_noise. Defaults to
        org_opt.std_scalar, so the global options object must exist at call
        time when this is omitted.

    Returns
    -------
    np.ndarray of shape (n_instances, len(powers), *stack.shape[-3:])

    Raises
    ------
    ValueError
        If powers and frames differ in length.
    '''
    if len(powers) != len(frames):
        raise ValueError('Powers and Frames are not equal in length')
    if std_scalar is None:
        std_scalar = org_opt.std_scalar

    real = np.zeros([n_instances, len(powers), *stack.shape[-3:]])
    for i in range(n_instances):
        for j, (power, n_frames) in enumerate(zip(powers, frames)):
            real[i, j] = add_gauss_noise(stack, power, n_frames, std_scalar)
    return real

# Generate noise for each layer of G
def generate_noise(sz, num_layers=None, device='cuda'):
    '''Build one standard-normal noise tensor per generator resolution level.

    Level i halves the spatial size i times (integer division). The result is
    a list of num_layers+1 tensors of shape (sz[0], sz[1], H // 2**i, W // 2**i),
    where (H, W) = sz[-2:].

    Parameters
    ----------
    sz : sequence of int
        The first two entries become the leading dimensions; the last two are
        the full-resolution spatial size.
    num_layers : int, optional
        Number of upsampling layers. Defaults to the global num_up_layers.
    device : str or torch.device
        Destination of the noise ('cuda' by default, as before).
    '''
    if num_layers is None:
        num_layers = num_up_layers
    height, width = sz[-2], sz[-1]
    lead_dims = list(sz[:2])
    # Sample on the CPU and then move, so the drawn values match the original
    # `torch.randn(...).cuda()` path for a given torch seed.
    return [torch.randn([*lead_dims, height // 2**i, width // 2**i]).to(device)
            for i in range(num_layers + 1)]

# Generate fake images
def forward(model, vox, powers, frames, z_pos, n_instances=1, rand_noise=True, is_cycle=False):
    '''Run the generator n_instances times on one voxel patch for every style.

    Parameters
    ----------
    model : generator model (BioSPADEModel in this notebook)
    vox : array
        Voxel input; its last three dims are (slices incl. z padding, x, y).
    powers, frames : sequences of equal length
        One (power, frames) pair per style; the same vox is fed for each style.
    z_pos : z position(s) forwarded to the model.
    n_instances : int
        Number of generator runs.
    rand_noise : bool
        If False, one fixed noise set (from generate_noise) is reused for every
        instance; otherwise the model is left to draw its own noise.
    is_cycle : bool
        Use model.set_input / model.forward(True) instead of inference mode.
        In this branch fake_mu is returned as zeros.

    Returns
    -------
    (fake, fake_mu) : numpy arrays of shape
        (n_instances, len(powers), vox.shape[-3] - 2*sampler.z_pad, *vox.shape[-2:])

    Raises
    ------
    ValueError
        If powers and frames differ in length.
    '''
    if len(powers) != len(frames):
        raise ValueError('Powers and Frames are not equal in length')

    fake = torch.zeros([n_instances, len(powers), vox.shape[-3]-2*sampler.z_pad, *vox.shape[-2:]])
    fake_mu = torch.zeros_like(fake)
    noise = generate_noise([len(powers), *vox.shape[-3:]])

    # Evaluation only: disable autograd so no graphs are retained and the
    # `.numpy()` conversion below never sees a tensor that requires grad.
    with torch.no_grad():
        for i in range(n_instances):
            # Named `batch` to avoid shadowing the imported `data` package.
            batch = {'mesh_semantics': len(powers)*[vox],
                     'power': powers,
                     'frames': frames,
                     'z_pos': z_pos}
            batch = tensorize_dict(batch)
            if not rand_noise:
                batch['noise'] = noise

            if is_cycle:
                model.set_input(batch)
                fake_ = model.forward(True)
                fake[i] = torch.relu(fake_)
            else:
                fake_, fake_mu_ = model(batch, 'inference')
                fake[i], fake_mu[i] = fake_[:, 0], fake_mu_[:, 0]
    return fake.cpu().numpy(), fake_mu.cpu().numpy()

In [None]:
# Set Parameters

# Load the test options for the 'synGC3D' tag and switch to GAN mode.
# NOTE(review): train_mode is assigned before initialize(), presumably because
# initialize() depends on it — confirm.
tag = 'synGC3D'
org_opt = Options('options/test_options.yaml', tag)
org_opt.train_mode = 'GAN'
org_opt.initialize()

org_opt.dataset_mode = tag
org_opt.name = name
org_opt.how_many_patches = n_patches

# Used by generate_noise(), which builds num_up_layers+1 noise tensors
# (one per resolution level).
num_up_layers = networks.generator.compute_latent_vector_size(org_opt)

# Of the dataset, only its mesh file paths are used below; the sampler loads them.
dataloader, dataset = data.create_dataloader(org_opt, 'all')
files = dataset.mesh_paths

sampler = SynGC3DSampler(org_opt)
sampler.load(files)

# Reverse voxels and image stacks along their last axis.
# NOTE(review): presumably this matches the orientation the model was trained on — confirm.
sampler.vox = sampler.vox[:,:,:,::-1]
sampler.stack = sampler.stack[:,:,:,::-1]

# NOTE(review): VGG_Loss is not referenced elsewhere in this section.
VGG_Loss = networks.VGGLoss(org_opt.gpu_ids)
# Every (power, frames) pair is one imaging style: column 0 = power, column 1 = frames.
style_combs = np.asarray(list(itertools.product(org_opt.powers, org_opt.frames)))

In [None]:
# Generate testing data
#
# Per sampled patch:
#   input     - voxel input incl. z padding (in_Dslices + in_Gslices - 1 slices)
#   pre_stack - clean image stack (in_Dslices slices)
#   gt        - mask of voxels where the box-filtered (size 3) unpadded input
#               exceeds 1e-5; used to restrict the NMSE
#   stack     - n_instances noisy realisations for every style in style_combs,
#               laid out as (patch, instance, style, slice, x, y)

input = np.zeros((org_opt.how_many_patches, org_opt.in_Dslices+org_opt.in_Gslices-1, *org_opt.crop_xy_sz))
gt = np.zeros((org_opt.how_many_patches, org_opt.in_Dslices, *org_opt.crop_xy_sz))
pre_stack = np.zeros((org_opt.how_many_patches, org_opt.in_Dslices, *org_opt.crop_xy_sz))
stack = np.zeros((org_opt.how_many_patches, n_instances, len(style_combs), org_opt.in_Dslices, *org_opt.crop_xy_sz))
z_arr = np.zeros((org_opt.how_many_patches,))

for i in range(org_opt.how_many_patches):
    x, xy, y, z = sampler.sample()
    input[i] = x
    pre_stack[i] = y
    z_arr[i] = z
    # Strip the z padding. An explicit stop index keeps this correct when
    # z_pad == 0 (x[0:-0] would be an empty slice).
    gt[i] = uniform_filter(x[sampler.z_pad:x.shape[0]-sampler.z_pad], 3) > 1e-5

for i, pre_stack_ in enumerate(pre_stack):
    stack[i] = generate(pre_stack_, style_combs[:,0], style_combs[:,1].astype(int), n_instances=n_instances)

# Reorder to (style, instance, patch, slice, x, y) for the metric functions.
real = stack.swapaxes(0,2)
avg_real = real.mean(1)
text_real = blur_all(real)[:,0]

In [None]:
# Perform Evaluation

# Loss tables: one row per experiment, one column per (power, frames) style.
out_sz = (org_opt.number_of_experiments, len(style_combs))
losses = {metric: np.zeros(out_sz)
          for metric in ('NMSE', 'PSNR', 'JSD', 'GLCM', 'LBP', 'COOC', 'Frangi')}

# Reference statistics of the real data; computed once and reused for every experiment.
real_PSNR = calc_PSNR(real)

real_GLCM = calc_GLCM(text_real)
print('Real GLCM Done')

real_LBP = calc_LBP(text_real)
print('Real LBP Done')

real_COOC = calc_COOC(text_real)
print('Real COOC Done')

In [None]:
for exp in range(org_opt.number_of_experiments):
    # Same layout as `stack`: (patch, instance, style, slice, x, y).
    fake_patches = np.zeros(stack.shape)

    # Per-experiment options. Unless run_all is set, the experiment index
    # selects the settings and is appended to the model name.
    opt = deepcopy(org_opt)
    exp_str = ''
    if not org_opt.run_all:
        exp_str = '(exp'+str(exp)+')'
        opt.set_experiment(exp)
    opt.name = org_opt.name+exp_str

    # Load model
    print('Evaluating:', opt.name)
    model = BioSPADEModel(opt)
    model.eval()

    # Generate fake data for every test patch, using a fixed noise set per patch.
    for i, input_ in enumerate(input):
        fake_patches[i], _ = forward(model, input_[None,None],
                                     style_combs[:,0], style_combs[:,1],
                                     [z_arr[i]], n_instances=n_instances, rand_noise=False)

    # Same layout as `real`: (style, instance, patch, slice, x, y).
    fake = fake_patches.swapaxes(0,2)
    avg_fake = fake.mean(1)
    text_fake = blur_all(fake)[:,0]

    fake_LBP = calc_LBP(text_fake)
    fake_GLCM = calc_GLCM(text_fake)
    fake_COOC = calc_COOC(text_fake)

    losses['NMSE'][exp] = MSE(avg_fake*gt, avg_real*gt, normalize=True) # NMSE inside the gt mask
    losses['GLCM'][exp] = MSE(fake_GLCM, real_GLCM, axis=(1,2,3,4)) # MSE of GLCM
    losses['LBP'][exp] = mult_JSDs(fake_LBP, real_LBP) # JSD of LBP
    losses['COOC'][exp] = MSE(fake_COOC, real_COOC, axis=(1,2,3,4,5,6)) # MSE of COOC
    losses['PSNR'][exp] = (calc_PSNR(fake)-real_PSNR)**2 # Squared PSNR difference
    # JSD of pixel values. Compare against `real` (not the unswapped `stack`)
    # so both arguments share the (style, instance, patch, ...) layout.
    losses['JSD'][exp] = compare_hist(fake, real)

In [None]:
# Print Results (values vary from trial to trial because the test patches are randomly sampled)

keys = ['NMSE', 'LBP', 'GLCM', 'COOC', 'PSNR', 'JSD']
# Per-metric multipliers so the printed values have a readable magnitude.
scale = {'NMSE': 1e1, 'GLCM': 1e8, 'LBP': 1e3, 'PSNR':1e0, 'JSD': 1e2, 'COOC':1e8, 'Frangi':1e7}

# Header row: metric names, each followed by a LaTeX-style ' & ' separator.
print(''.join(key + ' & ' for key in keys))

for i_loss in range(org_opt.number_of_experiments):
    # Scaled per-experiment mean of every metric, in header order.
    cells = ['%0.3f & ' % (losses[metric][i_loss].mean() * scale[metric]) for metric in keys]
    # Drop the 'EXP n: ' prefix if you want rows that paste cleanly into a spreadsheet.
    print('EXP ' + str(i_loss) + ': ' + ''.join(cells))