In [1]:
%matplotlib inline
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as vutils
from matplotlib import pyplot as plt
from IPython import display
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image

In [2]:
# Use the first GPU when available. Every tensor and model should be moved
# with `xx.to(device)` so the notebook runs unchanged on CPU or GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the random number generators for reproducibility.
torch.manual_seed(1)
np.random.seed(1)
# `device` is a torch.device, not a str: comparing it to the string "cuda:0"
# is not a reliable GPU check, so inspect its `.type` instead.
if device.type == "cuda":
    torch.cuda.manual_seed_all(1)  # seed every visible GPU

# Dataset paths
celea_dataset = "Dataset/HIGH/celea_60000_SFD"
sr_dataset = "Dataset/HIGH/SRtrainset_2"
vgg_dataset = "Dataset/HIGH/vggface2/vggcrop_train"



<torch._C.Generator at 0x7f7e541490b0>

In [9]:
# Every tunable number used by the notebook is defined in this cell.

# --- batching / schedule ---
hightolow_batch_size: int = 8
epoch: int = 200
learning_rate: float = 1e-4

# --- loss weights ---
loss_a_coeff: int = 1
loss_b_coeff: float = 0.05

# --- optimiser betas (intended for torch.optim.Adam) ---
adam_beta1: int = 0
adam_beta2: float = 0.9

# --- sizes ---
high_image_size: int = 64
low_image_size: int = 16
noise_dimension: int = 64

In [8]:
def batch_to_image(batch, normalize=False):
    """Display a batch of images as one grid using matplotlib.

    Args:
        batch: tensor accepted by ``torchvision.utils.make_grid``
            (assumed (B, C, H, W) -- TODO confirm). It may live on any device
            and may require grad; it is detached and copied to the CPU first.
        normalize: forwarded to ``make_grid``. Pass True to rescale values
            into [0, 1] (e.g. for images normalized to [-1, 1]), which
            imshow would otherwise clip. Defaults to False (previous behavior).
    """
    # .numpy() only works on CPU tensors that do not require grad.
    np_grid = vutils.make_grid(batch.detach().cpu(), normalize=normalize).numpy()
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(np_grid, (1, 2, 0)), interpolation='nearest')
    
def create_noise(batch_size=None, noise_dim=None):
    """Sample a batch of noise vectors from a standard normal distribution.

    Args:
        batch_size: number of noise vectors; defaults to the global
            ``hightolow_batch_size``.
        noise_dim: length of each vector; defaults to the global
            ``noise_dimension`` (previously hard-coded as 64).

    Returns:
        A CPU float tensor of shape (batch_size, noise_dim).
    """
    if batch_size is None:
        batch_size = hightolow_batch_size
    if noise_dim is None:
        noise_dim = noise_dimension
    return torch.randn(batch_size, noise_dim)

In [5]:
# Dataset for high images
def load_image(path):
    """Load the image file at ``path`` as a 3-channel RGB PIL image.

    ``Image.open`` is lazy and keeps the file handle open. ``convert("RGB")``
    forces the pixel data to load into a new image, so the file can be closed
    right away. Converting also guarantees 3 channels (grayscale, palette or
    RGBA inputs) to match the 3-channel Normalize used by the dataset transform.
    """
    with Image.open(path) as image:
        return image.convert("RGB")

# This dataset reads image folders: every entry located directly in a dataset
# folder is treated as one image file. This generic dataset should be
# customized for other dataset layouts; a CSV-based dataset requires a
# different loading function.

class HighDataset(Dataset):
    """Map-style dataset over every high-resolution image in the source folders.

    Each entry of each root folder is treated as one image path. By default the
    module-level ``celea_dataset``, ``sr_dataset`` and ``vgg_dataset`` folders
    are scanned, each joined with its own root.
    """

    def __init__(self, transform=None, roots=None):
        """Collect image paths.

        Args:
            transform: optional callable applied to each loaded PIL image.
            roots: optional iterable of folder paths to scan; defaults to
                ``(celea_dataset, sr_dataset, vgg_dataset)``.
        """
        if roots is None:
            roots = (celea_dataset, sr_dataset, vgg_dataset)

        images = []
        for root in roots:
            # os.listdir order is arbitrary; sort so the index -> file mapping
            # is reproducible across runs and filesystems.
            images.extend(os.path.join(root, name) for name in sorted(os.listdir(root)))

        self.images = images
        self.transform = transform
        self.count = len(images)

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``self.transform`` if it is set."""
        image = load_image(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Number of collected image paths."""
        return self.count

In [6]:
# Resize makes every image the same size so the default collate can stack a
# batch (a no-op for images already at that size). ToTensor scales pixels to
# [0, 1]; Normalize with mean = std = 0.5 then maps each RGB channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((high_image_size, high_image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = HighDataset(transform)
# shuffle: HighDataset concatenates its source folders, so without shuffling
#   each epoch would walk through them in fixed order.
# drop_last: every batch then has exactly hightolow_batch_size images, matching
#   the fixed-size noise batches produced by create_noise().
data_loader = DataLoader(dataset,
                         batch_size=hightolow_batch_size,
                         shuffle=True,
                         drop_last=True)

In [7]:
class HighToLowGenerator(nn.Module):
    def __init__(self):
        super(HighToLowGenerator, self).__init__()
        