In [1]:
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset

import os
import time
import cv2
import random
import skimage
from skimage.util import random_noise
import numpy as np
from PIL import Image
from PIL import ImageFile
from earlystopping import EarlyStopping
from loss import *
import matplotlib.pyplot as plt
import hiddenlayer as hl

In [2]:
# --- Experiment configuration -------------------------------------------
n_blocks = 5      # number of residual blocks handed to Generator
n_epochs = 100    # upper bound on epochs; train() may stop earlier
batch_size = 64
train_path = './data/COCO2014/train2014/'
val_path = './data/COCO2014/val2014/'
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let PIL load image files whose data is cut short instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# Random 96x96 crop; passed to get_data() as the per-image transform.
randomcrop = transforms.RandomCrop(96)

In [4]:
def addGaussNoise(data, sigma):
    """Return a copy of `data` corrupted by additive Gaussian noise.

    `sigma` is the noise standard deviation expressed on a 0-255 intensity
    scale; it is converted to a variance on the unit scale before being handed
    to skimage's `random_noise`, which also clips the result (clip=True).
    """
    return random_noise(
        data,
        mode='gaussian',
        var=sigma ** 2 / 255 ** 2,
        clip=True,
    )

In [5]:
class MyDataset(Dataset):
    """Dataset of (clean, noisy) image pairs built from the files of one directory.

    Each item is a transformed RGB image scaled by 1/255 together with a copy
    corrupted by Gaussian noise, both returned as channel-first tensors.

    Args:
        path: directory holding the image files (only its top level is read).
        transform: callable applied to the PIL image; its result must convert
            to an (H, W, C) array, e.g. a torchvision RandomCrop.
        sigma: noise standard deviation on a 0-255 scale, forwarded to
            addGaussNoise.
        ex: number of times the list of kept files is repeated.
        min_size: minimum width AND height (pixels) an image needs to be kept.

    Raises:
        FileNotFoundError: if `path` is not an existing directory.
    """

    def __init__(self, path, transform, sigma=30, ex=1, min_size=96):
        self.transform = transform
        self.sigma = sigma

        if not os.path.isdir(path):
            raise FileNotFoundError(f'image directory not found: {path}')

        # Read only the top-level listing: walking sub-directories would produce
        # file names that cannot be resolved relative to `path`.
        _, _, files = next(os.walk(path), (path, [], []))

        kept = []
        for file in files:
            full_path = os.path.join(path, file)
            # Image.open is lazy; the context manager releases the file handle
            # as soon as the header (and thus the size) has been read.
            with Image.open(full_path) as probe:
                if self._is_large_enough(probe.size, min_size):
                    kept.append(full_path)

        self.imgs = kept * ex
        np.random.shuffle(self.imgs)

    @staticmethod
    def _is_large_enough(size, min_size):
        """True when both entries of the (width, height) pair reach `min_size`."""
        # Compare each dimension on its own: `size >= (m, m)` on tuples is
        # lexicographic and would accept e.g. (100, 50) for m == 96.
        width, height = size
        return width >= min_size and height >= min_size

    def __getitem__(self, index):
        with Image.open(self.imgs[index]) as raw:
            # convert() returns an independent, fully loaded image, so the
            # underlying file can be closed right away.
            rgb = raw.convert('RGB')
        Img = np.array(self.transform(rgb))/255
        nImg = addGaussNoise(Img, self.sigma)
        # (H, W, C) -> (C, H, W) for PyTorch.
        Img = torch.tensor(Img.transpose(2,0,1))
        nImg = torch.tensor(nImg.transpose(2,0,1))
        return Img, nImg

    def __len__(self):
        return len(self.imgs)

