In [16]:
%matplotlib inline
from IPython.core.interactiveshell import InteractiveShell
#InteractiveShell.ast_node_interactivity = "all"
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from matplotlib import pyplot as plt
from IPython import display
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import PIL
import time
from model import *

In [17]:
# Select the compute device once. Every tensor and model in this notebook is
# moved with "xx.to(device)" so the same code runs on CPU and on GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Basic parameters for reproducibility: seed every random number generator
# that is used (PyTorch CPU/CUDA and NumPy).
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)

# `device` is a torch.device, not a string, so it has to be inspected through
# `.type`; comparing it with the string "cuda:0" never matches.
if device.type == "cuda":
    torch.cuda.manual_seed(seed)      # current GPU
    torch.cuda.manual_seed_all(seed)  # every visible GPU

# Dataset paths
celea_dataset = "Dataset/HIGH/celea_60000_SFD"
sr_dataset = "Dataset/HIGH/SRtrainset_2"
vgg_dataset = "Dataset/HIGH/vggface2/vggcrop_train"
wider_dataset = "Dataset/LOW/widerface_subset/wider_lnew"

# Model saving paths (file-name prefixes; save_checkpoint appends the epoch).
# NOTE(review): the training cell rebinds `hightolow_generator` and
# `hightolow_discriminator` to the model instances, so these prefixes are lost
# once that cell has run. Consider giving the prefixes their own names.
hightolow_generator = "Checkpoint/hightolow_g_"
hightolow_discriminator = "Checkpoint/hightolow_d_"

In [27]:
# Every tunable setting of the notebook is collected in this cell.

# --- Image and noise sizes ---
high_image_size = 64
low_image_size = 16
noise_dimension = 64

# --- Optimisation ---
hightolow_batch_size = 8
epoch = 200  # number of passes over the high-resolution loader
learning_rate = 1e-4
adam_beta1 = 0
adam_beta2 = 0.9

# --- Loss weighting: pixel (a) and adversarial (b) terms of the generator ---
loss_a_coeff = 1
loss_b_coeff = 0.05

# --- Schedule ---
# The generator is updated once every `generator_update_ratio` iterations,
# while the discriminator is updated on every iteration.
# In the original paper the ratio is 5:1.
generator_update_ratio = 5
# Interval (in iterations) at which generated images are saved so that the
# training progress can be inspected visually.
sample_interval = 500

