In [1]:
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImage
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")

In [2]:
def make_layer(block, n_layers):
    """Stack ``n_layers`` modules produced by ``block`` into one ``nn.Sequential``.

    Args:
        block: Zero-argument callable returning an ``nn.Module`` (a class, or a
            ``functools.partial`` / lambda wrapping one). It is called once per
            layer, so parameters are independent as long as it builds a new
            module on every call.
        n_layers: Number of modules to build.

    Returns:
        ``nn.Sequential`` holding the modules in construction order.
    """
    return nn.Sequential(*(block() for _ in range(n_layers)))

In [3]:
class ResidualDenseBlock_5C(nn.Module):
    """Residual dense block with five 3x3 convolutions.

    conv1-conv4 each see the block input concatenated (along dim 1) with the
    LeakyReLU-activated outputs of every earlier conv, and each adds ``gc``
    channels. conv5 maps the full ``nf + 4 * gc`` stack back to ``nf``
    channels. Its output is scaled by 0.2 and added to the input.

    Args:
        nf: Channels of the block's input and output.
        gc: Channels produced by each of conv1-conv4.
        bias: Whether the convolutions have a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super().__init__()
        # conv1..conv5 are the state_dict key names the checkpoint is loaded against.
        self.conv1 = nn.Conv2d(nf, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv3 = nn.Conv2d(nf + gc * 2, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv4 = nn.Conv2d(nf + gc * 3, gc, kernel_size=3, stride=1, padding=1, bias=bias)
        self.conv5 = nn.Conv2d(nf + gc * 4, nf, kernel_size=3, stride=1, padding=1, bias=bias)
        # In-place is safe here: it only overwrites fresh conv outputs, never ``x``.
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        # Dense connectivity: every conv consumes all features gathered so far.
        features = [x, self.lrelu(self.conv1(x))]
        for conv in (self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        residual = self.conv5(torch.cat(features, 1))
        return x + 0.2 * residual

In [4]:
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Passes the input through three ResidualDenseBlock_5C modules in sequence.
    The chain's output is scaled by 0.2 and added to the input.

    Args:
        nf: Channels of the input and output feature map.
        gc: Growth channels forwarded to each inner dense block.
    """

    def __init__(self, nf, gc=32):
        super().__init__()
        # RDB1..RDB3 are the state_dict key names the checkpoint is loaded against.
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return x + 0.2 * out

