In [0]:
!pip3 install http://download.pytorch.org/whl/cu80/torch-0.4.0-cp36-cp36m-linux_x86_64.whl 
!pip3 install torchvision
!pip install Pillow==4.0.0
!pip install PIL
!pip install image

In [0]:
!wget -x --load-cookies cookies.txt --cut-dirs=5 https://www.kaggle.com/evgeniumakov/images4k/downloads/Dataset4K.zip
!unzip ./www.kaggle.com/Dataset4K -d Dataset4K >> temp_log

import os
os.remove("./Dataset4K/Dataset4K/Thumbs.db")
os.remove("./Dataset4K/Dataset4K/4k-3840-x-2160-wallpapers-themefoxx (275).jpg")

In [0]:
import scipy
from scipy import misc
from scipy import ndimage
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from DataLoader import Loader
from Generator import GenGAN
from Discriminator import DiscGAN

import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
# Untyped tensor constructors such as torch.Tensor(...) produce CPU float32 tensors.
torch.set_default_tensor_type('torch.FloatTensor')

# ---- Data and optimisation ----------------------------------------------------
batch_size = 20
# 1e-3 is the same float as 10e-4. NOTE(review): confirm 1e-3 (not 1e-4) is the intended rate.
lr = 1e-3
loader = Loader(path="./Dataset4K", crop_size=2160)

# ---- Devices ------------------------------------------------------------------
# Models run on the first CUDA device when one exists, otherwise on the CPU.
if torch.cuda.is_available():
    gpu = torch.device("cuda:0")
else:
    gpu = torch.device("cpu")
cpu = torch.device("cpu")

# ---- Resolution factors -------------------------------------------------------
# The training cells downscale raw crops by `scale_input_data` to form the "true"
# images and by `scale_input_data * scale` to form the low-resolution input;
# `scale` is also passed to GenGAN as its scale argument.
scale_input_data = 3
scale = 4

# ---- Generator hyperparameters (forwarded positionally to GenGAN) -------------
n_resgroups, n_resblocks = 1, 10
n_feats = 6
reduction = 4
n_colors = 3
res_scale = 1

In [0]:
# Discriminator and its Adam optimiser. The D-then-G construction order is kept as-is:
# NOTE(review): weight initialisation presumably draws from the global RNG, so reordering
# would change the initial weights.
discriminator = DiscGAN(ff=44, latent_size=1000, device=gpu)
D_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=lr)

# Generator and its Adam optimiser; the positional arguments follow GenGAN's signature order.
generator = GenGAN(
    n_resgroups, n_resblocks, n_feats, reduction,
    n_colors, res_scale, scale,
    device=gpu,
)
G_optimizer = torch.optim.Adam(params=generator.parameters(), lr=lr)

In [0]:
def rescale(image, degree_compress=0.125):
    """Zoom every entry along the first axis of ``image`` by the same factor.

    Each element of ``image`` (presumably one colour channel of a (C, H, W) array —
    TODO confirm against the Loader output) is resampled independently with
    ``scipy.ndimage.zoom``.

    Args:
        image: array-like iterated along its first axis.
        degree_compress: zoom factor; values below 1 shrink, values above 1 enlarge.

    Returns:
        np.ndarray stacking the zoomed entries along a new leading axis.
    """
    zoomed_planes = []
    for plane in image:
        zoomed_planes.append(ndimage.zoom(plane, degree_compress))
    return np.array(zoomed_planes)

In [0]:
def bath_rescale(batch, degree_compress=0.125):
    """Divide a batch of images by 255 and zoom each image with ``rescale``.

    Note that every call divides by 255, so feeding the output of one call into
    another normalises twice.

    Args:
        batch: torch tensor whose first axis indexes images (presumably (N, C, H, W)
            — TODO confirm). CPU, GPU and grad-tracking tensors are all accepted.
        degree_compress: zoom factor forwarded to ``rescale``.

    Returns:
        A CPU float32 tensor with the zoomed images stacked along the first axis.
    """
    # numpy() needs a CPU tensor that does not track gradients.
    images = batch.detach().cpu().numpy()
    zoomed = np.array([rescale(image / 255, degree_compress) for image in images])
    # Convert one contiguous ndarray instead of calling torch.Tensor(list_of_ndarrays),
    # which walks every pixel in Python and is very slow for megapixel images.
    # .float() matches the float32 default tensor type the original constructor produced.
    return torch.from_numpy(zoomed).float()

