## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import time
import numpy as np
from torchvision.transforms import Compose, Resize, ToTensor
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
import matplotlib.pyplot as plt
import sys
from torchvision.transforms import functional as TF
import torchvision.transforms as transforms
import random
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import scipy.io as sio
import cv2
import cudacanvas
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim


## Bicubic Plus Plus Model

In [3]:
class Bicubic_plus_plus(nn.Module):
    """
    Bicuic Plus Plus model. Adapted from Aselsan Researach group.
    - Pretrained weights from their github repository.
    - https://github.com/aselsan-research-imaging-team/bicubic-plusplus 
    """
    def __init__(self, sr_rate=3):
        super(Bicubic_plus_plus, self).__init__()
        self.conv0 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.conv1 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False)
        self.conv_out = nn.Conv2d(32, (2*sr_rate)**2 * 3, kernel_size=3, padding=1, bias=False)
        self.Depth2Space = nn.PixelShuffle(2*sr_rate)
        self.act = nn.LeakyReLU(inplace=True, negative_slope=0.1)

    def forward(self, x):
        x0 = self.conv0(x)
        x0 = self.act(x0)
        x1 = self.conv1(x0)
        x1 = self.act(x1)
        x2 = self.conv2(x1)
        x2 = self.act(x2) + x0
        y = self.conv_out(x2)
        y = self.Depth2Space(y)
        return y

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Bicubic_plus_plus().to(device)
model.load_state_dict(torch.load('bicubic_pp_x3.pth'))
model.eval()
# Print number of parameters

