In [22]:
# Colab-only setup: mount Google Drive so the project folder is reachable from the runtime.
from google.colab import drive
drive.mount('/content/drive')

# Work from the project folder; the relative dataset paths used later in this
# notebook are resolved against this directory.
import os 
os.chdir('/content/drive/My Drive/pSp_self_customize')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
import glob
import numpy as np

from tqdm import tqdm

import pytorch_ssim

import torch
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

from PIL import Image


In [24]:
# Directory holding the original images (relative to the current working directory).
original_path = "image_datasets/ffhq0k_5k"

# Directory holding the images to compare against the originals.
# Swap in one of these to score a different set:
#   "image_datasets/blurred_images"
#   "image_datasets/mosaiced_images"
#   "image_datasets/inference_results"
generated_path = "image_datasets/inference_results"

# File-name suffixes treated as images. Matching is case-sensitive, so each
# suffix is listed in lower and upper case ('.tiff' only has a lower-case entry).
IMG_EXTENSIONS = [
    '.jpg', '.JPG',
    '.jpeg', '.JPEG',
    '.png', '.PNG',
    '.ppm', '.PPM',
    '.bmp', '.BMP',
    '.tiff',
]


In [25]:
def is_image_file(filename):
    """Return True if `filename` ends with one of the suffixes in IMG_EXTENSIONS.

    The comparison is case-sensitive: only the exact suffix spellings listed
    in IMG_EXTENSIONS are accepted.
    """
    # str.endswith accepts a tuple of candidate suffixes and checks them all.
    accepted_suffixes = tuple(IMG_EXTENSIONS)
    return filename.endswith(accepted_suffixes)
  

In [26]:
class ImagesPathDataset(torch.utils.data.Dataset):
    """Dataset over every image file found recursively under a directory.

    Files are listed in a deterministic order: directories sorted by path,
    then file names sorted within each directory. This notebook pairs two
    datasets item-by-item, so both must enumerate their files the same way;
    the order in which os.walk reports file names is not guaranteed.

    Args:
        path: Directory to scan (sub-directories included).
        transforms: Optional callable applied to each loaded PIL image.

    Attributes:
        files: Paths of the discovered image files, in iteration order.
        transforms: The callable passed in, or None.

    Raises:
        AssertionError: If `path` is not an existing directory.
    """

    def __init__(self, path, transforms=None):
        assert os.path.isdir(path), "{} is not a valid directory".format(path)

        files_path = []
        # sorted() orders the (root, dirs, files) tuples by root only, so the
        # file names of each directory are sorted explicitly as well.
        for root, _dirs, files in sorted(os.walk(path)):
            for file_name in sorted(files):
                if is_image_file(file_name):
                    files_path.append(os.path.join(root, file_name))

        self.files = files_path
        self.transforms = transforms

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, i):
        """Load the i-th image as RGB and return it, transformed if requested."""
        path = self.files[i]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(path) as source:
            img = source.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        return img


In [27]:
# The only preprocessing is PIL image -> tensor; no resizing is applied.
data_transforms = transforms.Compose([transforms.ToTensor()])

ori_dataset_fn = ImagesPathDataset(path=original_path, transforms=data_transforms)
gen_dataset_fn = ImagesPathDataset(path=generated_path, transforms=data_transforms)

# Aliases under the names the rest of the notebook uses.
ori_dataset = ori_dataset_fn
gen_dataset = gen_dataset_fn
print(len(ori_dataset),len(gen_dataset))

# The SSIM loop pairs the two loaders with zip(), which silently stops at the
# shorter one. A size mismatch means images would go unscored (or be paired
# with the wrong counterpart), so fail loudly instead.
if len(ori_dataset) != len(gen_dataset):
    raise ValueError(
        "original and generated datasets differ in size: {} vs {}".format(
            len(ori_dataset), len(gen_dataset)
        )
    )

# Both loaders must iterate in the same, fixed order so item i of one matches
# item i of the other: no shuffling, one image per batch, nothing dropped.
loader_kwargs = dict(
    shuffle=False,
    batch_size=1,
    drop_last=False,
    num_workers=4,
)

ori_loader = torch.utils.data.DataLoader(ori_dataset, **loader_kwargs)
gen_loader = torch.utils.data.DataLoader(gen_dataset, **loader_kwargs)

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


5000 5000
cuda:0


In [28]:
# Per-pair SSIM scores, one entry per (original, generated) image pair.
ssim_values = []

# zip() stops at the shorter loader, so this is the number of pairs scored;
# passing it as `total` lets tqdm show real progress instead of a bare counter.
num_pairs = min(len(ori_loader), len(gen_loader))

# Evaluation only, so gradient tracking is disabled. The `with` block closes
# the progress bar on any exit (including KeyboardInterrupt), so no separate
# handler is needed.
with torch.no_grad(), tqdm(zip(ori_loader, gen_loader), total=num_pairs) as t:
    for ori_batch, gen_batch in t:
        # The loaders use batch_size=1, so index 0 is the only image in each batch.
        # NOTE(review): the tensors passed here have no batch dimension; confirm
        # that is what this project's pytorch_ssim.ssim expects.
        ssim_value = pytorch_ssim.ssim(
            img1=ori_batch[0].to(device),
            img2=gen_batch[0].to(device),
        )
        ssim_values.append(ssim_value.cpu().numpy())

print(len(ssim_values))


5000it [10:24,  8.01it/s]

5000





In [29]:
# Stack the per-pair scores into one array and average over the pairs.
ssim_scores = np.asarray(ssim_values)
ssim_avg_value = ssim_scores.mean(axis=0)
print("SSIM: ", ssim_avg_value)


SSIM:  0.4827097


In [None]:
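# Recorded mean SSIM against the originals, one line per generated_path setting
# (higher means closer to the original images):
# mosaiced_images   SSIM: 0.39804336
# blurred_images    SSIM: 0.42993903
# inference_results SSIM: 0.4827097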
