## Video Inference

In this notebook, the WaveMixSR model is used to perform video inference on a sample video. The video is first split into frames, and each frame is passed through the model to generate a super-resolved image. The super-resolved frames are then combined to form the output video. This approach is not recommended, as a model trained for image super-resolution may not perform well on video data. However, this notebook is provided to demonstrate how to perform video inference using the WaveMixSR model.


In [1]:
import cv2
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
import numpy as np
import wavemix.sisr as sisr
import kornia

  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Input clips that the later cells upscale.
video1, video2 = "Kyoto 360p.mp4", "f1 360p.mp4"

In [None]:
class WaveMixSR(nn.Module):
    """Upscale a 3-channel image, refining only its first channel with learned layers.

    Channel 0 of the input (called ``y`` by the original code; the luma plane
    when the caller feeds YCbCr data) is bilinearly upsampled, lifted to
    ``final_dim`` feature maps, passed through ``depth`` residual
    ``sisr.Level1Waveblock`` stages and projected back to a single channel.
    Channels 1-2 are only bilinearly upsampled. Both results are concatenated
    along the channel axis.

    Args:
        depth: Number of ``Level1Waveblock`` stages (keyword-only).
        mult: Forwarded to every ``Level1Waveblock``.
        ff_channel: Forwarded to every ``Level1Waveblock``.
        final_dim: Feature width used by the blocks and the surrounding convs.
        dropout: Forwarded to every ``Level1Waveblock``.
        scale_factor: Spatial upsampling factor. The refined path uses the
            value as given; the pass-through path uses ``int(scale_factor)``.
    """

    def __init__(
        self,
        *,
        depth,
        mult=1,
        ff_channel=16,
        final_dim=16,
        dropout=0.3,
        scale_factor=2,
    ):
        super().__init__()

        half_dim = int(final_dim / 2)

        # Submodules are created in the order layers -> final -> path1 -> path2
        # so that state_dict keys and parameter-initialisation order stay stable.
        self.layers = nn.ModuleList(
            [
                sisr.Level1Waveblock(
                    mult=mult,
                    ff_channel=ff_channel,
                    final_dim=final_dim,
                    dropout=dropout,
                )
                for _ in range(depth)
            ]
        )

        # Projection head: final_dim feature maps -> one output channel.
        self.final = nn.Sequential(
            nn.Conv2d(final_dim, half_dim, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(half_dim, 1, kernel_size=1),
        )

        # Refined path: upsample first, then lift 1 -> half_dim -> final_dim channels.
        self.path1 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
            nn.Conv2d(1, half_dim, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(half_dim, final_dim, kernel_size=3, stride=1, padding=1),
        )

        # Pass-through path: plain bilinear upsampling, no learned parameters.
        self.path2 = nn.Sequential(
            nn.Upsample(scale_factor=int(scale_factor), mode='bilinear', align_corners=False),
        )

    def forward(self, img):
        """Return the upscaled image: refined channel 0 followed by upsampled channels 1-2."""
        first_channel, other_channels = img[:, :1], img[:, 1:3]

        features = self.path1(first_channel)
        for block in self.layers:
            # Residual connection around every wave block.
            features = block(features) + features
        refined = self.final(features)

        upsampled = self.path2(other_channels)

        return torch.cat([refined, upsampled], dim=1)

In [None]:
# NOTE(review): torch.load unpickles the checkpoint file; if 'weights.pth' is
# not fully trusted, consider passing weights_only=True.
weights = torch.load('weights.pth', map_location=device)

# Build the network with the configuration used here for the checkpoint.
model = WaveMixSR(
    depth=4,
    mult=1,
    ff_channel=144,
    final_dim=144,
    dropout=0.3,
    scale_factor=2,
)
model = model.to(device)
model.load_state_dict(weights)

# Put the module in evaluation mode before running inference.
model.eval()

In [None]:
# Preprocessing applied to every decoded frame before it is fed to the model.
transform_target = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])