In [19]:
def save_checkpoint(model_path,epoch,model,optimizer,loss):
    """Write a training checkpoint to ``model_path + str(epoch)``.

    Args:
        model_path: File-name prefix of the checkpoint (directory included).
        epoch: Epoch identifier; stored in the checkpoint and appended to
            ``model_path``. May be an int or a string.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        loss: Loss value stored alongside the weights.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    # str() so that the usual integer epoch counter does not raise a TypeError
    # when it is concatenated with the path prefix.
    torch.save(checkpoint, model_path + str(epoch))

# Restores what save_checkpoint wrote. The model and optimizer should be
# passed explicitly; the notebook-global fallback exists only for old callers.
def load_checkpoint(model_path,epoch,model=None,optimizer=None,map_location=None):
    """Load a checkpoint into ``model`` (and ``optimizer``) and return its metadata.

    Args:
        model_path: File-name prefix used when saving, or the full path of a
            checkpoint file.
        epoch: Epoch identifier appended to ``model_path`` to build the file
            name (the same naming save_checkpoint uses).
        model: Module that receives the stored weights. When omitted, a
            module-level variable named ``model`` is used (legacy behaviour).
        optimizer: Optimizer that receives the stored state. When omitted, a
            module-level variable named ``optimizer`` is used if it exists;
            otherwise the optimizer state is skipped.
        map_location: Forwarded to ``torch.load``.

    Returns:
        Tuple ``(epoch, loss)`` as stored in the checkpoint.

    Raises:
        ValueError: If no model is available to load the weights into.
    """
    # Prefer the file name produced by save_checkpoint; fall back to treating
    # `model_path` as the complete path, which is what older callers passed.
    path = model_path + str(epoch)
    if not os.path.isfile(path):
        path = model_path
    checkpoint = torch.load(path, map_location=map_location)

    # Legacy fallback: this function used to read the notebook globals.
    if model is None:
        model = globals().get("model")
    if optimizer is None:
        optimizer = globals().get("optimizer")
    if model is None:
        raise ValueError("load_checkpoint needs a model to load the weights into")

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # The model is left in training mode, as before; call model.eval() on the
    # result when the checkpoint is loaded for inference.
    model.train()
    return checkpoint['epoch'], checkpoint['loss']
    
def batch_to_image(batch):
    """Show a batch of images as a single grid with matplotlib.

    Args:
        batch: Anything ``vutils.make_grid`` accepts; the grid it returns is
            transposed from channel-first to channel-last before plotting.
    """
    # detach() and cpu() make the conversion work for tensors that require
    # grad (e.g. generator outputs) and for tensors living on the GPU; a plain
    # .numpy() raises in both cases.
    np_grid = vutils.make_grid(batch).detach().cpu().numpy()
    plt.imshow(np.transpose(np_grid, (1,2,0)), interpolation='nearest')
    
# Noise vectors sampled from the standard normal distribution
def create_noise(batch_size=None, noise_dim=None):
    """Return a ``(batch_size, noise_dim)`` tensor of standard-normal noise.

    Args:
        batch_size: Number of noise vectors; defaults to ``hightolow_batch_size``.
        noise_dim: Length of each vector; defaults to ``noise_dimension``
            instead of a hard-coded 64.
    """
    if batch_size is None:
        batch_size = hightolow_batch_size
    if noise_dim is None:
        noise_dim = noise_dimension
    return torch.randn(batch_size, noise_dim)

def calculate_remaining_training_time(batches_done,batches_left):
    """Placeholder for a remaining-training-time estimate.

    Both arguments are currently unused: the function only writes an empty
    line to standard output and returns ``None``.
    """
    # Sketch of the intended computation, not implemented yet:
    #   batches_done = epoch * len(dataloader) + i
    #   batches_left = opt.n_epochs * len(dataloader) - batches_done
    #   time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
    print("")

In [31]:
# Dataset for high images
def load_image(path):
    """Read the image at ``path`` and return it as an RGB ``PIL.Image``.

    ``Image.open`` is lazy and keeps the file open until the pixel data is
    read. The ``with`` block closes the file deterministically, and
    ``convert("RGB")`` reads the pixels and returns an independent image.
    Forcing RGB also guarantees the three channels that the 3-channel
    ``Normalize`` transform used in this notebook expects, even for grayscale
    or RGBA files.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

# These datasets read image files directly from the dataset folders
# configured above. They are generic and should be customized to match
# the layout of the dataset at hand; a CSV-based dataset, for example,
# requires a different loading function.

