In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
from PIL import Image
import pathlib
import numpy as np
import time

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Jan 13 05:33:07 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
###
# Generator
###
class ResBlock_Generator(nn.Module):

    def __init__(self, ch_in, ch_out, kernel_size=3, stride=1):
        super(ResBlock_Generator, self).__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(ch_in, ch_in, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(ch_in),
            nn.PReLU(),
            nn.Conv2d(ch_in, ch_out, kernel_size=kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(ch_out)
        )

    def forward(self, x):
        residual = self.residual(x)
        return x + residual

class Upsample_Generator(nn.Module):

    def __init__(self, ch_in, ch_out, upscale_factor=2):
        super(Upsample_Generator, self).__init__()

        self.UpsampleBlock = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor=upscale_factor),
            nn.PReLU()
        )

    def forward(self, x):
        return self.UpsampleBlock(x)

class Generator(nn.Module):

    def __init__(self, ch_in, ch_out, n_ResBlock=5):
        super(Generator, self).__init__()

        self.Conv_input = nn.Sequential(
            nn.Conv2d(ch_in, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        Resnet = []
        for i in range(n_ResBlock):
            Resnet.append(ResBlock_Generator(64, 64, kernel_size=3, stride=1))
        self.ResNet = nn.Sequential(*Resnet)

        self.Conv_afterRes = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        )

        self.UpsampleNet = nn.Sequential(
            Upsample_Generator(64, 256, upscale_factor=2),
            # Upsample_Generator(64, 256, upscale_factor=2)
        )

        self.Conv_output = nn.Conv2d(64, ch_out, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        out1 = self.Conv_input(x)
        x = self.ResNet(out1)
        x = self.Conv_afterRes(x)
        x = x + out1
        x = self.UpsampleNet(x)
        out = self.Conv_output(x)

        return out


###
# Discriminator
###

class Discriminator(nn.Module):

    def __init__(self, ch_in, ch_out, n_Convblock=7):
        super(Discriminator, self).__init__()

        self.Conv_input = nn.Sequential(
            nn.Conv2d(ch_in, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2)
        )

        ch_in_ConvNet = [64, 64, 128, 128, 256, 256, 512]
        ch_out_ConvNet = [64, 128, 128, 256, 256, 512, 512]
        stride_CovNet = [2, 1, 2, 1, 2, 1, 2]
        self.ConvNet = nn.ModuleList(
            [self.Get_ConvBlock(ch_in_ConvNet[i], ch_out_ConvNet[i], stride=stride_CovNet[i])
             for i in range(n_Convblock)]
        )

        self.output = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        )

    def forward(self, x):
        x = self.Conv_input(x)
        for ConvPath in self.ConvNet:
            x = ConvPath(x)
        x = self.output(x)
        out = torch.sigmoid(x.view(x.size(0), -1))

        return out

    def Get_ConvBlock(self, ch_in, ch_out, stride):
        return nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(ch_out),
            nn.LeakyReLU(0.2)
        )

In [5]:
class CelebAdata(Dataset):

    def __init__(self, image_dir):
        super(CelebAdata, self).__init__()

        self.image_dir = image_dir
        self.image_list = self.get_image_list()
        self.means, self.stds = self.get_mean_std()

    def get_image_list(self):
        return list(sorted(pathlib.Path(self.image_dir).glob("*.jpg")))

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, index):
        image_path = self.image_list[index]

        image = Image.open(str(image_path))

        image_lr = self.get_augmentation(image)
        image_hr = transforms.ToTensor()(image)
        image_hr = transforms.Normalize(mean=self.means,
                                std=self.stds)(image_hr)

        return image_hr, image_lr

    def get_mean_std(self):
        means = np.zeros((3, 1), dtype=np.float)
        stds = np.zeros((3, 1), dtype=np.float)
        for i in range(len(self.image_list)):
            image_path = self.image_list[i]
            image = transforms.ToTensor()(Image.open(str(image_path)))
            image = image.cpu().numpy()
            for ch in range(3):
                means[ch] += image[ch, :, :].mean()
                stds[ch] += image[ch, :, :].std()

        means = means / len(self.image_list)
        stds = stds / len(self.image_list)

        return means.squeeze(), stds.squeeze()

    def get_augmentation(self, image):
        zoom_transform = [
            transforms.Resize((109, 89)),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.means,
                                 std=self.stds)
        ]
        aug = transforms.Compose(zoom_transform)
        image = aug(image)
        
        return image

