# Notebook to calculate CW-SSIM, PSNR, FID, and other Pytorch metrics:

### Calculate CW-SSIM for GT vs pix2pix, GT vs I2SB:
### from repo: https://github.com/jterrace/pyssim/tree/master

In [3]:
from ssim import SSIM
from PIL import Image, ImageOps
import os
from natsort import natsorted
import numpy as np
from tqdm import tqdm

In [4]:
GT_img_pth = r"\\10.99.68.51\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\val\HE"
p2p_img_pth = r"\\10.99.68.51\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\pix2pix_sample1_only\val"
i2sb_img_pth = r"\\10.99.68.51\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\i2sb\unconditional\test-run-4\iter_24388"
GT_img_list = [os.path.join(GT_img_pth,x) for x in os.listdir(GT_img_pth) if x.endswith(".png")]
p2p_img_list = [os.path.join(p2p_img_pth,x) for x in os.listdir(p2p_img_pth) if x.endswith(".png")]
i2sb_img_list = [os.path.join(i2sb_img_pth,x) for x in os.listdir(i2sb_img_pth) if x.endswith(".png")]
GT_img_list = natsorted(GT_img_list)
p2p_img_list = natsorted(p2p_img_list)
i2sb_img_list = natsorted(i2sb_img_list)
GT_img_list = GT_img_list[:5952] #sample1, sample2A
p2p_img_list = p2p_img_list[:5952]

### Pix2Pix CWSSIM score:

In [3]:
pix2pix_cwssim_total = 0
for idx, images in tqdm(enumerate(GT_img_list),colour='red',desc='Images Processed', total = len(GT_img_list)):
    ref_img = Image.open(GT_img_list[idx])
    test_img = Image.open(p2p_img_list[idx])
    ref_img = ImageOps.grayscale(ref_img)
    test_img = ImageOps.grayscale(test_img)
    pix2pix_cwssim_total += SSIM(ref_img).cw_ssim_value(test_img)
