In [8]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import math
import itertools
import sys
from torchvision.utils import save_image, make_grid
from torchvision.models import vgg19

from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import math


# Per-channel normalization statistics for pre-trained PyTorch models.
# Both arrays hold three values, one per image channel, and are passed to
# transforms.Normalize by the dataset transforms in this notebook.
mean, std = (
    np.array([0.485, 0.456, 0.406]),
    np.array([0.229, 0.224, 0.225]),
)


class Setup:
    """Container for the run options of this notebook.

    Calling ``Setup()`` yields the default configuration. Any default can be
    replaced by passing it as a keyword argument, e.g. ``Setup(batch_size=4)``.

    Raises:
        TypeError: if a keyword does not name one of the known options
            (guards against silently ignored typos).
    """

    def __init__(self, **overrides):
        # --- model / schedule ---
        self.upscale_factor = 2          # super-resolution factor given to the model
        self.epoch = 0                   # first epoch of the loop; non-zero triggers a checkpoint load
        self.n_epochs= 200               # upper bound (exclusive) of the epoch loop
        self.dataset_name = "ksdakf"     # not read anywhere in the visible notebook
        self.batch_size = 8
        # --- Adam hyper-parameters ---
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.decay_epoch = 100           # not read anywhere in the visible notebook
        self.n_cpu = 8                   # not read by the loaders here (they use num_workers=0)
        # --- image geometry ---
        self.hr_height = 256
        self.hr_width = 256
        self.channels = 3
        # --- logging / checkpoints ---
        self.sample_interval = 100       # batches between saved sample grids
        self.checkpoint_interval = 100   # epochs between checkpoints; -1 disables saving
        # --- locations ---
        self.training_image_dir  = 'data/DIV2K_train_HR/DIV2K_train_HR'
        self.validation_image_dir = 'data/DIV2K_valid_HR/DIV2K_valid_HR'
        self.model_checkpoint_dir = 'model/checkpoints/ESPCN'
        self.results = 'data/results/ESPCN'

        # Apply caller overrides last, accepting only names defined above.
        for key, value in overrides.items():
            if key not in vars(self):
                raise TypeError("Setup got an unexpected option %r" % key)
            setattr(self, key, value)


opt = Setup()

class ImageDataset(Dataset):
    """Dataset yielding a low-/high-resolution pair for every ``*.png`` in ``root``.

    Each item is produced from one image file: the image is resized to
    ``hr_shape`` for the high-resolution target and to ``hr_shape`` divided by
    ``lr_factor`` for the low-resolution input. Both are converted to tensors
    and normalized with the module-level ``mean`` / ``std``.

    Args:
        root: directory that is searched (non-recursively) for ``*.png`` files.
        hr_shape: ``(height, width)`` of the high-resolution output.
        lr_factor: integer divisor applied to both sides of ``hr_shape`` to get
            the low-resolution size.
    """

    def __init__(self, root, hr_shape, lr_factor):
        hr_height, hr_width = hr_shape
        # Resize takes (height, width); the width must come from hr_width so
        # non-square shapes are honoured.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // lr_factor, hr_width // lr_factor), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        # Sorted so that the index -> file mapping is stable between runs.
        self.files = sorted(glob.glob(root + "/*.png"))

    def __getitem__(self, index):
        """Return ``{"lr": tensor, "hr": tensor}`` for the file at ``index`` (wraps around)."""
        path = self.files[index % len(self.files)]
        # Force three channels: Normalize is configured with 3-element
        # mean/std, so palette, greyscale or RGBA PNGs would otherwise fail.
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(path) as handle:
            img = handle.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)

# Build the training and validation datasets / loaders.
# The low-resolution factor is taken from opt.upscale_factor (the same value
# the model is constructed with) instead of a literal, so changing the option
# keeps the data and the network output size in agreement.
train_set = ImageDataset(
    opt.training_image_dir,
    (opt.hr_height, opt.hr_width),
    opt.upscale_factor,
)

train_loader = DataLoader(
    train_set,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=0)

validation_set = ImageDataset(
    opt.validation_image_dir,
    (opt.hr_height, opt.hr_width),
    opt.upscale_factor,
)

validation_loader = DataLoader(
    validation_set,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=0)



