In [1]:
import math
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import (
    Compose,
    ToTensor,
    ToPILImage,
    CenterCrop,
    Resize,
)
from PIL import Image
import os
from torchvision.models.vgg import vgg16
from torch import optim
from tqdm import tqdm
from datetime import datetime
import torchvision

In [2]:
class Generator(nn.Module):
    """Super-resolution generator network.

    Layout: a 9x9 stem convolution, five residual blocks, a 3x3
    convolution with batch norm, a skip connection from the stem, a chain
    of 2x pixel-shuffle upsampling blocks, and a final 9x9 convolution
    back to 3 channels.

    Args:
        scale_factor: Requested upscaling factor. One 2x upsampling block
            is created per power of two, i.e. ``int(log2(scale_factor))``
            blocks. A factor that is not a power of two is truncated to
            the next lower one (for example 3 yields a single 2x block).
    """

    def __init__(self, scale_factor):
        # math.log2 is exact for powers of two; math.log(x, 2) divides two
        # rounded natural logarithms and may land just below an integer,
        # which int() would then truncate to one block too few.
        upsample_block_num = int(math.log2(scale_factor))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        # Each UpsampleBLock(64, 2) pixel-shuffles by a factor of 2.
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Upscale a batch of 3-channel images.

        Args:
            x: Input tensor with 3 channels in dimension 1.

        Returns:
            The upscaled tensor, mapped into [0, 1] via ``(tanh + 1) / 2``.
        """
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Long skip connection: stem features are added to the output of
        # the residual trunk before upsampling.
        block8 = self.block8(block1 + block7)

        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    """Convolutional real/fake classifier for 3-channel images.

    The network is a stem convolution followed by seven
    Conv-BatchNorm-LeakyReLU stages (alternating stride 2 and stride 1
    while the channel count grows from 64 to 512), then global average
    pooling and two 1x1 convolutions that reduce to one logit per image.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU stage.
        stage_plan = [
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        ]

        # Stem has no batch norm.
        layers = [
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
        ]
        for in_ch, out_ch, stride in stage_plan:
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2),
            ]
        # Head: pool to 1x1, then 1x1 convolutions act as a small MLP.
        layers += [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ]

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one sigmoid score per image, as a 1-D tensor of length ``x.size(0)``."""
        logits = self.net(x)
        return torch.sigmoid(logits.view(x.size(0)))


class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch added back onto its own input.

    Args:
        channels: Number of input and output channels; both 3x3
            convolutions keep the channel count unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        hidden = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(hidden))
        return x + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling: 3x3 convolution, pixel shuffle, then PReLU.

    Args:
        in_channels: Channel count of the input; the output has the same
            channel count after the pixel shuffle.
        up_scale: Factor passed to ``nn.PixelShuffle``; the convolution
            expands channels by ``up_scale ** 2`` to feed it.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Apply convolution, pixel shuffle and PReLU in that order."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))

In [None]:
def train_hr_transform():
    """Build the transform for high-resolution training images.

    Returns:
        A ``Compose`` whose single step is ``ToTensor``.
    """
    steps = [ToTensor()]
    return Compose(steps)


def train_lr_transform(size):
    """Build the transform that derives a low-resolution image from an HR tensor.

    Args:
        size: Target size forwarded unchanged to ``Resize``.

    Returns:
        A ``Compose`` that converts the tensor to a PIL image, resizes it
        with bicubic interpolation, and converts it back to a tensor.
    """
    to_pil = ToPILImage()
    downscale = Resize(size, interpolation=Image.BICUBIC)
    to_tensor = ToTensor()
    return Compose([to_pil, downscale, to_tensor])


def display_transform():
    """Build the transform used to prepare tensors for display.

    Returns:
        A ``Compose`` that converts to a PIL image, applies ``Resize(400)``
        followed by ``CenterCrop(400)``, and converts back to a tensor.
    """
    steps = [
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor(),
    ]
    return Compose(steps)


def is_image_file(filename):
    """Return True when ``filename`` ends with a supported image extension.

    Supported extensions are ``.png``, ``.jpg`` and ``.jpeg``. The match is
    case-insensitive, so mixed-case suffixes such as ``.Jpg`` are accepted
    in addition to the all-lowercase and all-uppercase spellings.

    Args:
        filename: File name or path to test.

    Returns:
        bool: Whether the name carries one of the supported extensions.
    """
    # str.endswith accepts a tuple, so one call covers every extension.
    return filename.lower().endswith((".png", ".jpg", ".jpeg"))


class TrainDatasetFromFolder(Dataset):
    """Dataset yielding ``(low_res, high_res)`` tensor pairs from an image folder.

    Args:
        dataset_dir: Directory scanned (non-recursively) for image files.
        upscale_factor: Ratio between the HR and LR sizes; the LR resize
            target is ``hr_size // upscale_factor``.
        max_images: Maximum number of images to use, taken from the start
            of the alphabetically sorted file list. ``None`` uses every
            image found. Defaults to 100.
        hr_size: Size of the HR images, used only to derive the LR resize
            target. Defaults to 256.

    Note:
        HR images are returned at their stored size; no crop or resize is
        applied to them. Batching presumably relies on every file already
        being ``hr_size`` x ``hr_size`` — confirm against the dataset.
    """

    def __init__(self, dataset_dir, upscale_factor, max_images=100, hr_size=256):
        super(TrainDatasetFromFolder, self).__init__()
        # Sort BEFORE truncating: os.listdir returns entries in arbitrary
        # order, so slicing first would pick a platform-dependent subset.
        image_names = sorted(x for x in os.listdir(dataset_dir) if is_image_file(x))
        self.image_filenames = [
            os.path.join(dataset_dir, x) for x in image_names[:max_images]
        ]
        self.hr_transform = train_hr_transform()
        self.lr_transform = train_lr_transform(hr_size // upscale_factor)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(lr_image, hr_image)`` tensors."""
        # The context manager closes the file handle promptly. convert("RGB")
        # guarantees 3 channels, so grayscale or RGBA files do not break the
        # 3-channel networks or batch collation.
        with Image.open(self.image_filenames[index]) as image:
            hr_image = self.hr_transform(image.convert("RGB"))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_filenames)