In [6]:
def get_data(batch_size, train_path, val_path, transform, sigma, ex=1, num_workers=6):
    """Build the training and validation DataLoaders.

    Args:
        batch_size: samples per batch for both loaders.
        train_path / val_path: image directories handed to MyDataset.
        transform: per-image transform handed to MyDataset.
        sigma: noise standard deviation (0-255 scale) handed to MyDataset.
        ex: dataset repetition factor handed to MyDataset.
        num_workers: worker processes used by each DataLoader.

    Returns:
        (train_iter, val_iter). Both drop the last incomplete batch; only the
        training loader reshuffles its samples every epoch.
    """
    train_dataset = MyDataset(train_path, transform, sigma, ex)
    val_dataset = MyDataset(val_path, transform, sigma, ex)
    # shuffle=True: MyDataset shuffles once at construction only, so without it
    # every epoch would see the exact same batches in the same order.
    train_iter = DataLoader(train_dataset, batch_size, shuffle=True,
                            drop_last=True, num_workers=num_workers)
    val_iter = DataLoader(val_dataset, batch_size, drop_last=True, num_workers=num_workers)
    return train_iter, val_iter

In [7]:
# Build both loaders: random crops as the transform, noise sigma 30, file list used once.
train_iter, val_iter = get_data(
    batch_size=batch_size,
    train_path=train_path,
    val_path=val_path,
    transform=randomcrop,
    sigma=30,
    ex=1,
)

In [8]:
def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two tensors, taking 1.0 as the peak value.

    Returns a 0-dim tensor; it is infinite when the two inputs are identical.
    """
    mse = (img1 - img2).pow(2).mean()
    return 10. * torch.log10(1. / mse)

In [9]:
class ResBlock(nn.Module):
    """Residual block: conv-BN-PReLU followed by conv-BN, with the input added back.

    The skip connection is a plain addition, so the input must be addable to
    the block output (in practice `inC == outC`).
    """

    def __init__(self, inC, outC):
        super().__init__()
        # Both convolutions are 3x3, stride 1, padding 1 and bias-free.
        conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias=False)

        self.layer1 = nn.Sequential(
            nn.Conv2d(inC, outC, **conv_cfg),
            nn.BatchNorm2d(outC),
            nn.PReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(outC, outC, **conv_cfg),
            nn.BatchNorm2d(outC),
        )

    def forward(self, x):
        branch = self.layer2(self.layer1(x))
        return branch + x

In [10]:
class Generator(nn.Module):
    """Residual denoising network mapping a 3-channel image to a 3-channel image.

    Layout: 9x9 conv + PReLU, `n_blocks` ResBlocks, 3x3 conv + BatchNorm with a
    long skip connection back to the first stage, then a 9x9 output conv.
    All convolutions keep the spatial size and have no bias.
    """

    def __init__(self, n_blocks):
        super().__init__()
        width = 64  # feature channels used throughout the trunk

        self.convlayer1 = nn.Sequential(
            nn.Conv2d(3, width, kernel_size=9, stride=1, padding=4, bias=False),
            nn.PReLU(),
        )

        trunk = []
        for _ in range(n_blocks):
            trunk.append(ResBlock(width, width))
        self.ResBlocks = nn.ModuleList(trunk)

        self.convlayer2 = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        self.convout = nn.Conv2d(width, 3, kernel_size=9, stride=1, padding=4, bias=False)

    def forward(self, x):
        head = self.convlayer1(x)

        feat = head
        for block in self.ResBlocks:
            feat = block(feat)

        # Long skip connection around the whole residual trunk.
        feat = self.convlayer2(feat) + head

        return self.convout(feat)

In [11]:
class DownSample(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU stage.

    Whether the stage actually reduces the spatial size depends on `stride`
    (with the default 3x3 kernel and padding 1, stride 1 keeps the size).
    """

    def __init__(self, input_channel, output_channel,  stride, kernel_size=3, padding=1):
        super().__init__()
        stages = [
            nn.Conv2d(input_channel, output_channel, kernel_size, stride, padding),
            nn.BatchNorm2d(output_channel),
            nn.LeakyReLU(inplace=True),
        ]
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        return self.layer(x)