In [5]:
class RRDBNet(nn.Module):
    """ESRGAN-style generator: RRDB feature trunk, then 4x nearest-neighbour upsampling.

    Args:
        in_nc: Channels of the input image.
        out_nc: Channels of the output image.
        nf: Feature channels used throughout the network.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels inside each dense block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super().__init__()
        # Attribute names and registration order match the checkpoint's state_dict
        # keys (conv_first, RRDB_trunk.N.*, trunk_conv, upconv1/2, HRconv, conv_last).
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(lambda: RRDB(nf=nf, gc=gc), nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = self.conv_first(x)
        # Long skip connection around the whole RRDB trunk.
        features = features + self.trunk_conv(self.RRDB_trunk(features))

        # Two stages of 2x nearest-neighbour resize + conv give 4x height and width.
        for upconv in (self.upconv1, self.upconv2):
            features = self.lrelu(upconv(F.interpolate(features, scale_factor=2, mode='nearest')))

        return self.conv_last(self.lrelu(self.HRconv(features)))

In [6]:
import os
import glob
import os.path as osp
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt

In [7]:
# Run inference on the CPU.
device = torch.device('cpu')
# Generator checkpoint; a relative path, resolved against the kernel's working directory.
model_path = 'GAN_weights.pth'

In [8]:
# Build the generator (3-channel in/out, 64 features, 23 RRDB blocks) and move it to
# `device` up front. load_state_dict copies weights into the existing parameters, so
# without .to(device) the model would stay on the CPU even if `device` were a GPU.
# enhance_image would then send its input to a different device than the weights.
model = RRDBNet(3, 3, 64, 23, gc=32).to(device)
# weights_only=True (PyTorch >= 1.13) limits unpickling to tensors and plain containers.
# That is all a state_dict needs, and it prevents a tampered .pth from running code.
# strict=True turns any missing or unexpected key into an error instead of a partial load.
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True), strict=True)

<All keys matched successfully>

In [9]:
# Sanity check: print every tensor in the loaded state dict with its shape, one per line.
for param_name, tensor in model.state_dict().items():
    print(f"{param_name}: {tensor.shape}")

conv_first.weight: torch.Size([64, 3, 3, 3])
conv_first.bias: torch.Size([64])
RRDB_trunk.0.RDB1.conv1.weight: torch.Size([32, 64, 3, 3])
RRDB_trunk.0.RDB1.conv1.bias: torch.Size([32])
RRDB_trunk.0.RDB1.conv2.weight: torch.Size([32, 96, 3, 3])
RRDB_trunk.0.RDB1.conv2.bias: torch.Size([32])
RRDB_trunk.0.RDB1.conv3.weight: torch.Size([32, 128, 3, 3])
RRDB_trunk.0.RDB1.conv3.bias: torch.Size([32])
RRDB_trunk.0.RDB1.conv4.weight: torch.Size([32, 160, 3, 3])
RRDB_trunk.0.RDB1.conv4.bias: torch.Size([32])
RRDB_trunk.0.RDB1.conv5.weight: torch.Size([64, 192, 3, 3])
RRDB_trunk.0.RDB1.conv5.bias: torch.Size([64])
RRDB_trunk.0.RDB2.conv1.weight: torch.Size([32, 64, 3, 3])
RRDB_trunk.0.RDB2.conv1.bias: torch.Size([32])
RRDB_trunk.0.RDB2.conv2.weight: torch.Size([32, 96, 3, 3])
RRDB_trunk.0.RDB2.conv2.bias: torch.Size([32])
RRDB_trunk.0.RDB2.conv3.weight: torch.Size([32, 128, 3, 3])
RRDB_trunk.0.RDB2.conv3.bias: torch.Size([32])
RRDB_trunk.0.RDB2.conv4.weight: torch.Size([32, 160, 3, 3])
RRDB_trun

In [10]:
# Switch to inference mode; train(False) is what eval() does, and it returns the module.
model.train(False)

RRDBNet(
  (conv_first): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (RRDB_trunk): Sequential(
    (0): RRDB(
      (RDB1): ResidualDenseBlock_5C(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv5): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (RDB2): ResidualDenseBlock_5C(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), str

In [11]:
def enhance_image(input_image):
    """Super-resolve one image with the global ``model`` on ``device``.

    Args:
        input_image: PIL image, as delivered by ``gr.Image(type="pil")``, or an
            HxW / HxWx{1,3,4} uint8 array.
            - Any PIL mode (L, RGBA, P, ...) is converted to RGB.
            - Alpha is dropped, because the network's first conv takes 3 channels.
            - ``None`` (sent by Gradio when the input is cleared) is passed through.

    Returns:
        PIL RGB image built from the network output, or ``None`` for ``None`` input.
    """
    if input_image is None:
        return None

    if isinstance(input_image, Image.Image):
        rgb = np.asarray(input_image.convert("RGB"))
    else:
        rgb = np.asarray(input_image)
        if rgb.ndim == 2:  # grayscale: add a channel axis
            rgb = rgb[:, :, None]
        if rgb.shape[2] == 1:
            rgb = np.repeat(rgb, 3, axis=2)
        elif rgb.shape[2] == 4:  # drop alpha
            rgb = rgb[:, :, :3]

    # PIL already yields RGB. The [2, 1, 0] swap in the reference ESRGAN script exists
    # only to undo cv2.imread's BGR order; applying it to PIL input fed the net BGR.
    # NOTE(review): assumes GAN_weights.pth follows the reference RGB convention.
    hwc = rgb.astype(np.float32) / 255.0
    lr_tensor = torch.from_numpy(np.ascontiguousarray(hwc.transpose(2, 0, 1))).unsqueeze(0).to(device)

    with torch.no_grad():
        sr_chw = model(lr_tensor).squeeze(0).clamp_(0, 1).cpu().numpy()

    # Round instead of truncating, so e.g. 0.999 maps to 255 rather than 254.
    sr_hwc = np.round(sr_chw.transpose(1, 2, 0) * 255.0).astype(np.uint8)
    return Image.fromarray(sr_hwc)

In [12]:
import gradio as gr

# One image in, one image out; both components exchange PIL images with enhance_image.
interface = gr.Interface(
    enhance_image,
    gr.Image(type="pil", label="Upload Low-Resolution Image"),
    gr.Image(type="pil", label="Enhanced Image"),
    title="ESRGAN Super-Resolution",
    description="Upload a low-resolution image, and this tool will generate a high-resolution version using ESRGAN.",
    allow_flagging="never",
)

# In a notebook kernel __name__ is "__main__", so this launches immediately.
# share=True also opens a temporary public *.gradio.live URL to the app.
if __name__ == "__main__":
    interface.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://23c80731cad1618608.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