print("Average CWSSIM score for Pix2Pix is {}".format(pix2pix_cwssim_total // len(GT_img_list)))

Images Processed: 100%|[31m██████████[0m| 5952/5952 [18:33<00:00,  5.34it/s]  

Average CWSSIM score for Pix2Pix is 0.0





In [8]:
pix2pix_cwssim_total/5952

0.5002800103454716

### I2SB CWSSIM score:

In [4]:
i2sb_cwssim_total = 0
for idx, images in tqdm(enumerate(GT_img_list),colour='red',desc='Images Processed', total = len(GT_img_list)):
    ref_img = Image.open(GT_img_list[idx])
    test_img = Image.open(i2sb_img_list[idx])
    ref_img = ImageOps.grayscale(ref_img)
    test_img = ImageOps.grayscale(test_img)
    i2sb_cwssim_total += SSIM(ref_img).cw_ssim_value(test_img)
print("Average CWSSIM score for I2SB is {}".format(i2sb_cwssim_total // len(GT_img_list)))

Images Processed: 100%|[31m██████████[0m| 5952/5952 [15:57<00:00,  6.21it/s]

Average CWSSIM score for I2SB is 0.0





In [7]:
i2sb_cwssim_total/5952

0.5174699244136098

# Calculate PSNR for GT vs pix2pix, GT vs I2SB:

In [11]:
def calculate_psnr(img1, img2, max_value=255):
    mse = np.mean((np.array(img1, dtype=np.float32) - np.array(img2, dtype=np.float32)) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(max_value / (np.sqrt(mse)))

### Pix2Pix PSNR Score:

In [14]:
pix2pix_psnr_total = 0
for idx, images in tqdm(enumerate(GT_img_list),colour='red',desc='Images Processed', total = len(GT_img_list)):
    ref_img = Image.open(GT_img_list[idx])
    test_img = Image.open(p2p_img_list[idx])
    ref_img = ImageOps.grayscale(ref_img)
    test_img = ImageOps.grayscale(test_img)
    pix2pix_psnr_total += calculate_psnr(ref_img, test_img)
print("Average PSNR score for Pix2Pix is {}".format(pix2pix_psnr_total / len(GT_img_list)))

Images Processed: 100%|[31m██████████[0m| 5952/5952 [02:36<00:00, 37.92it/s]

Average PSNR score for Pix2Pix is 14.566912677808372





### I2SB PSNR Score:

In [17]:
i2sb_psnr_total = 0
for idx, images in tqdm(enumerate(GT_img_list),colour='red',desc='Images Processed', total = len(GT_img_list)):
    ref_img = Image.open(GT_img_list[idx])
    test_img = Image.open(i2sb_img_list[idx])
    ref_img = ImageOps.grayscale(ref_img)
    test_img = ImageOps.grayscale(test_img)
    i2sb_psnr_total += calculate_psnr(ref_img, test_img)
print("Average PSNR score for I2SB is {}".format(i2sb_psnr_total / len(GT_img_list)))

Images Processed: 100%|[31m██████████[0m| 5952/5952 [03:26<00:00, 28.85it/s]

Average PSNR score for I2SB is 14.40829158977186





# Calculate FID for GT vs pix2pix, GT vs I2SB:
### from repo: https://github.com/mseitzer/pytorch-fid

### Pix2Pix FID Score:

In [18]:
!python -m pytorch_fid \\shelter\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\tmp\GT \\shelter\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\tmp\pix2pix_sample1_only --device cuda:0

FID:  25.728143461570426



  0%|          | 0/120 [00:00<?, ?it/s]
  1%|          | 1/120 [00:26<52:13, 26.33s/it]
  2%|1         | 2/120 [00:26<21:27, 10.91s/it]
  2%|2         | 3/120 [00:26<11:40,  5.98s/it]
  3%|3         | 4/120 [00:26<07:05,  3.67s/it]
  4%|4         | 5/120 [00:26<04:35,  2.39s/it]
  5%|5         | 6/120 [00:26<03:04,  1.62s/it]
  6%|5         | 7/120 [00:27<02:07,  1.13s/it]
  7%|6         | 8/120 [00:27<01:30,  1.24it/s]
  8%|7         | 9/120 [00:27<01:05,  1.69it/s]
  8%|8         | 10/120 [00:27<00:48,  2.25it/s]
  9%|9         | 11/120 [00:27<00:37,  2.90it/s]
 10%|#         | 12/120 [00:27<00:30,  3.58it/s]
 11%|#         | 13/120 [00:27<00:24,  4.41it/s]
 12%|#1        | 14/120 [00:27<00:20,  5.06it/s]
 12%|#2        | 15/120 [00:28<00:18,  5.72it/s]
 13%|#3        | 16/120 [00:28<00:16,  6.31it/s]
 14%|#4        | 17/120 [00:28<00:15,  6.79it/s]
 15%|#5        | 18/120 [00:28<00:13,  7.34it/s]
 16%|#5        | 19/120 [00:28<00:13,  7.34it/s]
 17%|#6        | 20/120 [00:28<00:13,

### I2SB FID Score

In [19]:
!python -m pytorch_fid \\shelter\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\tmp\GT \\shelter\Kyu\IHC2HE\Balanced_Aligned\dataset_v1_256x256\infer\tmp\iter_24388 --device cuda:0

FID:  6.961478786222131



  0%|          | 0/120 [00:00<?, ?it/s]
  1%|          | 1/120 [00:20<40:41, 20.52s/it]
  2%|1         | 2/120 [00:20<16:44,  8.51s/it]
  2%|2         | 3/120 [00:20<09:07,  4.68s/it]
  3%|3         | 4/120 [00:20<05:33,  2.88s/it]
  4%|4         | 5/120 [00:20<03:36,  1.89s/it]
  5%|5         | 6/120 [00:21<02:26,  1.29s/it]
  6%|5         | 7/120 [00:21<01:42,  1.10it/s]
  7%|6         | 8/120 [00:21<01:14,  1.51it/s]
  8%|7         | 9/120 [00:21<00:54,  2.05it/s]
  8%|8         | 10/120 [00:21<00:41,  2.66it/s]
  9%|9         | 11/120 [00:21<00:32,  3.38it/s]
 10%|#         | 12/120 [00:21<00:25,  4.15it/s]
 11%|#         | 13/120 [00:21<00:21,  4.90it/s]
 12%|#1        | 14/120 [00:22<00:18,  5.64it/s]
 12%|#2        | 15/120 [00:22<00:17,  6.08it/s]
 13%|#3        | 16/120 [00:22<00:15,  6.81it/s]
 14%|#4        | 17/120 [00:22<00:14,  7.01it/s]
 15%|#5        | 18/120 [00:22<00:13,  7.57it/s]
 16%|#5        | 19/120 [00:22<00:12,  7.77it/s]
 17%|#6        | 20/120 [00:22<00:12,

### Trying DISTS (lower the better), LPIPS (lower the better), style loss (lower the better), VSI (0-1, higher the better), and haarpsi (0-1, higher the better) as a metric:

In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from piq import DISTS, LPIPS, vsi, haarpsi, StyleLoss
from torchvision.transforms import ToTensor
import torchvision.transforms as T

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
tensor_transform = T.Compose([ToTensor()]) # [0,255] -> [0,1], HWC -> CHW

### For Pix2Pix (Batch Size doesn't change the results, higher batch size for faster computation):

In [9]:
class TestDataset(Dataset):
    def __init__(self, GT_img_list, test_img_list, transforms=None):
        self.GT_img_list = GT_img_list
        self.test_img_list = test_img_list
        self.transforms = transforms

    def __len__(self):
        if len(self.GT_img_list) != len(self.test_img_list):
            assert("Length of Test and GT Image List Unequal")
        return len(GT_img_list)

    def __getitem__(self,idx):
        ref_img = np.array(Image.open(self.GT_img_list[idx]))
        test_img = np.array(Image.open(self.test_img_list[idx]))
        if self.transforms is not None:
            ref_img = self.transforms(ref_img)
            test_img = self.transforms(test_img)
        return ref_img, test_img

In [12]:
p2p_img_dataset = TestDataset(GT_img_list= GT_img_list, test_img_list= p2p_img_list, transforms = tensor_transform)
p2p_img_dataloader = DataLoader(dataset=p2p_img_dataset,
                                  batch_size=32,
                                  #pin_memory= true allows faster data transport from cpu to gpu
                                  num_workers=0, pin_memory=True, shuffle=False)
#check
ref_img_batch, p2p_img_batch = next(iter(p2p_img_dataloader))
print("Images have a tensor size of {}, and Labels have a tensor size of {}".
      format(ref_img_batch.size(), p2p_img_batch.size()))

Images have a tensor size of torch.Size([32, 3, 256, 256]), and Labels have a tensor size of torch.Size([32, 3, 256, 256])


In [13]:
pix2pix_dists_total = 0
pix2pix_lpips_total = 0
pix2pix_vsi_total = 0
pix2pix_haarpsi_total = 0
pix2pix_style_loss_total = 0
pbar = tqdm(enumerate(p2p_img_dataloader), total=len(p2p_img_dataloader), desc='Images Processed', colour='red')
for idx, (ref_img, p2p_img) in pbar:
    ref_img = ref_img.to(device)
    p2p_img = p2p_img.to(device)
    pix2pix_vsi_total += vsi(p2p_img,ref_img)
    pix2pix_dists_total += DISTS()(p2p_img,ref_img)
    pix2pix_lpips_total += LPIPS()(p2p_img,ref_img)
    pix2pix_haarpsi_total += haarpsi(p2p_img,ref_img)
    pix2pix_style_loss_total += StyleLoss()(p2p_img,ref_img)
print("Average VSI score for Pix2Pix is {}".format(pix2pix_vsi_total / len(p2p_img_dataloader)))
print("Average DISTS score for Pix2Pix is {}".format(pix2pix_dists_total / len(p2p_img_dataloader)))
print("Average LPIPS score for Pix2Pix is {}".format(pix2pix_lpips_total / len(p2p_img_dataloader)))
print("Average HAARPSI score for Pix2Pix is {}".format(pix2pix_haarpsi_total / len(p2p_img_dataloader)))
print("Average Style score for Pix2Pix is {}".format(pix2pix_style_loss_total / len(p2p_img_dataloader)))

Images Processed: 100%|[31m██████████[0m| 186/186 [11:24<00:00,  3.68s/it]

Average VSI score for Pix2Pix is 0.834528923034668
Average DISTS score for Pix2Pix is 0.2199782282114029
Average LPIPS score for Pix2Pix is 0.4710122048854828
Average HAARPSI score for Pix2Pix is 0.35326284170150757
Average Style score for Pix2Pix is 20185016.0





### For I2SB:

In [15]:
i2sb_img_dataset = TestDataset(GT_img_list= GT_img_list, test_img_list= i2sb_img_list, transforms = tensor_transform)
i2sb_img_dataloader = DataLoader(dataset=i2sb_img_dataset,
                                  batch_size=32,
                                  #pin_memory= true allows faster data transport from cpu to gpu
                                  num_workers=0, pin_memory=True, shuffle=False)
#check:
ref_img_batch, i2sb_img_batch = next(iter(i2sb_img_dataloader))
print("Images have a tensor size of {}, and Labels have a tensor size of {}".
      format(ref_img_batch.size(), i2sb_img_batch.size()))

Images have a tensor size of torch.Size([32, 3, 256, 256]), and Labels have a tensor size of torch.Size([32, 3, 256, 256])


In [16]:
i2sb_dists_total = 0
i2sb_lpips_total = 0
i2sb_vsi_total = 0
i2sb_haarpsi_total = 0
i2sb_style_loss_total = 0
pbar = tqdm(enumerate(i2sb_img_dataloader), total=len(i2sb_img_dataloader), desc='Images Processed', colour='red')
for idx, (ref_img, i2sb_img) in pbar:
    ref_img = ref_img.to(device)
    i2sb_img = i2sb_img.to(device)
    i2sb_vsi_total += vsi(i2sb_img,ref_img)
    i2sb_dists_total += DISTS()(i2sb_img,ref_img)
    i2sb_lpips_total += LPIPS()(i2sb_img,ref_img)
    i2sb_haarpsi_total += haarpsi(i2sb_img,ref_img)
    i2sb_style_loss_total += StyleLoss()(i2sb_img,ref_img)
print("Average VSI score for I2SB is {}".format(i2sb_vsi_total / len(i2sb_img_dataloader)))
print("Average DISTS score for I2SB is {}".format(i2sb_dists_total / len(i2sb_img_dataloader)))
print("Average LPIPS score for I2SB is {}".format(i2sb_lpips_total / len(i2sb_img_dataloader)))
print("Average HAARPSI score for I2SB is {}".format(i2sb_haarpsi_total / len(i2sb_img_dataloader)))
print("Average Style score for I2SB is {}".format(i2sb_style_loss_total / len(i2sb_img_dataloader)))

Images Processed: 100%|[31m██████████[0m| 186/186 [11:32<00:00,  3.72s/it]

Average VSI score for I2SB is 0.8354562520980835
Average DISTS score for I2SB is 0.18133263289928436
Average LPIPS score for I2SB is 0.4468766748905182
Average HAARPSI score for I2SB is 0.3566141426563263
Average Style score for I2SB is 7872890.5





In [18]:
### VSI and HAARPSI, like PSNR and SSIM don't seem to be able to differentiate the two! While DISTS LPIPS and Style, which are perceptual based using VGG16, kind of similar to FID, are able to differentiate the two!

In [28]:
from piq import PieAPP, ssim

In [30]:
pix2pix_pieapp_total = 0
pix2pix_ssim_total = 0
pbar = tqdm(enumerate(p2p_img_dataloader), total=len(p2p_img_dataloader), desc='Images Processed', colour='red')
for idx, (ref_img, p2p_img) in pbar:
    ref_img = ref_img.to(device)
    p2p_img = p2p_img.to(device)
    pix2pix_pieapp_total += PieAPP()(p2p_img,ref_img)
    pix2pix_ssim_total += ssim(p2p_img,ref_img)
print("Average PieAPP score for Pix2Pix is {}".format(pix2pix_pieapp_total / len(p2p_img_dataloader)))
print("Average SSIM score for Pix2Pix is {}".format(pix2pix_ssim_total / len(p2p_img_dataloader)))


Images Processed: 100%|[31m██████████[0m| 186/186 [22:16<00:00,  7.19s/it]

Average PieAPP score for Pix2Pix is 2.0399372577667236
Average SSIM score for Pix2Pix is 0.18936103582382202





In [31]:
i2sb_pieapp_total = 0
i2sb_ssim_total = 0
pbar = tqdm(enumerate(i2sb_img_dataloader), total=len(i2sb_img_dataloader), desc='Images Processed', colour='red')
for idx, (ref_img, i2sb_img) in pbar:
    ref_img = ref_img.to(device)
    i2sb_img = i2sb_img.to(device)
    i2sb_pieapp_total += PieAPP()(i2sb_img,ref_img)
    i2sb_ssim_total += ssim(i2sb_img,ref_img)
print("Average PieAPP score for I2SB is {}".format(i2sb_pieapp_total / len(i2sb_img_dataloader)))
print("Average SSIM score for I2SB is {}".format(i2sb_ssim_total / len(i2sb_img_dataloader)))


Images Processed: 100%|[31m██████████[0m| 186/186 [20:08<00:00,  6.50s/it]

Average PieAPP score for I2SB is 1.549119234085083
Average SSIM score for I2SB is 0.18343544006347656