In [None]:
def process_video(input_video_path, output_video_path, model, device, max_frames=1000):
    """Upscale a video frame by frame with ``model`` and write the result.

    Every decoded frame is converted to a tensor, moved from BGR to YCbCr,
    passed through ``model`` without gradient tracking, converted back to
    8-bit BGR and appended to the output file.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the ``mp4v``-encoded video to write.
        model: Callable taking a ``(1, 3, H, W)`` YCbCr tensor. The writer is
            opened at exactly twice the input resolution, so ``model`` is
            expected to upscale by 2 -- TODO confirm for other checkpoints.
        device: Torch device the frame tensor is moved to before inference.
        max_frames: Stop after this many frames have been written; ``None``
            processes the whole video. Defaults to the previous fixed cap
            of 1000.

    Returns:
        None. Prints an error and returns early if the input or the output
        cannot be opened.
    """
    cap = cv2.VideoCapture(input_video_path)
    out = None
    # try/finally guarantees both handles are released even if a frame fails.
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width * 2, frame_height * 2))
        if not out.isOpened():
            # Without this check the loop would run inference but write nothing.
            print("Error: Could not open video writer.")
            return

        i = 0
        while cap.isOpened():
            print(i)  # progress: index of the frame about to be read
            ret, frame = cap.read()
            if not ret:
                break
            if max_frames is not None and i >= max_frames:
                break

            # OpenCV frame (BGR) -> tensor -> RGB -> YCbCr, with a batch axis.
            frame = transform_target(frame)
            frame = kornia.color.bgr_to_rgb(frame)
            frame = kornia.color.rgb_to_ycbcr(frame)
            frame = frame.unsqueeze(0).to(device)

            with torch.no_grad():
                output_tensor = model(frame)

            # Back to an 8-bit HWC image; clip because the model output is
            # not guaranteed to stay inside the displayable range.
            output_tensor = kornia.color.ycbcr_to_rgb(output_tensor)
            output_frame = output_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.0
            output_frame = np.clip(output_frame, 0, 255).astype(np.uint8)
            output_frame_bgr = cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR)
            out.write(output_frame_bgr)
            i += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

In [None]:
# Upscale the first clip with the WaveMixSR model.
process_video(
    input_video_path=video1,
    output_video_path="Kyoto 720p_model.mp4",
    model=model,
    device=device,
)

In [None]:
# Upscale the second clip with the WaveMixSR model.
process_video(
    input_video_path=video2,
    output_video_path="f1 720p_model.mp4",
    model=model,
    device=device,
)

In [None]:
def process_video_bicubic(input_video_path, output_video_path, max_frames=1000):
    """Upscale a video to twice its resolution with bicubic interpolation.

    Each decoded frame is resized with ``cv2.INTER_CUBIC`` and appended to
    the output file.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the ``mp4v``-encoded video to write.
        max_frames: Stop after this many frames have been written; ``None``
            processes the whole video. Defaults to the previous fixed cap
            of 1000.

    Returns:
        None. Prints an error and returns early if the input or the output
        cannot be opened.
    """
    cap = cv2.VideoCapture(input_video_path)
    out = None
    # try/finally guarantees both handles are released even if a frame fails.
    try:
        if not cap.isOpened():
            print("Error: Could not open video.")
            return

        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        target_size = (frame_width * 2, frame_height * 2)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, target_size)
        if not out.isOpened():
            # Without this check the loop would resize frames but write nothing.
            print("Error: Could not open video writer.")
            return

        i = 0
        while cap.isOpened():
            print(i)  # progress: index of the frame about to be read
            ret, frame = cap.read()
            if not ret:
                break
            if max_frames is not None and i >= max_frames:
                break

            # Resizing treats every channel independently, so the frame can
            # stay in BGR; no colour-space round trip is needed.
            output_frame_bgr = cv2.resize(frame, target_size, interpolation=cv2.INTER_CUBIC)
            out.write(output_frame_bgr)
            i += 1
    finally:
        cap.release()
        if out is not None:
            out.release()

In [None]:
# Bicubic baseline for the first clip.
process_video_bicubic(
    input_video_path=video1,
    output_video_path="Kyoto 720p_bicubic.mp4",
)

In [None]:
# Bicubic baseline for the second clip.
process_video_bicubic(
    input_video_path=video2,
    output_video_path="f1 720p_bicubic.mp4",
)