class HighDataset(Dataset):
    """High-resolution image dataset assembled from one or more folders.

    Every entry of every folder is treated as an image file; the folders are
    not searched recursively.
    """

    def __init__(self,transform = None, roots = None):
        """Collect the image paths.

        Args:
            transform: Optional callable applied to each loaded image.
            roots: Optional iterable of folders to read. Defaults to the
                module-level ``celea_dataset``, ``sr_dataset`` and
                ``vgg_dataset`` paths, in that order.
        """
        if roots is None:
            roots = (celea_dataset, sr_dataset, vgg_dataset)

        images = []
        for root in roots:
            # os.listdir returns entries in arbitrary order; sorting makes
            # the index -> file mapping reproducible across runs and machines.
            images.extend(os.path.join(root, name) for name in sorted(os.listdir(root)))

        self.images = images
        self.transform = transform
        self.count = len(images)

    def __getitem__(self, index):
        """Load the image stored at ``index`` and apply the transform, if any."""
        image = load_image(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of image paths collected at construction time."""
        return self.count
    

class LowDataset(Dataset):
    """Low-resolution image dataset assembled from one or more folders.

    Every entry of every folder is treated as an image file; the folders are
    not searched recursively.
    """

    def __init__(self,transform = None, roots = None):
        """Collect the image paths.

        Args:
            transform: Optional callable applied to each loaded image.
            roots: Optional iterable of folders to read. Defaults to the
                module-level ``wider_dataset`` path.
        """
        if roots is None:
            roots = (wider_dataset,)

        images = []
        for root in roots:
            # os.listdir returns entries in arbitrary order; sorting makes
            # the index -> file mapping reproducible across runs and machines.
            images.extend(os.path.join(root, name) for name in sorted(os.listdir(root)))

        self.images = images
        self.transform = transform
        self.count = len(images)

    def __getitem__(self, index):
        """Load the image stored at ``index`` and apply the transform, if any."""
        image = load_image(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of image paths collected at construction time."""
        return self.count

In [21]:
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (default 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class BasicBlock(nn.Module):
    """Residual block: (BN) -> ReLU -> conv1 -> (BN) -> ReLU -> conv2, plus a skip path.

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels.
        stride: Stride of the first 3x3 convolution (unused when upsampling).
        downsample: Halve the resolution with 2x2 average pooling on both the
            main and the skip path.
        upsample: Double the resolution with a 4x4 / stride-2 transposed
            convolution on both the main and the skip path.
        nobn: Leave out both batch-normalisation layers.

    NOTE(review): `stride` is applied to conv1 only and is not mirrored on the
    skip path, so a stride other than 1 looks incompatible with the residual
    addition. No caller in this notebook passes it; confirm before relying on it.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=False, upsample=False, nobn = False):
        super().__init__()
        self.upsample = upsample
        self.downsample = downsample
        self.nobn = nobn
        with_bn = not nobn

        # Sub-modules are created in a fixed order (conv1, bn1, relu, conv2,
        # bn2, skip) so that parameter initialisation and state_dict keys do
        # not change.
        self.conv1 = (nn.ConvTranspose2d(inplanes, planes, 4, 2, 1)
                      if upsample else conv3x3(inplanes, planes, stride))
        if with_bn:
            self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=False)

        self.conv2 = (nn.Sequential(nn.AvgPool2d(2,2), conv3x3(planes, planes))
                      if downsample else conv3x3(planes, planes))
        if with_bn:
            self.bn2 = nn.BatchNorm2d(planes)

        # The skip path needs its own layer whenever the main path changes the
        # resolution or the channel count; otherwise the input is added as is.
        if upsample:
            self.skip = nn.ConvTranspose2d(inplanes, planes, 4, 2, 1)
        elif downsample:
            self.skip = nn.Sequential(nn.AvgPool2d(2,2), nn.Conv2d(inplanes, planes, 1, 1))
        elif inplanes != planes:
            self.skip = nn.Conv2d(inplanes, planes, 1, 1, 0)
        else:
            self.skip = None
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return main path + skip path."""
        out = x if self.nobn else self.bn1(x)
        out = self.conv1(self.relu(out))
        if not self.nobn:
            out = self.bn2(out)
        out = self.conv2(self.relu(out))
        shortcut = x if self.skip is None else self.skip(x)
        out += shortcut
        return out

    # Upsamplings can be changed with pixelShuffle
class HighToLowGenerator(nn.Module):
    """Generator that maps a high-resolution image to a low-resolution one.

    The input goes through four downsampling stages followed by two upsampling
    stages, so the output has a quarter of the input's height and width
    (64x64 -> 16x16). The output has 3 channels and is squashed by ``Tanh``.
    """

    def __init__(self):
        super(HighToLowGenerator, self).__init__()
        self.layers_in = conv3x3(3, 64)
        # 64x64
        self.residual1 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,64,downsample=True))
        # 32x32
        self.residual2 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,64,downsample=True))
        # 16x16
        self.residual3 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,64,downsample=True))
        # 8x8
        self.residual4 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,64,downsample=True))
        # 4x4
        self.residual5 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,64,upsample=True))
        # 8x8
        self.residual6 = nn.Sequential(BasicBlock(64,64),BasicBlock(64,3,upsample=True),nn.Tanh())
        # 16x16

    def forward(self, input, noise= None):
        """Degrade ``input`` after adding ``noise`` to it.

        Args:
            input: Batch of 3-channel images.
            noise: Tensor added to ``input`` before the first convolution.
                When omitted, standard-normal noise is drawn with the same
                shape, dtype and device as ``input``.

        Returns:
            The generated batch, with values limited by ``Tanh``.
        """
        if noise is None:
            # randn_like follows the input, so the generator no longer depends
            # on a hard-coded 3x64x64 shape or on the notebook-global `device`
            # (which raised a device mismatch for CPU inputs on a GPU machine).
            noise = torch.randn_like(input)

        x = input + noise
        x = self.layers_in(x)
        x = self.residual1(x)
        x = self.residual2(x)
        x = self.residual3(x)
        x = self.residual4(x)
        x = self.residual5(x)
        out = self.residual6(x)
        return out
        
