In [1]:
import os
from easydict import EasyDict as edict

config = edict()

# training parameters
config.batch_size = 64#32
config.patch_size = 100
config.mode = "RGB"
config.channels = 3
config.content_layer = 'relu2_2' # originally relu5_4 in DPED
config.learning_rate = 1e-4
config.augmentation = True #data augmentation (flip, rotation)
config.test_every = 200
config.train_iter = 50000
config.data_loader_workers = 16
config.pin_memory = 2
# config.sample_size = 100000

# weights for loss
config.w_content = 2 # reconstruction (originally 1)
config.w_profile = 0.2
config.w_color = 5 # gan color (originally 5e-3)
config.w_texture = 2 # gan texture (originally 5e-3)
config.w_tv = 3 # total variation (originally 400)
config.gamma = 0.6
config.model_name = "WESPE_DIV2K_arnav_gpu1"

# directories
config.dataset_name = "iphone"
config.train_path_phone = os.path.join("/home/grads/v/vineet/Downloads/DPED/dped/iphone/training_data/iphone","*.jpg")
config.train_path_canon = os.path.join("/home/grads/v/vineet/Downloads/DPED/dped/iphone/training_data/canon","*.jpg")
config.train_path_DIV2K = os.path.join("/home/grads/v/vineet/Downloads/DPED/DIV2K_train_HR","*.png")

config.test_path_phone_patch = os.path.join("/home/grads/v/vineet/Downloads/DPED/sample_images/original_images/iphone","*.jpg")
config.test_path_phone_image = os.path.join("/home/grads/v/vineet/Downloads/DPED/sample_images/original_images/iphone","*.jpg")

config.vgg_dir = "./vgg_pretrained/imagenet-vgg-verydeep-19.mat"

config.result_dir = os.path.join("./result_1", config.model_name)
config.result_img_dir = os.path.join(config.result_dir, "samples")
config.checkpoint_dir = os.path.join(config.result_dir, "model")

if not os.path.exists(config.result_dir):
    print("creating dir...", config.result_dir)
    os.makedirs(config.result_dir)
    
if not os.path.exists(config.checkpoint_dir):
    print("creating dir...", config.checkpoint_dir)
    os.makedirs(config.checkpoint_dir)

if not os.path.exists(config.result_img_dir):
    print("creating dir...", config.result_img_dir)
    os.makedirs(config.result_img_dir)
    
config.sample_dir = "samples_DIV2K"
if not os.path.exists(config.sample_dir):
    print("creating dir...", config.sample_dir)
    os.makedirs(config.sample_dir)

In [61]:
from dataloader.dataloader_torch import Dataset, get_patch
from ops_torch import preprocess, postprocess
import imageio
from utils import calc_PSNR

dataset = Dataset(config)

Dataset: iphone, 160471 images
DIV2K: 800 images
160471 images loaded! setting took: 0.0952s


In [3]:
from generator import Generator
from discriminator import Discriminator
import torch
import torch.nn as nn

from vgg19_torch import net

In [4]:
data_loader = torch.utils.data.DataLoader(dataset,
                                       batch_size=4,
                                       shuffle=True,
                                       num_workers=config.data_loader_workers,
                                       pin_memory=config.pin_memory,
                                       drop_last=True)

In [5]:
vgg_dir = config.vgg_dir
content_layer = config.content_layer

w_content = config.w_content
w_profile = config.w_profile
w_texture = config.w_texture 
w_color = config.w_color
w_tv = config.w_tv
gamma = config.gamma

# Networks
generator = Generator()
discriminator1 = Discriminator(in_channels=3)
discriminator2 = Discriminator(in_channels=3)
discriminator3 = Discriminator(in_channels=1)

if torch.cuda.is_available():
    generator, discriminator1, discriminator2, discriminator3 = \
    generator.cuda(), discriminator1.cuda(), discriminator2.cuda(), discriminator3.cuda()

In [6]:
loss_fn_D = nn.BCEWithLogitsLoss()
optimizer_G = torch.optim.Adam(generator.parameters())
optimizer_D1 = torch.optim.Adam(discriminator1.parameters())
optimizer_D2 = torch.optim.Adam(discriminator2.parameters())
optimizer_D3 = torch.optim.Adam(discriminator3.parameters())

In [7]:
def build_discriminator_unit(generated_patch, actual_batch, index, preprocess):

    if index == 1:
        act, _ = discriminator1(actual_batch, preprocess = preprocess)
        fake, _ = discriminator1(generated_patch.detach(), preprocess = preprocess)

    elif index == 2:
        act, _ = discriminator2(actual_batch, preprocess = preprocess)
        fake, _ = discriminator2(generated_patch.detach(), preprocess = preprocess)

    elif index == 3:
        act, _ = discriminator3(actual_batch, preprocess = preprocess)
        fake, _ = discriminator3(generated_patch.detach(), preprocess = preprocess)

    else:
        raise NotImplementedError

    loss_real = loss_fn_D(act, torch.ones_like(act))
    loss_fake = loss_fn_D(fake, torch.zeros_like(fake))
    total_loss = loss_real+loss_fake

    return total_loss, act, fake

