# Setup

In [1]:
import sys
from torch.utils.data import DataLoader

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt

import numpy as np
import torchvision.transforms as T
to_img = T.ToPILImage()
from PIL import Image
import cv2
import os

In [2]:
# Crop size passed as `crop_size` to the (currently commented-out) training dataset.
CROP_SIZE = 128
# Super-resolution factor: the datasets resize LR images to crop_size // UPSCALE_FACTOR.
UPSCALE_FACTOR = 4

In [3]:
class AddGaussianNoise(object):
    """Transform that adds sparse Gaussian noise to a tensor.

    Each element is selected when a uniform draw in [0, 1) exceeds 0.9
    (about 10% of elements). Selected elements receive an independent sample
    of ``randn * std + mean``; all other elements are returned unchanged.

    The result is NOT clamped, so values may fall outside the input's range.

    Args:
        mean: Offset added to every noise sample.
        std: Scale applied to the standard-normal draw.
    """

    def __init__(self, mean=0., std=0.05):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return ``tensor`` plus masked Gaussian noise of the same shape."""
        # Draw on the input's device; CPU-only draws would fail with a
        # device-mismatch error when `tensor` lives on the GPU.
        size, device = tensor.size(), tensor.device
        mask = torch.rand(size, device=device) > 0.9
        noise = torch.randn(size, device=device) * self.std + self.mean
        return tensor + mask * noise

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [4]:
from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize, GaussianBlur, Grayscale

def is_image_file(filename):
    """Return True if ``filename`` ends with a PNG/JPG/JPEG extension.

    The comparison is case-insensitive, so mixed-case extensions such as
    ``.Png`` or ``.Jpeg`` are accepted alongside all-lower/all-upper ones.
    """
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Trim ``crop_size`` by its remainder modulo ``upscale_factor``.

    For positive inputs this is the largest multiple of ``upscale_factor``
    that does not exceed ``crop_size``.
    """
    remainder = crop_size % upscale_factor
    return crop_size - remainder


def train_hr_transform(crop_size):
    """Build the HR training pipeline: a random crop followed by ``ToTensor``."""
    steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(steps)


