In [None]:
from google.colab import drive

# Google Drive is mounted at this path; the project files used in this notebook
# (helper modules, trained checkpoint, test images) live under
# <mount point>/MyDrive/DL_Project.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Copy the project's helper modules into the working directory so the next cell's
# `import utils` / `from vgg import Vgg16` can find them.
# shutil.copy raises immediately if a source file is missing, whereas a failing
# `!cp` only prints an error and lets execution continue until the import fails.
import os
import shutil

HELPER_SRC_DIR = '/content/drive/MyDrive/DL_Project/image_filtering'
for helper_file in ('utils.py', 'vgg.py'):
    shutil.copy(os.path.join(HELPER_SRC_DIR, helper_file), '.')

In [None]:
from __future__ import print_function, division
import os
import torchvision
import torch
from skimage import io, transform
from skimage.color import rgb2gray
from skimage.metrics import structural_similarity
import numpy as np
from torchvision import transforms, utils
import torch.nn as nn
import cv2

from torchvision.utils import save_image
import utils
from vgg import Vgg16

In [None]:
class MFFNet(torch.nn.Module):
    """Encoder / residual trunk / decoder network with skip concatenations.

    * Encoder: ``conv1`` (4->32, 9x9), ``conv2`` (32->64, 3x3, stride 2) and
      ``conv3`` (64->128, 3x3, stride 2), each followed by ReLU.
    * Trunk: ``res1`` ... ``res16``, sixteen 128-channel ``ResidualBlock`` s.
    * Decoder: each stage first concatenates (along channels) the current feature
      map with the encoder activation of the same resolution, then applies
      ``deconv1`` (256->64, x2 nearest upsample) + ReLU, ``deconv2`` (128->32,
      x2 upsample) + ReLU, and finally ``deconv3`` (64->3, 9x9) with no activation.

    Input is ``(N, 4, H, W)``. Because of the two stride-2 convolutions and the two
    x2 upsamplings, H and W must be multiples of 4, otherwise the skip
    ``torch.cat`` calls fail on mismatched spatial sizes.

    NOTE(review): ``in1`` ... ``in5`` (InstanceNorm2d) are built but never applied in
    ``forward``. They still own parameters, so they are kept: removing them would
    presumably break the strict ``load_state_dict`` of an existing checkpoint.
    """

    _NUM_RES_BLOCKS = 16

    def __init__(self):
        super().__init__()

        # Encoder (each conv is paired with an InstanceNorm2d that forward() does not use).
        self.conv1 = ConvLayer(4, 32, kernel_size=9, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3 = torch.nn.InstanceNorm2d(128, affine=True)

        # Residual trunk registered as res1..res16, in that order, so state_dict keys
        # and parameter-initialisation order match an explicit attribute listing.
        for idx in range(1, self._NUM_RES_BLOCKS + 1):
            setattr(self, 'res%d' % idx, ResidualBlock(128))

        # Decoder: input channel counts are doubled by the skip concatenations.
        self.deconv1 = UpsampleConvLayer(128 * 2, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64 * 2, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ConvLayer(32 * 2, 3, kernel_size=9, stride=1)

        self.relu = torch.nn.ReLU()

    def forward(self, X):
        skip1 = self.relu(self.conv1(X))
        skip2 = self.relu(self.conv2(skip1))
        skip3 = self.relu(self.conv3(skip2))

        feat = skip3
        for idx in range(1, self._NUM_RES_BLOCKS + 1):
            feat = getattr(self, 'res%d' % idx)(feat)

        feat = self.relu(self.deconv1(torch.cat((feat, skip3), 1)))
        feat = self.relu(self.deconv2(torch.cat((feat, skip2), 1)))
        return self.deconv3(torch.cat((feat, skip1), 1))

class ConvLayer(torch.nn.Module):
    """Reflection-padded 2-D convolution.

    The input is padded by ``kernel_size // 2`` on every side with
    ``ReflectionPad2d`` and then passed through an unpadded ``Conv2d``; with
    ``stride=1`` and an odd kernel the spatial size is therefore preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        pad = kernel_size // 2
        self.reflection_pad = torch.nn.ReflectionPad2d(pad)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        return self.conv2d(self.reflection_pad(x))

class ResidualBlock(torch.nn.Module):
    """Two 3x3 reflection-padded convolutions with a ReLU between them and an identity skip.

    Computes ``conv2(relu(conv1(x))) + x``; no activation is applied after the sum.
    Channel count and spatial size are preserved.

    NOTE(review): ``in1`` / ``in2`` (InstanceNorm2d) are created but not applied in
    ``forward``; they only contribute parameters to the state_dict, which a strict
    ``load_state_dict`` of an existing checkpoint presumably expects.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        branch = self.conv2(self.relu(self.conv1(x)))
        return branch + x


class UpsampleConvLayer(torch.nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded convolution.

    If ``upsample`` is truthy, the input is first enlarged by that scale factor via
    ``interpolate(mode='nearest')``. It is then reflection-padded by
    ``kernel_size // 2`` and passed through an unpadded ``Conv2d``. With
    ``upsample=None`` (the default) this is just a reflection-padded convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.reflection_pad = torch.nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        scaled = (
            torch.nn.functional.interpolate(x, mode='nearest', scale_factor=self.upsample)
            if self.upsample
            else x
        )
        return self.conv2d(self.reflection_pad(scaled))

In [None]:
# Build MFF-net on the selected device and restore its trained weights.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
imageFilter = MFFNet()

model_name = 'MFF-net'
checkpoint_path = '/content/drive/MyDrive/DL_Project/trained_model/%s.ckpt' % (model_name)
# map_location remaps the stored tensors onto `device`. Without it, a checkpoint saved
# from a GPU cannot be deserialised when the runtime has fallen back to the CPU.
imageFilter.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Inference only: switch to eval mode. Assigning the result keeps the module repr
# out of the cell output.
imageFilter = imageFilter.to(device).float().eval()

In [None]:
# Evaluate pre-computed filter outputs against ground truth for test images 1..5:
# PSNR, SSIM on grayscale versions, and a VGG-16 feature (perceptual) loss.
# NOTE(review): the images scored here are read from out_<n>.png on disk; this cell
# does not call `imageFilter`.

data_root = '/content/drive/MyDrive/DL_Project/image_filtering/test'
out_root = '/content/drive/MyDrive/DL_Project/image_filtering/output'
os.makedirs(out_root, exist_ok=True)

NUM_TEST_IMAGES = 5
# Index handed to Vgg16's forward; features 0..VGG_LAST_FEATURE are compared.
# NOTE(review): the exact meaning of this argument is defined in vgg.py — confirm there.
VGG_LAST_FEATURE = 3


def load_image_tensor(path):
    """Read an image as a float32 (3, H, W) tensor on `device`, with pixel values divided by 255.

    Only the first three channels are kept, so an RGBA PNG is compared as RGB instead
    of failing with a channel-count mismatch against the RGB ground truth.
    NOTE(review): the /255 scaling assumes 8-bit images.
    """
    img = io.imread(path)[..., :3] / 255
    return torch.from_numpy(np.transpose(img, (2, 0, 1))).float().to(device)


def compute_psnr(reference, estimate):
    """PSNR in dB for images whose peak value is 1.0; returns a 0-d tensor."""
    mse = torch.mean((reference - estimate) ** 2)
    return 20 * torch.log10(1 / torch.sqrt(mse))


def compute_gray_ssim(reference, estimate):
    """SSIM between the grayscale versions of two (3, H, W) tensors, data range 1.0.

    The grayscale arrays are 2-D (H, W), so no `multichannel` / `channel_axis` is
    passed. With it, skimage would treat the width axis as channels and return the
    mean of per-column 1-D SSIMs instead of the 2-D SSIM.
    """
    ref_gray = transforms.functional.rgb_to_grayscale(reference).squeeze(0).cpu().numpy()
    est_gray = transforms.functional.rgb_to_grayscale(estimate).squeeze(0).cpu().numpy()
    return structural_similarity(ref_gray, est_gray, gaussian_weights=True, sigma=1.5,
                                 use_sample_covariance=False, data_range=1.0)


def compute_vgg_loss(vgg, criterion, reference, estimate):
    """Sum of `criterion(est_feature, ref_feature)` over VGG features 0..VGG_LAST_FEATURE.

    Both (3, H, W) inputs get a batch dimension and are passed through
    `utils.normalize_ImageNet_stats` before being fed to `vgg`.
    """
    ref_feats = vgg(utils.normalize_ImageNet_stats(reference.unsqueeze(0)), VGG_LAST_FEATURE)
    est_feats = vgg(utils.normalize_ImageNet_stats(estimate.unsqueeze(0)), VGG_LAST_FEATURE)
    return sum(criterion(est_feats[i], ref_feats[i]) for i in range(VGG_LAST_FEATURE + 1))


# Construct the VGG network once, not once per image.
VGG = Vgg16(requires_grad=False).to(device)
criterion_vgg = nn.MSELoss()

results = []
with torch.no_grad():  # evaluation only; no autograd graph is needed
    for seq in range(1, NUM_TEST_IMAGES + 1):
        groundtruth = load_image_tensor(os.path.join(data_root, 'gt_%s.bmp' % seq))
        output = load_image_tensor(os.path.join(data_root, 'out_%s.png' % seq)).clamp(0, 1)

        # The parameters for colour balance and brightness should be tuned for
        # different scenes, e.g.:
        # output = output * torch.tensor([1.1, 1.0, 1.5], device=device).view(3, 1, 1) * 1.5

        scores = {
            'psnr': compute_psnr(groundtruth, output).item(),
            'ssim': compute_gray_ssim(groundtruth, output),
            'vgg': compute_vgg_loss(VGG, criterion_vgg, groundtruth, output).item(),
        }
        results.append(scores)

        print('Image %s:' % seq)
        print('psnr: %s' % scores['psnr'])
        print('ssim: %s' % scores['ssim'])
        print('VGG loss: %s' % scores['vgg'])

mean_scores = {name: sum(r[name] for r in results) / len(results) for name in results[0]}
print('Mean over %d images: %s' % (len(results), mean_scores))

Image 1:
psnr: 21.442401885986328
ssim: 0.878429125857347


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

VGG loss: 3.8615498542785645
Image 2:
psnr: 19.636703491210938
ssim: 0.8890140575614649
VGG loss: 3.738367795944214
Image 3:
psnr: 12.616279602050781
ssim: 0.2850319679552631
VGG loss: 8.377556800842285
Image 4:
psnr: 10.188135147094727
ssim: 0.2576724574067086
VGG loss: 9.671786308288574
Image 5:
psnr: 19.26607894897461
ssim: 0.8511619438607132
VGG loss: 6.547539234161377
