In [1]:
from __future__ import print_function
import os
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import time

from math import log10
import numpy as np
from torch.autograd import Variable
from torch.utils.data import DataLoader
# from data_utils import DatasetFromH5_SFSR
# from model import Net_SRCNN
!pip install tensorboard_logger
from tensorboard_logger import configure, log_value

# from data_utils import DatasetFromH5_MFSR
# from model import Net_VSRNet

import torch.nn.functional as F
import torch.nn.init as init
import matplotlib.pyplot as plt


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## SRCNN

In [2]:
class Net_SRCNN(nn.Module):
    def __init__(self, upscale_factor):
        super(Net_SRCNN, self).__init__()

        self.conv1 = nn.Conv2d(1,  64, (9, 9), (1, 1), (4, 4))
        self.conv2 = nn.Conv2d(64, 32, (5, 5), (1, 1), (2, 2))
        self.conv3 = nn.Conv2d(32, 1,  (5, 5), (1, 1), (2, 2))
        
        self._initialize_weights()

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = (self.conv3(x))
        return x
    
    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight)

In [3]:
from torch.utils.data.dataset import Dataset
import h5py

class DatasetFromH5_MFSR(Dataset):
    def __init__(self, image_dataset_dir, target_dataset_dir, upscale_factor, input_transform=None, target_transform=None):
        super(DatasetFromH5_MFSR, self).__init__()
        
        image_h5_file = h5py.File(image_dataset_dir, 'r')
        target_h5_file = h5py.File(target_dataset_dir, 'r')
        image_dataset = image_h5_file['data']
        target_dataset = target_h5_file['data']
        
        self.image_datasets = image_dataset
        self.target_datasets = target_dataset
        self.total_count = image_dataset.shape[0]
        
        self.input_transform = input_transform
        self.target_transform = target_transform
        
    def __getitem__(self, index):        
        image = self.image_datasets[index, :, :, :]
        target = self.target_datasets[index, [2], :, :]
        
        image  = image.astype(np.float32)
        target = target.astype(np.float32)
        
        #   Notice that image is the bicubic upscaled LR image patch, in float format, in range [0, 1]
#        image = image / 255.0 
        #   Notice that target is the HR image patch, in uint8 format, in range [0, 255]
        target = target / 255.0
        
        image =  torch.from_numpy(image)
        target = torch.from_numpy(target)

        return image, target

    def __len__(self):
        return self.total_count

In [4]:
data_dir = "./data"

downloads_dir = data_dir + '/downloads'
datasets_dir = data_dir + '/datasets'
models_dir = data_dir + '/models'
pretrained_models = data_dir + '/pretrained_models'

os.makedirs(downloads_dir, exist_ok=True)
os.makedirs(datasets_dir, exist_ok=True)
os.makedirs(models_dir, exist_ok=True)
os.makedirs(pretrained_models, exist_ok=True)

uf4_train_dir = datasets_dir + '/uf4_train'
uf4_val_dir = datasets_dir + '/uf4_val'

srrnet_train_lr = uf4_train_dir + '/srrnet_train_lr.h5'
srrnet_train_hr = uf4_train_dir + '/srrnet_train_hr.h5'

srrnet_val_lr = uf4_val_dir + '/srrnet_val_lr.h5'
srrnet_val_hr = uf4_val_dir + '/srrnet_val_hr.h5'

#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AACDZmHK7d2JQi0ADaoliM04a/uf_4/train/Data_CDVL_LR_Bic_MC_uf_4_ps_72_fn_5_tpn_225000.h5 -O srrnet_train_lr
#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AACmrvoqkXXnZTXUFsWvNDCsa/uf_4/train/Data_CDVL_HR_uf_4_ps_72_fn_5_tpn_225000.h5 -O srrnet_train_hr