Bicubic_plus_plus(
  (conv0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv_out): Conv2d(32, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (Depth2Space): PixelShuffle(upscale_factor=6)
  (act): LeakyReLU(negative_slope=0.1, inplace=True)
)

## Handle Data

### Function to get cursor position

In [4]:
# Mouse callback function to update cursor position
def update_cursor_position(event, x, y, flags, param):
    global cursor_x, cursor_y
    if event == cv2.EVENT_MOUSEMOVE:
        cursor_x, cursor_y = x, y

In [17]:
# Define transform to be applied to frames
# Currently only transforms.ToTensor()
transform = transforms.Compose([
    transforms.ToTensor(),
])

def upscale_video(video_path, model, transform = None, out_video_path = None, evaluate_mode = False):
    """
    Upscale a video 3x by using bicubic plus plus model
    - video_path: path to the video
    - model: bicubic plus plus model
    - transform: transform to be applied to frames
    - out_video_path: path to the output video (only if evaluate_mode is True)
    - evaluate_mode: if True, model is used to write the video to storage so that it can be evaluated.
    - evaluate mode is slow due to GPu -> CPU transfer overhead

    - Output: Upscaled Video feed or File storage if in evaluate mode 
    """
    # Set CUDA device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load video file
    cap = cv2.VideoCapture(video_path)

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    
    # Define codec and VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    output_video = cv2.VideoWriter(out_video_path, fourcc, fps, (frame_width*3, frame_height*3))


    # Set up cudacanvas window for renders
    if not evaluate_mode:
        white_screen = torch.ones((3, frame_height, frame_width*2)).to(device)
        cudacanvas.set_image(white_screen)
        cudacanvas.create_window()

    # Process each frame
    while cap.isOpened():
        elapsed_time = 0
        torch.cuda.synchronize()
        start_time = time.time()
        # Read frame
        ret, frame = cap.read()
        if not ret:
            break
        
        frame_downsampled = cv2.resize(frame, (int(frame.shape[1] / 3), int(frame.shape[0] / 3))) 
        
        # Convert frame to RGB and apply transform
        frame_rgb = frame_downsampled
        frame_rgb = cv2.cvtColor(frame_rgb, cv2.COLOR_BGR2RGB)
        frame_tensor = transform(frame_rgb).unsqueeze(0).to(device)

        frame_bicubic = cv2.resize(frame_rgb, (frame_width, frame_height))
        # frame_bicubic = cv2.cvtColor(frame_bicubic, cv2.COLOR_BGR2RGB)
        frame_bicubic = transform(frame_bicubic).unsqueeze(0).to(device)

        # Perform super-resolution on the frame
        with torch.no_grad():
            upscaled_frame_tensor = model(frame_tensor)
            upscaled_frame_tensor = torch.clamp(upscaled_frame_tensor, 0, 1)
        
        # print(frame_downsampled.shape, frame_bicubic.shape, frame_tensor.shape, upscaled_frame_tensor.shape)
        
        # Write the upscaled frame to output video file if in evaluate mode
        if evaluate_mode:
            # Convert tensor back to numpy array
            upscaled_frame = (upscaled_frame_tensor.squeeze().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
            upscaled_frame = cv2.cvtColor(upscaled_frame, cv2.COLOR_RGB2BGR)
            # cv2.imshow('Upscaled Frame', upscaled_frame)
            # # Write upscaled frame to output video
            output_video.write(upscaled_frame)
        
        # If no evaluation, just output using cudacanvas
        else:
            cudacanvas.render()
            # Concatenate frame and upscaled frame
            combine = torch.concat([frame_bicubic, upscaled_frame_tensor], dim = -1)
            cudacanvas.set_image(combine.squeeze())
            if cudacanvas.should_close():
                break
        
        
        cv2.waitKey(1)
        if evaluate_mode:
            if cv2.waitKey(0) & 0xFF == ord('q'):
                break

    # Release resources
    cap.release()
    output_video.release()
    cv2.destroyAllWindows()

# testing
video_path = 'test_videos/4K ULtra HD ｜ SAMSUNG UHD Demo׃ LED TV [R3GfuzLMPkA].mp4' 
upscale_video(video_path, model, transform = transform, evaluate_mode = False)

## Calculate average FPS for 1000 frames

In [None]:
noise = torch.randn(1, 3, 720, 1280).to(device)
times = []
for i in range(1000):
  torch.cuda.synchronize()
  start = time.time()
  with torch.no_grad():
    pred = model(noise)
    pred = torch.clamp(pred, 0, 1)
  torch.cuda.synchronize()
  end = time.time() - start
  times.append(end)

# plt.imshow(pred.squeeze().cpu().numpy().transpose(1, 2, 0))
avg_time = np.mean(times)

print("Input frame size =", noise.shape)
print("Output frame size =", pred.shape)
print("Average Time per frame =", 1000*avg_time, "ms")
print("Average FPS =", 1/avg_time, "FPS")

Input frame size = torch.Size([1, 3, 720, 1280])
Output frame size = torch.Size([1, 3, 2160, 3840])
Average Time per frame = 8.745447874069214 ms
Average FPS = 114.34520157224433 FPS


# Get evaluation metrics

## Performance Timings
- Uncomment each line to test individual transfer latencies

In [None]:
import torch
import time

tot_time = 0

# input tensor (360p)
frame_tensor_cpu = torch.randn(1,3,960,720)
frame_tensor_gpu = frame_tensor_cpu.to('cuda')
frame_tensor_cpu_3x = torch.randn(1,3,2160,3840)
frame_tensor_gpu_3x = frame_tensor_cpu_3x.to('cuda')
frame_tensor_shared = frame_tensor_cpu_3x.pin_memory()
shared_ref = torch.zeros(1,3,2160,3840).pin_memory()

for i in range(1000):
    st = time.time()
    # frame_tensor_cpu.to('cuda') # 1.4 ms - normal cpu to gpu / 24.240
    # frame_tensor_shared.to('cuda') # 9 ms - pinned cpu to gpu
    # shared_ref[:] = frame_tensor_cpu_3x # 14.26 ms - normal cpu to pinned cpu
    # shared_ref[:] = frame_tensor_gpu_3x # 8.07 ms - gpu to pinned cpu
    # shared_ref.to('cpu') # 0.001 ms
    frame_tensor_gpu_3x.to('cpu') # 16.12 ms - gpu to normal cpu
    
    tot_time += time.time() - st
print('Average time (ms):', tot_time/i*1000)    
    

Average time (ms): 15.062652192674243


## Quality Metrics
- PSNR
- SSIM
- Inference Time

In [15]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

def evaluate_model(video_path, model, crop_size = [1920, 1080], upscale_factor = 3, n_samples = 10, transform = None):
    """
    - This function samples a random area of 'crop_size' from the video. 
    - This is then downsampled by 3x and then upscaled back to 'crop_size'
    - This is then compared to the original cropped frame
    - The baseline is the regular bicubic upsampling

    Inputs:
    - video_path: path to the test 4k video
    - model: Super res model
    - n_samples: number of samples to evaluate
    - transform: transform to use

    Outputs:
    - PSNR: Average Peak Signal to Noise Ratio
    - SSIM: Average Structural Similarity Index
    - infer_time: Average inference time
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load video
    cap = cv2.VideoCapture(video_path)

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    print("Video properties:")
    print("Frame width:", frame_width)
    print("Frame height:", frame_height)
    print("FPS:", fps)

    # Read and process n_samples of frames
    n = 0
    psnr_bicubic_list = []
    psnr_sr_list = []
    ssim_bicubic_list = []
    ssim_sr_list = []
    infer_time_bicubic_list = []
    infer_time_sr_list = []

    while True and n < n_samples:
        # Read frame 
        ret, frame = cap.read()
        
        if np.amax(frame) == 0 and np.amin(frame) == 0:
            continue
        n += 1
        # if n == 1:
        #     continue
        if not ret:
            break

        crop_width, crop_height = crop_size
        # Get cropped region of frame
        x_start = random.randint(0, frame_width - crop_width)
        y_start = random.randint(0, frame_height - crop_height)
        x_end = x_start + crop_width
        y_end = y_start + crop_height

        

        # Crop the frame
        cropped_frame = frame[y_start:y_end, x_start:x_end]

        # Downsample frame
        downsampled_frame = cv2.resize(cropped_frame, (int(cropped_frame.shape[1] / upscale_factor), int(cropped_frame.shape[0] / upscale_factor)))

        # Upsample frame
        torch.cuda.synchronize()
        time.time()
        upscaled_frame_bicubic = cv2.resize(downsampled_frame, None, fx=upscale_factor, fy=upscale_factor, interpolation=cv2.INTER_CUBIC)
        torch.cuda.synchronize()
        infer_time_bicubic = time.time() - time.time()
        # print(frame.shape, cropped_frame.shape, downsampled_frame.shape, upscaled_frame_bicubic.shape)

        # Convert frame to tensor
        frame_rgb = cv2.cvtColor(downsampled_frame, cv2.COLOR_BGR2RGB)
        frame_tensor = transform(frame_rgb).unsqueeze(0).to(device)

        # Perform super-resolution on the frame
        torch.cuda.synchronize()
        start = time.time()
        with torch.no_grad():
            upscaled_frame_sr = model(frame_tensor)
            upscaled_frame_sr = torch.clamp(upscaled_frame_sr, 0, 1)
        torch.cuda.synchronize()
        infer_time_sr = time.time() - start

        # Convert tensor back to numpy array
        upscaled_frame_sr = (upscaled_frame_sr.squeeze().cpu().numpy().transpose(1, 2, 0))# * 255).astype(np.uint8)
        upscaled_frame_sr = cv2.cvtColor(upscaled_frame_sr, cv2.COLOR_RGB2BGR)
        
        # Calculate PSNR
        psnr_bicubic = compare_psnr(cropped_frame, upscaled_frame_bicubic)
        psnr_sr = compare_psnr(cropped_frame/255, upscaled_frame_sr)

        # Calculate SSIM
        ssim_bicubic = compare_ssim(cropped_frame, upscaled_frame_bicubic, channel_axis=-1, data_range=1, multichannel=True)
        ssim_sr = compare_ssim(cropped_frame/255, upscaled_frame_sr, channel_axis=-1, data_range=1, multichannel=True)

        print(psnr_bicubic, psnr_sr, ssim_bicubic, ssim_sr, infer_time_bicubic, infer_time_sr)

        psnr_bicubic_list.append(psnr_bicubic)
        psnr_sr_list.append(psnr_sr)
        ssim_bicubic_list.append(ssim_bicubic)
        ssim_sr_list.append(ssim_sr)
        infer_time_bicubic_list.append(infer_time_bicubic)
        infer_time_sr_list.append(infer_time_sr)

        # Display the upscaled frame by concatenating the original and upscaled frames
        # top = np.concatenate((cropped_frame/255, cropped_frame/255), axis=1)
        # bottom = np.concatenate((upscaled_frame_bicubic/255, upscaled_frame_sr), axis=1)
        # combined = np.concatenate((top, bottom), axis=0)
        # cv2.imshow('frame', combined)
        
        # # cv2.imshow('frame', upscaled_frame_sr)
        # if cv2.waitKey(1) & 0xFF == ord('q'):
        #     break
        

    # Calculate average PSNR and SSIM only for values that are not Nan or Inf
    psnr_bicubic_list = [psnr for psnr in psnr_bicubic_list if not np.isnan(psnr) and not np.isinf(psnr)]
    psnr_sr_list = [psnr for psnr in psnr_sr_list if not np.isnan(psnr) and not np.isinf(psnr)]
    ssim_bicubic_list = [ssim for ssim in ssim_bicubic_list if not np.isnan(ssim) and not np.isinf(ssim)]
    ssim_sr_list = [ssim for ssim in ssim_sr_list if not np.isnan(ssim) and not np.isinf(ssim)]
    infer_time_bicubic_list = [infer_time for infer_time in infer_time_bicubic_list if not np.isnan(infer_time) and not np.isinf(infer_time)]
    infer_time_sr_list = [infer_time for infer_time in infer_time_sr_list if not np.isnan(infer_time) and not np.isinf(infer_time)]

    avg_psnr_bicubic = np.mean(psnr_bicubic_list)
    avg_psnr_sr = np.mean(psnr_sr_list)
    avg_ssim_bicubic = np.mean(ssim_bicubic_list)
    avg_ssim_sr = np.mean(ssim_sr_list)
    avg_infer_time_bicubic = np.mean(infer_time_bicubic_list)
    avg_infer_time_sr = np.mean(infer_time_sr_list)


    return avg_psnr_bicubic, avg_psnr_sr, avg_ssim_bicubic, avg_ssim_sr, avg_infer_time_bicubic, avg_infer_time_sr


video_path = "./test_videos/4K ULtra HD ｜ SAMSUNG UHD Demo׃ LED TV [R3GfuzLMPkA].mp4"
avg_psnr_bicubic, avg_psnr_sr, avg_ssim_bicubic, avg_ssim_sr, avg_infer_time_bicubic, avg_infer_time_sr = evaluate_model(video_path, model, transform=transform, n_samples=5
                                                                                                                         , crop_size=[3840,2160])        
print("Average PSNR (bicubic):", avg_psnr_bicubic)
print("Average PSNR (SR):", avg_psnr_sr)
print("Average SSIM (bicubic):", avg_ssim_bicubic)
print("Average SSIM (SR):", avg_ssim_sr)
print("Average inference time (bicubic):", avg_infer_time_bicubic)
print("Average inference time (SR):", avg_infer_time_sr)

Video properties:
Frame width: 3840
Frame height: 2160
FPS: 30.00101354775499


  psnr_sr = compare_psnr(cropped_frame/255, upscaled_frame_sr)


43.29567274080101 40.427541713214936 0.8155311144817291 0.9857818057670306 0.0 0.015794754028320312
43.68689957231816 40.63730226354383 0.8169425584459106 0.98642192523354 0.0 0.03543710708618164
43.83924868755939 40.73180058732581 0.8155399466915778 0.9864480328324094 0.0 0.1671442985534668
43.83541550455777 40.72340917135741 0.8133725763913162 0.9863034814045518 0.0 0.02533411979675293
44.01337951038991 40.78098481045388 0.8135616657382485 0.98638831772638 0.0 0.0
Average PSNR (bicubic): 43.73412320312524
Average PSNR (SR): 40.660207709179176
Average SSIM (bicubic): 0.8149895723497564
Average SSIM (SR): 0.9862687125927824
Average inference time (bicubic): 0.0
Average inference time (SR): 0.048742055892944336


In [None]:
cap = cv2.VideoCapture(video_path)

# Get video properties
fps = cap.get(cv2.CAP_PROP_FPS)
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

while True:
    ret, frame = cap.read()
    if not ret:
        break
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