In [0]:
def pretrain_Descriminator():
    """Run one pass over the dataset, training only the discriminator.

    For every batch the discriminator sees:
      * a "fake" image: the crop downscaled by ``scale_input_data * scale`` and
        upscaled back by ``scale``;
      * a "true" image: the crop downscaled by ``scale_input_data`` only.
    Both are divided by 255 exactly once (``bath_rescale`` divides on every call).
    The discriminator weights are saved to ``./pretrain_dicriminator.mdl`` after
    every batch, and the loss is printed per batch.

    Reads the module-level globals ``loader``, ``batch_size``, ``discriminator``,
    ``D_optimizer``, ``gpu``, ``scale_input_data`` and ``scale``.
    """
    train_size = len(loader)
    data_loader = DataLoader(loader, batch_size=batch_size)
    # len(DataLoader) is the number of batches; batch_idx counts batches, not samples.
    num_batches = len(data_loader)
    for batch_idx, (data, _) in enumerate(data_loader):
        # train D
        discriminator.zero_grad()
        data_low = bath_rescale(data, 1 / (scale_input_data * scale))
        # bath_rescale divides by 255 again; undo that first so the "fake" input has
        # the same intensity range as the "true" one.
        data_false = bath_rescale(data_low * 255, scale)
        data_true = bath_rescale(data, 1 / scale_input_data)

        fake_pred = discriminator(data_false.to(gpu))
        true_pred = discriminator(data_true.to(gpu))

        fake_loss = F.mse_loss(fake_pred, true_pred)
        # NOTE(review): the MSE of a tensor against itself is always 0, so this term adds
        # nothing to the loss or its gradients — confirm the intended objective.
        true_loss = F.mse_loss(true_pred, true_pred)

        D_loss = 0.5 * (fake_loss + true_loss)

        D_loss.backward()
        D_optimizer.step()

        torch.save(discriminator.state_dict(), "./pretrain_dicriminator.mdl")
        torch.cuda.empty_cache()

        line = 'Train Epoch: [{}/{} ({:.0f}%)]\tLosses '.format(
            batch_idx * len(data), train_size, 100. * batch_idx / num_batches)
        losses = 'D: {:.4f}'.format(D_loss.item())
        print(line + losses)

In [0]:
def train_GAN():
    """Run one adversarial training pass over the (shuffled) dataset.

    Per batch:
      1. Discriminator step: minimise
         ``0.5 * (MSE(D(fake), D(true)) + MSE(D(true), D(true)))``. Here "fake" is the
         crop downscaled by ``scale_input_data * scale`` and upscaled back by ``scale``,
         and "true" is the crop downscaled by ``scale_input_data``.
      2. Generator step: minimise ``MSE(D(G(low)), D(true))``, where ``low`` is the crop
         downscaled by ``scale_input_data * scale``.
    Both models are saved to disk after every batch. When the discriminator loss
    drops below 0.001, both Adam optimisers are re-created, resetting their state;
    that is why they are declared ``global``.

    Reads the module-level globals ``loader``, ``batch_size``, ``generator``,
    ``discriminator``, ``gpu``, ``lr``, ``scale_input_data`` and ``scale``.
    """
    global D_optimizer
    global G_optimizer
    train_size = len(loader)
    data_loader = DataLoader(loader, batch_size=batch_size, shuffle=True)
    # len(DataLoader) is the number of batches; batch_idx counts batches, not samples.
    num_batches = len(data_loader)
    for batch_idx, (data, _) in enumerate(data_loader):
        # train D
        generator.zero_grad()
        discriminator.zero_grad()
        # Low-resolution batch (divided by 255 once); also reused as the generator input.
        data_low = bath_rescale(data, 1 / (scale_input_data * scale))
        # bath_rescale divides by 255 on every call: multiply back before upscaling so the
        # "fake" batch is not 255x darker than data_true.
        data_false = bath_rescale(data_low * 255, scale).to(gpu)
        data_true = bath_rescale(data, 1 / scale_input_data).to(gpu)

        fake_pred = discriminator(data_false)
        true_pred = discriminator(data_true)

        fake_loss = F.mse_loss(fake_pred, true_pred)
        # NOTE(review): the MSE of a tensor against itself is always 0, so this term adds
        # nothing to the loss or its gradients — confirm the intended objective.
        true_loss = F.mse_loss(true_pred, true_pred)

        D_loss = 0.5 * (fake_loss + true_loss)

        D_loss.backward()
        D_optimizer.step()
        torch.save(discriminator.state_dict(), "./train_dicriminator.mdl")

        # train G
        discriminator.zero_grad()
        generator.zero_grad()
        fake_pred = discriminator(generator(data_low.to(gpu)))
        true_pred = discriminator(data_true)
        G_loss = F.mse_loss(fake_pred, true_pred)

        G_loss.backward()
        G_optimizer.step()
        torch.save(generator.state_dict(), "./train_generator.mdl")

        # Release per-batch tensors before clearing the CUDA cache.
        del data_low
        del data_false
        del data_true
        del fake_pred
        del true_pred
        torch.cuda.empty_cache()

        # Reset both optimisers once the discriminator loss becomes tiny.
        if D_loss.item() < 0.001:
            D_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
            G_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)

        line = 'Train Epoch: [{}/{} ({:.0f}%)]\tLosses '.format(
            batch_idx * len(data), train_size, 100. * batch_idx / num_batches)
        losses = 'G: {:.6f}, D: {:.6f}'.format(G_loss.item(), D_loss.item())
        print(line + losses)


In [0]:
# Launch one training pass over the dataset; losses and progress are printed per batch.
train_GAN()