#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AADJnJmRvFxmf7sxEk5G0Uuma/uf_4/val/Data_CDVL_LR_Bic_MC_uf_4_ps_72_fn_5_tpn_45000.h5 -O srrnet_val_lr
#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AAChoVG4fLqdpsmSuq9wrEvFa/uf_4/val/Data_CDVL_HR_uf_4_ps_72_fn_5_tpn_45000.h5 -O srrnet_val_hr



In [5]:
import pickle

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
with open('/content/drive/MyDrive/Computer_Vision/train_subset_12800.pkl', 'rb') as f:
    subset_train = pickle.load(f)

In [7]:
with open('/content/drive/MyDrive/Computer_Vision/val_subset_12800.pkl', 'rb') as f:
    subset_val = pickle.load(f)

In [8]:
upscale_factor = 4
threads = 1
batchSize = 256


train_loader = DataLoader(dataset=subset_train, num_workers=threads, batch_size=batchSize, shuffle=False)
val_loader = DataLoader(dataset=subset_val, num_workers=threads, batch_size=batchSize, shuffle=False)

## Get the pretrained SRCNN

In [9]:
!wget https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth -O {pretrained_models}/srcnn_model.pth

--2023-05-14 23:03:41--  https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth
Resolving www.dropbox.com (www.dropbox.com)... 162.125.81.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.81.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/pd5b2ketm0oamhj/srcnn_x4.pth [following]
--2023-05-14 23:03:41--  https://www.dropbox.com/s/raw/pd5b2ketm0oamhj/srcnn_x4.pth
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uccb873599992d9f8c92bdcb5060.dl.dropboxusercontent.com/cd/0/inline/B8EZN22Q2LS8iOm4pwqWnMNhfvzj_Hd-TpVL-y4ptM_DBEM0db66Hl3P7dT7fBTn7ZGFh24WOVIYCuYD432WE3netObKcOyc2kBcp5lb9rfLf0RuAsZgyaLbUZqIz64DUbbA6_InQdNK6uInY0zKMKNT8w25DwJpfru5f27ozjw8Ig/file# [following]
--2023-05-14 23:03:41--  https://uccb873599992d9f8c92bdcb5060.dl.dropboxusercontent.com/cd/0/inline/B8EZN22Q2LS8iOm4pwqWnMNhfvzj_Hd-TpVL-y4ptM_DBEM0db66Hl3P7dT7fBTn7ZGFh24WOVIYC

In [10]:
upscale_factor = 4
srcnn = Net_SRCNN(upscale_factor=upscale_factor)

state_dict = srcnn.state_dict()
for n, p in torch.load(pretrained_models+'/srcnn_model.pth', map_location=lambda storage, loc: storage).items():
    if n in state_dict.keys():
        state_dict[n].copy_(p)
    else:
        raise KeyError(n)

torch.save(srcnn, pretrained_models+'/srcnn_model.pth')

Define the VSRNet