In [4]:
class GeneratorLoss(nn.Module):
    """Composite generator loss: pixel MSE + adversarial + VGG feature MSE + TV.

    The feature extractor is the first 30 layers of a pretrained VGG16
    ``features`` stack, put in eval mode with gradients disabled on its
    parameters.
    """

    def __init__(self):
        super().__init__()
        feature_layers = list(vgg16(pretrained=True).features)[:30]
        extractor = nn.Sequential(*feature_layers).eval()
        # The extractor is a fixed feature function; its weights are never trained.
        for weight in extractor.parameters():
            weight.requires_grad = False
        self.loss_network = extractor
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Combine the four loss terms into one scalar.

        Args:
            out_labels: Discriminator output for the generated images.
            out_images: Generated images.
            target_images: Ground-truth images.

        Returns:
            ``image + 0.001 * adversarial + 0.006 * perception + 2e-8 * tv``.
        """
        # Adversarial term: small when the discriminator scores fakes near 1.
        adversarial_term = torch.mean(1 - out_labels)
        # Perceptual term: MSE between VGG feature maps.
        generated_features = self.loss_network(out_images)
        target_features = self.loss_network(target_images)
        perception_term = self.mse_loss(generated_features, target_features)
        # Pixel-wise term.
        image_term = self.mse_loss(out_images, target_images)
        # Total-variation term on the generated images.
        tv_term = self.tv_loss(out_images)
        return (
            image_term
            + 0.001 * adversarial_term
            + 0.006 * perception_term
            + 2e-8 * tv_term
        )


class TVLoss(nn.Module):
    """Total-variation penalty on neighbouring-pixel differences.

    Args:
        tv_loss_weight: Multiplier applied to the result. Defaults to 1.
    """

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Compute the weighted total-variation loss of ``x``.

        Squared differences along dimensions 2 and 3 are summed over the
        whole batch, each sum is divided by the per-sample element count of
        its difference tensor, and the total is scaled by
        ``2 * tv_loss_weight / batch_size``.
        """
        n_batch = x.size(0)
        height = x.size(2)
        width = x.size(3)

        lower_rows = x[:, :, 1:, :]
        right_cols = x[:, :, :, 1:]
        rows_count = self.tensor_size(lower_rows)
        cols_count = self.tensor_size(right_cols)

        vertical_sq = (lower_rows - x[:, :, :height - 1, :]).pow(2).sum()
        horizontal_sq = (right_cols - x[:, :, :, :width - 1]).pow(2).sum()

        per_axis_mean = vertical_sq / rows_count + horizontal_sq / cols_count
        return self.tv_loss_weight * 2 * per_axis_mean / n_batch

    @staticmethod
    def tensor_size(t):
        """Return the product of dimensions 1, 2 and 3 of ``t`` (batch excluded)."""
        dims = t.size()
        return dims[1] * dims[2] * dims[3]


In [7]:
# Training configuration.
params = {
    "upscale_factor": 4,
    "num_epochs": 5,
    "batch_size": 32,
    "num_workers": 4,
    "dataset_path": "images_256",
}

# Resolve the device once so models and batches are always moved consistently.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_set = TrainDatasetFromFolder(
    params["dataset_path"], upscale_factor=params["upscale_factor"]
)
train_loader = DataLoader(
    dataset=train_set,
    num_workers=params["num_workers"],
    batch_size=params["batch_size"],
    shuffle=True,
)