def get_dataloader(image_dir, batch_size):
    CelebA = CelebAdata(image_dir)
    CelebALoader = DataLoader(CelebA, batch_size=batch_size, shuffle=True, drop_last=True)

    return CelebALoader, CelebA.means, CelebA.stds

In [6]:
from torchvision.models.vgg import vgg16

class CententLoss(nn.Module):

    def __init__(self):
        super(CententLoss, self).__init__()

        vgg = vgg16(pretrained=True)
        loss_net = nn.Sequential(*list(vgg.features)[:31]).eval()
        for par in loss_net.parameters():
            par.requires_grad = False

        self.loss_net = loss_net
        self.loss = nn.MSELoss()

    def forward(self, image_sr, image_hr):
        loss_perception = self.loss(self.loss_net(image_sr), self.loss_net(image_hr))
        loss_mse = self.loss(image_sr, image_hr)

        return loss_mse + 0.006 * loss_perception

class AdversarialLoss(nn.Module):

    def __init__(self):
        super(AdversarialLoss, self).__init__()

        self.loss = nn.BCELoss()

    def forward(self, label_f):
        loss_adversarial = self.loss(label_f, torch.ones_like(label_f))

        return 0.001 * loss_adversarial

In [7]:
from skimage.metrics import structural_similarity

def compare_ssim(im1, im2):
  ssim = structural_similarity(im1, im2, data_range=1.0, channel_axis=3, multichannel=True)
  return ssim

In [8]:
import shutil
train_dir = "/content/drive/Shareddrives/MachineLearning/Colab_Notebooks/DLearning/SRGAN/Train_data"
shutil.copytree(train_dir, 'Train_data')
eval_dir = "/content/drive/Shareddrives/MachineLearning/Colab_Notebooks/DLearning/SRGAN/Eval_data"
shutil.copytree(eval_dir, 'Eval_data')

'Eval_data'

In [9]:
batchsize = 32
epochs = 100
ch_in = 3
ch_out = 3
train_dir = "/content/Train_data/"
eval_dir = "/content/Eval_data/"
file_name = "_E%d_B%d.pth" % (epochs, batchsize)
model_dir = "/content/drive/Shareddrives/MachineLearning/Colab_Notebooks/DLearning/SRGAN/Model/"
device = torch.device('cuda')

trainData = CelebAdata(train_dir)
train_mean = trainData.means
train_std = trainData.stds
trainLoader = DataLoader(trainData, batch_size=batchsize, shuffle=True, drop_last=True)

evalData = CelebAdata(eval_dir)
eval_mean = evalData.means
eval_std = evalData.stds
evalLoader = DataLoader(evalData, batch_size=8, shuffle=True, drop_last=True)

Gmodel = Generator(ch_in, ch_out).to(device)
Dmodel = Discriminator(ch_in, ch_out).to(device)

Goptim = torch.optim.Adam(Gmodel.parameters(), lr=0.0001)
Doptim = torch.optim.Adam(Dmodel.parameters(), lr=0.0001)

bce = nn.BCELoss()
contentLoss = CententLoss().to(device)
adversarialLoss = AdversarialLoss().to(device)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [10]:
print(train_mean)
print(train_std)
print(eval_mean)
print(eval_std)

[0.50656407 0.42607656 0.38357797]
[0.26642735 0.24562411 0.24160843]
[0.50605952 0.42561501 0.38276385]
[0.26673186 0.24566819 0.2417474 ]


In [11]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Jan 13 05:38:51 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    34W / 250W |   1089MiB / 16280MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [12]:
# load model

Gmodel = torch.load(model_dir+"G"+file_name)
Dmodel = torch.load(model_dir+"D"+file_name)

In [None]:
g_loss_recorder = []
d_loss_recorder = []

since = time.time()