In [11]:
class Net_VSRNet(nn.Module):
    def __init__(self, upscale_factor, srcnn_model):
        super(Net_VSRNet, self).__init__()

        self.conv1_f0 = nn.Conv2d(1,  64, (9, 9), (1, 1), (4, 4))
        self.conv1_f1 = nn.Conv2d(1,  64, (9, 9), (1, 1), (4, 4))
        self.conv1_f2 = nn.Conv2d(1,  64, (9, 9), (1, 1), (4, 4))
        
        
        self.conv2_1 = nn.Conv2d(192, 32, (5, 5), (1, 1), (2, 2))
        self.conv2_2 = nn.Conv2d(192, 32, (5, 5), (1, 1), (2, 2))
        self.conv3 = nn.Conv2d(64, 1,  (5, 5), (1, 1), (2, 2))
        
        self.srcnn_model = srcnn_model
        self.upscale_factor = upscale_factor
        
        self._initialize_weights()


    def forward(self, x):
        
        h10 = x[:,[0],:,:]
        h11 = x[:,[1],:,:]
        h12 = x[:,[2],:,:]
        h13 = x[:,[3],:,:]
        h14 = x[:,[4],:,:] 

        h10 = self.conv1_f0(h10)
        h11 = self.conv1_f1(h11)
        h12 = self.conv1_f2(h12)
        h13 = self.conv1_f1(h13)
        h14 = self.conv1_f0(h14) 

        x1 = F.relu(torch.cat((h10, h11, h12), 1))
        x2 = F.relu(torch.cat((h12, h13, h14), 1))

        x1 = self.conv2_1(x1)
        x2 = self.conv2_2(x2)

        x = F.relu(torch.cat((x1,x2),1))
        x = (self.conv3(x))

        return x
    
    def _initialize_weights(self):
        
        srcnn_model = torch.load(self.srcnn_model, map_location=lambda storage, loc: storage) # forcing to load to CPU       
        
        self.conv1_f0.weight.data = (srcnn_model.conv1.weight.data).clone()
        self.conv1_f1.weight.data = (srcnn_model.conv1.weight.data).clone()
        self.conv1_f2.weight.data = (srcnn_model.conv1.weight.data).clone()
        
        self.conv1_f0.bias.data = (srcnn_model.conv1.bias.data).clone()
        self.conv1_f1.bias.data = (srcnn_model.conv1.bias.data).clone()
        self.conv1_f2.bias.data = (srcnn_model.conv1.bias.data).clone()
        
        self.conv2_1.weight.data = torch.cat((srcnn_model.conv2.weight.data, 
                                            srcnn_model.conv2.weight.data, 
                                            srcnn_model.conv2.weight.data), 1).clone()/3.0

        self.conv2_2.weight.data = torch.cat((srcnn_model.conv2.weight.data, 
                                            srcnn_model.conv2.weight.data, 
                                            srcnn_model.conv2.weight.data), 1).clone()/3.0
        
        self.conv2_1.bias.data = (srcnn_model.conv2.bias.data).clone()     
        self.conv2_2.bias.data = (srcnn_model.conv2.bias.data).clone()

        self.conv3.weight.data = torch.cat((srcnn_model.conv3.weight.data, 
                                            srcnn_model.conv3.weight.data), 1).clone()/2.0
        
        self.conv3.bias.data = (srcnn_model.conv3.bias.data).clone()

In [12]:
model = Net_VSRNet(upscale_factor=upscale_factor, srcnn_model=pretrained_models+'/srcnn_model.pth')
criterion = nn.MSELoss()

if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

In [13]:
lr = 0.001
optimizer = optim.Adam([{'params': model.conv1_f0.parameters()},
                        {'params': model.conv1_f1.parameters()},
                        {'params': model.conv1_f2.parameters()},
                        {'params': model.conv2_1.parameters()},
                        {'params': model.conv2_2.parameters()},
                        {'params': model.conv3.parameters(), 'lr': lr/10.0}
                        ], lr=lr)

In [14]:
configure("tensorBoardRuns/VSRNet-relu-mid-fusion-pretrain-sym-x4-batch-128-CDVL-225000x5x72x72-wd")

In [26]:
from scipy import ndimage 
torch.set_grad_enabled(True)  # Context-manager 

sobelx = [[1, 0, -1], [2, 0, -2], [1, 0, -1]]
sobely = [[1, 2, 1], [0, 0, 0], [-1, -2, -1]]        
depth = 1
channels = 1

sobelx_kernel = torch.tensor(sobelx, dtype=torch.float32, requires_grad=True).unsqueeze(0).expand(depth, channels, 3, 3)
sobely_kernel = torch.tensor(sobely, dtype=torch.float32, requires_grad=True).unsqueeze(0).expand(depth, channels, 3, 3)

sobelz = [[[1, 2, 1], [2, 4, 2], [1, 2, 1]],[[0, 0, 0], [0, 0, 0], [0, 0, 0]],[[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]]] 
sobelz_kernel = torch.tensor(sobelz, dtype=torch.float32, requires_grad=True).unsqueeze(0)