def train_lr_transform(crop_size, upscale_factor):
    """Build the LR training pipeline applied to an HR tensor crop.

    Steps, in order: add sparse Gaussian noise, convert to a PIL image,
    bicubic-resize to ``crop_size // upscale_factor``, convert back to a tensor.
    """
    # NOTE(review): AddGaussianNoise does not clamp its output, so ToPILImage
    # may receive values outside [0, 1] - confirm this is intended.
    lr_size = crop_size // upscale_factor
    steps = [
        AddGaussianNoise(),
        ToPILImage(),
        # GaussianBlur(3, sigma=(0.1, 2.0)),  # disabled
        Resize(lr_size, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(steps)


def display_transform():
    """Build a pipeline that turns a tensor into a 400x400 display tensor.

    Steps, in order: to PIL image, resize with size 400, center-crop to 400,
    back to tensor.
    """
    steps = [ToPILImage(), Resize(400), CenterCrop(400), ToTensor()]
    return Compose(steps)

In [5]:
class TrainDatasetFromFolder(Dataset):
    """Training dataset yielding ``(lr_image, hr_image)`` tensor pairs.

    Every image file directly inside ``dataset_dir`` is one sample. The HR
    tensor is a random crop of the image; the LR tensor is derived from that
    crop by ``train_lr_transform``.

    Args:
        dataset_dir: Directory containing the training images.
        crop_size: Requested HR crop size; reduced to a multiple of
            ``upscale_factor`` via ``calculate_valid_crop_size``.
        upscale_factor: Ratio between the HR and LR side lengths.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # listdir() order is arbitrary; sort so an index always maps to the
        # same file across runs and machines.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        # Close the file promptly and force 3 channels so grayscale / RGBA /
        # palette images produce tensors of a consistent shape.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = self.hr_transform(img.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)
        
        
class ValDatasetFromFolder(Dataset):
    """Validation dataset yielding ``(lr, hr_restored, hr)`` tensor triples.

    For each image, the largest centered square whose side is a multiple of
    ``upscale_factor`` is taken as the HR reference. It is bicubic-downscaled
    to produce the LR image, which is bicubic-upscaled again to produce the
    "restored" baseline.

    Args:
        dataset_dir: Directory containing the validation images.
        upscale_factor: Ratio between the HR and LR side lengths.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # listdir() order is arbitrary; sort so an index always maps to the
        # same file across runs and machines.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )

    def __getitem__(self, index):
        # Close the file promptly and force 3 channels so every sample has a
        # consistent tensor shape regardless of the source image mode.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = img.convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)

In [6]:
# Only the validation split is loaded; the training set/loader are disabled.
# train_set = TrainDatasetFromFolder('data/DIV2K/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
# train_loader = DataLoader(dataset=train_set, num_workers=0, batch_size=1, shuffle=False)

val_set = ValDatasetFromFolder(
    '/home/ubuntu/noisy-superres/data/DIV2K/val',
    upscale_factor=UPSCALE_FACTOR,
)
# One image per batch, in dataset order, loaded in the main process.
val_loader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=0)

In [7]:
# iter_train = iter(train_loader)
# iter_val = iter(val_loader)

In [8]:
# lr_img, lr_img_rescaled, gt_img = val_set[0]
# print(lr_img.shape)
# print(lr_img_rescaled.shape)
# print(gt_img.shape)

In [9]:
# to_img(gt_img)

In [10]:
# to_img(lr_img)

In [11]:
# to_img(lr_img_rescaled)

In [12]:
# Make the vendored model repositories importable. The paths are relative, so
# they resolve against the kernel's current working directory.
# NOTE(review): nothing visible in this notebook identifies what ./SRGAN
# provides - confirm that entry is still needed.
sys.path.append("./SRGAN")
sys.path.append("./SRGAN-PyTorch")
import srgan_pytorch
import srgan_pytorch.models as srgan_models
from srgan_pytorch.utils.estimate import iqa

sys.path.append("./ESRGAN-PyTorch")
import esrgan_pytorch
import esrgan_pytorch.models as esrgan_models


# color_model supplies the _NetG denoiser used in the DIDN cells.
sys.path.append("./DIDN")
import color_model


In [13]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Overview

```
srgan one network (srgan_one)
esrgan one network (esrgan_one)

denoise first
cv2 (cv2_lr_denoised) -> srgan (cv2_srgan)
cv2 (cv2_lr_denoised) -> esrgan (cv2_esrgan)
didn (didn_lr_denoised) -> srgan (didn_srgan)
didn (didn_lr_denoised) -> esrgan (didn_esrgan)

superres first
srgan (srgan_hr_noisy) -> cv2 (srgan_cv2)
esrgan (esrgan_hr_noisy) -> cv2 (esrgan_cv2)
srgan (srgan_hr_noisy) -> didn (srgan_didn)
esrgan (esrgan_hr_noisy) -> didn (esrgan_didn)
```

In [14]:
# Root directory for all generated images. The trailing slash is required:
# later cells build paths by plain string concatenation (basepath + "lr/" ...).
basepath = "/home/ubuntu/noisy-superres/data/DIV2K/vis_images/"

# Get lr, lr_noisy, and gt images

In [25]:
# Export up to 10 validation samples (taken from indices 30..59) as PNG + .pt:
# the clean LR image, a noisy copy of it, and the HR ground truth.
for subdir in ("lr/", "lr_noisy/", "gt/"):
    # Saving fails if the target directory is missing, so create it up front.
    os.makedirs(basepath + subdir, exist_ok=True)

count = 0
for i in range(0, 30):
    # Stop after 10 exported images, or when the dataset runs out of samples.
    if count >= 10 or i + 30 >= len(val_set):
        break
    lr_img, lr_img_rescaled, gt_img = val_set[i + 30]
    # Skip samples whose LR width is below 328 pixels.
    if lr_img.shape[-1] < 328:
        continue
    # Clamp to [0, 1]: the noise is unbounded, and out-of-range values wrap
    # around when converted to 8-bit (to_img here, uint8 casts downstream).
    noisy_lr_img = AddGaussianNoise()(lr_img).clamp(0, 1)

    stem = "img" + str(count)

    to_img(lr_img).save(basepath + "lr/" + stem + ".png")
    torch.save(lr_img, basepath + "lr/" + stem + ".pt")

    to_img(noisy_lr_img).save(basepath + "lr_noisy/" + stem + ".png")
    torch.save(noisy_lr_img, basepath + "lr_noisy/" + stem + ".pt")

    to_img(gt_img).save(basepath + "gt/" + stem + ".png")
    torch.save(gt_img, basepath + "gt/" + stem + ".pt")

    count += 1

# Get cv2_lr_denoised images

In [32]:
# Denoise every noisy LR tensor with OpenCV non-local means and save the
# result as both .pt and .png under cv2_lr_denoised/.
src_dir = basepath + "lr_noisy/"
dst_dir = basepath + "cv2_lr_denoised/"
for filename in os.listdir(src_dir):
    if not filename.endswith(".pt"):
        continue
    filename = filename[:-3]  # strip ".pt"
    cur_filepath = os.path.join(src_dir, filename + ".pt")
    img = torch.load(cur_filepath, map_location="cpu")
    # CHW float -> HWC uint8. Clip first: the noisy tensors may hold values
    # outside [0, 1], which would wrap around in the uint8 cast.
    img = np.moveaxis(img.numpy(), 0, -1)
    img = np.ascontiguousarray(np.clip(img * 255, 0, 255).astype('uint8'))
    # OpenCV assumes BGR channel order for its internal Lab conversion, while
    # these tensors are RGB; convert around the call.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> CHW float in [0, 1]
    img = torch.from_numpy(np.moveaxis(img, -1, 0).astype('float32'))
    img = img / 255
    torch.save(img, dst_dir + filename + ".pt")
    to_img(img).save(dst_dir + filename + ".png")

# Get didn_lr_denoised images

In [15]:
# Build the DIDN denoiser and load its pretrained weights.
denoise_model = color_model._NetG()
# The checkpoint stores a whole model object under 'model' (only its
# state_dict is used). NOTE(review): presumably this is unpickled by
# torch.load - only load checkpoints from a trusted source.
checkpoint = torch.load('DIDN/checkpoint/color_model.pth', map_location=lambda storage, loc: storage)
denoise_model.load_state_dict(checkpoint['model'].state_dict())
# Inference only: switch to eval mode, as every other model in this notebook does.
denoise_model.eval()
denoise_model = denoise_model.to(device)



In [44]:
# Denoise every noisy LR tensor with DIDN and save the result as both .pt and
# .png under didn_lr_denoised/.
torch.cuda.empty_cache()
src_dir = basepath + "lr_noisy/"
dst_dir = basepath + "didn_lr_denoised/"
with torch.no_grad():
    for filename in os.listdir(src_dir):
        if not filename.endswith(".pt"):
            continue
        filename = filename[:-3]  # strip ".pt"
        cur_filepath = os.path.join(src_dir, filename + ".pt")
        img = torch.load(cur_filepath)
        # Crop height and width independently down to multiples of 8.
        # NOTE(review): presumably DIDN needs sides divisible by 8 - confirm.
        height = img.shape[-2] // 8 * 8
        width = img.shape[-1] // 8 * 8
        img = denoise_model(img[:, :height, :width].unsqueeze(0).to(device))
        # Move to CPU before saving so the .pt file loads without a GPU.
        img = img[0].cpu().clip(0, 1)
        torch.save(img, dst_dir + filename + ".pt")
        # `filename` already has ".pt" stripped; stripping three more
        # characters would save every PNG under the same truncated name.
        to_img(img).save(dst_dir + filename + ".png")

torch.Size([3, 336, 336]) <class 'torch.Tensor'>


RuntimeError: No active exception to reraise

# Get srgan_hr_noisy images

In [31]:
# Vanilla SRGAN generator: read the checkpoint, build the network, load the
# weights, then switch to eval mode on the target device.
model_path = "SRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=device)
srgan = vars(srgan_models)["srgan"]()
srgan.load_state_dict(state_dict)
srgan = srgan.eval().to(device)

In [45]:
# Super-resolve every noisy LR tensor with vanilla SRGAN and write the result
# (clipped to [0, 1]) as .pt and .png under srgan_hr_noisy/.
torch.cuda.empty_cache()
src_dir = basepath + "lr_noisy/"
dst_dir = basepath + "srgan_hr_noisy/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = srgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

torch.Size([3, 1356, 1356])
torch.Size([3, 1620, 1620])
torch.Size([3, 1356, 1356])
torch.Size([3, 1356, 1356])
torch.Size([3, 1356, 1356])
torch.Size([3, 1872, 1872])
torch.Size([3, 1356, 1356])
torch.Size([3, 1596, 1596])
torch.Size([3, 1536, 1536])
torch.Size([3, 1500, 1500])


# Get esrgan_hr_noisy images

In [19]:
# Vanilla ESRGAN generator: read the checkpoint onto the CPU, build the
# network, load the weights, then switch to eval mode on the target device.
model_path = "ESRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
esrgan = vars(esrgan_models)["esrgan16"]()
esrgan.load_state_dict(state_dict)
esrgan = esrgan.eval().to(device)

In [22]:
# Super-resolve every noisy LR tensor with vanilla ESRGAN and write the result
# (clipped to [0, 1]) as .pt and .png under esrgan_hr_noisy/.
torch.cuda.empty_cache()
src_dir = basepath + "lr_noisy/"
dst_dir = basepath + "esrgan_hr_noisy/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = esrgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# SRGAN one network

In [23]:
# model_path = "SRGAN-PyTorch/weights/Generator_best.pth"
# SRGAN generator from the "noisy" checkpoint. The file may hold the weights
# directly or wrapped under a 'state_dict' key; both layouts are accepted.
model_path = "SRGAN-PyTorch/best_weights/noisy/GAN.pth"
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
weights = state_dict.get('state_dict', state_dict)
noisy_srgan = vars(srgan_models)["srgan"]()
noisy_srgan.load_state_dict(weights)
noisy_srgan = noisy_srgan.eval().to(device)

In [24]:
# Run the noisy-checkpoint SRGAN directly on every noisy LR tensor and write
# the result (clipped to [0, 1]) as .pt and .png under srgan_one/.
torch.cuda.empty_cache()
src_dir = basepath + "lr_noisy/"
dst_dir = basepath + "srgan_one/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = noisy_srgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# ESRGAN one network

In [25]:
# ESRGAN generator from the "noisy" checkpoint. The file may hold the weights
# directly or wrapped under a 'state_dict' key; both layouts are accepted.
model_path = "ESRGAN-PyTorch/best_weights/noisy/GAN.pth"
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
weights = state_dict.get('state_dict', state_dict)
noisy_esrgan = vars(esrgan_models)["esrgan16"]()
noisy_esrgan.load_state_dict(weights)
noisy_esrgan = noisy_esrgan.eval().to(device)

In [26]:
# Run the noisy-checkpoint ESRGAN directly on every noisy LR tensor and write
# the result (clipped to [0, 1]) as .pt and .png under esrgan_one/.
torch.cuda.empty_cache()
src_dir = basepath + "lr_noisy/"
# Write to esrgan_one/ (the name used in the overview). Writing to srgan_one/
# would overwrite the SRGAN one-network outputs.
dst_dir = basepath + "esrgan_one/"
with torch.no_grad():
    for filename in os.listdir(src_dir):
        if not filename.endswith(".pt"):
            continue
        filename = filename[:-3]  # strip ".pt"
        cur_filepath = os.path.join(src_dir, filename + ".pt")
        img = torch.load(cur_filepath)
        img = noisy_esrgan(img.unsqueeze(0).to(device))[0]
        img = img.cpu().detach().clip(0, 1)
        torch.save(img, dst_dir + filename + ".pt")
        to_img(img).save(dst_dir + filename + ".png")

# cv2_srgan

In [28]:
# (Re)create the vanilla SRGAN generator for the cv2 -> SRGAN pipeline: read
# the checkpoint, build the network, load the weights, eval mode on device.
model_path = "SRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=device)
srgan = vars(srgan_models)["srgan"]()
srgan.load_state_dict(state_dict)
srgan = srgan.eval().to(device)

In [34]:
# Super-resolve every cv2-denoised LR tensor with vanilla SRGAN and write the
# result (clipped to [0, 1]) as .pt and .png under cv2_srgan/.
torch.cuda.empty_cache()
src_dir = basepath + "cv2_lr_denoised/"
dst_dir = basepath + "cv2_srgan/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = srgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# cv2_esrgan

In [35]:
# (Re)create the vanilla ESRGAN generator for the cv2 -> ESRGAN pipeline: read
# the checkpoint onto the CPU, build the network, load weights, eval on device.
model_path = "ESRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
esrgan = vars(esrgan_models)["esrgan16"]()
esrgan.load_state_dict(state_dict)
esrgan = esrgan.eval().to(device)

In [37]:
# Super-resolve every cv2-denoised LR tensor with vanilla ESRGAN and write the
# result (clipped to [0, 1]) as .pt and .png under cv2_esrgan/.
torch.cuda.empty_cache()
src_dir = basepath + "cv2_lr_denoised/"
dst_dir = basepath + "cv2_esrgan/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = esrgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# didn_srgan

In [38]:
# (Re)create the vanilla SRGAN generator for the DIDN -> SRGAN pipeline: read
# the checkpoint, build the network, load the weights, eval mode on device.
model_path = "SRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=device)
srgan = vars(srgan_models)["srgan"]()
srgan.load_state_dict(state_dict)
srgan = srgan.eval().to(device)

In [41]:
# Super-resolve every DIDN-denoised LR tensor with vanilla SRGAN and write the
# result (clipped to [0, 1]) as .pt and .png under didn_srgan/.
torch.cuda.empty_cache()
src_dir = basepath + "didn_lr_denoised/"
dst_dir = basepath + "didn_srgan/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = srgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# didn_esrgan

In [42]:
# (Re)create the vanilla ESRGAN generator for the DIDN -> ESRGAN pipeline: read
# the checkpoint onto the CPU, build the network, load weights, eval on device.
model_path = "ESRGAN-PyTorch/best_weights/vanilla/GAN.pth"
state_dict = torch.load(model_path, map_location=torch.device("cpu"))
esrgan = vars(esrgan_models)["esrgan16"]()
esrgan.load_state_dict(state_dict)
esrgan = esrgan.eval().to(device)

In [43]:
# Super-resolve every DIDN-denoised LR tensor with vanilla ESRGAN and write
# the result (clipped to [0, 1]) as .pt and .png under didn_esrgan/.
torch.cuda.empty_cache()
src_dir = basepath + "didn_lr_denoised/"
dst_dir = basepath + "didn_esrgan/"
with torch.no_grad():
    for entry in os.listdir(src_dir):
        if not entry.endswith(".pt"):
            continue
        stem = entry[:-3]
        lr_tensor = torch.load(os.path.join(src_dir, stem + ".pt"))
        sr_tensor = esrgan(lr_tensor.unsqueeze(0).to(device))[0]
        sr_tensor = sr_tensor.cpu().detach().clip(0, 1)
        torch.save(sr_tensor, dst_dir + stem + ".pt")
        to_img(sr_tensor).save(dst_dir + stem + ".png")

# srgan_cv2

In [44]:
# Denoise every SRGAN output with OpenCV non-local means and save the result
# as both .pt and .png under srgan_cv2/.
src_dir = basepath + "srgan_hr_noisy/"
dst_dir = basepath + "srgan_cv2/"
for filename in os.listdir(src_dir):
    if not filename.endswith(".pt"):
        continue
    filename = filename[:-3]  # strip ".pt"
    cur_filepath = os.path.join(src_dir, filename + ".pt")
    img = torch.load(cur_filepath, map_location="cpu")
    # CHW float -> HWC uint8; clip so out-of-range values cannot wrap around.
    img = np.moveaxis(img.numpy(), 0, -1)
    img = np.ascontiguousarray(np.clip(img * 255, 0, 255).astype('uint8'))
    # OpenCV assumes BGR channel order for its internal Lab conversion, while
    # these tensors are RGB; convert around the call.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> CHW float in [0, 1]
    img = torch.from_numpy(np.moveaxis(img, -1, 0).astype('float32'))
    img = img / 255
    torch.save(img, dst_dir + filename + ".pt")
    to_img(img).save(dst_dir + filename + ".png")

# esrgan_cv2

In [45]:
# Denoise every ESRGAN output with OpenCV non-local means and save the result
# as both .pt and .png under esrgan_cv2/.
src_dir = basepath + "esrgan_hr_noisy/"
dst_dir = basepath + "esrgan_cv2/"
for filename in os.listdir(src_dir):
    if not filename.endswith(".pt"):
        continue
    filename = filename[:-3]  # strip ".pt"
    cur_filepath = os.path.join(src_dir, filename + ".pt")
    img = torch.load(cur_filepath, map_location="cpu")
    # CHW float -> HWC uint8; clip so out-of-range values cannot wrap around.
    img = np.moveaxis(img.numpy(), 0, -1)
    img = np.ascontiguousarray(np.clip(img * 255, 0, 255).astype('uint8'))
    # OpenCV assumes BGR channel order for its internal Lab conversion, while
    # these tensors are RGB; convert around the call.
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    img = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> CHW float in [0, 1]
    img = torch.from_numpy(np.moveaxis(img, -1, 0).astype('float32'))
    img = img / 255
    torch.save(img, dst_dir + filename + ".pt")
    to_img(img).save(dst_dir + filename + ".png")

# srgan_didn TODO

In [46]:
# Denoise every SRGAN output with DIDN and save the result as both .pt and
# .png under srgan_didn/.
# TODO: the recorded run of this cell hit a CUDA out-of-memory error on the
# full-size SR images; processing in tiles (or on CPU) may be needed.
torch.cuda.empty_cache()
src_dir = basepath + "srgan_hr_noisy/"
dst_dir = basepath + "srgan_didn/"
with torch.no_grad():
    for filename in os.listdir(src_dir):
        if not filename.endswith(".pt"):
            continue
        filename = filename[:-3]  # strip ".pt"
        cur_filepath = os.path.join(src_dir, filename + ".pt")
        print(cur_filepath)
        img = torch.load(cur_filepath)
        # Crop height and width independently down to multiples of 8.
        # NOTE(review): presumably DIDN needs sides divisible by 8 - confirm.
        height = img.shape[-2] // 8 * 8
        width = img.shape[-1] // 8 * 8
        img = denoise_model(img[:, :height, :width].unsqueeze(0).to(device))
        # Move to CPU before saving so the .pt file loads without a GPU.
        img = img[0].cpu().clip(0, 1)
        torch.save(img, dst_dir + filename + ".pt")
        # `filename` already has ".pt" stripped; stripping three more
        # characters would save every PNG under the same truncated name.
        to_img(img).save(dst_dir + filename + ".png")
        # Release cached GPU blocks before the next (large) image.
        torch.cuda.empty_cache()

/home/ubuntu/noisy-superres/data/DIV2K/vis_images/srgan_hr_noisy/img3.pt


RuntimeError: CUDA out of memory. Tried to allocate 2.62 GiB (GPU 0; 11.17 GiB total capacity; 6.40 GiB already allocated; 2.51 GiB free; 8.17 GiB reserved in total by PyTorch)

# esrgan_didn