for i in range(epochs):
    Gmodel.train()
    Dmodel.train()

    print('-' * 10)
    print('Epoch {}/{}'.format(i + 1, epochs))
    print('-' * 10)

    g_loss_epoch = 0.0
    d_loss_epoch = 0.0

    for image_hr, image_lr in trainLoader:
        image_hr = image_hr.to(device)
        image_lr = image_lr.to(device)

        # train discriminator
        Doptim.zero_grad()

        image_sr = Gmodel(image_lr)
        logits_fake = Dmodel(image_sr)
        logits_true = Dmodel(image_hr)

        true_smooth = torch.tensor(np.random.rand(batchsize, 1) * 0.1 + 0.9, dtype=torch.float).to(device)
        fake_smooth = torch.tensor(np.random.rand(batchsize, 1) * 0.1, dtype=torch.float).to(device)

        d_loss = bce(logits_true, true_smooth) + bce(logits_fake, fake_smooth)
        d_loss.backward(retain_graph=True)

        # train generator
        Goptim.zero_grad()
        g_loss = contentLoss(image_sr, image_hr) + adversarialLoss(logits_fake)
        g_loss.backward()
        Doptim.step()
        Goptim.step()

        # running loss
        g_loss_epoch += g_loss.cpu().detach().numpy()
        d_loss_epoch += d_loss.cpu().detach().numpy()

    g_loss_recorder.append(g_loss_epoch)
    d_loss_recorder.append(d_loss_epoch)

    print("Train Loss: ")
    print("G Loss : {:.4f}".format(g_loss_epoch))
    print("D Loss : {:.4f}".format(d_loss_epoch))

    if i % 2 == 0:
        Gmodel.eval()
        Dmodel.eval()

        g_loss_eval = 0.0
        d_loss_eval = 0.0
        ssim = 0.0
        mssim = 0.0
        best_ssim = 0.0

        print("  "+'-' * 10)
        print('Evaluation stage')
        print("  "+'-' * 10)

        with torch.no_grad():
            i = 0
            for eval_hr, eval_lr in evalLoader:
                eval_hr = eval_hr.to(device)
                eval_lr = eval_lr.to(device)

                test_sr = Gmodel(eval_lr)
                test_true = Dmodel(eval_hr)
                test_fake = Dmodel(test_sr)

                true_smooth_eval = torch.tensor(np.random.rand(8, 1) * 0.1 + 0.9, dtype=torch.float).to(device)
                fake_smooth_eval = torch.tensor(np.random.rand(8, 1) * 0.1, dtype=torch.float).to(device)

                d_eval = bce(test_true, true_smooth_eval) + bce(test_fake, fake_smooth_eval)
                g_eval = contentLoss(test_sr, eval_hr) + adversarialLoss(test_fake)

                i = i + 1

                d_loss_eval += d_eval.cpu().detach().numpy()
                g_loss_eval += g_eval.cpu().detach().numpy()

                if (i == 1) or (i == 2):
                    plt.figure(figsize=(12, 12))

                    test_sr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(test_sr)
                    test_sr = test_sr.cpu().numpy()[0].transpose((1, 2, 0)).astype(np.float)
                    test_sr[test_sr<=0] = 0
                    test_sr[test_sr>=1] = 1

                    eval_hr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(eval_hr)
                    eval_hr = eval_hr.cpu().numpy()[0].transpose((1, 2, 0)).astype(np.float)
                    eval_hr[eval_hr<=0] = 0
                    eval_hr[eval_hr>=1] = 1

                    eval_lr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(eval_lr)
                    eval_lr = eval_lr.cpu().numpy()[0].transpose((1, 2, 0)).astype(np.float)
                    eval_lr[eval_lr<=0] = 0
                    eval_lr[eval_lr>=1] = 1

                    ssim_i = compare_ssim(test_sr, eval_hr)
                    ssim += ssim_i

                    plt.subplot(1, 3, 1)
                    plt.title("LR image")
                    plt.imshow(eval_lr)
                    plt.subplot(1, 3, 2)
                    plt.title("SR image with D:{:.4f}".format(test_fake.cpu().numpy()[0,0]))
                    plt.imshow(test_sr)
                    plt.subplot(1, 3, 3)
                    plt.title("HR image with D:{:.4f}".format(test_true.cpu().numpy()[0,0]))
                    plt.imshow(eval_hr)
                    plt.show()

            mssim = ssim / 2

            print("  Eval Loss: ")
            print("  G Loss : {:.4f}".format(g_loss_eval))
            print("  D Loss : {:.4f}".format(d_loss_eval))
            print("  MSSIM : {:.4f}".format(mssim))

            if mssim >= best_ssim:
              best_ssim = mssim
              # save model and test
              torch.save(Gmodel, model_dir + "G" + file_name)
              torch.save(Dmodel, model_dir + "D" + file_name)
              print('Model saved!')

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
    time_elapsed // 60, time_elapsed % 60))