if torch.cuda.is_available():
    sobelx_kernel = sobelx_kernel.cuda()
    sobely_kernel = sobely_kernel.cuda()
    sobelz_kernel = sobelz_kernel.cuda()

def motion_loss(input, predicted, target):

    mse = 0
    batch_size=256

    for i in range(batch_size):

        input[i,2,:,:] = predicted[i,:,:,:]

        frames_inp = input[i,1:4,:,:] 
        frames_tar = frames_inp.clone()
        frames_tar[i,:,:] = target[i,:,:,:] 

        x=input[i,2,:,:].unsqueeze(0).unsqueeze(0)
        x=x.reshape((x.shape[1],1,x.shape[2],x.shape[3]))
        t=target[i,0,:,:].unsqueeze(0).unsqueeze(0)
        t=t.reshape((t.shape[1],1,t.shape[2],t.shape[3]))
        # print(frames_inp.shape)
        
        # dx = torch.tensor(ndimage.sobel(input[:,2,:,:], 0), requires_grad=True)  # horizontal derivative
        # dy = torch.tensor(ndimage.sobel(input[:,2,:,:], 1), requires_grad=True)
        dx = F.conv2d(x, sobelx_kernel, stride=1, padding=1).squeeze(0)
        dy = F.conv2d(x, sobely_kernel, stride=1, padding=1).squeeze(0)
        # dz = ndimage.sobel(frames_inp, 2)[1,:,:]
        dz = F.conv2d(frames_inp, sobelz_kernel, stride=1, padding=1)
        # print(dz.shape)

        # dxt = torch.tensor(ndimage.sobel(target[:,0,:,:], 0), requires_grad=True)  # horizontal derivative
        # dyt = torch.tensor(ndimage.sobel(target[:,0,:,:], 1), requires_grad=True)
        dxt = F.conv2d(t, sobelx_kernel, stride=1, padding=1).squeeze(0)
        dyt = F.conv2d(t, sobely_kernel, stride=1, padding=1).squeeze(0)
        # dzt = ndimage.sobel(frames_tar, 2)[1,:,:]
        dzt = F.conv2d(frames_tar, sobelz_kernel, stride=1, padding=1)

        # dxdt = torch.tensor(np.divide(dx, dz, out=np.zeros_like(dx), where=dz!=0))
        # dydt = torch.tensor(np.divide(dy, dz, out=np.zeros_like(dy), where=dz!=0))

        mask = (dz != 0)
        dxdt = torch.full_like(dx, fill_value=float(0))
        dxdt[mask] = dx[mask] / dz[mask]
        dydt = torch.full_like(dy, fill_value=float(0))
        dydt[mask] = dy[mask] / dz[mask]

        # dxtdz = torch.tensor(np.divide(dxt, dzt, out=np.zeros_like(dxt), where=dzt!=0))
        # dytdz = torch.tensor(np.divide(dyt, dzt, out=np.zeros_like(dyt), where=dzt!=0))

        mask = (dzt != 0)
        dxtdz = torch.full_like(dxt, fill_value=float(0))
        dxtdz[mask] = dxt[mask] / dzt[mask]
        dytdz = torch.full_like(dyt, fill_value=float(0))
        dytdz[mask] = dyt[mask] / dzt[mask]

        dxdt = (dxdt - torch.min(dxdt)) / (torch.max(dxdt) - torch.min(dxdt))
        dxtdz = (dxtdz - torch.min(dxtdz)) / (torch.max(dxtdz) - torch.min(dxtdz))
        dydt = (dydt - torch.min(dydt)) / (torch.max(dydt) - torch.min(dydt))
        dytdz = (dytdz - torch.min(dytdz)) / (torch.max(dytdz) - torch.min(dytdz))

        mse += 0.8*F.mse_loss(input[i,2,:,:], target[i,0,:,:])+0.1*F.mse_loss(dxdt, dxtdz)+0.1*F.mse_loss(dydt, dytdz)

    return mse