In [8]:
for step, (phone_patch, canon_patch, DIV2K_patch) in enumerate(data_loader):
    break

phone_patch, canon_patch, DIV2K_patch = phone_patch.float(), canon_patch.float(), DIV2K_patch.float()
if torch.cuda.is_available():
    phone_patch, canon_patch, DIV2K_patch = phone_patch.cuda(), canon_patch.cuda(), DIV2K_patch.cuda()

In [9]:
enhanced_patch = generator(phone_patch)

# Discrimiator 1
d_loss_profile, logits_DIV2K_profile, logits_enhanced_profile = build_discriminator_unit(enhanced_patch, DIV2K_patch, index=1, preprocess='blur')

# Discrimiator 2
d_loss_color, logits_original_color, logits_enhanced_color = build_discriminator_unit(enhanced_patch, canon_patch, index=2, preprocess='none')

# Discrimiator 3
d_loss_texture, logits_original_texture, logits_enhanced_texture = build_discriminator_unit(enhanced_patch, canon_patch, index=3, preprocess='gray')

Discriminator-color (blur)
Discriminator-color (blur)
Discriminator-color (none)
Discriminator-color (none)
Discriminator-texture
Discriminator-texture


In [10]:
def total_variation_loss(images):

    ndims = len(images.shape)

    if ndims == 3:
        pixel_dif1 = images[:, 1:, :] - images[:, :-1, :]
        pixel_dif2 = images[:, :, 1:] - images[:, :, :-1]
        sum_axis = None

    if ndims == 4:
        pixel_dif1 = images[:, :, 1:, :] - images[:, :, :-1, :]
        pixel_dif2 = images[:, :, :, 1:] - images[:, :, :, :-1]
        sum_axis = (1, 2, 3)

    else:
        raise ValueError('\'images\' must be either 3 or 4-dimensional.')

    tot_var = (
        torch.sum(torch.abs(pixel_dif1)) +
        torch.sum(torch.abs(pixel_dif2), dim=sum_axis))

    return tot_var

In [11]:
original_vgg = net(vgg_dir, canon_patch * 255)
enhanced_vgg = net(vgg_dir, enhanced_patch * 255)

In [12]:
content_loss = torch.mean(torch.pow(original_vgg[content_layer] - enhanced_vgg[content_layer], 2))

#profile loss(gan, enhanced-div2k)
profile_loss = loss_fn_D(logits_DIV2K_profile, logits_enhanced_profile)

# color loss (gan, enhanced-original)
color_loss = loss_fn_D(logits_original_color, logits_enhanced_color)

# texture loss (gan, enhanced-original)
texture_loss = loss_fn_D(logits_original_texture, logits_enhanced_texture)

In [13]:
tv_loss = torch.mean(torch.abs(total_variation_loss(enhanced_patch) - total_variation_loss(canon_patch)))

In [14]:
g_loss = w_content*content_loss + w_profile*profile_loss + w_color*color_loss + w_texture*texture_loss + w_tv*tv_loss

In [15]:
optimizer_G.zero_grad()

g_loss.backward(retain_graph=True)
optimizer_G.step()

optimizer_D1.zero_grad()
optimizer_D2.zero_grad()
optimizer_D3.zero_grad()

d_loss_profile.backward()
optimizer_D1.step()

d_loss_color.backward()
optimizer_D2.step()

d_loss_texture.backward()
optimizer_D3.step()

In [100]:
generator.eval()
discriminator1.eval()
discriminator2.eval()
discriminator3.eval()

In [101]:
from glob import glob
import numpy as np
import scipy.misc
import scipy.io

config.test_path_phone_patch = '/home/grads/v/vineet/Downloads/DPED/DIV2K_train_HR/*.png'
test_list_phone = sorted(glob(config.test_path_phone_patch))

In [102]:
test_num_patch = 200
test_num_image = 5
PSNR_phone_enhanced_list = np.zeros([test_num_patch])

indexes = []

In [54]:
for i in range(test_num_patch):
    break

In [55]:
index = np.random.randint(len(test_list_phone))
indexes.append(index)
test_img = scipy.misc.imread(test_list_phone[index], mode = "RGB").astype("float32")

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
  This is separate from the ipykernel package so we can avoid doing imports until


In [93]:
test_patch_phone = get_patch(test_img, config.patch_size)
test_patch_phone = preprocess(test_patch_phone)

test_patch_phone = torch.from_numpy(np.transpose(test_patch_phone, (2,1,0))).float().unsqueeze(0).cuda()

if torch.cuda.is_available():
    test_patch_phone = test_patch_phone.cuda()

test_patch_enhanced = generator(test_patch_phone)
test_patch_enhanced = np.transpose(test_patch_enhanced.cpu().data.numpy(), (0,2,3,1))
test_patch_phone = np.transpose(test_patch_phone.cpu().data.numpy(), (0,2,3,1))

In [107]:
PSNR = calc_PSNR(postprocess(test_patch_enhanced[0]), postprocess(test_patch_phone[0]))
PSNR_phone_enhanced_list[i] = PSNR