In [1]:
import os
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import cv2
import random
from torchjpeg import dct
import torch
import torch.fft
from torch.nn import functional as F
import torchvision.transforms as transforms
from pycocotools.coco import COCO

In [2]:
def dct_transform(x, chs_remove=None, chs_pad=False,
                  size=8, stride=8, pad=0, dilation=1, ratio=8, chs_select = 9):
    """
        Transform a spatial image into its frequency channels.
        Prune low-frequency channels if necessary.
    """

    # assert x is a (3, H, W) RGB image
    assert x.shape[1] == 3

    # convert the spatial image's range into [0, 1], recommended by TorchJPEG
    x = x / 255.0

    # up-sample
    x = F.interpolate(x, scale_factor=ratio, mode='bilinear', align_corners=True)

    # convert to the YCbCr color domain, required by DCT
    x = x * 255
    x = dct.to_ycbcr(x)
    x = x - 128

    # perform block discrete cosine transform (BDCT)
    b, c, h, w = x.shape
    h_block = h // stride
    w_block = w // stride
    x = x.view(b * c, 1, h, w)
    x = F.unfold(x, kernel_size=(size, size), dilation=dilation, padding=pad, stride=(stride, stride))
    x = x.transpose(1, 2)
    x = x.view(b, c, -1, size, size)
    x_freq = dct.block_dct(x)
    x_freq = x_freq.view(b, c, h_block, w_block, size * size).permute(0, 1, 4, 2, 3)

    # prune channels
    if chs_remove is not None:
        channels = list(set([i for i in range(64)]) - set(chs_remove))
        #channels = [0, 1, 2, 3, 8, 9, 10, 16, 17, 24]
        selected_channels = random.sample(channels, chs_select)
        if not chs_pad:
            # simply remove channels
            x_freq = x_freq[:, :, selected_channels, :, :]
        else:
            # pad removed channels with zero, helpful for visualization
            x_freq[:, :, channels] = 0

    # stack frequency channels from each color domain
    x_freq = x_freq.reshape(b, -1, h_block, w_block)

    return x_freq

def idct_transform(x, size=8, stride=8, pad=0, dilation=1, ratio=8):
    """
        The inverse of DCT transform.
        Transform frequency channels (must be 192 channels, can be padded with 0) back to the spatial image.
    """
    b, c, h, w = x.shape
    expanded_x = torch.zeros(b, 192, h, w, dtype=x.dtype, device=x.device)
    for i in range(3):  
        expanded_x[:, i*64:(i*64)+(c//3), :, :] = x[:, i*(c//3):(i+1)*(c//3), :, :]
    # print(expanded_x.shape)
    
    b, c, h, w = expanded_x.shape
    x = expanded_x.view(b, 3, 64, h, w)
    x = x.permute(0, 1, 3, 4, 2)
    x = x.view(b, 3, h * w, 8, 8)
    x = dct.block_idct(x)
    x = x.view(b * 3, h * w, 64)
    x = x.transpose(1, 2)
    x = F.fold(x, output_size=(h * ratio, w * ratio),
               kernel_size=(size, size), dilation=dilation, padding=pad, stride=(stride, stride))
    x = x.view(b, 3, h * ratio, w * ratio)
    x = x + 128
    x = dct.to_rgb(x)
    #x = x / 255
    x = F.interpolate(x, scale_factor=1 / ratio, mode='bilinear', align_corners=True)
    #x = x.clamp(min=0.0, max=1.0)
    return x

In [3]:
def compute_batch_ssim_psnr(torch_img, numpy_img_set):
    """
    Compute the batch-wise SSIM and PSNR metrics.
    :param torch_img: The torch tensor image.
    :param numpy_img_set: The numpy array image set.
    :return: The batch-wise SSIM mean and PSNR mean metrics.
    """
    ssim_batch = []
    psnr_batch = []
    torch_img = torch_img.cpu()
    torch_img = torch_img.numpy()
    torch_img = torch_img.transpose((0, 2, 3, 1))
    for i in range(len(numpy_img_set)):
        psnr_value = psnr(torch_img[i], numpy_img_set[i], data_range=255)
        psnr_batch.append(psnr_value)
        for j in range(len(numpy_img_set[0])):
            ssim_index = ssim(numpy_img_set[i][j], torch_img[i][j], channel_axis=-1, data_range=255.0)
            ssim_batch.append(ssim_index)
    ssim_batch = np.array(ssim_batch)
    psnr_batch = np.array(psnr_batch)
    return np.mean(ssim_batch), np.mean(psnr_batch)

In [4]:
val_dataset = "/root/autodl-tmp/coco2017/val2017"
annFile = '/root/autodl-tmp/coco2017/annotations/person_keypoints_val2017.json'
#img_sets = os.listdir(path)
device = torch.device('cuda:0')
transform = transforms.ToTensor()
test_set = []
test_ssim = []
test_psnr = []
temp = 0

# COCO API
coco = COCO(annFile)

# get all image ID
image_ids = coco.getImgIds()

# get every image id
image_info = coco.loadImgs(image_ids)

loading annotations into memory...
Done (t=0.30s)
creating index...
index created!


In [5]:
for i in range(len(image_info)):
    img_path = os.path.join(val_dataset + "/" + image_info[i]['file_name'])
    img = cv2.imread(img_path)
    img = cv2.resize(img, (256, 256))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    test_set.append(img)
    temp = temp + 1
    if temp % 8 == 0 or i == len(image_info)-1:
        images = torch.stack([transform(img).to(device) for img in test_set])
        #images = dct_transform(images, chs_remove = [0, 1, 2, 3, 8, 9, 10, 16, 17, 24, 39, 46, 47, 53, 54, 55, 60, 61, 62, 63])
        images = dct_transform(images, chs_remove = [0, 1, 2, 3, 8, 9, 10, 16, 17, 24])
        img_spat = idct_transform(images)
        a, b = compute_batch_ssim_psnr(img_spat, test_set)
        if temp % 200 == 0:
            print(a, b)
        test_ssim.append(a)
        test_psnr.append(b)
        test_set = []

test_ssim = np.array(test_ssim)
test_psnr = np.array(test_psnr)
print("val_s9_ssim: ", np.mean(test_ssim), "val_s9_psnr: ", np.mean(test_psnr))

0.32202406817008766 10.804399690500794
0.43923410647965133 13.192959391409044
0.4990015243927739 10.975258889247286
0.43624621292841814 12.529504631558147
0.3992937911145037 11.116746866627945
0.37758064479892134 11.640122438508508
0.48475095213178016 12.977788729118709
0.35631677925766025 11.35582164749406
0.32278947595087204 10.532536467544766
0.4507062023642441 10.907366532492302
0.3828182632514336 10.706525900113887
0.3514391153527673 10.60919645494293
0.4200334066812093 10.939692961055353
0.3199174919563026 10.925025663005844
0.42851695256532174 12.179270188143317
0.42971394957730435 10.65957088805165
0.3838379564279494 10.531025455492852
0.38990591986277734 11.32853242614673
0.376727841708293 10.406993443998322
0.3864319481955223 11.100305765558804
0.3950127384345212 12.337731399202946
0.38400018777630673 11.898691562040923
0.4092148318101783 12.068269217182886
0.36386880239966835 11.113905705588868
0.35568918644841574 10.628210967246341
val_s9_ssim:  0.3867971773640353 val_s9_ps