In [16]:
def train(epoch):
    lr = 0.001
    epoch_loss = 0
    epoch_psnr = 0
    start = time.time()
    #   Step up learning rate decay
    #   The network have 3 layers
    lr = lr * (0.1 ** (epoch // (nEpochs // 4)))
    
    optimizer.param_groups[0]['lr'] = lr
    optimizer.param_groups[1]['lr'] = lr
    optimizer.param_groups[2]['lr'] = lr
    optimizer.param_groups[3]['lr'] = lr
    optimizer.param_groups[4]['lr'] = lr/10.0
    

    n = 0
    for iteration, batch in enumerate(train_loader, 1):
        if n >= 49:
          break 
        n = n+1
        image, target = Variable(batch[0]), Variable(batch[1])
        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

        optimizer.zero_grad()
        #loss = criterion(model(image), target)
        loss = motion_loss(image,model(image), target)
        psnr = 10 * log10(1 / loss.data.item())
        epoch_loss += loss.data.item()
        epoch_psnr += psnr
        loss.backward()
        optimizer.step()
        
    end = time.time()
    print("===> Epoch {} Complete: lr: {}, Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Time: {:.4f}".format(epoch, lr, epoch_loss / len(train_loader), epoch_psnr / len(train_loader), (end-start)))
    
    log_value('train_loss', epoch_loss / len(train_loader), epoch)
    log_value('train_psnr', epoch_psnr / len(train_loader), epoch)

In [17]:
def val(epoch):
    #   Validation on CDVL val set
    lr = 0.001
    avg_psnr = 0
    avg_mse = 0
    frame_count = 0
    start = time.time()
    n = 0
    for batch in val_loader:
        if n >= 49:
          break 
        n = n+1
        image, target = Variable(batch[0]), Variable(batch[1])
        if torch.cuda.is_available():
            image = image.cuda()
            target = target.cuda()

        prediction = model(image)

        for i in range(0, image.shape[0]):
            mse = criterion(prediction[i], target[i])
            # mse = motion_loss(image[],prediction[i], target[i])
            psnr = 10 * log10(1 / mse.data.item())
            avg_psnr += psnr
            avg_mse  += mse.data.item()
            frame_count += 1

    end = time.time()
    print("===> Epoch {} Validation CDVL: Avg. Loss: {:.4f}, Avg.PSNR:  {:.4f} dB, Time: {:.4f}".format(epoch, avg_mse / frame_count, avg_psnr / frame_count, (end-start)))

    log_value('val_loss', avg_mse / frame_count, epoch)
    log_value('val_psnr', avg_psnr / frame_count, epoch)

In [18]:
def checkpoint(epoch):
    if epoch%10 == 0:
        if not os.path.exists("epochs_VSRNet"):
            os.makedirs("epochs_VSRNet")
        model_out_path = "epochs_VSRNet/" + "model_epoch_{}.pth".format(epoch)
        torch.save(model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

In [27]:
nEpochs = 5
lr = 0.001

#val(0)
#checkpoint(0)
for epoch in range(1, nEpochs + 1):
    train(epoch)
    val(epoch)
    checkpoint(epoch)

RuntimeError: ignored

### Let's test with a video!

In [None]:
uf4_test_dir = datasets_dir + '/uf4_test'

vsrnet_test_lr = uf4_test_dir + '/vsrnet_test_lr.h5'
vsrnet_test_hr = uf4_test_dir + '/vsrnet_test_hr.h5'

!wget https://www.dropbox.com/s/q3evjn917cwv9ax/scene_40.h5?dl=0 -O vsrnet_test_lr
!wget https://www.dropbox.com/s/lxm30agjddg72xe/scene_40.h5?dl=0  -O vsrnet_test_hr

#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AAADzBQ7iA492oQ26ag67ZsKa/uf_4/test/LR_Bic_MC/scene_30.h5 -O vsrnet_test_lr
#!wget https://www.dropbox.com/sh/1jz9zeer9wxetx2/AADSka3PgSR5EuCt9ByugfY6a/uf_4/test/HR/scene_30.h5 -O vsrnet_test_hr


--2023-05-13 15:15:19--  https://www.dropbox.com/s/q3evjn917cwv9ax/scene_40.h5?dl=0
Resolving www.dropbox.com (www.dropbox.com)... 162.125.80.18, 2620:100:6031:18::a27d:5112
Connecting to www.dropbox.com (www.dropbox.com)|162.125.80.18|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: /s/raw/q3evjn917cwv9ax/scene_40.h5 [following]
--2023-05-13 15:15:20--  https://www.dropbox.com/s/raw/q3evjn917cwv9ax/scene_40.h5
Reusing existing connection to www.dropbox.com:443.
HTTP request sent, awaiting response... 302 Found
Location: https://uc145348ac01982bad3f9db461f2.dl.dropboxusercontent.com/cd/0/inline/B7_1_v5JyUOfpUhDM01CBICsyK6-4KLxrfRBLEMd6vXG7FeS_lbEykTOtAfOIe-jvs8BxhAnMhVDEp8MAwvHkjUrZGtDK9S1By9FlWrazxnlthyedYVhU7n-DBb7LHOVrzgHuZGOspt96Wqlaw4_cOgJ9vZHK_7Sy8-TOTM8g5-QHQ/file# [following]
--2023-05-13 15:15:20--  https://uc145348ac01982bad3f9db461f2.dl.dropboxusercontent.com/cd/0/inline/B7_1_v5JyUOfpUhDM01CBICsyK6-4KLxrfRBLEMd6vXG7FeS_lbEykTOtAfOIe-jvs8BxhAnMhV

In [None]:
path_LR_Bic_MC = './vsrnet_test_hr'
path_HR = './vsrnet_test_hr'
videos_h5_name = ['scene_40.h5']
videos_h5_name.sort()

In [None]:
h5_len = len(videos_h5_name)
model_PSNR   = np.zeros(h5_len)
model_SSIM   = np.zeros(h5_len)
bicubic_PSNR = np.zeros(h5_len)
bicubic_SSIM = np.zeros(h5_len)
model_time   = np.zeros(h5_len)

In [None]:
out_path = './'
if not os.path.exists(out_path):
    os.makedirs(out_path)

In [None]:
import numpy
import math

def psnr(img1, img2):
    mse = numpy.mean( (img1 - img2) ** 2 )
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))

In [None]:
from scipy.ndimage import gaussian_filter

from numpy.lib.stride_tricks import as_strided as ast

"""
Hat tip: http://stackoverflow.com/a/5078155/1828289
"""
def block_view(A, block=(3, 3)):
    """Provide a 2D block view to 2D array. No error checking made.
    Therefore meaningful (as implemented) only for blocks strictly
    compatible with the shape of A."""
    # simple shape and strides computations may seem at first strange
    # unless one is able to recognize the 'tuple additions' involved ;-)
    shape = (A.shape[0]// block[0], A.shape[1]// block[1])+ block
    strides = (block[0]* A.strides[0], block[1]* A.strides[1])+ A.strides
    return ast(A, shape= shape, strides= strides)


def ssim(img1, img2, C1=0.01**2, C2=0.03**2):

    bimg1 = block_view(img1, (4,4))
    bimg2 = block_view(img2, (4,4))
    s1  = numpy.sum(bimg1, (-1, -2))
    s2  = numpy.sum(bimg2, (-1, -2))
    ss  = numpy.sum(bimg1*bimg1, (-1, -2)) + numpy.sum(bimg2*bimg2, (-1, -2))
    s12 = numpy.sum(bimg1*bimg2, (-1, -2))

    vari = ss - s1*s1 - s2*s2
    covar = s12 - s1*s2

    ssim_map =  (2*s1*s2 + C1) * (2*covar + C2) / ((s1*s1 + s2*s2 + C1) * (vari + C2))
    return numpy.mean(ssim_map)

# FIXME there seems to be a problem with this code
def ssim_exact(img1, img2, sd=1.5, C1=0.01**2, C2=0.03**2):

    mu1 = gaussian_filter(img1, sd)
    mu2 = gaussian_filter(img2, sd)
    mu1_sq = mu1 * mu1
    mu2_sq = mu2 * mu2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = gaussian_filter(img1 * img1, sd) - mu1_sq
    sigma2_sq = gaussian_filter(img2 * img2, sd) - mu2_sq
    sigma12 = gaussian_filter(img1 * img2, sd) - mu1_mu2

    ssim_num = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2))

    ssim_den = ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    ssim_map = ssim_num / ssim_den
    return numpy.mean(ssim_map)

