# Wrapper for SwinIR

> I have taken the code from [__Wrapper for SwinIR__](https://github.com/Lin-Sinorodin/SwinIR_wrapper) repository. This wrapper based on the official PyTorch implementation of [SwinIR: Image Restoration Using Shifted Window Transformer](https://arxiv.org/abs/2108.10257).




In [None]:
#@title Imports and Utils { display-mode: "form" }

%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%config InlineBackend.rc = {'figure.figsize': (10.0, 10.0)}

!git clone -qq https://github.com/Lin-Sinorodin/SwinIR_wrapper.git
!pip install -qq timm
from SwinIR_wrapper.SwinIR_wrapper import SwinIR_SR

import cv2
import torch
import urllib.request
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    print(f'Using GPU: {torch.cuda.get_device_properties(0).name}')
else:
    print('Using CPU. Concider using GPU for faster inference.')

def compare_sr_with_original(img_lq, img_hq):
    plt.figure()

    plt.subplot(1, 2, 1)
    plt.imshow(img_lq[::,::,::-1])
    plt.title(f'Original - {img_lq.shape}')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img_hq[::,::,::-1])
    plt.title(f'Super Resolution - {img_hq.shape}')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m24.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m38.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m89.6 MB/s[0m eta [36m0:00:00[0m
[?25h

  from tqdm.autonotebook import tqdm


Using GPU: Tesla T4


In [None]:
import cv2

video_path = '/content/drive/MyDrive/test.mp4'
output_folder = r'/content/drive/MyDrive/frames/'

cap = cv2.VideoCapture(video_path)
frame_count = 0

if not cap.isOpened():
    print("Error: Video file not found or cannot be opened.")
else:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frame_count += 1
        frame_filename = f"{output_folder}frame_{frame_count:04d}.jpg"
        print(frame_filename)
        cv2.imwrite(frame_filename, frame)
        if not cv2.imwrite(frame_filename, frame):
          raise Exception("Could not write image")

    cap.release()
    print(f"Extracted {frame_count} frames.")

/content/drive/MyDrive/frames/frame_0001.jpg
/content/drive/MyDrive/frames/frame_0002.jpg
/content/drive/MyDrive/frames/frame_0003.jpg
/content/drive/MyDrive/frames/frame_0004.jpg
/content/drive/MyDrive/frames/frame_0005.jpg
/content/drive/MyDrive/frames/frame_0006.jpg
/content/drive/MyDrive/frames/frame_0007.jpg
/content/drive/MyDrive/frames/frame_0008.jpg
/content/drive/MyDrive/frames/frame_0009.jpg
/content/drive/MyDrive/frames/frame_0010.jpg
/content/drive/MyDrive/frames/frame_0011.jpg
/content/drive/MyDrive/frames/frame_0012.jpg
/content/drive/MyDrive/frames/frame_0013.jpg
/content/drive/MyDrive/frames/frame_0014.jpg
/content/drive/MyDrive/frames/frame_0015.jpg
/content/drive/MyDrive/frames/frame_0016.jpg
/content/drive/MyDrive/frames/frame_0017.jpg
/content/drive/MyDrive/frames/frame_0018.jpg
/content/drive/MyDrive/frames/frame_0019.jpg
/content/drive/MyDrive/frames/frame_0020.jpg
/content/drive/MyDrive/frames/frame_0021.jpg
/content/drive/MyDrive/frames/frame_0022.jpg
/content/d

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive




---





---



In [None]:
#@title Setup Super Resolution Model { run: "auto" }
pretrained_model = "real_sr x4" # ["real_sr x4", "classical_sr x2", "classical_sr x3", "classical_sr x4", "classical_sr x8", "lightweight x2", "lightweight x3", "lightweight x4"]

model_type, scale = pretrained_model.split(' ')
scale = int(scale[1])

# initialize super resolution model
sr = SwinIR_SR(model_type, scale)

print(f'Loaded {pretrained_model} successfully')

downloading weights to /content/SwinIR_wrapper/SwinIR_wrapper/weights/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth
Loaded real_sr x4 successfully




---



## Use The Model

> If the image size isn't too big, the image can be fed to the model as it is However, for larger images the GPU will run out of memory. A simple solution is to use the model on smaller patches and then combine the results for a large image.

### Directly - `sr.upscale()`


In [None]:


# show results
    #compare_sr_with_original(img_lq, img_hq)
import cv2
import os
import glob


input_folder='/content/drive/MyDrive/frames'
output_folder = '/content/drive/MyDrive/output/'  # Adjust the folder name if needed

# Create the output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)

# Process all images in the input folder
frame_files = sorted(glob.glob(os.path.join(input_folder, '*.jpg')))  # Adjust the file extension if needed
 # Adjust the file extension if needed

for frame_file in frame_files:
    # Load the input image
    img_lq = cv2.imread(frame_file, cv2.IMREAD_COLOR)

    # Perform super-resolution
    img_hq = sr.upscale(img_lq)

    # Save the high-quality image to the output folder
    output_filename = os.path.join(output_folder, os.path.basename(frame_file))
    cv2.imwrite(output_filename, img_hq)


In [None]:
import cv2
import os
import glob

# Path to the folder containing image frames
frame_folder = '/content/drive/MyDrive/output'  # Update with the correct folder path
output_video_path = '/content/drive/MyDrive/reconstructed_video.mp4'  # Path for the output video

# List the image frame files in the folder
frame_files = sorted(glob.glob(os.path.join(frame_folder, '*.jpg')))  # Adjust the file extension if needed

# Check if there are any frames
if not frame_files:
    print("No frames found in the specified folder.")
else:
    # Read the first frame to get frame dimensions
    frame = cv2.imread(frame_files[0])
    frame_height, frame_width, channels = frame.shape

    # Define the codec and create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Video codec, change as needed
    output_video = cv2.VideoWriter(output_video_path, fourcc, 30.0, (frame_width, frame_height))

    # Write each frame to the video
    for frame_file in frame_files:
        frame = cv2.imread(frame_file)
        output_video.write(frame)

    # Release the VideoWriter
    output_video.release()

    print(f"Reconstructed video saved to {output_video_path}")


Reconstructed video saved to /content/drive/MyDrive/reconstructed_video.mp4


In [1]:
!pip install opencv-python-headless
!pip install scikit-video


Collecting scikit-video
  Downloading scikit_video-1.1.11-py2.py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: scikit-video
Successfully installed scikit-video-1.1.11


In [6]:
from moviepy.editor import VideoFileClip, AudioFileClip


swinir_video = VideoFileClip("/content/drive/MyDrive/reconstructed_video.mp4")
original_audio = AudioFileClip("/content/drive/MyDrive/test.mp4")
video_with_audio = swinir_video.set_audio(original_audio)
video_with_audio.write_videofile("/content/drive/MyDrive/swinir.mp4", codec='libx264')


Moviepy - Building video /content/drive/MyDrive/swinir.mp4.
MoviePy - Writing audio in swinirTEMP_MPY_wvf_snd.mp3




MoviePy - Done.
Moviepy - Writing video /content/drive/MyDrive/swinir.mp4





Moviepy - Done !
Moviepy - video ready /content/drive/MyDrive/swinir.mp4


In [10]:
#@title # **Metrics**
import cv2
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity


original_video = cv2.VideoCapture("/content/drive/MyDrive/test.mp4")
swinir_output_video = cv2.VideoCapture("/content/drive/MyDrive/swinir.mp4")


fps = original_video.get(cv2.CAP_PROP_FPS)


frame_count = min(
    int(original_video.get(cv2.CAP_PROP_FRAME_COUNT)),
    int(swinir_output_video.get(cv2.CAP_PROP_FRAME_COUNT))
)


psnr_values = []
ssim_values = []


for frame_index in range(frame_count):

    original_frame = original_video.read()[1]
    swinir_frame = swinir_output_video.read()[1]


    if original_frame.shape != swinir_frame.shape:
        swinir_frame = cv2.resize(swinir_frame, (original_frame.shape[1], original_frame.shape[0]))


    psnr = peak_signal_noise_ratio(original_frame, swinir_frame)
    ssim = structural_similarity(original_frame, swinir_frame, multichannel=True)


    psnr_values.append(psnr)
    ssim_values.append(ssim)


average_psnr = sum(psnr_values) / len(psnr_values)
average_ssim = sum(ssim_values) / len(ssim_values)

print(f"Average PSNR: {average_psnr:.2f}")
print(f"Average SSIM: {average_ssim:.2f}")





  ssim = structural_similarity(original_frame, swinir_frame, multichannel=True)



Average PSNR: 28.29
Average SSIM: 0.87
