In [None]:
import os
import cv2
import time

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
class ResidualBlock(nn.Module):
    """SRGAN residual block: conv-BN-ReLU-conv-BN, added back onto its input.

    Parameters
    ----------
    channels : int
        Input and output channel count (kept equal so the skip is a plain add). Default 64.
    """

    def __init__(self, channels=64):
        super(ResidualBlock, self).__init__()
        # No activation after the second BN: a trailing ReLU would clamp the residual to
        # >= 0, so the block could only ever add non-negative values to its input.
        # The parameterised modules keep indices 0, 1, 3, 4, so existing state_dicts load.
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=channels),
        )

    def forward(self, x):
        """Return ``x + block(x)``; output shape equals input shape."""
        return x + self.block(x)


class SRGAN_g(nn.Module):
    """Generator in "Photo-Realistic Single Image Super-Resolution Using a Generative
    Adversarial Network".

    Here, upsampling is done by two bilinear 2x stages. ``forward`` maps a
    [N, 3, H, W] batch to [N, 3, 4H, 4W] with values in (0, 1) (sigmoid output).

    Parameters
    ----------
    n_res_blocks : int
        Number of ``ResidualBlock``s in the trunk. Default 10.
    """

    def __init__(self, n_res_blocks=10):
        super(SRGAN_g, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU()
        )
        self.residual_block = self.make_layer(n_res_blocks)
        # NOTE(review): the reference design ends this stage with conv+BN only (no activation)
        # before the long skip; the ReLU is kept so trained checkpoints behave the same.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU()
        )
        # Two 2x stages -> 4x total. Module order (including BN after ReLU) is kept as-is so
        # existing state_dict keys still match.
        self.upsample = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=(1, 1), stride=(1, 1)),
        )

    def make_layer(self, n_blocks=10):
        """Stack ``n_blocks`` residual blocks named ``res_block1`` ... ``res_block<n_blocks>``."""
        layer_list = OrderedDict(
            ('res_block' + str(i + 1), ResidualBlock()) for i in range(n_blocks)
        )
        return nn.Sequential(layer_list)

    def forward(self, x):
        x = self.conv1(x)
        skip = x  # long skip connection around the whole residual trunk
        x = self.conv2(self.residual_block(x))
        x = x + skip
        return torch.sigmoid(self.upsample(x))