In [None]:
from tqdm import tqdm
import cv2

video_idx = 0
#   Read h5 file
LR_Bic_MC_h5_file = h5py.File('./vsrnet_test_lr', 'r')
LR_Bic_MC_h5_data = LR_Bic_MC_h5_file['data']
HR_h5_file = h5py.File('./vsrnet_test_hr', 'r')
HR_h5_data = HR_h5_file['data']
    
# load to memory
HR_h5_data = HR_h5_data[()]#.value
LR_Bic_MC_h5_data = LR_Bic_MC_h5_data[()]#.value
    
# transpose to correct order
HR_h5_data = np.transpose(HR_h5_data, (3, 2, 1, 0))
LR_Bic_MC_h5_data = np.transpose(LR_Bic_MC_h5_data, (3, 2, 1, 0))
    
frame_number = LR_Bic_MC_h5_data.shape[0]

IS_REAL_TIME = False

video_name = 'scene_40'
    
if not IS_REAL_TIME:
    fps = 30
    size = (LR_Bic_MC_h5_data.shape[3], LR_Bic_MC_h5_data.shape[2])
    output_name = out_path + video_name.split('.')[0] + '.avi'
    videoWriter = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc('M','J','P','G'), fps, size)
#            videoWriter = cv2.VideoWriter(output_name, cv2.VideoWriter_fourcc(*'XVID'), fps, size)
        
