
SRCNN - PyTorch (DEMO)

Currently, there are 2 predominant upscalers on Civitai: Real-ESRGAN and UltraSharp. Both are based on ESRGAN. If you look at any recent paper regarding Super-Resolution, you will see sentences like:

"Since the pioneering work of SRCNN [9], deep convolution neural network (CNN) approaches have brought prosperous developments in the SR field"

-- "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data" by Wang et. all.

SRCNN? This sounds familiar. In 2015 I wrote an implementation in raw OpenCL that runs on the GPU. This repo is a PyTorch reimplementation that I wrote some time ago. I also had a TensorFlow one, but it seems to be lost in the depths of my hard drive.

Overview

The super-resolution problem is to upscale an image so that the perceived loss of quality is minimal. For example, after scaling with bicubic interpolation, it is apparent that some pixels are just smudged together. The question is: can AI do a better job?
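
For reference, the bicubic baseline is a one-liner with Pillow (the file names below are placeholders):

from PIL import Image

# Plain bicubic upscaling - the baseline the network is compared against.
img = Image.open("images_raw/photo.jpg")  # placeholder path
upscaled = img.resize((img.width * 2, img.height * 2), Image.BICUBIC)
upscaled.save("photo_bicubic_2x.png")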

Results


left: upscaling with bicubic interpolation, right: result of the presented algorithm


Details closeup - left: upscaling with bicubic interpolation, right: result of the presented algorithm

The current algorithm only upscales the luma channel; the chroma is preserved as-is. This is a common trick, closely related to chroma subsampling: the human eye is far more sensitive to luma detail than to chroma.
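
A minimal sketch of that pipeline, assuming the classic pre-upscaled SRCNN setup (the whole image is first resized with bicubic, then only the Y channel is refined by the network; the function name and scale factor are mine, not the repo's):

import numpy as np
import torch
from PIL import Image

def upscale_luma_only(img_path, model, scale=2):
    # 1. Bicubic upscale of the whole image - this already gives us the chroma.
    img = Image.open(img_path).convert("RGB")
    bicubic = img.resize((img.width * scale, img.height * scale), Image.BICUBIC)
    y, cb, cr = bicubic.convert("YCbCr").split()

    # 2. Refine only the luma with the network ([N, C, H, W], values in 0..1).
    y_in = torch.from_numpy(np.asarray(y, dtype=np.float32) / 255.0)[None, None]
    with torch.no_grad():
        y_out = model(y_in).clamp(0.0, 1.0)
    y_img = Image.fromarray((y_out[0, 0].numpy() * 255.0).astype(np.uint8))

    # 3. Reassemble: sharpened luma + bicubic chroma.
    return Image.merge("YCbCr", (y_img, cb, cr)).convert("RGB")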

Usage

Install dependencies

pip install -r requirements.txt will install the CPU version of PyTorch. If you want to run the code on a GPU, use pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu118 (docs). Be aware that it is a much bigger download and you do not really need it - the model is small enough.

Alternatively, you can reuse the packages from your kohya_ss or Stable Diffusion web UI install. Add their site-packages directory to Python's path:

def inject_path():
    import sys
    sys.path.append('C:/programs/install/kohya_ss/venv/Lib/site-packages') # your path here
inject_path() # call this fn from both 'main.py' and 'gen_samples.py'

Training

  1. Put some images into ./images_raw.
  2. python gen_samples.py 400 -64. Generates 400 image pairs (one image blurred by downscaling and upscaling back, the other left sharp). Each image is 64x64 px and is stored in ./samples_gen_64.
  3. python main.py train --cpu -n -e 20 -i "samples_gen_64". Runs training for 20 epochs using samples from ./samples_gen_64. By default:
    • The program will use the GPU if a GPU-capable PyTorch is installed. Use the --cpu flag to force the CPU (even if you have GPU-capable PyTorch).
    • The program will continue from the last checkpoint (stored in ./models). Use -n to start from scratch.

First, we need to generate training data. ./gen_samples.py reads images from ./images_raw and randomly crops 32x32 px patches (or 64x64 px with -64). Each crop is stored as e.g. ./samples_gen_64/0b0mkhrd.large.png. A corresponding ./samples_gen_64/0b0mkhrd.small.png is generated by downscaling the crop and upscaling it back. Our goal is to learn how to turn the blurred small image into the sharp one.
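
A rough sketch of how one such pair could be produced (the scale factor and file naming are my assumptions; the real gen_samples.py may differ):

import random
import string
from PIL import Image

def make_pair(img, patch=64, scale=2, out_dir="samples_gen_64"):
    # Random crop of the source image - this is the sharp "large" sample.
    x = random.randint(0, img.width - patch)
    y = random.randint(0, img.height - patch)
    large = img.crop((x, y, x + patch, y + patch))

    # Degrade: downscale, then upscale back to the original patch size.
    small = large.resize((patch // scale, patch // scale), Image.BICUBIC)
    small = small.resize((patch, patch), Image.BICUBIC)

    name = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
    large.save(f"{out_dir}/{name}.large.png")
    small.save(f"{out_dir}/{name}.small.png")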

If you only want a model that is good enough, training takes a few minutes at most, even on the CPU.

After training, the model is saved to e.g. ./models/srcnn_model.2024-02-27--23-43-05.pt
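
Under the hood this is a standard supervised setup: minimize the pixel-wise difference between model(small) and large. A minimal sketch of such a loop (the loss, optimizer, and batch size are my assumptions, not necessarily what main.py uses):

import torch
from torch.utils.data import DataLoader

def train(model, dataset, epochs=20, device="cpu"):
    # dataset yields (small, large) tensor pairs of shape [1, H, W], values in 0..1
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    model.to(device).train()
    for _ in range(epochs):
        for small, large in loader:
            small, large = small.to(device), large.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(small), large)
            loss.backward()
            optimizer.step()

    torch.save(model, "models/srcnn_model.pt")  # the real file name carries a timestamp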

Inference

  • python main.py upscale -i "<some path here>/image_to_upscale.jpg". Run main.py with -i set to your image.

The program will automatically separate the luma, run the upscale, and reassemble the final image. The --cpu flag works here too. By default, it will use the latest model from the ./models directory.
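
Picking the "latest model" can be as simple as sorting the checkpoints by modification time (a sketch; the actual selection logic in main.py may differ):

from pathlib import Path

def latest_checkpoint(models_dir="models"):
    checkpoints = sorted(Path(models_dir).glob("*.pt"), key=lambda p: p.stat().st_mtime)
    if not checkpoints:
        raise FileNotFoundError("No .pt files found - train a model first")
    return checkpoints[-1]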

The result is stored to e.g. './images_upscaled/<your_image_name>.2024-02-27--23-43-27.png'.

Web demo

The PyTorch model was exported to an ONNX file, which allows inference in the web browser. Unfortunately, the ONNX runtime on the web has errors that prevent using the GPU backends (WebGPU, WebGL), and the CPU backend is much slower. Fortunately, this app is just my private playground. Use netron.app to preview the srcnn.onnx file.
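
The export itself boils down to a single torch.onnx.export call. A sketch (the checkpoint path and the "input"/"output" tensor names are placeholders; I also assume height/width are marked as dynamic so the web demo can feed images of any size):

import torch

# Assumes the checkpoint stores the whole pickled model, not just a state_dict.
model = torch.load("models/srcnn_model.pt", map_location="cpu", weights_only=False)
model.eval()

# The dummy input MUST include the batch dimension: [N, C, H, W].
dummy = torch.rand(1, 1, 64, 64)
torch.onnx.export(
    model, dummy, "srcnn.onnx",
    input_names=["input"], output_names=["output"],
    dynamic_axes={"input": {2: "height", 3: "width"},
                  "output": {2: "height", 3: "width"}},
)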

Lessons from ONNX conversion

  1. During training, your image-based PyTorch model takes input of shape [batch_size, img_channel_count, img_height, img_width]. During inference, PyTorch also accepts e.g. [img_channel_count, img_height, img_width] - it does not mind that the batch_size dimension is missing. THIS IS NOT TRUE FOR ONNX! (See the sketch below.)
  2. Double-check that your image tensors always have the shape [batch_size, img_channel_count, img_height, img_width]. I lost "a bit" of time because my input had width and height swapped. It is evident when:
    • The model works only for square images.
    • Vertical images have a few "ghost duplicates" along the horizontal axis.
    • Horizontal images have many "ghost duplicates" along the horizontal axis.

The second point sounds silly. But after years of writing computer graphics code, your fingers type [width, height] without thinking about it.
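
To illustrate the first point, a minimal sketch comparing eager PyTorch with ONNX Runtime (the "input" tensor name is an assumption from my export settings):

import numpy as np
import onnxruntime as ort
import torch

model = torch.nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for the trained net

# Eager PyTorch is forgiving: an unbatched [C, H, W] tensor is accepted.
with torch.no_grad():
    model(torch.rand(1, 64, 64))

# ONNX Runtime is not: the exported graph expects the full [N, C, H, W] shape.
session = ort.InferenceSession("srcnn.onnx")
result = session.run(None, {"input": np.random.rand(1, 1, 64, 64).astype(np.float32)})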

I recommend the following utils (for single grayscale image processing):

const encodeDims = (w, h) => [1, 1, h, w]; // [batch_size, channels, height, width]
const decodeDims = (dims) => [dims[3], dims[2]]; // returns: [w, h]

The files

  • images_raw/. The original images we will use to generate training samples from. Add some images to this directory.
  • images_upscaled/. Contains the final upscaled image after inference.
  • models/. Contains learned models as .pt files.
  • samples_gen_32/. Training patches generated from images_raw with gen_samples.py with default patch size (32x32 px).
  • samples_gen_64/. Training patches generated from images_raw with gen_samples.py with -64 flag (64x64 px).
  • gen_samples.py. Script to generate sample patches from images_raw.
  • main.py. CLI for training/inference.
  • srcnn.py. CNN model implementation (see the sketch after this list).
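
For orientation, the original paper describes a three-layer CNN (the common 9-1-5 setting with 64 and 32 filters). A minimal sketch of that architecture, which may differ in details from srcnn.py:

import torch.nn as nn

class SRCNN(nn.Module):
    # Three conv layers as in the original paper: patch extraction,
    # non-linear mapping, reconstruction. Operates on a single (luma) channel.
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=9, padding=4),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 1, kernel_size=5, padding=2),
        )

    def forward(self, x):
        return self.net(x)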

References

If you are interested in the math or implementation details, I wrote 2 articles 9 years ago:

Of course, the original "Image Super-Resolution Using Deep Convolutional Networks" paper is still relevant. Even the current state of the art references it as the progenitor.