class LowDiscriminator(nn.Module):
    """Discriminator for low-resolution images.

    Six ``BasicBlock`` layers without batch normalisation are followed by a
    single linear unit and a sigmoid. The linear layer has 16*16*6 inputs, so
    the network expects 3-channel 16x16 images.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) of the consecutive residual blocks.
        channel_plan = [(3, 6), (6, 6), (6, 12), (12, 12), (12, 6), (6, 6)]
        blocks = [BasicBlock(c_in, c_out, nobn=True) for c_in, c_out in channel_plan]
        self.disc = nn.Sequential(*blocks)

        classifier = [nn.Linear(16*16*6, 1), nn.Sigmoid()]
        self.linear = nn.Sequential(*classifier)

    def forward(self,input):
        """Return one score in (0, 1) per image of the batch."""
        features = self.disc(input)
        flat = features.view(features.size(0), -1)
        return self.linear(flat)

In [32]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and std 0.5 then
# maps every channel to [-1, 1].
# NOTE(review): no resize/crop is applied, so batching assumes the images in
# each dataset already share one size - confirm against the data on disk.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
high_dataset = HighDataset(transform)
low_dataset = LowDataset(transform)

# shuffle=True: the high-resolution dataset is a concatenation of several
# folders, so unshuffled batches would come from one source at a time.
# drop_last=True: the training loop uses one batch from each loader per
# iteration; a ragged final batch would give the two different sizes.
high_loader = DataLoader(high_dataset, batch_size=hightolow_batch_size,
                         shuffle=True, drop_last=True)
low_loader = DataLoader(low_dataset, batch_size=hightolow_batch_size,
                        shuffle=True, drop_last=True)
print(len(high_loader))

5006


In [29]:
# Training loop
start_time = time.time()

# Adversarial criterion for the generator. The discriminator ends in a sigmoid
# and is trained with a least-squares loss towards 1 (real) / 0 (fake), so the
# generator uses the same loss with the "real" target. The previous
# HingeEmbeddingLoss with a target of 1 evaluates to mean(D(fake)), and
# minimising that pushed the generator towards being classified as fake.
GANLoss = nn.MSELoss().to(device)
L2Loss = nn.MSELoss().to(device)

# NOTE(review): these two names are also used for the checkpoint path prefixes
# in the configuration cell; they are rebound to the models here.
hightolow_generator = HighToLowGenerator().to(device)
hightolow_discriminator = LowDiscriminator().to(device)
optimizer_G = torch.optim.Adam(hightolow_generator.parameters(), lr=learning_rate, betas=(adam_beta1, adam_beta2))
optimizer_D = torch.optim.Adam(hightolow_discriminator.parameters(), lr=learning_rate, betas=(adam_beta1, adam_beta2))
# Noise reused for every sample grid so that successive grids are comparable.
fixed_noise = torch.randn([hightolow_batch_size,3,high_image_size,high_image_size]).float().to(device)


def sample_images(batches_done):
    """Save generator outputs for a fixed set of inputs to ``Checkpoint/<batches_done>.png``.

    The first images of ``high_dataset`` are used together with
    ``fixed_noise``, so the saved grids show the same faces at every call.
    """
    n_samples = min(hightolow_batch_size, len(high_dataset))
    sample_batch = torch.stack([high_dataset[k] for k in range(n_samples)]).to(device)
    noise = fixed_noise[:n_samples]  # slice (not index) to keep the batch axis
    with torch.no_grad():
        samples = hightolow_generator(sample_batch, noise)
    os.makedirs("Checkpoint", exist_ok=True)
    vutils.save_image(samples, "Checkpoint/%s.png" % batches_done, normalize=True )


def endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it when it is exhausted."""
    while True:
        for item in loader:
            yield item