#   Prepare to save PSNR and SSIM of the current video
#   Each value corresponding to one test frame
model_PSNR_cur   = np.zeros(frame_number)
model_SSIM_cur   = np.zeros(frame_number)
bicubic_PSNR_cur = np.zeros(frame_number)
bicubic_SSIM_cur = np.zeros(frame_number)
model_time_cur   = np.zeros(frame_number)
    
for idx in tqdm(range(0, frame_number)):
    img_HR = HR_h5_data[idx, 0, :, :] #2D
    img_LR_Bic_MC = LR_Bic_MC_h5_data[idx, :, :, :] #3D 5x1080x1920
    
    # Reshape to 4D
    img_LR_Bic_MC = img_LR_Bic_MC.reshape((1, img_LR_Bic_MC.shape[0], img_LR_Bic_MC.shape[1], img_LR_Bic_MC.shape[2]))
    
    img_LR_Bic_MC = img_LR_Bic_MC.astype(np.float32)

    img_LR_Bic_MC =  torch.from_numpy(img_LR_Bic_MC)
                        
    if torch.cuda.is_available():
        img_LR_Bic_MC = img_LR_Bic_MC.cuda()

    start = time.time()
    if img_LR_Bic_MC.sum() != 0:
        img_HR_net = model(img_LR_Bic_MC)

    else:
        img_HR_net = img_LR_Bic_MC[:,2,:,:]
        img_HR_net = img_HR_net.reshape((1, 1, img_HR.shape[0], img_HR.shape[1])) # reshape to 1x1x1080x1920
        
    end = time.time() # measure the computation time
    
    img_HR_net = img_HR_net.cpu()
    img_HR_net = img_HR_net.data[0].numpy()
    img_HR_net *= 255.0
    img_HR_net = img_HR_net.clip(0, 255)
    img_HR_net = img_HR_net.astype(np.uint8)
    
    img_LR_Bic_MC = img_LR_Bic_MC.cpu()
    img_LR_Bic = img_LR_Bic_MC[:, 2, :, :] # center frame
    img_LR_Bic = img_LR_Bic.data[0].numpy()
    img_LR_Bic *= 255.0
    img_LR_Bic = img_LR_Bic.clip(0, 255)
    img_LR_Bic = img_LR_Bic.astype(np.uint8)
    
    img_HR = img_HR.reshape((1, img_HR.shape[0], img_HR.shape[1]))
    img_LR_Bic = img_LR_Bic.reshape((1, img_LR_Bic.shape[0], img_LR_Bic.shape[1]))

    
    model_PSNR_cur[idx]   = psnr((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_HR_net).reshape(img_HR_net.shape[1], img_HR_net.shape[2]).astype(int))
    model_SSIM_cur[idx]   = ssim((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_HR_net).reshape(img_HR_net.shape[1], img_HR_net.shape[2]).astype(int))
    bicubic_PSNR_cur[idx] = psnr((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_LR_Bic).reshape(img_LR_Bic.shape[1], img_LR_Bic.shape[2]).astype(int))
    bicubic_SSIM_cur[idx] = ssim((img_HR).reshape(img_HR.shape[1], img_HR.shape[2]).astype(int), (img_LR_Bic).reshape(img_LR_Bic.shape[1], img_LR_Bic.shape[2]).astype(int))
    model_time_cur[idx]   = (end-start)

    # Repeat to 3 channels to save and display
    img_HR_net = np.repeat(img_HR_net, 3, axis=0)
    img_HR_net = np.transpose(img_HR_net, (1, 2, 0))

    if IS_REAL_TIME:
        plt.imshow(img_HR_net, cmap = 'gray')
        plt.show()

