In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as tfms
import torchvision.datasets as dsets
from torchvision.utils import save_image, make_grid

In [2]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import glob
import sys

# from torchsummaryM import summary

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
# Weight Initializer

def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

In [5]:
class ResidualBlock(nn.Module):
    def __init__(self, in_c):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_c, in_c, 3, 1, 1),
            nn.BatchNorm2d(in_c, 0.8),
            nn.PReLU(),
            nn.Conv2d(in_c, in_c, 3, 1, 1),
            nn.BatchNorm2d(in_c, 0.8),
        )
    def forward(self, x):
        return x + self.conv_block(x)

In [6]:
class GeneratorResNet(nn.Module):
    def __init__(self, in_c, out_c, n_blocks=16):
        super(GeneratorResNet, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_c, 64, kernel_size=9, stride=1, padding=4), 
            nn.PReLU()
        )
        
        res = []
        for _ in range(n_blocks):
            res.append(ResidualBlock(64))
        self.res_blocks = nn.Sequential(
            *res
        )
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64, 0.8)
        )
        
        up = []
        for _ in range(2):
            up += [
                nn.Conv2d(64, 256, 3, 1, 1),
                nn.BatchNorm2d(256),
                nn.PixelShuffle(upscale_factor=2),  # (b, C×(r^2), H, W) => (b, C, rxH, rxW), r=upscale_factor
                nn.PReLU()
            ]
    
        self.up= nn.Sequential(
            *up
        )
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, out_c, kernel_size=9, stride=1, padding=4),
            nn.Tanh()
        )
        
    def forward(self, x):
        o1 = self.conv1(x)
        o = self.res_blocks(o1)
        o2 = self.conv2(o)
        
        o = torch.add(o1, o2)
        o = self.up(o)
        o = self.conv3(o)
        return o

In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_shape):
        super(Discriminator, self).__init__()
        
        self.in_shape = in_shape
        in_c, in_h, in_w = self.in_shape
        patch_h, patch_w = int(in_h/2**4), int(in_w/2**4)
        self.out_shape = (1, patch_h, patch_w)
        
        def block(in_f, out_f, first=False):
            layers = []
            layers.append(nn.Conv2d(in_f, out_f, 3, 1, 1))
            if not first:
                layers.append(nn.BatchNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, True))
            layers.append(nn.Conv2d(out_f, out_f, 3, 2, 1))
            layers.append(nn.BatchNorm2d(out_f))
            layers.append(nn.LeakyReLU(0.2, True))
            
            return layers
        
        in_f = in_c
        dis_layers = []
        for i, out_f in enumerate([64, 128, 256, 512]):
            dis_layers.extend(block(in_f, out_f, first=(i==0)))
            in_f = out_f
        
        dis_layers.append(nn.Conv2d(out_f, 1, 3, 1, 1))
        self.discriminator = nn.Sequential(
            *dis_layers
        )
    def forward(self, img):
        return self.discriminator(img)

In [8]:
from torchvision.models import vgg19

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(*list(vgg19_model.features.children())[:18])

    def forward(self, img):
        return self.feature_extractor(img)

In [9]:
hr_shape = (256, 256)