plt.plot(np.array(d_loss_recorder), 'r')
plt.plot(np.array(g_loss_recorder), 'b')
plt.show()

Output hidden; open in https://colab.research.google.com to view.

In [32]:
Gmodel.eval()
Dmodel.eval()

g_loss_eval = 0.0
d_loss_eval = 0.0
ssim_sh = 0.0
ssim_lh = 0.0
best_ssim = 0.0
min_ssim = 1.0

print("  "+'-' * 10)
print('Evaluation stage')
print("  "+'-' * 10)

with torch.no_grad():
  i=0
  for eval_hr, eval_lr in evalLoader:
    eval_hr = eval_hr.to(device)
    eval_lr = eval_lr.to(device)

    test_sr = Gmodel(eval_lr)
    test_true = Dmodel(eval_hr)
    test_fake = Dmodel(test_sr)

    true_smooth_eval = torch.tensor(np.random.rand(8, 1) * 0.1 + 0.9, dtype=torch.float).to(device)
    fake_smooth_eval = torch.tensor(np.random.rand(8, 1) * 0.1, dtype=torch.float).to(device)

    d_eval = bce(test_true, true_smooth_eval) + bce(test_fake, fake_smooth_eval)
    g_eval = contentLoss(test_sr, eval_hr) + adversarialLoss(test_fake)

    d_loss_eval += d_eval.cpu().detach().numpy()
    g_loss_eval += g_eval.cpu().detach().numpy()

    test_sr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(test_sr)
    eval_hr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(eval_hr)

    for j in range(8):

      show_sr = test_sr.cpu().numpy()[j].transpose((1, 2, 0)).astype(np.float)
      show_sr[show_sr<=0] = 0
      show_sr[show_sr>=1] = 1

      show_hr = eval_hr.cpu().numpy()[j].transpose((1, 2, 0)).astype(np.float)
      show_hr[show_hr<=0] = 0
      show_hr[show_hr>=1] = 1

      show_lr = transforms.Normalize(mean=-(eval_mean/eval_std),std= 1/eval_std)(eval_lr)
      show_lr = transforms.Resize((218, 178))(show_lr)
      show_lr = show_lr.cpu().numpy()[j].transpose((1, 2, 0)).astype(np.float)
      show_lr[show_lr<=0] = 0
      show_lr[show_lr>=1] = 1

      ssim_s = compare_ssim(show_sr, show_hr)
      ssim_l = compare_ssim(show_lr, show_hr)

      if i % 50 == 0:
        plt.figure(figsize=(12, 12))

        plt.subplot(1, 3, 1)
        plt.title("LR image")
        plt.imshow(show_lr)
        plt.subplot(1, 3, 2)
        plt.title("SR image with D:{:.4f}".format(test_fake.cpu().numpy()[j,0]))
        plt.imshow(show_sr)
        plt.subplot(1, 3, 3)
        plt.title("HR image with D:{:.4f}".format(test_true.cpu().numpy()[j,0]))
        plt.imshow(show_hr)
        plt.show()

        print("  SSIM : {:.4f}".format(ssim_s))
        print("  SSIM : {:.4f}".format(ssim_l))


      if ssim_s >= best_ssim:
        best_ssim = ssim_s
      if ssim_s <= min_ssim:
        min_ssim = ssim_s

      ssim_sh += ssim_s
      ssim_lh += ssim_l

    i = i + 1

ssim_sh = ssim_sh / (len(evalLoader)*8)
ssim_lh = ssim_lh / (len(evalLoader)*8)
d_loss_eval = d_loss_eval / (len(evalLoader)*8)
g_loss_eval = g_loss_eval / (len(evalLoader)*8)

print("  Eval Loss: ")
print("  G Loss : {:.4f}".format(g_loss_eval))
print("  D Loss : {:.4f}".format(d_loss_eval))
print("  SSIM between sr and hr : {:.4f}".format(ssim_sh))
print("  SSIM max : {:.4f}".format(best_ssim))
print("  SSIM min : {:.4f}".format(min_ssim))
print("  SSIM between lr and hr : {:.4f}".format(ssim_lh))


Output hidden; open in https://colab.research.google.com to view.