netG = Generator(params["upscale_factor"])
print("# generator parameters:", sum(param.numel() for param in netG.parameters()))
netD = Discriminator()
print("# discriminator parameters:", sum(param.numel() for param in netD.parameters()))

generator_criterion = GeneratorLoss()

netG.to(device)
netD.to(device)
generator_criterion.to(device)

optimizerG = optim.Adam(netG.parameters())
optimizerD = optim.Adam(netD.parameters())

# Running averages, one entry appended per processed batch.
results = {
    "d_loss": [],
    "g_loss": [],
    "d_score": [],
    "g_score": [],
}

# Each run writes checkpoints and sample images into its own timestamped folder.
dirname = f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
os.makedirs(dirname)

for epoch in range(1, params["num_epochs"] + 1):
    train_bar = tqdm(train_loader)
    running_results = {
        "batch_sizes": 0,
        "d_loss": 0,
        "g_loss": 0,
        "d_score": 0,
        "g_score": 0,
    }
    netG.train()
    netD.train()

    for data, target in train_bar:
        batch_size = data.size(0)
        running_results["batch_sizes"] += batch_size

        real_img = target.float().to(device)
        z = data.float().to(device)

        ############################
        # (1) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        fake_img = netG(z)
        fake_out = netD(fake_img).mean()

        optimizerG.zero_grad()
        g_loss = generator_criterion(fake_out, fake_img, real_img)
        g_loss.backward()
        optimizerG.step()

        ############################
        # (2) Update D network: maximize D(x)-1-D(G(z))
        ###########################
        real_out = netD(real_img).mean()
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_out = netD(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out

        # Also clears the D gradients accumulated by g_loss.backward() above.
        optimizerD.zero_grad()
        d_loss.backward()

        # Score of the just-updated generator, evaluated before D is stepped.
        # Used only for logging, so no autograd graph is needed.
        with torch.no_grad():
            g_score = netD(netG(z)).mean()

        optimizerD.step()

        # Accumulate batch-weighted sums for the running averages.
        running_results["g_loss"] += g_loss.item() * batch_size
        running_results["d_loss"] += d_loss.item() * batch_size
        running_results["d_score"] += real_out.item() * batch_size
        running_results["g_score"] += g_score.item() * batch_size

        train_bar.set_description(
            desc="[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f"
            % (
                epoch,
                params["num_epochs"],
                running_results["d_loss"] / running_results["batch_sizes"],
                running_results["g_loss"] / running_results["batch_sizes"],
                running_results["d_score"] / running_results["batch_sizes"],
                running_results["g_score"] / running_results["batch_sizes"],
            )
        )

        # Record the running averages of losses and scores for this batch.
        for key in results:
            results[key].append(
                running_results[key] / running_results["batch_sizes"]
            )

    # Checkpoint both networks after every epoch.
    torch.save(
        netG.state_dict(),
        "%s/netG_epoch_%d_%d.pth" % (dirname, epoch, params["upscale_factor"]),
    )
    torch.save(
        netD.state_dict(),
        "%s/netD_epoch_%d_%d.pth" % (dirname, epoch, params["upscale_factor"]),
    )

    # Save a small preview from the last batch of the epoch. Use the
    # on-device tensor `z` (the raw `data` tensor lives on the CPU and would
    # mismatch a CUDA model) and eval mode, so BatchNorm uses its running
    # statistics and the preview does not modify them. Train mode is restored
    # at the top of the next epoch.
    netG.eval()
    with torch.no_grad():
        fake_images = netG(z[:4])
    # File name uses the same epoch number as the checkpoints above.
    torchvision.utils.save_image(
        fake_images,
        f"{dirname}/output_epoch_{epoch}.png",
        normalize=True,
    )


# generator parameters: 734219
# discriminator parameters: 5215425


[1/5] Loss_D: 0.7280 Loss_G: 0.0359 D(x): 0.4856 D(G(z)): 0.2715: 100%|██████████| 25/25 [23:06<00:00, 55.47s/it]
[2/5] Loss_D: 0.9946 Loss_G: 0.0241 D(x): 0.4877 D(G(z)): 0.4822: 100%|██████████| 25/25 [23:43<00:00, 56.94s/it]
[3/5] Loss_D: 0.9940 Loss_G: 0.0215 D(x): 0.5617 D(G(z)): 0.5556: 100%|██████████| 25/25 [23:52<00:00, 57.31s/it]
[4/5] Loss_D: 0.9913 Loss_G: 0.0191 D(x): 0.4954 D(G(z)): 0.4931: 100%|██████████| 25/25 [23:45<00:00, 57.01s/it]
[5/5] Loss_D: 0.9968 Loss_G: 0.0185 D(x): 0.4657 D(G(z)): 0.4605: 100%|██████████| 25/25 [23:43<00:00, 56.95s/it]