class SRGAN_d(nn.Module):
    """SRGAN discriminator: strided conv stack plus a two-layer dense head giving a real/fake probability.

    Parameters
    ----------
    dim : int
        Base channel width; the conv widths are dim, 2*dim, 4*dim, 4*dim, 8*dim, dim.
    input_size : int
        Side length of the square input images. Each of the five k=4/s=2/p=1 convs halves
        (floors) the side, so the final map is ``input_size // 32`` per side. Default 384,
        which gives the original 9216 -> 4608 -> 1 dense head for ``dim=64``.
    """

    def __init__(self, dim=64, input_size=384):
        super(SRGAN_d, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=dim, kernel_size=(4, 4), stride=(2, 2), padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=dim, out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), padding=1),
            nn.BatchNorm2d(num_features=dim * 2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=dim * 2, out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), padding=1),
            nn.BatchNorm2d(num_features=dim * 4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=dim * 4, out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), padding=1),
            nn.BatchNorm2d(num_features=dim * 4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=dim * 4, out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), padding=1),
            nn.BatchNorm2d(num_features=dim * 8),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=dim * 8, out_channels=dim, kernel_size=(3, 3), stride=(1, 1), padding=1),
            nn.BatchNorm2d(num_features=dim),
            nn.LeakyReLU()
        )

        self.flat = nn.Flatten()
        # Flattened size of the final (dim, input_size//32, input_size//32) map. This was
        # hard-coded as 9216, which broke any dim != 64 or input size != 384.
        n_flat = dim * (input_size // 32) ** 2
        # NOTE(review): no activation between the two Linear layers, so they compose to a
        # single affine map; left unchanged to keep state_dict keys (dense.0 / dense.1).
        self.dense = nn.Sequential(
            nn.Linear(in_features=n_flat, out_features=n_flat // 2),
            nn.Linear(in_features=n_flat // 2, out_features=1),
        )

    def forward(self, x):
        """Return per-image probabilities of shape [N, 1] (sigmoid of the dense output)."""
        x = self.block(x)
        x = self.flat(x)
        x = self.dense(x)
        return torch.sigmoid(x)

In [None]:
# Public names of the VGG helpers defined in this cell. 'vgg16'/'VGG16' used to be listed
# but are never defined, which made `from <module> import *` raise AttributeError.
__all__ = [
    'VGG',
    'vgg19',
    'VGG19',
]

# Layer names laid out in parallel with every `cfg` entry. A nested list names the convs of a
# conv group; a plain string names the single layer at that position. make_layers uses these
# as the module keys and compares them against `end_with` to truncate the network.
layer_names = [
    ['conv1_1', 'conv1_2'], 'pool1', ['conv2_1', 'conv2_2'], 'pool2',
    ['conv3_1', 'conv3_2', 'conv3_3', 'conv3_4'], 'pool3', ['conv4_1', 'conv4_2', 'conv4_3', 'conv4_4'], 'pool4',
    ['conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'], 'pool5', 'flatten', 'fc1_relu', 'fc2_relu', 'outputs'
]

# Layer specs consumed by make_layers:
#   [n, ...]      -> group of 3x3 convs (+ReLU), each int is that conv's output channel count
#   'M'           -> 2x2 / stride-2 max-pool
#   'F'           -> flatten
#   'fc1', 'fc2'  -> 4096-unit dense + ReLU
#   'O'           -> 4096 -> 1000 dense output
# The keys are selected through `mapped_cfg` below (vgg11 -> 'A', vgg13 -> 'B',
# vgg16 -> 'D', vgg19 -> 'E').
cfg = {
    'A': [[64], 'M', [128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
    'B': [[64, 64], 'M', [128, 128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
    'D':
        [
            [64, 64], 'M', [128, 128], 'M', [256, 256, 256], 'M', [512, 512, 512], 'M', [512, 512, 512], 'M', 'F',
            'fc1', 'fc2', 'O'
        ],
    'E':
        [
            [64, 64], 'M', [128, 128], 'M', [256, 256, 256, 256], 'M', [512, 512, 512, 512], 'M', [512, 512, 512, 512],
            'M', 'F', 'fc1', 'fc2', 'O'
        ],
}

# Model name -> `cfg` key. The *_bn variants share the same layer spec; batch norm is
# switched on separately through the `batch_norm` argument of VGG / make_layers.
mapped_cfg = {
    'vgg11': 'A',
    'vgg11_bn': 'A',
    'vgg13': 'B',
    'vgg13_bn': 'B',
    'vgg16': 'D',
    'vgg16_bn': 'D',
    'vgg19': 'E',
    'vgg19_bn': 'E'
}

# Download locations for the pretrained weight files. Nothing in this notebook reads them;
# restore_model loads local copies from the `model/` directory instead.
model_urls = {
    'vgg16': 'https://git.openi.org.cn/attachments/760835b9-db71-4a00-8edd-d5ece4b6b522?type=0',
    'vgg19': 'https://git.openi.org.cn/attachments/503c8a6c-705f-4fb6-ba18-03d72b6a949a?type=0'
}

# File names under `model/` that restore_model loads for each architecture.
model_saved_name = {'vgg16': 'vgg16_weights.npz', 'vgg19': 'vgg19.npy'}


class VGG(nn.Module):
    """VGG network assembled from ``cfg`` by ``make_layers``.

    Parameters
    ----------
    layer_type : str
        Key of ``mapped_cfg`` (e.g. ``'vgg19'``) selecting the layer spec.
    batch_norm : bool
        Forwarded to ``make_layers``.
    end_with : str
        Name from ``layer_names`` at which the network is truncated.
    name : None or str
        Accepted for API compatibility; unused.
    """

    # Per-channel RGB means on the 0-255 scale, subtracted from the input in forward().
    _RGB_MEAN = (123.68, 116.779, 103.939)

    def __init__(self, layer_type, batch_norm=False, end_with='outputs', name=None):
        super(VGG, self).__init__()
        self.end_with = end_with

        config = cfg[mapped_cfg[layer_type]]
        self.make_layer = make_layers(config, batch_norm, end_with)

    def forward(self, inputs):
        """
        inputs : tensor
            Channel-first RGB, shape [N, 3, H, W], value range [0, 1].
        """
        # Build the mean on the input's own device/dtype, instead of a global `device`, so
        # the model works wherever it and its inputs live.
        mean = torch.tensor(self._RGB_MEAN, dtype=inputs.dtype, device=inputs.device).view(-1, 1, 1)
        return self.make_layer(inputs * 255. - mean)


def make_layers(config, batch_norm=False, end_with='outputs'):
    """Build a VGG layer stack from ``config`` as an ``nn.Sequential``.

    Parameters
    ----------
    config : list
        One of the ``cfg`` entries. Nested lists are groups of 3x3 convs (each int is an
        output channel count, each conv followed by ReLU). 'M' is a 2x2/stride-2 max-pool,
        'F' is flatten, 'fc1'/'fc2' are 4096-unit dense + ReLU, and 'O' is 4096 -> 1000 dense.
    batch_norm : bool
        If True, a ``BatchNorm2d`` is inserted after every conv+ReLU.
    end_with : str
        Name (from ``layer_names``) of the last layer to include; construction stops right
        after it. A name that never matches builds the whole config.

    Returns
    -------
    nn.Sequential
        Children in config order, keyed by ``layer_names``. Conv entries carry an extra
        index suffix, e.g. ``'conv1_10'``; this is kept so state_dict keys stay stable.
    """
    layer_list = OrderedDict()
    is_end = False
    for layer_group_idx, layer_group in enumerate(config):
        if isinstance(layer_group, list):
            for idx, n_filter in enumerate(layer_group):
                layer_name = layer_names[layer_group_idx][idx]
                if idx > 0:
                    in_channels = layer_group[idx - 1]
                elif layer_group_idx > 0:
                    # Conv groups alternate with 'M' entries, so the previous group is two back.
                    in_channels = config[layer_group_idx - 2][-1]
                else:
                    in_channels = 3
                layer_list[layer_name + str(idx)] = nn.Sequential(
                    nn.Conv2d(in_channels=in_channels, out_channels=n_filter, kernel_size=(3, 3), stride=(1, 1), padding=1),
                    nn.ReLU()
                )
                if batch_norm:
                    # nn.BatchNorm does not exist; 2-D feature maps need BatchNorm2d.
                    layer_list[layer_name + "_batch_norm"] = nn.BatchNorm2d(num_features=n_filter)
                if layer_name == end_with:
                    is_end = True
                    break
        else:
            layer_name = layer_names[layer_group_idx]
            if layer_group == 'M':
                # Unpadded pooling halves 224 five times to 7, which is what fc1's 512*7*7
                # expects; the previous padding=1 produced 8x8 and made fc1 unusable.
                layer_list[layer_name] = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
            elif layer_group == 'O':
                layer_list[layer_name] = nn.Linear(out_features=1000, in_features=4096)
            elif layer_group == 'F':
                # nn.Flatten takes no `name` argument (it raised TypeError before).
                layer_list[layer_name] = nn.Flatten()
            elif layer_group == 'fc1':
                layer_list[layer_name] = nn.Sequential(
                    nn.Linear(out_features=4096, in_features=512 * 7 * 7),
                    nn.ReLU()
                )
            elif layer_group == 'fc2':
                layer_list[layer_name] = nn.Sequential(
                    nn.Linear(out_features=4096, in_features=4096),
                    nn.ReLU()
                )
            if layer_name == end_with:
                is_end = True
        if is_end:
            break

    return nn.Sequential(layer_list)

def restore_model(model, layer_type):
    """Copy pretrained VGG weights from ``model/<file>`` into ``model``'s layers, in place.

    Every ``nn.Conv2d`` / ``nn.Linear`` of ``model`` (in registration order) receives the
    next kernel/bias pair from the file. A truncated model (e.g. ``end_with='pool4'``)
    therefore consumes only the leading entries. Conv kernels are transposed from
    (kh, kw, in, out) to PyTorch's (out, in, kh, kw).

    Parameters
    ----------
    model : nn.Module
        Network whose conv/linear layers appear in the same order as the file's entries.
    layer_type : str
        ``'vgg16'`` (``.npz`` with one array per ``<layer>_W`` / ``<layer>_b`` key) or
        ``'vgg19'`` (pickled ``.npy`` dict mapping layer name -> [kernel, bias]).

    Raises
    ------
    ValueError
        Unknown ``layer_type``, fewer arrays in the file than the model needs, or a shape
        mismatch between a stored array and its target parameter.
    """
    if layer_type not in ('vgg16', 'vgg19'):
        raise ValueError("No pretrained weights available for %r" % (layer_type,))

    # Parameterised layers in registration order; the weight files list layers in that order.
    targets = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
    n_needed = 2 * len(targets)  # one kernel + one bias array per layer
    path = os.path.join('model', model_saved_name[layer_type])

    # NOTE(review): allow_pickle=True can execute code embedded in the file; only load
    # weight files from a trusted source.
    weights = []
    if layer_type == 'vgg16':
        with np.load(path, allow_pickle=True) as npz:
            # Sorted keys pair up as '<layer>_W', '<layer>_b' ('W' sorts before 'b').
            # Only the needed arrays are read, since NpzFile loads members lazily.
            weights = [npz[key] for key in sorted(npz.files)[:n_needed]]
    else:
        layer_dict = np.load(path, allow_pickle=True, encoding='latin1').item()
        for _, pair in sorted(layer_dict.items()):
            if len(weights) >= n_needed:
                break
            weights.extend(pair)
        weights = weights[:n_needed]

    if len(weights) < n_needed:
        raise ValueError("%s holds %d arrays but the model needs %d" % (path, len(weights), n_needed))

    with torch.no_grad():
        for layer, kernel, bias in zip(targets, weights[0::2], weights[1::2]):
            kernel = np.asarray(kernel)
            if kernel.ndim == 4:
                kernel = np.transpose(kernel, axes=[3, 2, 0, 1])
            elif kernel.ndim == 2:
                # Assumes dense kernels are stored (in, out) -> (out, in); TODO confirm.
                # NOTE(review): the first dense layer may also expect a different flatten
                # order (HWC vs CHW) than this model produces.
                kernel = kernel.T
            _copy_param(layer.weight, kernel)
            if layer.bias is not None:
                _copy_param(layer.bias, bias)


def _copy_param(param, array):
    """Copy a NumPy array into ``param`` in place after checking that the shapes agree."""
    array = np.ascontiguousarray(array)
    if tuple(array.shape) != tuple(param.shape):
        raise ValueError("weight shape %s does not match parameter shape %s"
                         % (tuple(array.shape), tuple(param.shape)))
    param.copy_(torch.as_tensor(array, dtype=param.dtype, device=param.device))


def vgg19(pretrained=False, end_with='outputs', mode='dynamic', name=None):
    """VGG19 network, optionally initialised from ``model/vgg19.npy``.

    Parameters
    ------------
    pretrained : boolean
        Whether to load weights via ``restore_model``. Default False.
    end_with : str
        Name from ``layer_names`` at which the network is truncated (e.g. ``'pool4'``).
        Default ``'outputs'``, i.e. the whole model including the 1000-way output layer.
    mode : str
        Only ``'dynamic'`` is supported. Default 'dynamic'.
    name : None or str
        Forwarded to ``VGG``, which does not use it.

    Returns
    -------
    VGG

    Raises
    ------
    NotImplementedError
        If ``mode == 'static'``.
    ValueError
        For any other unknown ``mode``.

    Examples
    ---------
    >>> vgg = vgg19(pretrained=True, end_with='pool4').to(device).eval()
    >>> features = vgg(img)  # img: [N, 3, H, W] tensor with values in [0, 1]
    """
    if mode == 'static':
        raise NotImplementedError
    if mode != 'dynamic':
        raise ValueError("No such mode %s" % mode)
    model = VGG(layer_type='vgg19', batch_norm=False, end_with=end_with, name=name)
    if pretrained:
        restore_model(model, layer_type='vgg19')
    return model


VGG19 = vgg19

In [None]:
# Download Oxford Flowers-102 (images plus the label/split .mat files) under ./train_data;
# the dataset object's repr is the cell output below.
# NOTE(review): later cells read /content/train_data/..., which assumes the cwd is /content.
datasets.Flowers102(root='train_data', download=True)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to train_data/flowers-102/102flowers.tgz


100%|██████████| 344862509/344862509 [00:20<00:00, 16613703.44it/s]


Extracting train_data/flowers-102/102flowers.tgz to train_data/flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to train_data/flowers-102/imagelabels.mat


100%|██████████| 502/502 [00:00<00:00, 160275.60it/s]


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to train_data/flowers-102/setid.mat


100%|██████████| 14989/14989 [00:00<00:00, 30848097.48it/s]


Dataset Flowers102
    Number of datapoints: 1020
    Root location: train_data
    split=train

In [None]:
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

###====================== HYPER-PARAMETERS ===========================###
batch_size = 4
n_epoch_init = 10   # generator-only warm-up epochs (pixel + VGG feature loss)
n_epoch = 2000      # adversarial epochs
# Create folders for result images and trained models.
# NOTE(review): the folders are created under /content, but later cells write to the
# relative paths `save_dir` / `checkpoint_dir`; the two agree only when cwd is /content.
save_dir = "samples"
checkpoint_dir = "models"
os.makedirs(os.path.join('/content', save_dir), exist_ok=True)
os.makedirs(os.path.join('/content', checkpoint_dir), exist_ok=True)

# HR target: resize to 384x384, then random horizontal flip.
hr_transform = transforms.Compose([
    transforms.Resize(size = (384, 384)),
    transforms.RandomHorizontalFlip(0.5),
])
nor = transforms.Compose([
    transforms.ToTensor(),
])
# LR input is derived from the (already flipped) HR image, so the pair stays aligned.
lr_transform = transforms.Resize(size=(96, 96))

# list(os.walk(root)): element [0] is (dirpath, dirnames, filenames) of the top directory.
train_hr_imgs = list(os.walk('/content/train_data/flowers-102/jpg/'))


class TrainData(Dataset):
    """Flowers-102 images as (low-res, high-res) tensor pairs for super-resolution training.

    Each item is read from ``/content/train_data/flowers-102/jpg/``. ``hr_trans`` yields
    the high-res target, and ``lr_trans`` is applied to that target to get the low-res
    input, so the pair shares the same random flip. Both are returned through ``nor``.
    """

    def __init__(self, hr_trans=hr_transform, lr_trans=lr_transform):
        # train_hr_imgs comes from os.walk(); [0][2] is the top directory's filename list.
        self.train_hr_imgs = train_hr_imgs[0][2]
        self.hr_trans = hr_trans
        self.lr_trans = lr_trans

    def __getitem__(self, index):
        path = os.path.join('/content/train_data/flowers-102/jpg/', self.train_hr_imgs[index])
        # Decode inside a context manager so the file handle is released. Force 3 channels
        # so a grayscale/CMYK file cannot break the 3-channel networks.
        with Image.open(path) as src:
            img = src.convert('RGB')
        hr_patch = self.hr_trans(img)
        lr_patch = self.lr_trans(hr_patch)

        return nor(lr_patch), nor(hr_patch)

    def __len__(self):
        return len(self.train_hr_imgs)


class WithLoss_init(nn.Module):
    """Generator warm-up loss: pixel loss plus a weighted VGG feature (perceptual) loss.

    Parameters
    ----------
    vgg : nn.Module
        Feature extractor applied to 224x224-resized generated and target images.
    G_net : nn.Module
        Generator mapping LR batches to HR batches.
    loss_fn : callable
        Used for both the pixel and the feature term (``F.mse_loss`` in ``train``).
    vgg_weight : float
        Multiplier on the feature term. Default 100000, the previously hard-coded value.
    """

    def __init__(self, vgg, G_net, loss_fn, vgg_weight=100000):
        super(WithLoss_init, self).__init__()
        self.net = G_net
        self.vgg = vgg
        self.loss_fn = loss_fn
        self.vgg_weight = vgg_weight
        self.counter = 0
        self.trans = transforms.Compose([
            transforms.Resize(size = (224, 224)),
        ])

    def forward(self, lr, hr):
        out = self.net(lr)
        # Show one generated sample every 51 calls (counter cycles 0..50) to monitor progress.
        if self.counter == 50:
            plt.imshow(out[0].detach().cpu().squeeze().permute(1, 2, 0))
            plt.show()
            self.counter = 0
        else:
            self.counter += 1
        feature_fake = self.vgg(self.trans(out))
        # Target features have no gradient path to the generator; don't build a graph for them.
        with torch.no_grad():
            feature_real = self.vgg(self.trans(hr))
        loss = self.loss_fn(out, hr)
        vgg_loss = self.vgg_weight * self.loss_fn(feature_fake, feature_real)
        return loss + vgg_loss


class WithLoss_D(nn.Module):
    """Discriminator loss: real images scored against 1, generated images against 0.

    Parameters
    ----------
    D_net : nn.Module
        Discriminator; its outputs are fed to ``loss_fn`` (BCE on probabilities in ``train``).
    G_net : nn.Module
        Generator used only to produce fake samples; it receives no gradient from this loss.
    loss_fn : callable
        ``loss_fn(prediction, target)``.
    """

    def __init__(self, D_net, G_net, loss_fn):
        super(WithLoss_D, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.loss_fn = loss_fn

    def forward(self, lr, hr):
        # Generate fakes without a graph: the D update must not backpropagate into G
        # (that work was wasted and filled G's .grad buffers).
        with torch.no_grad():
            fake_patchs = self.G_net(lr)
        logits_fake = self.D_net(fake_patchs)
        logits_real = self.D_net(hr)
        d_loss1 = torch.mean(self.loss_fn(logits_real, torch.ones_like(logits_real)))
        d_loss2 = torch.mean(self.loss_fn(logits_fake, torch.zeros_like(logits_fake)))
        return d_loss1 + d_loss2


class WithLoss_G(nn.Module):
    """Adversarial-phase generator loss: pixel loss + GAN loss (+ optional VGG feature loss).

    Parameters
    ----------
    vgg : nn.Module
        Feature extractor; only used when ``vgg_weight`` is non-zero.
    D_net : nn.Module
        Discriminator scoring the generated images.
    G_net : nn.Module
        Generator mapping LR batches to HR batches.
    loss_fn1 : callable
        GAN loss between D's output and an all-ones target (BCE in ``train``).
    loss_fn2 : callable
        Pixel loss between generated and real HR images, also used for the feature term.
    vgg_weight : float
        Weight of the VGG feature loss. Default 0.0 disables it, matching the earlier
        behaviour where that term was commented out.
    """

    def __init__(self, vgg, D_net, G_net, loss_fn1, loss_fn2, vgg_weight=0.0):
        super(WithLoss_G, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.vgg = vgg
        self.loss_fn1 = loss_fn1
        self.loss_fn2 = loss_fn2
        self.vgg_weight = vgg_weight
        self.counter = 0
        self.trans = transforms.Compose([
            transforms.Resize(size = (224, 224)),
        ])

    def forward(self, lr, hr):
        fake_patchs = self.G_net(lr)
        # Show one generated sample every 201 calls (counter cycles 0..200).
        if self.counter == 200:
            plt.imshow(fake_patchs[0].detach().cpu().squeeze().permute(1, 2, 0))
            plt.show()
            self.counter = 0
        else:
            self.counter += 1
        logits_fake = self.D_net(fake_patchs)

        g_gan_loss = torch.mean(self.loss_fn1(logits_fake, torch.ones_like(logits_fake)))
        mse_loss = self.loss_fn2(fake_patchs, hr)
        g_loss = mse_loss + g_gan_loss
        if self.vgg_weight:
            feature_fake = self.vgg(self.trans(fake_patchs))
            with torch.no_grad():  # target features need no graph
                feature_real = self.vgg(self.trans(hr))
            g_loss = g_loss + self.vgg_weight * self.loss_fn2(feature_fake, feature_real)
        return g_loss


G = SRGAN_g().to(device)
D = SRGAN_d().to(device)
# Named `vgg_net`, not `VGG`, so the VGG *class* is not shadowed. With the shadowing,
# re-running this cell made vgg19() call the instance and fail.
vgg_net = VGG19(pretrained=True, end_with='pool4', mode='dynamic').to(device)
# The VGG is a fixed feature extractor: freeze it so backward passes stop accumulating
# (never-used) gradients in its parameters. Gradients w.r.t. its inputs still flow.
vgg_net.requires_grad_(False)


def train(n_train_imgs=40, lr_decay_every=250):
    """Train the SRGAN: generator warm-up, then alternating generator/discriminator updates.

    Uses the module-level ``G``, ``D``, ``vgg_net``, ``batch_size``, ``n_epoch_init``,
    ``n_epoch``, ``device`` and ``checkpoint_dir``.

    Parameters
    ----------
    n_train_imgs : int
        Number of leading ``TrainData`` items pre-computed once and reused every epoch, so
        their random flips are fixed. Default 40; capped at the dataset length.
    lr_decay_every : int
        Both learning rates are halved every ``lr_decay_every`` adversarial epochs (not at
        epoch 0). Checkpoints are written at the same milestones and at the final epoch.
        Default 250.
    """
    G.train()
    D.train()
    vgg_net.eval()
    full_ds = TrainData()
    train_pairs = [full_ds[i] for i in range(min(n_train_imgs, len(full_ds)))]
    train_ds = DataLoader(train_pairs, batch_size=batch_size, shuffle=True, drop_last=True)
    n_step_epoch = len(train_ds)  # drop_last=True -> len(train_pairs) // batch_size

    g_optimizer = optim.Adam(G.parameters(), lr=0.0001)
    scheduler_g = optim.lr_scheduler.ExponentialLR(g_optimizer, gamma=0.5)
    d_optimizer = optim.Adam(D.parameters(), lr=0.0001)
    scheduler_d = optim.lr_scheduler.ExponentialLR(d_optimizer, gamma=0.5)
    net_with_loss_init = WithLoss_init(vgg_net, G, loss_fn=F.mse_loss)
    net_with_loss_D = WithLoss_D(D_net=D, G_net=G, loss_fn=F.binary_cross_entropy)
    net_with_loss_G = WithLoss_G(vgg=vgg_net, D_net=D, G_net=G, loss_fn1=F.binary_cross_entropy,
                                 loss_fn2=F.mse_loss)

    # initialize learning (G)
    for epoch in range(n_epoch_init):
        for step, (lr_patch, hr_patch) in enumerate(train_ds):
            step_time = time.time()
            lr_patch = lr_patch.to(device).float()
            hr_patch = hr_patch.to(device).float()
            g_optimizer.zero_grad()
            loss = net_with_loss_init(lr_patch, hr_patch)
            loss.backward()
            g_optimizer.step()
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.7f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, float(loss)))

    # adversarial learning (G, D)
    for epoch in range(n_epoch):
        for step, (lr_patch, hr_patch) in enumerate(train_ds):
            step_time = time.time()
            lr_patch = lr_patch.to(device).float()
            hr_patch = hr_patch.to(device).float()

            g_optimizer.zero_grad()
            loss_g = net_with_loss_G(lr_patch, hr_patch)
            loss_g.backward()
            g_optimizer.step()

            # Zero D's grads only now: loss_g.backward() also fills D's .grad (with the
            # generator's "call fakes real" objective), which must not leak into D's update.
            d_optimizer.zero_grad()
            loss_d = net_with_loss_D(lr_patch, hr_patch)
            loss_d.backward()
            d_optimizer.step()

            print(
                "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss:{:.7f}, d_loss: {:.7f}".format(
                    epoch, n_epoch, step, n_step_epoch, time.time() - step_time, float(loss_g), float(loss_d)))

        is_milestone = epoch != 0 and epoch % lr_decay_every == 0
        if is_milestone:
            scheduler_g.step()
            scheduler_d.step()

        if is_milestone or epoch == n_epoch - 1:
            # NOTE: torch.save (pickle) files despite the .npz suffix; evaluate() expects these names.
            torch.save(G.state_dict(), os.path.join(checkpoint_dir, 'g.npz'))
            torch.save(D.state_dict(), os.path.join(checkpoint_dir, 'd.npz'))


if __name__ == '__main__':
    import argparse

    # In a notebook kernel __name__ is '__main__', so running this cell starts training.
    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', type=str, default='train', help='train, eval')

    # parse_known_args ignores unrecognised argv entries (a notebook kernel passes its own,
    # e.g. -f <connection file>) instead of exiting.
    args = parser.parse_known_args()

    flag = args[0].mode

    if flag == 'train':
        train()
    elif flag == 'eval':
        # NOTE(review): evaluate() is defined in a later cell; on a fresh kernel this branch
        # raises NameError unless that cell has already been executed.
        evaluate()
    else:
        raise ValueError("Unknown --mode %s" % flag)

In [None]:
# Persist the current generator/discriminator weights (train() also does this at milestones).
# NOTE(review): despite the .npz suffix these are torch.save (pickle) files, not NumPy archives.
for _ckpt_file, _ckpt_net in (('g.npz', G), ('d.npz', D)):
    torch.save(_ckpt_net.state_dict(), os.path.join(checkpoint_dir, _ckpt_file))

In [None]:
def evaluate():
    """Super-resolve one Flowers-102 image with the saved generator and write comparison PNGs.

    Loads ``/content/models/g.npz`` into a fresh ``SRGAN_g`` and resizes
    ``image_00002.jpg`` to 96x96. It then runs the generator and writes the following
    into ``save_dir``:

    - ``valid_gen.png``: generator output.
    - ``valid_lr.png``: the 96x96 input.
    - ``valid_hr.png``: the original image.
    - ``valid_hr_cubic.png``: bicubic 4x upscale of the input.
    """
    ###========================LOAD WEIGHTS ============================###
    G = SRGAN_g().to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
    G.load_state_dict(torch.load(os.path.join('/content/models', 'g.npz'), map_location=device))
    G.eval()

    ###========================PREPARE INPUT ============================###
    # Decode and release the file; force 3 channels to match the generator's input.
    with Image.open('/content/train_data/flowers-102/jpg/image_00002.jpg') as src:
        valid_hr_img = src.convert('RGB')

    my_trans = transforms.Resize(size=(96, 96))
    nor = transforms.Compose([
        transforms.ToTensor(),
    ])

    valid_lr_img = np.asarray(my_trans(valid_hr_img))        # (96, 96, 3) uint8
    valid_lr_img_tensor = nor(valid_lr_img).unsqueeze(0)      # (1, 3, 96, 96) float in [0, 1]
    size = [valid_lr_img.shape[0], valid_lr_img.shape[1]]

    ###========================RUN GENERATOR ===========================###
    with torch.no_grad():  # inference only: no autograd graph needed
        out = G(valid_lr_img_tensor.to(device).float()).cpu().numpy()
    out = np.asarray(out * 255, dtype=np.uint8)
    out = np.transpose(out[0], axes=[1, 2, 0])
    # e.g. "LR size: [96, 96] /  generated HR size: (384, 384, 3)" (generator upsamples 4x)
    print("LR size: %s /  generated HR size: %s" % (size, out.shape))
    print("[*] save images")
    Image.fromarray(out).save(save_dir + '/valid_gen.png')
    Image.fromarray(valid_lr_img).save(save_dir + '/valid_lr.png')
    valid_hr_img.save(save_dir + '/valid_hr.png')
    out_bicu = cv2.resize(valid_lr_img, dsize=(size[1] * 4, size[0] * 4), interpolation=cv2.INTER_CUBIC)
    Image.fromarray(out_bicu).save(save_dir + '/valid_hr_cubic.png')


evaluate()