In [7]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import tifffile
from Utils.DataAug3D import *
from torch.cuda.amp import GradScaler, autocast
from Models.LargePNet3D import *
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import glob
from Utils.ssim3d import *
from Utils.FastMSSSIM import *
import pandas as pd

In [2]:
def rescale(restored, gt):
    '''Least-squares affine rescaling of ``restored`` onto ``gt``.

    For every sample ``i`` along dim 0, find scalars ``a_i, b_i`` minimising
    ``mean((a_i * restored[i] + b_i - gt[i]) ** 2)`` and return
    ``a * restored + b`` reshaped to ``restored.shape``.

    Args:
        restored: tensor whose first dimension indexes samples.
        gt: tensor with the same number of elements per sample as ``restored``.

    Returns:
        Tensor with the same shape as ``restored``. If a sample of ``restored``
        is constant, every slope is equally optimal; ``a = 0`` is used, so that
        sample becomes ``mean(gt[i])`` instead of NaN/inf.
    '''
    batch_size = restored.size(0)
    # reshape (not view) so non-contiguous inputs are accepted too
    restored_flat = restored.reshape(batch_size, -1)
    gt_flat = gt.reshape(batch_size, -1)
    # per-sample statistics, shape (batch, 1) so they broadcast over each row
    mean_restored = restored_flat.mean(dim=1, keepdim=True)
    mean_gt = gt_flat.mean(dim=1, keepdim=True)
    centered = restored_flat - mean_restored
    cov_restored_gt = (centered * (gt_flat - mean_gt)).mean(dim=1, keepdim=True)
    var_restored = (centered ** 2).mean(dim=1, keepdim=True)
    # guard zero variance: substitute 1 in the denominator, then select a = 0
    has_var = var_restored > 0
    safe_var = torch.where(has_var, var_restored, torch.ones_like(var_restored))
    a = torch.where(has_var, cov_restored_gt / safe_var, torch.zeros_like(var_restored))
    b = mean_gt - a * mean_restored
    return (a * restored_flat + b).reshape(restored.shape)

In [10]:
# Test-set locations, built with os.path.join so the notebook is not tied to
# Windows path separators.
head_dir = os.path.join("Data", "BackgroundRemoval", "BackgroundRemoval", "MousePhalloidin", "Testing")
GT_path = os.path.join(head_dir, 'TestGT')
Raw_path = os.path.join(head_dir, 'TestNoisy')
save_dir = os.path.join(head_dir, 'output')
# The evaluation loop reads 1.tif ... N.tif, so count only .tif files; any other
# file in the folder would otherwise inflate the count and make the loop open a
# non-existent volume.
frame = sum(
    1 for name in os.listdir(Raw_path)
    if name.lower().endswith('.tif') and os.path.isfile(os.path.join(Raw_path, name))
)
# makedirs creates missing parent folders and is a no-op if the folder exists
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Evaluate the trained LargePNet3D on every test volume: denoise, affinely
# rescale onto the ground truth, record PSNR / MS-SSIM / RMSE, save the restored
# volume as a 16-bit TIFF, and write each metric to an Excel file.
model_label = 'LargePNet3D'
model_dir = os.path.join('TrainedModel', 'BackgroundRemoval', 'MousePhalloidin', 'best_model.pth')
device = 'cuda'
model = LargePNet3D(1,1,15,4)
model.load_state_dict(torch.load(model_dir, map_location=device))
# loop-invariant: move to the device and switch to inference mode once
model.to(device)
model.eval()
save_path = os.path.join(save_dir, model_label)
PSNR = []
SSIM = []   # holds MS-SSIM values (criterionMSSSIM); name kept for the Excel file / later cells
NRMSE = []
criterionSSIM = SSIM3D()
criterionMSSSIM = MS_SSIM(channel=1)
os.makedirs(save_path, exist_ok=True)
Endstr = '.tif'


def _minmax_normalize(volume):
    '''Linearly map ``volume`` onto [0, 1] as float64; a constant volume maps to zeros.'''
    volume = np.asarray(volume, dtype=np.float64)
    lo, hi = volume.min(), volume.max()
    if hi == lo:
        # avoid 0/0 -> NaN for a constant volume
        return np.zeros_like(volume)
    return (volume - lo) / (hi - lo)


for i in tqdm(range(frame), desc="Processing images"):
    numstr = str(i + 1)
    raw_volume = tifffile.imread(os.path.join(Raw_path, numstr + Endstr))
    gt_volume = tifffile.imread(os.path.join(GT_path, numstr + Endstr))
    if raw_volume.ndim != 3:
        raise ValueError(f'{numstr}{Endstr}: expected a 3-D volume, got shape {raw_volume.shape}')

    # network input stays half precision, as in the original evaluation;
    # the ground truth stays float32 so the metrics are not limited by fp16
    noisy = torch.tensor(_minmax_normalize(raw_volume), dtype=torch.float16)[None, None]
    gt = torch.tensor(_minmax_normalize(gt_volume), dtype=torch.float32)[None, None]

    with torch.no_grad():
        with autocast():
            pred = model(noisy.to(device))
        # rescaling and metrics run outside autocast, in float32
        gt = gt.to(device)
        pred = rescale(pred.float(), gt)
        mse = torch.mean((pred - gt) ** 2)
        # gt is min-max normalised to [0, 1]: RMSE is therefore range-normalised
        # RMSE, and PSNR uses a peak value of 1
        NRMSE.append(torch.sqrt(mse).item())
        PSNR.append((10 * torch.log10(1 / mse)).item())
        # MS-SSIM over depth slices: (1, 1, D, H, W) -> (D, 1, H, W)
        # NOTE(review): confirm MS_SSIM's data_range default matches [0, 1] inputs
        ms_ssim = criterionMSSSIM(pred[0, 0].unsqueeze(1), gt[0, 0].unsqueeze(1))
        SSIM.append(ms_ssim.mean().item())

    restored = pred[0, 0].cpu().numpy().astype(np.float32)
    # clip to [0, 1] before scaling: after rescaling values can exceed 1, which
    # would wrap around when cast to uint16; scale in float32 to avoid fp16 overflow
    restored_u16 = np.rint(np.clip(restored, 0, 1) * 65535).astype(np.uint16)
    tifffile.imwrite(os.path.join(save_path, numstr + model_label + Endstr), restored_u16)

PSNR = np.array(PSNR)
SSIM = np.array(SSIM)
NRMSE = np.array(NRMSE)
# one Excel file per metric (one value per volume), then print the mean
for values, metric_name in ((PSNR, 'PSNR'), (SSIM, 'SSIM'), (NRMSE, 'NRMSE')):
    df = pd.DataFrame(values)
    df.to_excel(os.path.join(save_dir, f'LPNet3D{metric_name}.xlsx'), index=False, header=False)
    print(values.mean())

In [9]:
# Show the mean PSNR (dB) over all test volumes recorded by the evaluation cell.
PSNR = np.array(PSNR)
print(np.mean(PSNR))

35.122353
