In [None]:
'''
Code to perform evaluation on 3DFM_AC dataset
'''

import os
os.chdir("..")

import numpy as np
from copy import deepcopy
import itertools
import torch
import tifffile
import random
from PIL import Image, ImageDraw, ImageFont

import cv2

from scipy.ndimage import uniform_filter

import data
from data.sampler.fmAC3D_sampler import FMAC3DSampler
from options.options import Options
from util.mesh_handler import *
from util.util import *
from metric_funcs import *

from models.biospade_model import BioSPADEModel
from models.segment_model import SegmentModel

import models.networks as networks

import matplotlib.pyplot as plt

# Base experiment name: assigned to org_opt.name in the parameters cell; each
# evaluated run's opt.name is this plus an optional '(expN)' suffix.
name = 'FM3D_AC' # Name of experiments

# Generate fake images
def forward(vox, powers, frames, z_pos, is_cycle=False):
    """Generate fake images for every (power, frames) style.

    Parameters
    ----------
    vox : tensor / array
        Mesh semantics fed to the generator. Dim 0 is the sample axis and
        dims -3, -2, -1 are (slices, height, width).
    powers, frames : sequence
        Per-style values; ``powers[i]`` is paired with ``frames[i]``.
    z_pos : array-like
        Per-sample z positions, passed through to the model unchanged.
    is_cycle : bool
        If True, run the global ``model_Cycle`` chunk-by-chunk along the
        slice axis instead of the global ``model_Gan``.

    Returns
    -------
    (fake, fake_mu) : tuple of numpy arrays
        Both shaped ``(len(powers), n_samples, 1, n_out, H, W)`` with
        ``n_out = vox.shape[-3] - org_opt.in_Gslices + 1``. In cycle mode
        ``fake_mu`` is left as zeros.

    Relies on the notebook globals ``org_opt``, ``model_Gan`` / ``model_Cycle``
    and ``tensorize_dict``.
    """
    n_samples = vox.shape[0]
    # Output depth of a "valid" pass: every output slice needs in_Gslices input slices.
    n_out = vox.shape[-3] - org_opt.in_Gslices + 1
    chunk = org_opt.in_Dslices  # output slices produced per model_Cycle call

    fake = torch.zeros([len(powers), n_samples, 1, n_out, *vox.shape[-2:]])
    fake_mu = torch.zeros_like(fake)

    # Pure evaluation: build no autograd graph (saves memory and guarantees
    # that .numpy() below never fails on a grad-tracking tensor).
    with torch.no_grad():
        for i in range(len(powers)):
            power, n_frames = powers[i], frames[i]

            if is_cycle:
                # Walk over ALL n_out output slices (the old stop value of
                # n_out-1 dropped the last slice when it started a chunk).
                # Each call consumes in_Gslices + chunk - 1 input slices.
                for z in range(0, n_out, chunk):
                    batch = tensorize_dict({
                        'mesh_semantics': vox[:, :, z:z + org_opt.in_Gslices + chunk - 1],
                        'power': n_samples * [power],
                        'frames': n_samples * [n_frames],
                        'z_pos': z_pos,
                    })
                    model_Cycle.set_input(batch)
                    out = model_Cycle.forward(True)
                    fake[i, :, :, z:z + chunk] = torch.relu(out[:, None])
            else:
                # `batch` (not `data`) so the imported `data` module is not shadowed.
                batch = tensorize_dict({
                    'mesh_semantics': vox,
                    'power': n_samples * [power],
                    'frames': n_samples * [n_frames],
                    'z_pos': z_pos,
                })
                fake[i], fake_mu[i] = model_Gan(batch, 'inference')

    return fake.cpu().numpy(), fake_mu.cpu().numpy()

