<a href="https://colab.research.google.com/github/Nick088Official/Real-ESRGAN-Fix/blob/master/Real_ESRGAN_Fix_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Real-ESRGAN Pytorch Inference NO UI

[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2107.10833)
[![GitHub Stars](https://img.shields.io/github/stars/xinntao/Real-ESRGAN?style=social)](https://github.com/xinntao/Real-ESRGAN)
[![download](https://img.shields.io/github/downloads/xinntao/Real-ESRGAN/total.svg)](https://github.com/xinntao/Real-ESRGAN/releases)

This is a **Practical Image Restoration Demo** of our paper [''Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data''](https://arxiv.org/abs/2107.10833).
We extend the powerful ESRGAN to a practical restoration application (namely, Real-ESRGAN), which is trained with pure synthetic data. <br>
The following figure shows some real-life examples.

<img src="https://raw.githubusercontent.com/xinntao/Real-ESRGAN/master/assets/teaser.jpg" width="100%">

**Note that Real-ESRGAN may still fail in some cases as the real-world degradations are really too complex.**<br>
Moreover, it **may not** perform well on **human faces**; for those, you could try [GFPGAN](https://colab.research.google.com/github/Nick088Official/GFPGAN-Fix/blob/master/GFPGAN_fix_inference.ipynb) instead.
<br>

This is the Pytorch Implementation of https://github.com/ai-forever/Real-ESRGAN

**Credits:** [Nick088](https://linktr.ee/Nick088), Xinntao, Tencent, Geeve George, ai-forever

In [None]:
#@title Installation
#@markdown Before running, make sure that you choose
#@markdown * Runtime Type = Python 3
#@markdown * Hardware Accelerator = GPU (Faster, free daily limit of gpu around 12 hours) or CPU (very slow)

#@markdown in the **Runtime** menu -> **Change runtime type**

#@markdown By running this, we clone the repository and set up the environment.


from IPython.display import clear_output
import torch

if torch.cuda.is_available():
    device = "cuda"
    print("Using GPU")
else:
    device = "cpu"
    print("Using CPU, you won't be able to use v1 model")

!pip install git+https://github.com/sberbank-ai/Real-ESRGAN.git
%cd Real-ESRGAN
clear_output()
print("Installed!")

In [None]:
#@title Choose Model
from RealESRGAN import RealESRGAN
from PIL import Image
import numpy as np
import torch

# Upscaling factor picked in the Colab form; kept as a string because it is
# also interpolated into the weights file name below.
model_scale = "4" #@param ["2", "4", "8"] {allow-input: false}

# Instantiate RealESRGAN on the `device` chosen by the installation cell, then
# load the checkpoint whose file name matches the selected scale.
model = RealESRGAN(device, scale=int(model_scale))
model.load_weights(f'weights/RealESRGAN_x{model_scale}.pth')

# Inference Images

In [None]:
#@title Upload & upscale Images or .Tar Archives
import os
from google.colab import files
import shutil
from io import BytesIO
import io
import tarfile

# Working folders: uploaded files are moved into `inputs`, upscaled outputs
# are written under `results`.
upload_folder = 'inputs'
result_folder = 'results'

# exist_ok=True keeps the cell re-runnable when the folders already exist.
os.makedirs(upload_folder, exist_ok=True)
os.makedirs(result_folder, exist_ok=True)

# Extensions an archive member must end with to be upscaled by process_tar.
IMAGE_FORMATS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

def image_to_tar_format(img, image_name):
    """Encode ``img`` in memory and build the tar header describing it.

    Args:
        img: Image object exposing PIL-style ``convert(mode)`` and
            ``save(fp, format=...)`` methods.
        image_name: Name the archive member will be stored under. A name
            ending in ``.png`` (any case) is encoded as PNG after conversion
            to RGBA; every other name is encoded as JPEG.

    Returns:
        A ``(tar_info, fp)`` tuple: a ``tarfile.TarInfo`` whose ``size`` is
        the number of encoded bytes, and a buffered reader positioned at the
        first byte, ready to be passed to ``TarFile.addfile``.
    """
    buff = BytesIO()
    # Decide on the real extension: a substring test would also treat a name
    # such as "shot.png.jpg" as PNG.
    if image_name.lower().endswith('.png'):
        img = img.convert('RGBA')
        img.save(buff, format='PNG')
    else:
        img.save(buff, format='JPEG')
    # Seeking to the end yields the payload size without copying the buffer
    # the way getvalue() does.
    size = buff.seek(0, io.SEEK_END)
    buff.seek(0)
    fp = io.BufferedReader(buff)
    img_tar_info = tarfile.TarInfo(name=image_name)
    img_tar_info.size = size
    return img_tar_info, fp

def process_tar(path_to_tar):
    """Upscale every image inside a tar archive into a new archive.

    Each member whose name ends with one of ``IMAGE_FORMATS`` (compared
    case-insensitively) is decoded, passed through ``model.predict`` and
    stored under the same member name in ``results/<archive name>``.
    Members that are not images, or that cannot be opened, are skipped.

    Args:
        path_to_tar: Path of the archive to read.
    """
    result_tar_path = os.path.join('results/', os.path.basename(path_to_tar))

    # Context managers close both archives even if upscaling raises midway,
    # so the partially written result is still finalised and no handle leaks.
    with tarfile.open(path_to_tar, mode='r') as processing_tar, \
            tarfile.open(result_tar_path, 'w') as save_tar:
        for c, member in enumerate(processing_tar):
            print(f'{c}, processing {member.name}')

            # Compare in lower case so "PHOTO.JPG" is not silently skipped.
            if not member.name.lower().endswith(IMAGE_FORMATS):
                continue

            try:
                # Pass the member itself rather than its name to avoid a
                # second lookup. extractfile() returns None for entries that
                # are not regular files; the resulting AttributeError is
                # handled below like any undecodable image.
                img_bytes = BytesIO(processing_tar.extractfile(member).read())
                img_lr = Image.open(img_bytes, mode='r').convert('RGB')
            except Exception:
                # Best effort: one bad member must not abort the archive.
                print(f'Unable to open file {member.name}, skipping')
                continue

            img_sr = model.predict(np.array(img_lr))
            # adding to save_tar
            img_tar_info, fp = image_to_tar_format(img_sr, member.name)
            save_tar.addfile(img_tar_info, fp)

    print(f'Finished! Archive saved to {result_tar_path}')

def process_input(filename):
    """Upscale one uploaded file.

    A tar archive is delegated to ``process_tar``; anything else is opened
    as an image, upscaled with ``model.predict`` and saved under ``results/``
    with its original file name.
    """
    # Archives take the dedicated path; nothing else to do afterwards.
    if tarfile.is_tarfile(filename):
        process_tar(filename)
        return

    out_path = os.path.join('results/', os.path.basename(filename))
    source = Image.open(filename).convert('RGB')
    upscaled = model.predict(np.array(source))
    upscaled.save(out_path)
    print(f'Finished! Image saved to {out_path}')

# Ask for files, move each one into the upload folder, then upscale it.
uploaded = files.upload()
for filename in uploaded:
    dst_path = os.path.join(upload_folder, filename)
    shutil.move(filename, dst_path)
    print(f'moved input {filename} to {dst_path}')
    print('Processing:', filename)
    # Process the file where it was actually moved to, instead of a second
    # hard-coded "inputs/" path that would drift if upload_folder changes.
    process_input(dst_path)

In [None]:
#@title Visualize Input vs Output Image
# utils for visualization
import cv2
import matplotlib.pyplot as plt
def display(img1, img2):
  """Show the input image and its upscaled result side by side.

  Draws a 25x10 figure with two axis-less panels: ``img1`` on the left
  titled 'Input image', ``img2`` on the right titled with the active
  ``model_scale``.
  """
  canvas = plt.figure(figsize=(25, 10))
  panels = (
      (img1, 'Input image'),
      (img2, f'Real-ESRGAN Pytorch x{model_scale} output'),
  )
  for position, (picture, caption) in enumerate(panels, start=1):
    axes = canvas.add_subplot(1, 2, position)
    axes.set_title(caption, fontsize=16)
    axes.axis('off')
    axes.imshow(picture)
def imread(img_path):
  """Read the image at ``img_path`` with OpenCV and return it in RGB order."""
  bgr = cv2.imread(img_path)
  # OpenCV decodes to BGR; matplotlib expects RGB.
  return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

# display each uploaded image next to its upscaled result
import os
import glob

input_folder = 'inputs'
result_folder = 'results'
input_list = sorted(glob.glob(os.path.join(input_folder, '*')))
output_list = sorted(glob.glob(os.path.join(result_folder, '*')))

# Results are saved under the input's file name, so pair by name instead of by
# position: zipping the two sorted listings mispairs as soon as either folder
# holds an extra entry (a sub-folder, leftovers from an earlier run, ...).
outputs_by_name = {os.path.basename(path): path for path in output_list}
for input_path in input_list:
  output_path = outputs_by_name.get(os.path.basename(input_path))
  if output_path is None or not os.path.isfile(input_path) or not os.path.isfile(output_path):
    continue
  try:
    img_input = imread(input_path)
    img_output = imread(output_path)
  except cv2.error:
    # e.g. a .tar archive: present in both folders but not an image.
    print(f'Skipping {input_path}: not a readable image')
    continue
  display(img_input, img_output)

In [None]:
#@title Download the results
# Bundle everything in results/ into one archive and send it to the browser.
zip_filename = f'Real-ESRGAN-Pytorch-x{model_scale}_result.zip'
# Drop an archive left by a previous run so old entries are not carried over.
if os.path.exists(zip_filename):
  os.remove(zip_filename)
# -r recurses into folders, -j stores the files without their directory path;
# the results/* glob is expanded by the shell that os.system spawns.
os.system(f"zip -r -j {zip_filename} results/*")
files.download(zip_filename)

# Inference Videos

In [None]:
#@title Upload Input Video
import os
from google.colab import files
import shutil
import cv2
from google.colab.patches import cv2_imshow

# Working folders of the video pipeline.
upload_folder = 'upload'                      # extracted video frames
result_folder = 'results'
video_folder = 'videos'                       # uploaded source videos
video_result_folder = 'results_videos'        # upscaled .avi files
video_mp4_result_folder = 'results_mp4_videos'
result_restored_imgs_folder = 'restored_imgs' # upscaled frames, inside results/

# Create every missing top-level folder; report the ones already present.
for folder in (upload_folder, video_folder, video_result_folder,
               video_mp4_result_folder, result_folder):
  if os.path.isdir(folder):
    print(folder+" exists")
  else:
    os.mkdir(folder)

# restored_imgs lives inside results/, so test and create that exact path.
# Checking the bare name in the current directory and creating it after a
# `%cd results` made every re-run fail with FileExistsError and left the
# working directory stuck inside results/.
restored_imgs_path = os.path.join(result_folder, result_restored_imgs_folder)
if os.path.isdir(restored_imgs_path):
  print(result_restored_imgs_folder+" exists")
else:
  os.makedirs(restored_imgs_path)

# Always start from an empty video folder so earlier uploads are not
# processed again.
if os.path.isdir(video_folder):
    shutil.rmtree(video_folder)
os.mkdir(video_folder)

# Upload the videos and move each one into the video folder, which is where
# the inference cell looks for its input.
uploaded = files.upload()
for filename in uploaded.keys():
  dst_path = os.path.join(video_folder, filename)
  # The move itself was missing: the message was printed but the file stayed
  # in the working directory, so no video was ever found for processing.
  shutil.move(filename, dst_path)
  print(f'moved {filename} to {dst_path}')


In [None]:
#@title Inference Video
import cv2
import numpy as np
import glob
from os.path import isfile, join
import subprocess
from IPython.display import clear_output
import os
from google.colab import files
import shutil
from io import BytesIO
import io
import tarfile


# Extensions an archive member must end with to be upscaled by process_tar.
IMAGE_FORMATS = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

def image_to_tar_format(img, image_name):
    """Encode ``img`` in memory and build the tar header describing it.

    Args:
        img: Image object exposing PIL-style ``convert(mode)`` and
            ``save(fp, format=...)`` methods.
        image_name: Name the archive member will be stored under. A name
            ending in ``.png`` (any case) is encoded as PNG after conversion
            to RGBA; every other name is encoded as JPEG.

    Returns:
        A ``(tar_info, fp)`` tuple: a ``tarfile.TarInfo`` whose ``size`` is
        the number of encoded bytes, and a buffered reader positioned at the
        first byte, ready to be passed to ``TarFile.addfile``.
    """
    buff = BytesIO()
    # Decide on the real extension: a substring test would also treat a name
    # such as "shot.png.jpg" as PNG.
    if image_name.lower().endswith('.png'):
        img = img.convert('RGBA')
        img.save(buff, format='PNG')
    else:
        img.save(buff, format='JPEG')
    # Seeking to the end yields the payload size without copying the buffer
    # the way getvalue() does.
    size = buff.seek(0, io.SEEK_END)
    buff.seek(0)
    fp = io.BufferedReader(buff)
    img_tar_info = tarfile.TarInfo(name=image_name)
    img_tar_info.size = size
    return img_tar_info, fp

def process_tar(path_to_tar):
    """Upscale every image inside a tar archive into a new archive.

    Each member whose name ends with one of ``IMAGE_FORMATS`` (compared
    case-insensitively) is decoded, passed through ``model.predict`` and
    stored under the same member name in ``results/<archive name>``.
    Members that are not images, or that cannot be opened, are skipped.

    Args:
        path_to_tar: Path of the archive to read.
    """
    result_tar_path = os.path.join('results/', os.path.basename(path_to_tar))

    # Context managers close both archives even if upscaling raises midway,
    # so the partially written result is still finalised and no handle leaks.
    with tarfile.open(path_to_tar, mode='r') as processing_tar, \
            tarfile.open(result_tar_path, 'w') as save_tar:
        for c, member in enumerate(processing_tar):
            print(f'{c}, processing {member.name}')

            # Compare in lower case so "PHOTO.JPG" is not silently skipped.
            if not member.name.lower().endswith(IMAGE_FORMATS):
                continue

            try:
                # Pass the member itself rather than its name to avoid a
                # second lookup. extractfile() returns None for entries that
                # are not regular files; the resulting AttributeError is
                # handled below like any undecodable image.
                img_bytes = BytesIO(processing_tar.extractfile(member).read())
                img_lr = Image.open(img_bytes, mode='r').convert('RGB')
            except Exception:
                # Best effort: one bad member must not abort the archive.
                print(f'Unable to open file {member.name}, skipping')
                continue

            img_sr = model.predict(np.array(img_lr))
            # adding to save_tar
            img_tar_info, fp = image_to_tar_format(img_sr, member.name)
            save_tar.addfile(img_tar_info, fp)

    print(f'Finished! Archive saved to {result_tar_path}')

def process_input(filename):
    """Upscale one file for the video pipeline.

    A tar archive is delegated to ``process_tar``; anything else is opened
    as an image, upscaled with ``model.predict`` and saved under
    ``results/restored_imgs`` with its original file name.
    """
    # Archives take the dedicated path; nothing else to do afterwards.
    if tarfile.is_tarfile(filename):
        process_tar(filename)
        return

    out_path = os.path.join('results/restored_imgs', os.path.basename(filename))
    source = Image.open(filename).convert('RGB')
    upscaled = model.predict(np.array(source))
    upscaled.save(out_path)
    print(f'Finished! Image saved to {out_path}')


# Folder scanned for the videos to upscale.
directory = 'videos' #PATH_WITH_INPUT_VIDEOS
# Running count of processed videos; used to name the outputs video1, video2, ...
zee = 0

def convert_frames_to_video(pathIn,pathOut,fps):
    """Assemble the numbered frame images in ``pathIn`` into one video file.

    Frame files are expected to be named ``frame<N>`` followed by a
    four-character extension such as ``.jpg``; they are written in ascending
    order of ``N``. Frames are streamed to the writer one at a time instead
    of being collected in a list first, so memory use stays at one frame
    regardless of the video length.

    Args:
        pathIn: Folder holding the frame images.
        pathOut: Path of the video file to create (DIVX fourcc).
        fps: Frame rate of the output video.

    Raises:
        ValueError: If a file name in ``pathIn`` has no integer between the
            ``frame`` prefix and the extension.
    """
    frame_names = [f for f in os.listdir(pathIn) if isfile(join(pathIn, f))]
    # Sort on the numeric counter; a plain string sort would put frame10
    # before frame2.
    frame_names.sort(key = lambda x: int(x[5:-4]))

    out = None
    try:
        for frame_name in frame_names:
            # join() works whether or not pathIn ends with a separator.
            filename = join(pathIn, frame_name)
            img = cv2.imread(filename)
            if img is None:
                # imread signals an unreadable file by returning None.
                print(f'Unable to read frame {filename}, skipping')
                continue
            print(filename)
            if out is None:
                # The writer is created lazily: its frame size comes from the
                # first readable frame, and no file is produced for an empty
                # folder.
                height, width = img.shape[:2]
                out = cv2.VideoWriter(pathOut,cv2.VideoWriter_fourcc(*'DIVX'), fps, (width, height))
            out.write(img)
    finally:
        if out is not None:
            out.release()


def _reset_dir(path):
    """Empty ``path`` (creating it if needed) so stale files never leak in."""
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)


# Upscale every video found in `directory`: split it into frames, upscale the
# frames, reassemble them into an .avi and transcode that to .mp4.
for filename in os.listdir(directory):

    f = os.path.join(directory, filename)
    # only regular files can be videos
    if not os.path.isfile(f):
        continue

    print("PROCESSING :"+str(f)+"\n")

    # Start every video from empty frame folders; otherwise frames left by a
    # previous (longer) video would be upscaled again and appended to this one.
    _reset_dir('upload')
    _reset_dir('results/restored_imgs')
    os.makedirs('results_videos', exist_ok=True)
    os.makedirs('results_mp4_videos', exist_ok=True)

    # video to frames
    cam = cv2.VideoCapture(str(f))

    # Reuse the frame rate of the source so the result plays at the original
    # speed; fall back to 25 when the container reports none (0 or NaN).
    fps = cam.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 25.0

    currentframe = 0
    try:
        while True:
            # reading from frame
            ret, frame = cam.read()
            if not ret:
                break

            # writing the extracted images
            name = 'upload/frame' + str(currentframe) + '.jpg'
            cv2.imwrite(name, frame)

            # counter shows how many frames have been created so far
            currentframe += 1
            print(currentframe)
    finally:
        # Release all space and windows once done
        cam.release()
        cv2.destroyAllWindows()

    if currentframe == 0:
        print(f'No frames could be read from {f}, skipping')
        continue

    # apply super-resolution on all frames of the video
    for file_name in os.listdir("upload"):
        process_input(f"upload/{file_name}")

    # convert super res frames to .avi
    pathIn = 'results/restored_imgs/'

    zee = zee+1
    fName = "video"+str(zee)
    filenameVid = f"{fName}.avi"

    pathOut = "results_videos/"+filenameVid

    convert_frames_to_video(pathIn, pathOut, fps)

    # convert the .avi that was just written to .mp4 (earlier videos were
    # already converted in their own iteration)
    inputfile = pathOut
    outputfile = os.path.join('results_mp4_videos/', fName + ".mp4")
    print('[INFO] 1',inputfile)
    try:
        # -y overwrites an .mp4 left by an earlier run instead of waiting for
        # a confirmation that can never be typed in the notebook.
        returncode = subprocess.call(['ffmpeg', '-y', '-i', inputfile, outputfile])
    except OSError as err:
        # e.g. ffmpeg is not installed
        print("An exception occurred")
        print(err)
    else:
        if returncode != 0:
            print(f'ffmpeg exited with code {returncode} for {inputfile}')

In [None]:
#@title Download Video & Frames Results
#@markdown This downloads a .zip archive; unzip it however you want.

#@markdown It will have 2 folders:

#@markdown 1. 'results', which contains 'restored_imgs' = the upscaled frames of the video

#@markdown 2. 'results_mp4_videos' = the upscaled video (all the upscaled frames joined back together)

# download the result
!ls results
print('Download results')
os.system(f'zip -r download.zip results/restored_imgs results_mp4_videos')
files.download("download.zip")