In [1]:
import numpy as np
import matplotlib.pyplot as plt

from skimage import data, img_as_float
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio as psnr
import cv2
import os
import lpips
import torch

In [11]:
class NerfMetrics():
    def __init__(self):
        super().__init__()

    def computeMetrics(self, imgs_list, gt_list):
        metrics_table = {}
        ssim_list = []
        mse_list = []
        psnr_list = []
        lpips_list = []
        loss_fn_alex = lpips.LPIPS(net='alex')
        #loss_fn_alex.net = loss_fn_alex.net.float()

        for i in range(len(imgs_list)):
            max = (imgs_list[i]*2-1).max()
            min = (imgs_list[i]*2-1).min()
            # print(max)
            # print(min)
            ssim_ = ssim(imgs_list[i], gt_list[i], data_range=max - min, channel_axis=2)
            mse_ = mean_squared_error(imgs_list[i], gt_list[i])
            psnr_ = psnr(imgs_list[i], gt_list[i])

            ## image from -1 to 1
            lpips_ = loss_fn_alex(torch.Tensor(imgs_list[i]*2-1).permute(2, 0, 1).unsqueeze(0), torch.Tensor(gt_list[i]*2-1).permute(2, 0, 1).unsqueeze(0)).item()

            ssim_list.append(ssim_)
            mse_list.append(mse_)
            psnr_list.append(psnr_)
            lpips_list.append(lpips_)
            length = len(ssim_list)

            # ssim_list.append(sum(ssim_list)/length)
            # mse_list.append(sum(mse_list)/length)
            # psnr_list.append(sum(psnr_list)/length)
            # lpips_list.append(sum(lpips_list)/length)

        metrics_table['ssim'] = ssim_list
        metrics_table['mse'] = mse_list
        metrics_table['psnr'] = psnr_list
        metrics_table['lpips'] = lpips_list

        return metrics_table
    
    def load_imgs(self, img_path, gt_path, dsize=(473,265)):
        imgs_list = []
        gt_list = []
        for filename in os.listdir(img_path):
            img = cv2.imread(os.path.join(img_path,filename))
            gt = cv2.imread(os.path.join(gt_path,'gt_' + filename))
            if img is not None:
                img = cv2.resize(img, dsize)
                img = img /255.0
                imgs_list.append(img)
                gt = cv2.resize(gt, dsize)
                gt = gt /255.0
                gt_list.append(gt)
        # for filename in os.listdir(gt_path):
        #     gt = cv2.imread(os.path.join(gt_path,filename))
        #     if gt is not None:
        #         gt = cv2.resize(gt, dsize)
        #         gt = gt /255.0
        #         gt_list.append(gt)

        return imgs_list, gt_list


In [12]:
a = NerfMetrics()

In [13]:
path1 = '/home/fusang/Desktop/Nerf/torch-ngp-la/test_jardinMines1_scale_0.1/pre'
path2 = '/home/fusang/Desktop/Nerf/torch-ngp-la/test_jardinMines1_scale_0.1/gt'

In [14]:
img_list, gt_list = a.load_imgs(path1, path2)

In [15]:
for i in range(10):
#i = 4
    img = img_list[i]       
    gt = gt_list[i]
    cv2.imshow('image',img)  
    cv2.imshow('gt',gt)
    cv2.waitKey(0) # waits until a key is pressed
    cv2.destroyAllWindows()
 

In [16]:
metrcis_table = a.computeMetrics(img_list, gt_list)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /home/fusang/anaconda3/envs/ngp/lib/python3.9/site-packages/lpips/weights/v0.1/alex.pth


In [22]:
sum(metrcis_table['ssim'])/len((metrcis_table['ssim']))

0.38565666364224915

In [23]:
sum(metrcis_table['psnr'])/len((metrcis_table['psnr']))

13.221283408964693

In [24]:
sum(metrcis_table['mse'])/len((metrcis_table['mse']))

0.04970115035623289

In [26]:
sum(metrcis_table['lpips'])/len((metrcis_table['lpips']))

0.6958451188235983

In [17]:
import csv

def createCSV(outputpath, metrcis_table):
    field_names = ['ssim', 'mse', 'psnr', 'lpips']
       
    with open(outputpath, 'w') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames = field_names)
        writer.writeheader()
        writer.writerow(metrcis_table)


In [18]:
createCSV("./testjardin1.csv", metrcis_table)