In [1]:
%matplotlib inline
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import datetime
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython import display
from matplotlib import pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset

In [2]:
# Every tensor and model should be moved with `.to(device)` so the notebook
# runs unchanged on CPU or GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook uses, for reproducibility.
# `torch.cuda.manual_seed_all` is silently ignored when CUDA is unavailable,
# so no device check is needed here.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

# Dataset paths
celea_dataset = "Dataset/HIGH/celea_60000_SFD"
sr_dataset = "Dataset/HIGH/SRtrainset_2"
vgg_dataset = "Dataset/HIGH/vggface2/vggcrop_train"

# Model saving paths (prefixes: the epoch number is appended when saving)
hightolow_generator = "Checkpoint/hightolow_g_"
hightolow_discriminator = "Checkpoint/hightolow_d_"

<torch._C.Generator at 0x7f7e541490b0>

In [9]:
# Hyperparameters and every other tunable number live in this cell.

# --- Image / noise sizes (presumably pixels per side; TODO confirm) ---
high_image_size = 64
low_image_size = 16
noise_dimension = 64     # length of each noise vector

# --- Optimisation ---
hightolow_batch_size = 8
epoch = 200              # total number of training epochs
learning_rate = 1e-4
adam_beta1, adam_beta2 = 0, 0.9

# --- Loss weighting ---
loss_a_coeff = 1
loss_b_coeff = 0.05

# --- Schedule ---
# The generator is updated once every `generator_update_ratio` iterations
# (the original paper uses a 5:1 ratio).
generator_update_ratio = 5
# Interval, in iterations, at which generated images are sampled to
# visualise training progress (an integer despite the name).
sample_images = 400

In [8]:
def save_checkpoint(model_path, epoch, model, optimizer, loss):
    """Save model/optimizer state and bookkeeping to ``f"{model_path}{epoch}"``.

    Args:
        model_path: path prefix, e.g. ``"Checkpoint/hightolow_g_"``.
        epoch: epoch number (int or str); stored in the checkpoint and
            appended to ``model_path`` to build the file name.
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        loss: loss value stored alongside the states.
    """
    # Format instead of `+` so an int epoch does not raise TypeError.
    path = f"{model_path}{epoch}"
    # Create the checkpoint folder on first use instead of failing in torch.save.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, path)

# Counterpart of save_checkpoint: restores state into the objects passed in.
def load_checkpoint(model_path, epoch=None, model=None, optimizer=None,
                    map_location=None, train=True):
    """Load a checkpoint written by ``save_checkpoint`` and restore it in place.

    Args:
        model_path: path prefix, or the full checkpoint path when ``epoch`` is None.
        epoch: epoch suffix appended to ``model_path`` (the same naming
            scheme ``save_checkpoint`` uses). None means ``model_path`` is
            already the full path.
        model: module whose state is restored; skipped when None.
        optimizer: optimizer whose state is restored; skipped when None.
        map_location: forwarded to ``torch.load`` (e.g. ``device``) so a
            checkpoint saved on GPU can be loaded on CPU.
        train: when a model is given, leave it in train mode (True) or
            eval mode (False).

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint.
    """
    path = model_path if epoch is None else f"{model_path}{epoch}"
    checkpoint = torch.load(path, map_location=map_location)

    if model is not None:
        model.load_state_dict(checkpoint['model_state_dict'])
        # train() and eval() are mutually exclusive; choose the mode for the next step.
        model.train(train)
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return checkpoint['epoch'], checkpoint['loss']
    
def batch_to_image(batch, normalize=True):
    """Plot a batch of images as one grid with matplotlib.

    Args:
        batch: image tensor accepted by ``torchvision.utils.make_grid``
            (presumably ``(B, C, H, W)``). It may live on the GPU or require
            grad, e.g. raw generator output.
        normalize: rescale the grid to [0, 1] by its min/max before plotting.
            The dataset transform maps pixels to [-1, 1], which ``imshow``
            would otherwise clip. Pass False for data already in [0, 1].
    """
    # `.numpy()` fails on CUDA tensors and on tensors that require grad,
    # so detach and copy to the CPU first.
    grid = vutils.make_grid(batch.detach().cpu(), normalize=normalize)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)), interpolation='nearest')
    
# Noise vectors sampled from a standard normal distribution.
def create_noise(batch_size=None, noise_dim=None, device=None):
    """Sample a batch of Gaussian noise vectors.

    Args:
        batch_size: number of vectors; defaults to ``hightolow_batch_size``.
        noise_dim: length of each vector; defaults to ``noise_dimension``.
        device: device to allocate on; None uses torch's default device.

    Returns:
        Tensor of shape ``(batch_size, noise_dim)``.
    """
    if batch_size is None:
        batch_size = hightolow_batch_size
    if noise_dim is None:
        noise_dim = noise_dimension
    return torch.randn(batch_size, noise_dim, device=device)

