## GIF Super-resolver
This notebook is for GIF super-resolution. Meant to be shown at the project presentation.

In [1]:
# Imports and constants

import numpy as np
import os
import cv2
from time import time

import torch
import torch.backends.cudnn as cudnn

from model import ELSR
from PIL import Image
from preprocessing import prepare_img

In [5]:
# Super-resolution function

def super_resolve(model, device, video):

    gif = []
    for frame in video:
        frame = prepare_img(frame, device)
        gif.append(frame)      

    sr_video = []

    t0 = time()

    with torch.no_grad():
        for frame in gif:
            sr_img = model(frame).clamp(0, 1)
            sr_video.append(sr_img)

    t = time()-t0
    print(f"FPS: {5/t}")

    return sr_video

In [6]:
# Get gif frames

SCALE = 4
WEIGHTS = './checkpoints/best_X4_model.pth'
INPUT = './test/gif/test.gif'

cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ELSR(upscale_factor=SCALE).to(device)

state_dict = torch.load(WEIGHTS)
model.load_state_dict(state_dict=state_dict)
model.eval()

with Image.open(INPUT) as im:
    for i in range(im.n_frames):
        im.seek(i)
        im.save(f'./test/gif/frames/{i}.png')


gif = []
frames_path = "./test/gif/frames/"
for frame_path in os.listdir(frames_path):
    frame = cv2.cvtColor(cv2.imread(os.path.join(frames_path, frame_path)), cv2.COLOR_BGR2RGB)
    gif.append(frame)

In [27]:
# Super-resolve gif

sr_gif = super_resolve(model, device, gif)

FPS: 2534.016433059449


In [28]:
# Save output gif

gif = []
for i, sr_img in enumerate(sr_gif):
    out = sr_img.cpu().numpy().squeeze(0).transpose(1, 2, 0)
    gif.append(Image.fromarray((out*255).astype(np.uint8)))
gif[0].save('test/gif/sr_test.gif', save_all=True, append_images=gif[1:], loop=0, duration=10)