# Code to Generate Testing Data
def get_data(loader, z_pad=None, styles=None):
    """Draw the first batch from ``loader`` and unpack it into evaluation arrays.

    Parameters
    ----------
    loader : iterable
        Yields dicts of tensors with keys 'real_stack', 'mesh_semantics',
        'real_slices' and 'z_pos'. Only the first batch is used.
    z_pad : int, optional
        Slices dropped from each end of the mesh depth to obtain the
        ground-truth depth. Defaults to the global ``sampler.z_pad``.
    styles : sequence of (power, frames) pairs, optional
        Defaults to the global ``style_combs``.

    Returns
    -------
    stack : ndarray, shape (len(styles), N, vox.shape[1] - 2*z_pad, H, W)
        For each style, ``real_stack`` at index ``int(power)`` of axis 1,
        averaged over its first ``int(frames)`` frames.
        NOTE(review): the power value is used directly as an index, which
        assumes powers are stored as indices - confirm.
    gt : ndarray, shape (N, vox.shape[-3] - 2*z_pad, H, W)
    mask : ndarray, same shape and dtype as ``gt``
        1 where a 5x5x5 box filter of each ``gt`` sample exceeds 1e-5, else 0.
    vox : ndarray
        Mesh semantics with the two leading (batch, per-instance) axes merged.
    z_pos : ndarray, shape (N,)

    Raises
    ------
    ValueError
        If ``loader`` yields no batches.
    """
    if z_pad is None:
        z_pad = sampler.z_pad
    if styles is None:
        styles = style_combs

    try:
        sample = next(iter(loader))
    except StopIteration:
        raise ValueError('loader yielded no batches') from None

    # Merge the two leading axes (batch, samples per instance) into one sample axis.
    real_stack = sample['real_stack']
    org_stack = real_stack.reshape((-1, *real_stack.shape[2:])).numpy()

    mesh = sample['mesh_semantics']
    vox = mesh.reshape((-1, *mesh.shape[2:])).numpy()

    z_sz = vox.shape[-3] - 2 * z_pad
    xy_sz = vox.shape[-2:]

    gt = sample['real_slices'].reshape((-1, z_sz, *xy_sz)).numpy()
    # One vectorized box filter over (z, y, x) per sample: size 1 on the sample
    # axis means no mixing between samples (identical to filtering each gt[i]).
    mask = (uniform_filter(gt, size=(1, 5, 5, 5)) > 1e-5).astype(gt.dtype)

    z_pos = sample['z_pos'].reshape((-1,)).numpy()

    stack = np.zeros((len(styles), vox.shape[0], vox.shape[1] - 2 * z_pad, *xy_sz))
    for i, style in enumerate(styles):
        power, n_frames = int(style[0]), int(style[1])
        # Average the first n_frames frames (axis 2 after indexing out the power axis).
        stack[i] = org_stack[:, power, :, :n_frames].mean(2)

    return stack, gt, mask, vox, z_pos

In [None]:
# Set Parameters
#
# Builds two loaders from one shared Options object: a 'train'-split loader for
# the texture/intensity metrics and a 'test'-split loader for the segmentation
# metrics. org_opt is mutated between the two create_dataloader calls, so the
# statement order in this cell matters.

org_opt = Options('options/test_options.yaml', 'fmAC3D')
org_opt.initialize()

org_opt.train_mode = 'SEG'
org_opt.dataset_mode = 'fmAC3D'
org_opt.name = name

# Create loader for standard eval metrics (data taken from training dataset)
org_opt.batch_size = 13 # batch size
org_opt.samples_per_instance = 8 # samples per loaded mesh

dataloader, dataset = data.create_dataloader(org_opt,'train')

# Create loader for segmentation metrics (data taken from testing dataset)
org_opt.batch_size = 4
org_opt.samples_per_instance = 52
# NOTE(review): these overrides stay in org_opt, so they also reach the sampler
# built below, the per-experiment deepcopies in the evaluation cell and
# forward()'s chunking (org_opt.in_Dslices) - confirm that is intended.
org_opt.in_Dslices = org_opt.in_Sslices
org_opt.prob_new_mesh = 0
dataloader_seg, dataset_seg = data.create_dataloader(org_opt,'test')

# File lists of the 'train'-split dataset (not referenced elsewhere in this
# notebook section).
sac_files = dataset.stack_paths
gt_files = dataset.gt_paths
mesh_files = dataset.mesh_paths