def calculate_remaining_training_time(batches_done, batches_left, start_time, now=None):
    """Estimate remaining training time from the average time per batch so far.

    Args:
        batches_done: number of batches completed since ``start_time``.
        batches_left: number of batches still to run.
        start_time: ``time.time()`` timestamp taken when training started.
        now: current timestamp; defaults to ``time.time()``.

    Returns:
        ``datetime.timedelta`` rounded to whole seconds, or None when no batch
        has completed yet (there is no rate to extrapolate from).
    """
    if batches_done <= 0:
        return None
    if now is None:
        now = time.time()
    seconds_per_batch = (now - start_time) / batches_done
    return datetime.timedelta(seconds=round(batches_left * seconds_per_batch))

In [5]:
# Dataset for high images
def load_image(path):
    """Open the image file at ``path`` and return it as an RGB PIL image.

    ``Image.open`` is lazy and keeps the file handle open. ``convert("RGB")``
    loads the pixels into a new image, so the ``with`` block can close the file
    right away. This avoids running out of file handles over a large dataset.
    Converting also guarantees 3 channels (grayscale/RGBA inputs included),
    which the 3-channel ``Normalize`` in the dataset transform requires.
    """
    with Image.open(path) as image:
        return image.convert("RGB")

# This dataset reads image files stored directly inside the dataset folders.
# It is generic and should be adapted to each dataset's layout; for example,
# a CSV-indexed dataset needs a different loading function.

class HighDataset(Dataset):
    """Unlabelled dataset built from the files in the HIGH image folders.

    Collects every entry found directly inside each root folder and loads it
    lazily with ``load_image`` when indexed.

    Args:
        transform: optional callable applied to each loaded image.
        roots: iterable of folder paths to scan. Defaults to
            ``(celea_dataset, sr_dataset, vgg_dataset)``.
    """

    def __init__(self, transform=None, roots=None):
        if roots is None:
            roots = (celea_dataset, sr_dataset, vgg_dataset)

        images = []
        for root in roots:
            # Join each entry with its own root folder. Sorting makes the
            # index -> file mapping independent of the filesystem's listing
            # order, which keeps seeded runs reproducible.
            images.extend(os.path.join(root, name) for name in sorted(os.listdir(root)))

        self.images = images
        self.transform = transform
        self.count = len(images)

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``self.transform`` if set."""
        image = load_image(self.images[index])
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        """Return the number of image paths collected in ``__init__``."""
        return self.count

In [6]:
# Resize every image to `high_image_size` x `high_image_size` so batches can be
# stacked, then map pixels to [-1, 1]: ToTensor yields [0, 1], and
# Normalize(mean=0.5, std=0.5) maps that to [-1, 1]. A square Resize ignores
# aspect ratio.
transform = transforms.Compose([
    transforms.Resize((high_image_size, high_image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = HighDataset(transform)
# shuffle: mix the concatenated source folders within every epoch
#   (still reproducible thanks to torch.manual_seed).
# drop_last: create_noise() returns `hightolow_batch_size` rows, so a smaller
#   final batch would not line up with the noise batch.
data_loader = DataLoader(dataset, batch_size=hightolow_batch_size,
                         shuffle=True, drop_last=True)

In [7]:
class HighToLowGenerator(nn.Module):
    """High-to-low generator (work in progress).

    Presumably meant to turn a high-resolution image (``high_image_size``) plus
    a noise vector (``noise_dimension``) into a low-resolution image
    (``low_image_size``). TODO: confirm once the architecture is written.
    No layers or ``forward`` are defined yet, so calling the module raises
    ``NotImplementedError`` (``nn.Module``'s default).
    """

    def __init__(self):
        super(HighToLowGenerator, self).__init__()
        # TODO: define the layers here. The adversarial (GAN) loss belongs in
        # the training loop, not inside this module.
        

In [None]:
# Training loop
start_time = time.time()
batches_per_epoch = len(data_loader)
total_batches = epoch * batches_per_epoch  # `epoch` is the configured number of epochs

for cur_epoch in range(epoch):
    for i, batch in enumerate(data_loader):
        # Global index of the current batch across all epochs.
        batches_done = cur_epoch * batches_per_epoch + i

        # Train discriminator with real

        # Train discriminator with fake

        # Train generator if ratio is reached
        if i % generator_update_ratio == 0:
            # Train it

            # Log progress with an ETA extrapolated from the average time per
            # batch so far (batch `i` has finished at this point).
            completed = batches_done + 1
            seconds_per_batch = (time.time() - start_time) / completed
            time_left = datetime.timedelta(
                seconds=round((total_batches - completed) * seconds_per_batch))
            # TODO: append discriminator/generator losses once they are computed.
            sys.stdout.write(
                "\r[Epoch %d/%d] [Batch %d/%d] ETA: %s"
                % (cur_epoch, epoch, i, batches_per_epoch, time_left)
            )
            sys.stdout.flush()

            # `sample_images` holds the sampling interval (in batches).
            if batches_done % sample_images == 0:
                # TODO: visualise generator output here once the models exist,
                # e.g. batch_to_image(generated_low_images); plt.show()
                pass