## Video Inference

This notebook runs the WaveMixSR model on a sample video. The video is read frame by frame, each frame is passed through the model to produce a super-resolved image, and the super-resolved frames are then written out in order as a new video.

In [22]:
import cv2
from pathlib import Path
import random
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchvision
import os
import numpy as np
import wavemix.sisr as sisr
from PIL import Image
from torchinfo import summary
import gc
import json

In [23]:
# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [24]:
# Input clips for inference, given as paths relative to the working directory.
# Presumably 360p footage, judging by the file names - confirm against the files.
video1 = "Kyoto 360p.mp4"
# NOTE(review): video2 is not referenced by any cell visible in this notebook.
video2 = "f1 360p.mp4"

In [25]:
class WaveMixSR(nn.Module):
    """Super-resolution network that refines one channel and interpolates two.

    Channel 0 of the input (named ``y`` in the original code, i.e. luma) is
    bilinearly upsampled, lifted to ``final_dim`` feature maps, passed through
    ``depth`` residual ``sisr.Level1Waveblock`` stages and projected back to a
    single channel. Channels 1-2 (named ``crcb``, i.e. chroma) are only
    bilinearly upsampled. The two results are concatenated on the channel axis.

    All constructor arguments are keyword-only.

    Args:
        depth: Number of ``Level1Waveblock`` stages.
        mult: Forwarded unchanged to every ``Level1Waveblock``.
        ff_channel: Forwarded unchanged to every ``Level1Waveblock``.
        final_dim: Feature width used by the stem and head convolutions; also
            forwarded to every ``Level1Waveblock``.
        dropout: Forwarded unchanged to every ``Level1Waveblock``.
        scale_factor: Upsampling factor applied to both branches.
    """

    def __init__(
        self,
        *,
        depth,
        mult = 1,
        ff_channel = 16,
        final_dim = 16,
        dropout = 0.3,
        scale_factor = 2
    ):
        super().__init__()

        # Width of the intermediate conv layers in the stem and the head.
        half_dim = int(final_dim / 2)

        # Attribute names and creation order are kept as-is: they determine the
        # state_dict keys and the order in which parameters are initialised.
        self.layers = nn.ModuleList(
            sisr.Level1Waveblock(
                mult=mult,
                ff_channel=ff_channel,
                final_dim=final_dim,
                dropout=dropout,
            )
            for _ in range(depth)
        )

        # Head: final_dim feature maps -> single output channel.
        self.final = nn.Sequential(
            nn.Conv2d(final_dim, half_dim, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(half_dim, 1, kernel_size=1),
        )

        # Stem for channel 0: upsample first, then expand to final_dim maps.
        self.path1 = nn.Sequential(
            nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
            nn.Conv2d(1, half_dim, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(half_dim, final_dim, kernel_size=3, stride=1, padding=1),
        )

        # Channels 1-2 receive plain bilinear interpolation, no learned layers.
        # NOTE(review): the factor is truncated with int() here but not in
        # path1, so a non-integer scale_factor would give the two branches
        # different spatial sizes - confirm only integer factors are intended.
        self.path2 = nn.Sequential(
            nn.Upsample(scale_factor=int(scale_factor), mode='bilinear', align_corners=False),
        )

    def forward(self, img):
        """Upscale ``img`` and return the refined and interpolated channels.

        Args:
            img: 4-D tensor indexed as (batch, channel, height, width); channel
                0 feeds the learned branch, channels 1-2 the interpolation
                branch.

        Returns:
            Tensor made of the single refined channel followed by the
            interpolated channels, concatenated along dim 1.
        """
        luma = img[:, :1]
        chroma = img[:, 1:3]

        features = self.path1(luma)
        for block in self.layers:
            # Residual connection around every wave block.
            features = block(features) + features
        luma_sr = self.final(features)

        chroma_up = self.path2(chroma)

        return torch.cat((luma_sr, chroma_up), dim=1)

In [26]:
# weights_only=True restricts torch.load to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code while being unpickled.
# The result is passed straight to load_state_dict, so nothing else is needed.
weights = torch.load('weights.pth', map_location=device, weights_only=True)
# load_state_dict is strict by default: these hyper-parameters have to produce
# exactly the parameter names and shapes stored in the checkpoint.
model = WaveMixSR(depth = 4, mult = 1, ff_channel = 144, final_dim = 144, dropout = 0.3, scale_factor = 3).to(device)
model.load_state_dict(weights)
# Inference mode for Dropout/BatchNorm; as the last expression of the cell it
# also displays the module summary.
model.eval()

  weights = torch.load('weights.pth', map_location=device)


WaveMixSR(
  (layers): ModuleList(
    (0-3): 4 x Level1Waveblock(
      (feedforward): Sequential(
        (0): Conv2d(144, 144, kernel_size=(1, 1), stride=(1, 1))
        (1): GELU(approximate='none')
        (2): Dropout(p=0.3, inplace=False)
        (3): Conv2d(144, 144, kernel_size=(1, 1), stride=(1, 1))
        (4): ConvTranspose2d(144, 144, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (5): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (reduction): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (final): Sequential(
    (0): Conv2d(144, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(72, 1, kernel_size=(1, 1), stride=(1, 1))
  )
  (path1): Sequential(
    (0): Upsample(scale_factor=3.0, mode='bilinear')
    (1): Conv2d(1, 72, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): Conv2d(72, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (path2): Sequential(
   

In [31]:
def process_video(input_video_path, output_video_path, model, device):
    """Super-resolve every frame of a video and write the result to a new file.

    Frames are read with OpenCV, converted from BGR to YCrCb, scaled to
    [0, 1] and passed through ``model`` one at a time. The model output is
    scaled back to 8-bit, converted to BGR and appended to the output video,
    which keeps the frame rate reported by the input.

    The YCrCb conversion matters: ``WaveMixSR.forward`` in this notebook sends
    channel 0 through the learned path as luma (``y``) and only interpolates
    channels 1-2 (``crcb``). Feeding RGB would refine the red channel alone.

    Args:
        input_video_path: Path of the video to read (``str`` or ``Path``).
        output_video_path: Path of the video to write (``str`` or ``Path``);
            encoded with the ``mp4v`` FourCC.
        model: Callable mapping a ``(1, 3, H, W)`` float tensor to a
            ``(1, 3, H', W')`` float tensor. Call ``model.eval()`` beforehand;
            this function does not change the model's mode.
        device: Torch device the frame tensors are moved to; should be the
            device that ``model`` lives on.

    Returns:
        None. If the input cannot be opened, or the output cannot be opened
        for writing, an error message is printed and the function returns
        early. If the input yields no frames, no output file is created.
    """
    cap = cv2.VideoCapture(str(input_video_path))
    if not cap.isOpened():
        cap.release()
        print("Error: Could not open video.")
        return

    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes to BGR; the model wants luma first, then chroma.
            frame_ycrcb = cv2.cvtColor(frame, cv2.COLOR_BGR2YCrCb)
            # HWC uint8 -> 1xCxHxW float in [0, 1].
            frame_tensor = torch.from_numpy(frame_ycrcb).permute(2, 0, 1).unsqueeze(0).float().to(device) / 255.0
            with torch.no_grad():
                output_tensor = model(frame_tensor)

            # 1xCxHxW float -> HWC uint8, clipping values outside [0, 255].
            output_frame = output_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255.0
            output_frame = np.clip(output_frame, 0, 255).astype(np.uint8)
            output_frame_bgr = cv2.cvtColor(output_frame, cv2.COLOR_YCrCb2BGR)

            if out is None:
                # Size the writer from the first upscaled frame instead of
                # assuming a fixed x3 factor, so any model scale works.
                out_height, out_width = output_frame_bgr.shape[:2]
                out = cv2.VideoWriter(str(output_video_path), fourcc, fps, (out_width, out_height))
                if not out.isOpened():
                    print("Error: Could not open output video for writing.")
                    return

            out.write(output_frame_bgr)
    finally:
        # Always release the handles, even if decoding or inference raised.
        # No cv2.destroyAllWindows(): this function never opens a window.
        cap.release()
        if out is not None:
            out.release()

In [32]:
# Upscale the first sample clip with the loaded model; the result is written
# to the working directory under the given file name.
process_video(video1, "Kyoto 1080p_model.mp4", model, device)