# NOTE(review): z_pad is not referenced in the visible cells; get_data() uses
# sampler.z_pad instead.
z_pad = org_opt.delta_slice*(org_opt.in_Gslices//2)
sampler = FMAC3DSampler(org_opt)

# Every (power, frames) style pair. itertools.product varies frames fastest, so
# row k corresponds to power index k // len(frames), frame index k % len(frames).
style_combs = np.asarray(list(itertools.product(org_opt.powers, org_opt.frames)))
# Subset of styles used for the texture metrics: frame positions 0, 3 and 7 in
# each of 5 power blocks of stride 8.
# NOTE(review): the hard-coded 5 and 8 assume len(org_opt.powers) == 5 and
# len(org_opt.frames) == 8 - confirm, or derive them from org_opt.
pow_inds = []
for i in range(5):
    pow_inds += list(np.array([0,3,7])+8*i)

In [None]:
# Generate testing data for texture eval

stack, gt, real_mask, vox, z_pos = get_data(dataloader)

# Mask of voxels on or next to the normalised mesh, used so the fake-image
# texture metrics are computed only around dendrites. vox is cropped by
# sampler.z_pad slices at each end - the same crop get_data() applies when
# sizing gt (the old hard-coded [3:-3] only worked for z_pad == 3).
# The 3-wide box filter runs over every axis except the sample axis.
fake_mask = (
    uniform_filter(vox[:, sampler.z_pad:vox.shape[1] - sampler.z_pad] / vox.max(),
                   size=(1,) + (3,) * (vox.ndim - 1))
    > 1e-5
).astype(gt.dtype)
assert fake_mask.shape == gt.shape, (fake_mask.shape, gt.shape)

# Real stacks for the texture-metric styles (pow_inds) and their blur_all() inputs.
avg_real = stack[pow_inds]
text_real = blur_all(avg_real)


# Generate testing data for segmentation eval

seg_stack, seg_gt, _, _, seg_z_pos = get_data(dataloader_seg) # Might not fit in GPU, so segmentation eval must be done separately


In [None]:
# Perform Evaluation
#
# For each experiment: optionally score the segmentation models, then load the
# generator, synthesise images for every (power, frames) style and compare them
# with the real stacks using the metrics below.

out_sz = (org_opt.number_of_experiments, len(style_combs))  # (experiment, style)
pow_out_sz = (org_opt.number_of_experiments, len(pow_inds))  # (experiment, style in pow_inds)
losses = {'NMSE': np.zeros(pow_out_sz),
          'PSNR': np.zeros(out_sz),  # NOTE(review): allocated but never filled below
          'JSD': np.zeros(out_sz),
          'GLCM': np.zeros(pow_out_sz),
          'LBP': np.zeros(pow_out_sz),
          'COOC': np.zeros(pow_out_sz),
          'SEG': np.zeros([*out_sz, 3]),}  # last axis: one column per seg_instance (0..2)

# Reference statistics on the real data are computed once and reused for every
# experiment.
real_auto = calc_auto(avg_real) # Calculate real auto_correlation
real_LBP = calc_LBP(text_real,real_mask) # Calculate real LBP
real_GLCM = calc_GLCM(text_real,real_mask) # Calculate real GLCM
real_COOC = calc_COOC(text_real, real_mask>.5) # Calculate real COOC

# Set to True to also run the segmentation evaluation (three SegmentModel
# instances per experiment); when False, losses['SEG'] stays all zeros.
calc_seg=False
for exp in range(0,org_opt.number_of_experiments):
    # Get experiment details: a fresh copy of the base options per experiment.
    # Unless run_all is set, select experiment `exp` and tag the name '(expN)'.
    opt = deepcopy(org_opt)
    exp_str = ''
    if not org_opt.run_all:
        exp_str = '(exp'+str(exp)+')'
        opt.set_experiment(exp)
    opt.name = org_opt.name+exp_str

    # Calculate Segmentation results (one column of losses['SEG'] per seg_instance)
    if calc_seg:
        for i in range(3):
            opt.seg_instance = i
            model_Seg = SegmentModel(opt)
            model_Seg.eval()

            acc = SEG(model_Seg, seg_stack[:,:,None], seg_gt, style_combs[:,0], style_combs[:,1], seg_z_pos, )
            losses['SEG'][exp,:,i] = acc
    
    # Load model (forward() below reads this global model_Gan)
    print('Evaluating:', opt.name)
    model_Gan = BioSPADEModel(opt)
    model_Gan.eval()

    # Generate Data: fake has shape (styles, samples, 1, slices, H, W); the
    # second return value (fake_mu from forward) is not used here.
    fake, channels = forward(vox[:,None], style_combs[:,0], style_combs[:,1], z_pos)
    avg_fake = fake[pow_inds,:,0]  # keep the pow_inds styles, drop the singleton channel axis
    text_fake = blur_all(avg_fake)
    
    # Evaluate
    fake_auto = calc_auto(avg_fake)
    fake_LBP = calc_LBP(text_fake, fake_mask) # Use fake_mask as we only want to compute metric around dendrites
    fake_GLCM = calc_GLCM(text_fake, fake_mask) # Use fake_mask as we only want to compute metric around dendrites
    fake_COOC = calc_COOC(text_fake, fake_mask>.5) # Use fake_mask as we only want to compute metric around dendrites
    
    losses['NMSE'][exp] = MSE(fake_auto, real_auto, normalize=True) # Compute NMSE of autocorrelation
    losses['GLCM'][exp] = MSE(fake_GLCM, real_GLCM,axis=(1,2,3,4)) # Compute MSE of GLCM
    losses['LBP'][exp] = mult_JSDs(fake_LBP, real_LBP) # Compute JSD of LBP
    losses['COOC'][exp] = MSE(fake_COOC, real_COOC, axis=(1,2,3,4,5,6)) # Compute MSE of COOC
    losses['JSD'][exp] = compare_hist(fake, stack) # Compute JSD of pixels over all styles

In [None]:
# Print Results (Note: results vary from trial to trial since we are sampling patches)
#
# Header and rows are both ' & '-separated so the columns line up and the
# output can be pasted straight into a LaTeX table body. Each value is the
# experiment's mean loss multiplied by its display scale.

keys = ['NMSE', 'LBP', 'GLCM', 'COOC', 'JSD', 'SEG']
scale = {'NMSE': 1e1, 'GLCM': 1e8, 'LBP': 1e3, 'PSNR':1e1, 'JSD': 1e2, 'COOC':1e8, 'Frangi':1e5, 'SEG': 100}

print(' & '.join(keys))

for i_loss in range(org_opt.number_of_experiments):
    cells = []
    for key in keys:
        if key == 'SEG':
            # losses['SEG'][i_loss] is (styles, seg instances): average over
            # styles, then report mean +/- std across segmentation instances.
            per_instance = losses['SEG'][i_loss].mean(0)
            cells.append('%0.3f +/- %0.3f' % (per_instance.mean() * scale['SEG'],
                                              per_instance.std() * scale['SEG']))
        else:
            cells.append('%0.3f' % (losses[key][i_loss].mean() * scale[key]))
    print(' & '.join(cells))