#                cv2.imshow('LR Video ', img_LR_Bic)
#                cv2.imshow('SR Video ', img_HR_net)
#                cv2.waitKey(DELAY_TIME)
    else:
        # save video
        videoWriter.write(img_HR_net)
    
# Done video writing
videoWriter.release()

# Save PSNR and SSIM
# Exclude PSNR = 100 cases (caused by black frames)
cal_flag = (model_PSNR_cur != 100)
model_PSNR[video_idx]   = np.mean(model_PSNR_cur[cal_flag])
model_SSIM[video_idx]   = np.mean(model_SSIM_cur[cal_flag])
bicubic_PSNR[video_idx] = np.mean(bicubic_PSNR_cur[cal_flag])
bicubic_SSIM[video_idx] = np.mean(bicubic_SSIM_cur[cal_flag])
model_time[video_idx]   = np.mean(model_time_cur[cal_flag])

print("===> Test on Video Idx: " + str(video_idx) +" Complete: Model PSNR: {:.4f} dB, Model SSIM: {:.4f} , Bicubic PSNR:  {:.4f} dB, Bicubic SSIM: {:.4f} , Average time: {:.4f}"
  .format(model_PSNR[video_idx], model_SSIM[video_idx], bicubic_PSNR[video_idx], bicubic_SSIM[video_idx], model_time[video_idx]*1000))
video_idx += 1

100%|██████████| 14/14 [00:12<00:00,  1.11it/s]

===> Test on Video Idx: 0 Complete: Model PSNR: 31.1959 dB, Model SSIM: 0.9991 , Bicubic PSNR:  31.3232 dB, Bicubic SSIM: 0.9995 , Average time: 680.5826



