<a href="https://colab.research.google.com/github/Indukurivigneshvarma/Deep_Learning/blob/main/Computer_Vision/ESRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
import os
import sys
import argparse
import urllib.request
from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [17]:
class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf+gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf+2*gc, gc, 3, 1, 1)
        self.conv4 = nn.Conv2d(nf+3*gc, gc, 3, 1, 1)
        self.conv5 = nn.Conv2d(nf+4*gc, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.scale = 0.2

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x + x5 * self.scale

In [18]:
class RRDB(nn.Module):
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x

In [19]:

class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4):
        super().__init__()
        self.scale = scale
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1)
        self.RRDB_trunk = nn.Sequential(*[RRDB(nf, gc) for _ in range(nb)])
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        fea = self.lrelu(F.interpolate(self.upconv1(fea), scale_factor=2, mode='nearest'))
        fea = self.lrelu(F.interpolate(self.upconv2(fea), scale_factor=2, mode='nearest'))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))
        return out

In [20]:
MODEL_URLS = [
    "https://drive.google.com/uc?export=download&id=1pJ_T-V1dpb1ewoEra1TGSWl5e6H7M4NN",
    "https://huggingface.co/databuzzword/esrgan/resolve/main/RRDB_ESRGAN_x4.pth"
]

In [21]:
def download_file(url, out_path):
    out_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"Downloading model from:\n  {url}\n-> {out_path}")
    try:
        urllib.request.urlretrieve(url, out_path)
    except Exception as e:
        print("Automatic download failed:", e)
        print("Please download the file manually and place it at:", out_path)
        raise

def load_model(model_path, device):
    net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, scale=4)
    net.eval()
    state = torch.load(model_path, map_location=device)
    if isinstance(state, dict) and 'params_ema' in state:
        sd = state['params_ema']
    elif isinstance(state, dict) and 'state_dict' in state:
        sd = state['state_dict']
    else:
        sd = state
    new_sd = {}
    for k, v in sd.items():
        nk = k.replace('module.', '') if k.startswith('module.') else k
        new_sd[nk] = v
    net.load_state_dict(new_sd, strict=False)
    net.to(device)
    return net

def read_image(path):
    img = Image.open(path).convert('RGB')
    return img

def save_image(tensor, path):
    arr = (tensor.clip(0,1) * 255.0).round().astype(np.uint8)
    Image.fromarray(arr).save(path)

def tensor_from_pil(img):
    arr = np.array(img).astype(np.float32) / 255.0
    arr = arr.transpose((2,0,1))
    return torch.from_numpy(arr).unsqueeze(0)

def pil_from_tensor(tensor):
    arr = tensor.squeeze(0).cpu().numpy()
    arr = np.transpose(arr, (1,2,0))
    return (arr)

In [22]:

import os

os.makedirs("/content/ESRGAN/images", exist_ok=True)
os.makedirs("/content/ESRGAN/results", exist_ok=True)
os.makedirs("/content/ESRGAN/models", exist_ok=True)

os.chdir("/content/ESRGAN")

print("Current working directory:", os.getcwd())

Current working directory: /content/ESRGAN


In [29]:
from google.colab import files
uploaded = files.upload()

Saving Blurred_Pets.webp to Blurred_Pets.webp


In [31]:
!mv Blurred_Pets.webp /content/ESRGAN/images/


In [32]:
img_path = '/content/ESRGAN/images/Blurred_Pets.webp'

In [33]:
def main():
    ap = argparse.ArgumentParser()
    ap.add_argument('--input', '-i', required=True, help='input image path or folder')
    ap.add_argument('--model', '-m', default='./models/RRDB_ESRGAN_x4.pth', help='path to pretrained model')
    ap.add_argument('--out', '-o', default='./results', help='output folder')
    ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='cpu or cuda')
    class Args:
      input = '/content/ESRGAN/images/Blurred_Pets.webp'
      model = '/content/ESRGAN/weights/ESRGAN.pth'
      out = '/content/output.webp'
      device = 'cuda'

    args = Args()

    input_path = Path(args.input)
    model_path = Path(args.model)
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    # download model if not present - try mirrors
    if not model_path.exists():
        print(f"Model not found at {model_path}. Attempting to download...")
        for url in MODEL_URLS:
            try:
                download_file(url, model_path)
                print("Download completed.")
                break
            except Exception:
                print("Attempt failed, trying next URL...")
        if not model_path.exists():
            print("Model download failed. Please download RRDB_ESRGAN_x4.pth manually and place it at:", model_path)
            print("Example mirror: https://huggingface.co/databuzzword/esrgan/resolve/main/RRDB_ESRGAN_x4.pth")
            sys.exit(1)

    device = torch.device(args.device)
    print("Loading model on", device)
    net = load_model(model_path, device)

    # Collect images
    imgs = []
    if input_path.is_dir():
        for ext in ('*.png','*.jpg','*.jpeg','*.webp'):
            imgs += sorted(input_path.glob(ext))
    else:
        imgs = [input_path]

    if not imgs:
        print("No images found at", input_path)
        sys.exit(1)

    print(f"Found {len(imgs)} image(s). Starting inference...")
    for p in imgs:
        print("Processing:", p)
        img = read_image(p)
        # Convert to tensor normalized 0..1
        t = tensor_from_pil(img).to(device)
        with torch.no_grad():
            out = net(t)
            # clamp and convert to numpy
            # Some ESRGANs expect input in range [0,1] and produce same range
            out_np = pil_from_tensor(out)
        out_name = out_dir / (p.stem + '_esr.png')
        save_image(out_np, out_name)
        print("Saved:", out_name)

    print("Done.")

if __name__ == '__main__':
    main()

Loading model on cuda
Found 1 image(s). Starting inference...
Processing: /content/ESRGAN/images/Blurred_Pets.webp
Saved: /content/output.webp/Blurred_Pets_esr.png
Done.