generator = GeneratorResNet(3, 3).to(device)
discriminator = Discriminator(in_shape=(3, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

In [10]:
# s = summary(generator, inputs=(1, 3, 512//4, 512//4), device="cuda")
# # generator

In [11]:
optim_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002,  betas=(0.5, 0.999))

criterion_gan_loss = nn.MSELoss()
criterion_content  = nn.L1Loss()

In [12]:
from torch.utils.data import Dataset, DataLoader
import re

class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        
        hr_height, hr_width = hr_shape
        self.lr_transform = tfms.Compose(
            [
                tfms.Resize((hr_height // 4, hr_height // 4), Image.BICUBIC),
                tfms.ToTensor(),
                tfms.Normalize(mean, std),
            ]
        )
        self.hr_transform = tfms.Compose(
            [
                tfms.Resize((hr_height, hr_height), Image.BICUBIC),
                tfms.ToTensor(),
                tfms.Normalize(mean, std),
            ]
        )

        self.lrs = sorted(glob.glob(root + "LR/*.*"))
        self.hrs = sorted(glob.glob(root + "HR/*.*"))

    def __getitem__(self, index):
        img_lr = Image.open(self.lrs[index % len(self.lrs)])
        img_hr = Image.open(self.hrs[index % len(self.hrs)])
        
        img_lr = self.lr_transform(img_lr)
        img_hr = self.hr_transform(img_hr)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.lrs)

In [13]:
data_loader = DataLoader(
    ImageDataset("../data/SRGAN/", hr_shape=hr_shape),
    batch_size=8,
    shuffle=True
)

In [19]:
generator.load_state_dict(torch.load('./models/generator.ptr'))
discriminator.load_state_dict(torch.load('./models/discriminator.ptr'))
feature_extractor.load_state_dict(torch.load('./models/feature_extractor.ptr'))

<All keys matched successfully>

In [25]:
total_batch = len(data_loader)

for epoch in range(30, 51):
    for batch_idx, imgs in enumerate(data_loader):
        lrs = imgs["lr"].to(device).float()
        hrs = imgs["hr"].to(device).float()
        
        valid = torch.ones((lrs.size(0), *discriminator.out_shape), requires_grad=False).to(device)
        fake  = torch.zeros((lrs.size(0), *discriminator.out_shape), requires_grad=False).to(device)
        
        optim_g.zero_grad()
        gen_hr = generator(lrs)
        
        gan_loss = criterion_gan_loss(discriminator(gen_hr), valid)
        gen_feature  = feature_extractor(gen_hr)
        real_feature = feature_extractor(hrs)
        cotent_loss = criterion_content(gen_feature, real_feature.detach())
        
        loss_G = 1e-3 * gan_loss + cotent_loss
        loss_G.backward()
        optim_g.step()
        
        optim_d.zero_grad()
        
        real_loss = criterion_gan_loss(discriminator(hrs), valid)
        fake_loss = criterion_content(discriminator(gen_hr.detach()), fake)
        loss_D = .5 * (real_loss + fake_loss)
        loss_D.backward()
        optim_d.step()
        
        if batch_idx % 100 == 0:
            sys.stdout.write(
                "[Epoch {:3d}/{:3d}]    [Batch {:8d}/{:8d}]    [D loss: {:10.7f}]    [G loss: {:10.7f}]\n".format(
                    epoch, 200, batch_idx, total_batch, loss_D.item(), loss_G.item()
            ))
        
        batches_done = (epoch-1) * total_batch + batch_idx
        if batches_done % 100 == 0:
            # Save image grid with upsampled inputs and SRGAN outputs
            imgs_lr = nn.functional.interpolate(lrs, scale_factor=4)

            gen_hr = make_grid(gen_hr, nrow=1, normalize=True)
            imgs_lr = make_grid(imgs_lr, nrow=1, normalize=True)
            imgs_hr = make_grid(hrs, nrow=1, normalize=True)
            img_grid = torch.cat((imgs_lr, imgs_hr, gen_hr), -1)
            save_image(img_grid, "imgs2/%d.png" % batches_done, normalize=False)

[Epoch  30/200]    [Batch        0/      35]    [D loss:  0.0056528]    [G loss:  0.8430027]
[Epoch  31/200]    [Batch        0/      35]    [D loss:  0.0081026]    [G loss:  0.9632586]
[Epoch  32/200]    [Batch        0/      35]    [D loss:  0.0035245]    [G loss:  0.8340473]
[Epoch  33/200]    [Batch        0/      35]    [D loss:  0.0084954]    [G loss:  0.9637439]
[Epoch  34/200]    [Batch        0/      35]    [D loss:  0.0040131]    [G loss:  0.9517164]
[Epoch  35/200]    [Batch        0/      35]    [D loss:  0.0200303]    [G loss:  0.7589475]
[Epoch  36/200]    [Batch        0/      35]    [D loss:  0.0024076]    [G loss:  0.7768597]
[Epoch  37/200]    [Batch        0/      35]    [D loss:  0.0057516]    [G loss:  0.7512628]
[Epoch  38/200]    [Batch        0/      35]    [D loss:  0.0020700]    [G loss:  0.6927186]
[Epoch  39/200]    [Batch        0/      35]    [D loss:  0.0008019]    [G loss:  0.6756210]
[Epoch  40/200]    [Batch        0/      35]    [D loss:  0.0019863]  

In [26]:
torch.save(generator.state_dict(), "./models/generator.ptr")
torch.save(discriminator.state_dict(), "./models/discriminator.ptr")
torch.save(feature_extractor.state_dict(), "./models/feature_extractor.ptr")