In [None]:
import math
from math import sqrt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageEnhance
import os
import matplotlib.pyplot as plt
import random
import numpy as np
import torch.optim as optim
from sklearn.model_selection import train_test_split

In [None]:
# Mount Google Drive into the Colab VM (Colab-only API).
# NOTE(review): the dataset below is read from '/content/datasets/...' after
# unzipping a local 'datasets.zip'; nothing visible reads from /content/drive,
# so this mount may be unnecessary — confirm before removing.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Define Dataset

In [None]:
class CustomDataset(Dataset):
    """Super-resolution dataset that yields ``(lr, hr)`` image pairs.

    Each HR image is loaded as RGB and symmetrically padded to
    ``hr_size`` x ``hr_size``.

    * Evaluation (``training=False``): LR is the padded HR image bicubically
      downscaled to ``hr_size // scale`` on each side.
    * Training (``training=True``): the padded HR image is randomly flipped,
      cropped to ``crop_size`` and rotated by 90 degrees with p=0.5. Then one
      of color jitter / cutout / cutmix is applied, and LR is the augmented
      crop bicubically downscaled to ``crop_size // scale``.

    Args:
        root: directory containing the HR images.
        data_list: file names, relative to ``root``, that make up the dataset.
        training: enable random cropping and augmentation.
        transform: callable applied to both the LR and HR PIL images (e.g.
            ``ToTensor``). When ``None``, the PIL images are returned as-is.
        hr_size: side length every HR image is padded to.
        scale: downscaling factor between HR and LR.
        crop_size: side length of the random training crop.
    """

    def __init__(self, root, data_list, training=False, transform=None, hr_size=483, scale=3, crop_size=255):
        super(CustomDataset, self).__init__()
        self.path = root
        self.training = training
        self.image_filenames = data_list
        self.hr_size = hr_size
        self.crop_size = crop_size
        self.scale = scale
        self.transform = transform

    def __getitem__(self, index):
        hr_img = self._load_padded(self.image_filenames[index])

        if self.training:
            hr_img = self.flip_rotate_and_crop(hr_img)
            # Each augmentation returns the matching (lr, hr) pair.
            lr_img, hr_img = self.aug_pool(hr_img)
        else:
            lr_size = self.hr_size // self.scale
            lr_img = hr_img.resize((lr_size, lr_size), Image.BICUBIC)

        if self.transform is None:
            return lr_img, hr_img
        return self.transform(lr_img), self.transform(hr_img)

    def _load_padded(self, filename):
        """Open ``filename`` (relative to ``root``) as RGB and pad it to ``hr_size``."""
        with Image.open(os.path.join(self.path, filename)) as src:
            # convert() reads the pixels, so the file can be closed here. It
            # also guarantees three channels for the RGB model and the RGB
            # cutout mask.
            img = src.convert('RGB')
        return self.pad_img(img)

    def _make_lr(self, img):
        """Bicubically downscale a ``crop_size`` training crop by ``scale``."""
        lr_size = self.crop_size // self.scale
        return img.resize((lr_size, lr_size), Image.BICUBIC)

    @staticmethod
    def _clipped_span(center, length, limit):
        """Return int ``(start, end)`` of a window of ~``length`` centred on ``center``, clipped to ``[0, limit]``."""
        half = length // 2
        start = min(max(center - half, 0), limit)
        end = min(max(center + half, 0), limit)
        return int(start), int(end)

    def pad_img(self, img):
        """Symmetrically pad ``img`` to ``hr_size`` x ``hr_size``; an odd leftover pixel goes top/left.

        NOTE(review): assumes the image is no larger than ``hr_size`` on either
        side; larger images produce negative padding amounts.
        """
        width, height = img.size
        pad_b = (self.hr_size - height) // 2
        pad_r = (self.hr_size - width) // 2
        pad_t = (self.hr_size - height) - pad_b
        pad_l = (self.hr_size - width) - pad_r
        # transforms.Pad takes (left, top, right, bottom).
        return transforms.Pad((pad_l, pad_t, pad_r, pad_b), padding_mode="symmetric")(img)

    def aug_pool(self, img):
        """Apply one uniformly chosen augmentation to a training crop.

        Returns:
            ``(lr_img, hr_img)`` as produced by the chosen augmentation.
        """
        augmentations = {
            'color': self.color_jitter,
            'cutout': self.cutout,
            'cutmix': self.cutmix,
        }
        choice = random.choice(list(augmentations))
        return augmentations[choice](img)

    def flip_rotate_and_crop(self, img):
        """Random H/V flips, a random ``crop_size`` crop, then a 90-degree rotation with p=0.5."""
        aug_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomCrop(self.crop_size),
        ])
        out_img = aug_transform(img)

        if random.random() < 0.5:
            out_img = out_img.transpose(Image.ROTATE_90)

        return out_img

    def color_jitter(self, img):
        """Randomly jitter brightness, contrast, saturation and hue; return ``(lr, hr)``."""
        jitter = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
        out_img = jitter(img)
        return self._make_lr(out_img), out_img

    def cutmix(self, img):
        """Paste a random box from another (augmented) training image into ``img``.

        With ``lam ~ Beta(1, 1)``, the box spans ``sqrt(1 - lam)`` of the width
        and height before clipping at the borders. Returns ``(lr, hr)``.
        """
        beta = 1.0
        lam = np.random.beta(beta, beta)
        width, height = img.size
        cut_rat = np.sqrt(1. - lam)
        # np.int was removed in NumPy 1.24; the builtin int truncates identically.
        cut_w = int(width * cut_rat)
        cut_h = int(height * cut_rat)

        # The box centre is uniform over the image; the box is clipped at the borders.
        cx = np.random.randint(width)
        cy = np.random.randint(height)
        bbx1, bbx2 = self._clipped_span(cx, cut_w, width)
        bby1, bby2 = self._clipped_span(cy, cut_h, height)

        # The donor goes through the same load/pad/flip/crop pipeline, so it matches img's size.
        donor = self._load_padded(random.choice(self.image_filenames))
        donor = self.flip_rotate_and_crop(donor)

        out_im = img.copy()
        out_im.paste(donor.crop((bbx1, bby1, bbx2, bby2)), (bbx1, bby1))
        return self._make_lr(out_im), out_im

    def cutout(self, img):
        """Black out a random box of 40-60% of the width and height (clipped at the borders); return ``(lr, hr)``."""
        width, height = img.size
        lo_w, hi_w = int(0.4 * width), int(0.6 * width)
        lo_h, hi_h = int(0.4 * height), int(0.6 * height)
        # randint needs high > low; widen the range by one pixel for tiny images.
        cutout_w = np.random.randint(lo_w, max(hi_w, lo_w + 1))
        cutout_h = np.random.randint(lo_h, max(hi_h, lo_h + 1))

        cx = np.random.randint(width)
        cy = np.random.randint(height)
        bbx1, bbx2 = self._clipped_span(cx, cutout_w, width)
        # The vertical extent comes from cutout_h.
        bby1, bby2 = self._clipped_span(cy, cutout_h, height)

        mask = Image.new(mode='RGB', size=(bbx2 - bbx1, bby2 - bby1))  # all-black patch
        out_im = img.copy()
        out_im.paste(mask, (bbx1, bby1))
        return self._make_lr(out_im), out_im

    def __len__(self):
        return len(self.image_filenames)