assert len(low_loader) > 0, "low_loader yields no batches"
# The low-resolution loader is independent of (and may be shorter than) the
# high-resolution one, so it is cycled.
low_it = endless_batches(low_loader)

# Factor between generator input and output resolution (64 // 16 = 4).
downscale = high_image_size // low_image_size

batches_done = 0
last_generator_loss = float("nan")  # logged value until the first generator step

for cur_epoch in range(epoch):
    for i, batch in enumerate(high_loader):
        batch = batch.to(device)
        low_batch = next(low_it).to(device)

        # ---- Discriminator step: real low-res images vs. generated ones ----
        optimizer_D.zero_grad()
        pred_real = hightolow_discriminator(low_batch)
        # Targets are built from the predictions so that they always match
        # the size of the batch they are compared with.
        valid = torch.ones_like(pred_real)
        loss_D_real = L2Loss(pred_real, valid)

        # detach(): the discriminator step must not backpropagate into the
        # generator.
        fake_batch = hightolow_generator(batch).detach()
        pred_fake = hightolow_discriminator(fake_batch)
        fake = torch.zeros_like(pred_fake)
        loss_D_fake = L2Loss(pred_fake, fake)

        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        batches_done += 1

        # ---- Generator step, once every `generator_update_ratio` iterations ----
        if i % generator_update_ratio == 0:
            optimizer_G.zero_grad()
            fake_batch = hightolow_generator(batch)
            # Pixel loss against the average-pooled input. `low_batch` comes
            # from a separate, unpaired dataset and is therefore not a
            # pixel-aligned target for the generated images.
            pixel_target = nn.functional.avg_pool2d(batch, downscale)
            loss_L2 = L2Loss(fake_batch, pixel_target)
            pred_generated = hightolow_discriminator(fake_batch)
            loss_G = GANLoss(pred_generated, torch.ones_like(pred_generated))
            total_loss = loss_a_coeff*loss_L2 + loss_b_coeff*loss_G
            total_loss.backward()
            optimizer_G.step()
            last_generator_loss = total_loss.item()

        # Log the progress every `sample_interval` iterations of the epoch.
        if i % sample_interval == 0:
            info = "===> Epoch[{}]({}/{}): time: {:4.4f}:".format(cur_epoch, i, len(high_loader), time.time()-start_time)
            info += "Generator: {:.4f}, Discriminator: {:.4f}".format(last_generator_loss, loss_D.item())
            print(info)

        # If at sample interval sample and save images
        if batches_done % sample_interval == 0:
            sample_images(batches_done)
                

KeyboardInterrupt: 

In [None]:
# Smoke test: push one high-resolution batch through freshly initialised
# (untrained) networks and save the generator output.
it = iter(high_loader)  # `data_loader` is not defined anywhere in this notebook
batch = next(it).to(device)
# HighToLowGenerator takes no constructor arguments; noise is drawn inside
# forward() when none is passed.
model = HighToLowGenerator().to(device)
model2 = LowDiscriminator().to(device)
with torch.no_grad():
    out = model(batch)
    predict = model2(out)
print(predict)
# normalize=True rescales the Tanh output range to [0, 1] before saving, as
# done for the training samples.
vutils.save_image(out,"output.png", normalize=True)