In [9]:
class ESPCN(nn.Module):
    """Sub-pixel convolution upscaling network.

    Two Tanh-activated convolutions extract features at the input resolution;
    a final convolution produces ``num_channels * scale_factor ** 2`` maps that
    ``nn.PixelShuffle`` rearranges into an image ``scale_factor`` times larger
    in height and width. All convolutions are padded to keep the spatial size.

    Args:
        scale_factor: integer upscaling factor.
        num_channels: number of channels of both input and output images.
    """

    def __init__(self, scale_factor, num_channels=1):
        super().__init__()

        feature_layers = [
            nn.Conv2d(num_channels, 64, kernel_size=5, padding=2),
            nn.Tanh(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        self.first_part = nn.Sequential(*feature_layers)

        subpixel_channels = num_channels * (scale_factor ** 2)
        self.last_part = nn.Sequential(
            nn.Conv2d(32, subpixel_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize every convolution: normal weights, zero biases.

        Convolutions with 32 input channels get a fixed std of 0.001; all
        others use ``sqrt(2 / (out_channels * kernel_area))``.
        """
        for layer in self.modules():
            if not isinstance(layer, nn.Conv2d):
                continue
            if layer.in_channels == 32:
                sigma = 0.001
            else:
                # Element count of one 2-D kernel slice, i.e. kernel height * width.
                kernel_area = layer.weight.data[0][0].numel()
                sigma = math.sqrt(2 / (layer.out_channels * kernel_area))
            nn.init.normal_(layer.weight.data, mean=0.0, std=sigma)
            nn.init.zeros_(layer.bias.data)

    def forward(self, x):
        """Upscale the batch ``x`` by passing it through both stages."""
        return self.last_part(self.first_part(x))

In [11]:

# Select the device once; everything below is moved to it explicitly.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")

hr_shape = (opt.hr_height, opt.hr_width)
espcn = ESPCN(opt.upscale_factor, opt.channels)
# Resize expects (height, width); used only to enlarge the LR input for the
# side-by-side sample grids.
bicubic_upscaler = transforms.Resize(hr_shape, Image.BICUBIC)

# Losses
criterion_loss = torch.nn.MSELoss()

espcn = espcn.to(device)
criterion_loss = criterion_loss.to(device)

# Make sure the output locations exist before the first save.
os.makedirs(opt.results, exist_ok=True)
os.makedirs(opt.model_checkpoint_dir, exist_ok=True)

if opt.epoch != 0:
    # Resume from the checkpoint written for epoch `opt.epoch`.
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    resume_path = os.path.join(opt.model_checkpoint_dir, "espcn_%d.pth" % opt.epoch)
    espcn.load_state_dict(torch.load(resume_path, map_location=device))

# Optimizers
optimizer = torch.optim.Adam(espcn.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
# Kept for any later cell that still refers to it; the loop below uses `device`.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor


# ----------
#  Training
# ----------

for epoch in range(opt.epoch, opt.n_epochs):
    for i, imgs in enumerate(train_loader):

        # Configure model input (float32 on the selected device)
        imgs_lr = imgs["lr"].to(device=device, dtype=torch.float32)
        imgs_hr = imgs["hr"].to(device=device, dtype=torch.float32)

        # ------------------
        #  Train
        # ------------------

        optimizer.zero_grad()

        imgs_sr = espcn(imgs_lr)
        loss = criterion_loss(imgs_sr, imgs_hr)
        loss.backward()
        optimizer.step()

        # --------------
        #  Log Progress
        # --------------

        sys.stdout.write(
            "[Epoch %d/%d] [Batch %d/%d] [Loss: %f]\n"
            % (epoch, opt.n_epochs, i, len(train_loader), loss.item())
        )

        batches_done = epoch * len(train_loader) + i
        if batches_done % opt.sample_interval == 0:
            # Save an image grid: bicubic-upsampled input | network output | target.
            # The grids are built without autograd tracking and under their own
            # names so the batch tensors are not overwritten.
            with torch.no_grad():
                grid_lr = make_grid(bicubic_upscaler(imgs_lr), nrow=1, normalize=True)
                grid_sr = make_grid(imgs_sr.detach(), nrow=1, normalize=True)
                grid_hr = make_grid(imgs_hr, nrow=1, normalize=True)
            img_grid = torch.cat((grid_lr, grid_sr, grid_hr), -1)
            sample_path = os.path.join(opt.results, "training_%d.png" % batches_done)
            save_image(img_grid, sample_path, normalize=False)

    if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:
        # Save model checkpoints under the same name pattern the resume branch reads.
        checkpoint_path = os.path.join(opt.model_checkpoint_dir, "espcn_%d.pth" % epoch)
        torch.save(espcn.state_dict(), checkpoint_path)

[Epoch 0/200] [Batch 0/100] [Loss: 1.629568]
[Epoch 0/200] [Batch 1/100] [Loss: 1.668743]
[Epoch 0/200] [Batch 2/100] [Loss: 1.235530]
[Epoch 0/200] [Batch 3/100] [Loss: 1.604813]
[Epoch 0/200] [Batch 4/100] [Loss: 1.365054]
[Epoch 0/200] [Batch 5/100] [Loss: 1.127466]
[Epoch 0/200] [Batch 6/100] [Loss: 1.273879]
[Epoch 0/200] [Batch 7/100] [Loss: 0.890168]
[Epoch 0/200] [Batch 8/100] [Loss: 0.861514]
[Epoch 0/200] [Batch 9/100] [Loss: 1.101900]


KeyboardInterrupt: 

In [6]:
# Shape smoke test: a freshly built network should turn a 128x128 input into
# an image enlarged by opt.upscale_factor on each side.
# A throwaway name is used so that running this cell after training does not
# replace the trained `espcn` with an untrained one.
espcn_probe = ESPCN(opt.upscale_factor,3)
with torch.no_grad():  # no gradients are needed for a shape check
    probe_output = espcn_probe(torch.ones((1,3,128,128)))
probe_output.shape

torch.Size([1, 3, 256, 256])