# Create Dataset

In [None]:
!unzip 'datasets.zip'

In [None]:
# Training data: every regular, non-hidden file in `dataroot` is treated as an HR image.
dataroot = '/content/datasets/training_hr_images'

transform = transforms.Compose([
    transforms.ToTensor(),
])

# Skip hidden files (e.g. '.DS_Store') and sub-directories that archive tools
# can leave behind; PIL cannot open either.
train_list = sorted(
    name for name in os.listdir(dataroot)
    if not name.startswith('.') and os.path.isfile(os.path.join(dataroot, name))
)

train_dataset = CustomDataset(dataroot, train_list, training=True,
                              transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=24, shuffle=True)


In [None]:
# Sanity check: fail early with a clear message if no training images were found.
n_train_images = len(train_dataset)
assert n_train_images > 0, f"No training images found in {dataroot}"
print(n_train_images)

291


# Define Model 

## EDSR

In [None]:
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Return a stride-1 ``nn.Conv2d`` padded by ``kernel_size // 2``.

    For odd kernel sizes this keeps the spatial size unchanged.
    """
    same_padding = kernel_size // 2
    layer = nn.Conv2d(in_channels, out_channels, kernel_size,
                      padding=same_padding, bias=bias)
    return layer


class MeanShift(nn.Conv2d):
    """Frozen 1x1 convolution that shifts RGB images by a scaled channel mean.

    Output channel ``c`` is ``(x_c + sign * rgb_range * rgb_mean[c]) / rgb_std[c]``.
    So ``sign=-1`` subtracts the mean and ``sign=+1`` adds it back.
    All parameters have ``requires_grad=False``.
    """

    def __init__(
        self, rgb_range,
        rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        mean = torch.Tensor(rgb_mean)
        std = torch.Tensor(rgb_std)
        with torch.no_grad():
            # Diagonal kernel: channel c is only scaled by 1 / std[c].
            self.weight.copy_(torch.eye(3)[:, :, None, None] / std[:, None, None, None])
            self.bias.copy_(sign * rgb_range * mean / std)
        self.requires_grad_(False)


class ResBlock(nn.Module):
    """Residual block: ``x + res_scale * body(x)``.

    ``body`` is conv -> [BatchNorm] -> act -> conv -> [BatchNorm], where the
    BatchNorm layers are present only when ``bn`` is true.

    NOTE: the default ``act`` module is created once, when the class is
    defined, and is shared by every block that relies on the default. That is
    harmless for a parameter-free ReLU.
    """

    def __init__(
        self, conv, n_feats, kernel_size,
        bias=True, bn=False, act=nn.ReLU(True), res_scale=1):

        super(ResBlock, self).__init__()

        def conv_stage():
            # One conv, optionally followed by batch norm.
            stage = [conv(n_feats, n_feats, kernel_size, bias=bias)]
            if bn:
                stage.append(nn.BatchNorm2d(n_feats))
            return stage

        # Built left to right: first stage, activation, second stage.
        layers = conv_stage() + [act] + conv_stage()
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        scaled = self.body(x).mul(self.res_scale)
        return scaled + x


class Upsampler(nn.Sequential):
    """Sub-pixel (PixelShuffle) upsampler for ``scale`` = 2**n or 3.

    Each stage is conv(n_feats -> f*f*n_feats) -> PixelShuffle(f), optionally
    followed by BatchNorm and a 'relu'/'prelu' activation. Powers of two use
    ``log2(scale)`` stages with f=2 (``scale=1`` gives no stages); ``scale=3``
    uses a single f=3 stage. Any other scale raises ``NotImplementedError``.
    """

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):

        if (scale & (scale - 1)) == 0:    # scale is a power of two
            factors = [2] * int(math.log(scale, 2))
        elif scale == 3:
            factors = [3]
        else:
            raise NotImplementedError

        layers = []
        for factor in factors:
            layers.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            layers.append(nn.PixelShuffle(factor))
            if bn:
                layers.append(nn.BatchNorm2d(n_feats))
            if act == 'relu':
                layers.append(nn.ReLU(True))
            elif act == 'prelu':
                layers.append(nn.PReLU(n_feats))

        super(Upsampler, self).__init__(*layers)


class EDSR(nn.Module):
    """EDSR super-resolution network: head conv -> residual body -> upsampling tail.

    ``forward`` multiplies its input by 255, subtracts the RGB mean, and at
    the end adds the mean back and divides by 255. It therefore works on the
    [0, 1] range that ``ToTensor`` produces.

    Args:
        n_resblocks: number of residual blocks in the body.
        n_feats: feature channels used throughout the network.
        scale: upsampling factor (see ``Upsampler`` for supported values).
        res_scale: multiplier applied to each residual branch.
        pretrained: accepted for API compatibility; currently unused.
    """

    def __init__(self, n_resblocks, n_feats, scale, res_scale, pretrained=False):
        super(EDSR, self).__init__()
        self.scale = scale

        kernel_size = 3
        n_colors = 3
        rgb_range = 255
        conv = default_conv
        act = nn.ReLU(True)  # one ReLU instance shared by all residual blocks

        # Frozen mean shifts wrapped around the trainable network.
        self.sub_mean = MeanShift(rgb_range)
        self.add_mean = MeanShift(rgb_range, sign=1)

        head_layers = [conv(n_colors, n_feats, kernel_size)]

        body_layers = []
        for _ in range(n_resblocks):
            body_layers.append(
                ResBlock(conv, n_feats, kernel_size, act=act, res_scale=res_scale))
        body_layers.append(conv(n_feats, n_feats, kernel_size))

        tail_layers = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size),
        ]

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x, scale=None):
        if scale is not None and scale != self.scale:
            raise ValueError(f"Network scale is {self.scale}, not {scale}")
        features = self.head(self.sub_mean(255 * x))
        # Global skip connection around the whole residual body.
        body_out = self.body(features) + features
        return self.add_mean(self.tail(body_out)) / 255


def edsr_r16f64(scale, pretrained=False):
    """EDSR with 16 residual blocks, 64 feature maps and residual scaling 1.0."""
    return EDSR(n_resblocks=16, n_feats=64, scale=scale, res_scale=1.0,
                pretrained=pretrained)


def edsr_r32f256(scale, pretrained=False):
    """EDSR with 32 residual blocks, 256 feature maps and residual scaling 0.1."""
    return EDSR(n_resblocks=32, n_feats=256, scale=scale, res_scale=0.1,
                pretrained=pretrained)


def edsr_baseline(scale, pretrained=False):
    """Baseline EDSR: alias for the 16-block / 64-feature configuration."""
    return edsr_r16f64(scale, pretrained=pretrained)


def edsr(scale, pretrained=False):
    """Full EDSR: alias for the 32-block / 256-feature configuration."""
    return edsr_r32f256(scale, pretrained=pretrained)

# Prepare for training

In [None]:
# Run on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG this notebook uses. Python `random` and NumPy drive the data
# augmentations; torch drives weight init, DataLoader shuffling and the
# torchvision random transforms. GPU kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# EDSR baseline (16 residual blocks, 64 features) for x3 upscaling, moved to `device`.
model = edsr_baseline(scale=3, pretrained=False)
model = model.to(device)

In [None]:
# Optimisation setup: Adam at `lr`, halved every 200 scheduler steps
# (the training loop steps the scheduler once per epoch).
lr = 1e-4

criterionL2 = nn.MSELoss().to(device)  # NOTE(review): not used by the training loop below
criterionL1 = nn.L1Loss().to(device)   # training objective
optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-8)
# StepLR's `verbose` flag is deprecated in recent PyTorch releases; read the
# current rate from scheduler.get_last_lr() instead.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
print("Initial learning rate: {:.4e}".format(scheduler.get_last_lr()[0]))

Adjusting learning rate of group 0 to 1.0000e-04.


# Training

In [None]:
def plot_loss(train_loss_history, epoch, save_path='loss.png'):
    """Plot the per-epoch training loss, save it to ``save_path`` and display it.

    Args:
        train_loss_history: one loss value per completed epoch.
        epoch: zero-based index of the latest epoch. Kept for backward
            compatibility; the x-axis is derived from the history length so
            the two can never disagree.
        save_path: file the figure is written to (default ``'loss.png'``).
    """
    # 1-based epoch numbers, matching the "Epoch N" lines in the training log.
    epochs = list(range(1, len(train_loss_history) + 1))
    fig, ax = plt.subplots()
    ax.plot(epochs, train_loss_history, label='Training')
    ax.set(title='Training loss', xlabel='epochs', ylabel='loss')
    ax.legend()
    fig.savefig(save_path)
    plt.show()
    # Release the figure: the training loop calls this function repeatedly.
    plt.close(fig)

In [None]:
# Train with L1 loss; checkpoint and log after every epoch.
num_epochs = 300
plot_every = 10  # redraw the loss curve every N epochs and after the final one
train_loss_history = []
val_loss_history = []  # NOTE(review): never filled — there is no validation loop yet
log_path = 'log.txt'

for epoch in range(num_epochs):
    model.train()

    total_loss = 0.0
    count = 0

    # `lr_batch` / `hr_batch` avoid shadowing the global learning rate `lr`.
    for lr_batch, hr_batch in train_dataloader:
        lr_batch = lr_batch.to(device)
        hr_batch = hr_batch.to(device)

        optimizer.zero_grad()
        sr_batch = model(lr_batch)
        loss = criterionL1(sr_batch, hr_batch)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample mean.
        batch_size = hr_batch.size(0)
        count += batch_size
        total_loss += loss.item() * batch_size

    if count == 0:
        raise RuntimeError("train_dataloader yielded no batches; check the training data")

    train_loss = total_loss / count
    train_loss_history.append(train_loss)

    # One checkpoint per epoch, named by epoch number and training loss.
    torch.save(model.state_dict(), "model_ep{}_loss{:.8f}.pkl".format(epoch + 1, train_loss))

    message = "Epoch {}: Training Loss: {:.8f}.".format(epoch + 1, train_loss)
    with open(log_path, 'a') as f:
        f.write(message + "\n")
    print(message, "(lr = {:.2e})".format(scheduler.get_last_lr()[0]))

    # Plotting every epoch would embed hundreds of figures in the notebook.
    if (epoch + 1) % plot_every == 0 or epoch + 1 == num_epochs:
        plot_loss(train_loss_history, epoch)
    print("-------")

    scheduler.step()