In [12]:
class Discriminator(nn.Module):
    """Convolutional classifier producing one sigmoid score per input image.

    A 3x3 stem is followed by seven DownSample stages (four of them with
    stride 2), global average pooling and two 1x1 convolutions, so the output
    has shape (batch, 1, 1, 1) with values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )

        # (in_channels, out_channels, stride) for each successive stage.
        stage_plan = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]
        self.down = nn.Sequential(
            *[DownSample(c_in, c_out, stride=s, padding=1) for c_in, c_out, s in stage_plan]
        )

        self.dense = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, 1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(1024, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.dense(self.down(self.conv1(x)))

In [13]:
lr = 0.001  # discriminator learning rate; the generator optimizer uses lr*0.1

# Networks (moved to `device` later, inside train()).
G = Generator(n_blocks)
D = Discriminator()

# Loss terms. PerceptualLoss / RegularizationLoss are not defined in this
# notebook -- presumably provided by `from loss import *`.
G_loss = PerceptualLoss(device)
Regulaztion = RegularizationLoss().to(device)
D_loss = nn.BCELoss().to(device)

optimizer_g = torch.optim.Adam(G.parameters(), lr=lr*0.1)
optimizer_d = torch.optim.Adam(D.parameters(), lr=lr)

# BCE targets shaped like the discriminator output for one full batch.
real_label = torch.ones(batch_size, 1, 1, 1, device=device)
fake_label = torch.zeros(batch_size, 1, 1, 1, device=device)

# First argument is presumably the patience in epochs -- confirm in earlystopping.py.
early_stopping = EarlyStopping(10, verbose=True)

In [14]:
# Per-epoch histories: train() appends one averaged value per epoch to each
# list, and the plotting cell reads them afterwards.
train_loss_g = []
train_loss_d = []
train_psnr = []
val_loss = []
val_psnr = []

In [15]:
def train(generator, discriminator, train_iter, val_iter, n_epochs, optimizer_g, optimizer_d, loss_g, loss_d, Regulaztion, device):
    """Adversarially train `generator` against `discriminator`, validating every epoch.

    Args:
        generator, discriminator: the two networks; both are moved to `device`.
        train_iter, val_iter: iterables yielding (clean, noisy) image batches.
        n_epochs: maximum number of epochs (early stopping may end sooner).
        optimizer_g, optimizer_d: optimizers of the generator / discriminator.
        loss_g: generator criterion, called as
            loss_g(fake, clean, discriminator(fake)).
        loss_d: discriminator criterion, called as loss_d(prediction, target).
        Regulaztion: regularization term evaluated on the generated batch and
            added to the generator loss with weight 2e-8.
        device: device to train on.

    Side effects:
        Appends per-epoch averages to the module-level lists train_loss_g,
        train_loss_d, train_psnr, val_loss and val_psnr, and calls the
        module-level `early_stopping` object once per epoch.
    """
    print('train on',device)
    generator.to(device)
    discriminator.to(device)
    model_device = next(generator.parameters()).device
    for epoch in range(n_epochs):
        train_epoch_loss_g = []
        train_epoch_loss_d = []
        train_epoch_psnr = []
        val_epoch_loss = []
        val_epoch_psnr = []
        start = time.time()
        generator.train()
        discriminator.train()
        for img, nimg in train_iter:
            img, nimg = img.to(model_device).float(), nimg.to(model_device).float()
            fakeimg = generator(nimg)

            # --- discriminator step: real -> 1, generated -> 0 ---
            optimizer_d.zero_grad()
            realOut = discriminator(img)
            # detach() keeps this step from back-propagating into the generator.
            fakeOut = discriminator(fakeimg.detach())
            # Targets are built from the actual outputs, so they always match
            # the batch size, dtype and device of the predictions.
            d_loss_value = (loss_d(realOut, torch.ones_like(realOut))
                            + loss_d(fakeOut, torch.zeros_like(fakeOut)))
            d_loss_value.backward()
            optimizer_d.step()

            # --- generator step ---
            optimizer_g.zero_grad()
            g_loss_value = loss_g(fakeimg, img, discriminator(fakeimg)) + 2e-8*Regulaztion(fakeimg)
            g_loss_value.backward()
            optimizer_g.step()

            train_epoch_loss_d.append(d_loss_value.item())
            train_epoch_loss_g.append(g_loss_value.item())
            # The metric needs no gradient; detach to avoid extending the graph.
            train_epoch_psnr.append(calculate_psnr(fakeimg.detach(), img).item())
        train_epoch_avg_loss_g = np.mean(train_epoch_loss_g)
        train_epoch_avg_loss_d = np.mean(train_epoch_loss_d)
        train_epoch_avg_psnr = np.mean(train_epoch_psnr)
        train_loss_g.append(train_epoch_avg_loss_g)
        train_loss_d.append(train_epoch_avg_loss_d)
        train_psnr.append(train_epoch_avg_psnr)
        print(f'Epoch {epoch + 1}, Generator Train Loss: {train_epoch_avg_loss_g:.4f}, '
              f'Discriminator Train Loss: {train_epoch_avg_loss_d:.4f}, PSNR: {train_epoch_avg_psnr:.4f}')
        generator.eval()
        discriminator.eval()
        with torch.no_grad():
            for img, nimg in val_iter:
                img, nimg = img.to(model_device).float(), nimg.to(model_device).float()
                fakeimg = generator(nimg)
                g_loss_value = loss_g(fakeimg, img, discriminator(fakeimg)) + 2e-8*Regulaztion(fakeimg)
                val_epoch_loss.append(g_loss_value.item())
                val_epoch_psnr.append(calculate_psnr(fakeimg, img).item())
            val_epoch_avg_loss = np.mean(val_epoch_loss)
            val_epoch_avg_psnr = np.mean(val_epoch_psnr)
            val_loss.append(val_epoch_avg_loss)
            val_psnr.append(val_epoch_avg_psnr)
            print(f'Generator Val Loss: {val_epoch_avg_loss:.4f}, PSNR: {val_epoch_avg_psnr:.4f}, Cost: {(time.time()-start):.4f}s')
            # NOTE(review): the return value is indexed as (train PSNR, val PSNR)
            # below -- confirm against earlystopping.py.
            checkpoint_perf = early_stopping(generator, discriminator, train_epoch_avg_psnr, val_epoch_avg_psnr)
            if early_stopping.early_stop:
                print("Early stopping")
                print('Final model performance:')
                print(f'Train PSNR: {checkpoint_perf[0]}, Val PSNR: {checkpoint_perf[1]}')
                break
        torch.cuda.empty_cache()

In [16]:
# Launch training with the objects built in the cells above.
train(G, D, train_iter, val_iter, n_epochs, optimizer_g, optimizer_d, G_loss, D_loss, Regulaztion, device)

train on cuda


KeyboardInterrupt: 

In [None]:
# Two panels: losses on the left, PSNR on the right.
fig, (ax_loss, ax_psnr) = plt.subplots(1, 2)

# Labels now match the plotted series (they were swapped before).
ax_loss.plot(train_loss_d, label='Discriminator Train Loss')
ax_loss.plot(train_loss_g, label='Generator Train Loss')
ax_loss.plot(val_loss, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend()

ax_psnr.plot(train_psnr, label='Train PSNR')
ax_psnr.plot(val_psnr, label='Validation PSNR')
ax_psnr.set_xlabel('Epoch')
ax_psnr.set_ylabel('PSNR')
ax_psnr.legend()

# One title for the whole figure instead of only the second panel.
fig.suptitle('Training process')
plt.show()

In [None]:
# Fresh generator on the CPU, restored from the saved weights and switched to
# inference mode.
model = Generator(n_blocks)
saved_weights = torch.load('Generator.pth', map_location=torch.device('cpu'))
model.load_state_dict(saved_weights)
model.eval()

In [None]:
test_transform = transforms.ToTensor()
# Force three channels, as MyDataset does: a grayscale or RGBA PNG would
# otherwise break the (2,0,1) transpose or the generator's 3-channel input conv.
testimg = Image.open('img_011_SRF_4_HR.png').convert('RGB')
timg = np.array(testimg)/255
# Same noise level (sigma 30 on a 0-255 scale) as used for the training loaders.
timg = addGaussNoise(timg, 30)
# (H, W, C) -> (1, C, H, W) float32 batch for the generator.
timg = torch.tensor(timg.transpose(2,0,1)).float().unsqueeze(0)

In [None]:
# Inference only: skip autograd so no graph is kept for the full-size image.
with torch.no_grad():
    dnimg = model(timg)[0, :, :, :]
# (C, H, W) -> (H, W, C) numpy array for image export.
dnimg = dnimg.numpy().transpose((1, 2, 0))

In [None]:
# Noisy input back to (H, W, C), min-max stretched to 0..255 and saved.
# NOTE: `timg` is rebound from a tensor to a PIL image here, so this cell
# cannot be re-run without first re-running the cell that builds the tensor.
noisy_hwc = timg.squeeze().detach().numpy().transpose(1,2,0)
noisy_u8 = np.uint8(cv2.normalize(noisy_hwc, None, 0, 255, cv2.NORM_MINMAX))
timg = Image.fromarray(noisy_u8)
timg.save('noiseimg_011_SRF_4_HR.png')

In [None]:
# Denoised output, min-max stretched to 0..255 and saved as PNG.
denoised_u8 = np.uint8(cv2.normalize(dnimg, None, 0, 255, cv2.NORM_MINMAX))
img = Image.fromarray(denoised_u8)
img.save('set5_gan_test.png')

In [None]:
# Output directory for the images written by restruct(); keep the trailing slash.
outPath = './output_file/'

In [None]:
def restruct(model, val_iter, outPath):
    """Denoise every batch of `val_iter` with `model` and save the results as PNGs.

    For the k-th image seen (counting from 0 across batches) two files are
    written into `outPath`: `{k}_DN.png` (model output) and `{k}.png` (clean
    reference). Both are min-max normalized to 0..255 before saving.

    Args:
        model: network applied to the noisy batch; inputs are fed as float32 on
            the CPU, so the model is expected to live on the CPU as well.
        val_iter: iterable yielding (clean, noisy) channel-first batches.
        outPath: output directory; it is created when missing.
    """
    if outPath:
        os.makedirs(outPath, exist_ok=True)

    # Running index instead of i*batch_size+t: file names stay unique even if
    # the loader's batch size differs from the global `batch_size`.
    saved = 0
    with torch.no_grad():  # inference only, no autograd graph needed
        for img, nimg in val_iter:
            dnimg = model(nimg.float())
            # (B, C, H, W) -> (B, H, W, C) for image export.
            dnimg = dnimg.detach().numpy().transpose(0,2,3,1)
            img = img.detach().numpy().transpose(0,2,3,1)

            for t in range(img.shape[0]):
                dnimgs = Image.fromarray(np.uint8(cv2.normalize(dnimg[t,:,:,:], None, 0, 255, cv2.NORM_MINMAX)))
                rawimgs = Image.fromarray(np.uint8(cv2.normalize(img[t,:,:,:], None, 0, 255, cv2.NORM_MINMAX)))
                dnimgs.save(os.path.join(outPath, f'{saved}_DN.png'))
                rawimgs.save(os.path.join(outPath, f'{saved}.png'))
                saved += 1

In [None]:
# Export denoised / reference image pairs for the validation loader.
restruct(model, val_iter, outPath)