In [2]:
%matplotlib inline
from matplotlib import pyplot
import os
import time
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import utils as vutils
from Net import Net
from demo import ExclusionLoss

In [3]:
CHANNELS = 3
ITERATION = 60
NUM_IMG = 3

In [4]:
class LoadImage(data.Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
        self.data_list = {"img_in": []}
        for num_img in range(len(image_list)):
            self.data_list["img_in"].append(image_list[num_img])

    def __getitem__(self, index):
        img_in_path = self.data_list["img_in"][index]
        img_in = Image.open(img_in_path).convert("RGB")
        img_in = img_in.resize((512, 512))
        t_list = [transforms.ToTensor()]
        composed_transform = transforms.Compose(t_list)
        img_in = composed_transform(img_in)
        inputs = {"img_in": img_in}
        return inputs

    def __len__(self):
        return len(self.data_list["img_in"])

In [5]:
def tensor_to_pil(image_tensor):
    image_tensor = image_tensor.detach().cpu()
    image_tensor = image_tensor[0, :, :, :]
    grid = vutils.make_grid(image_tensor)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(ndarr)

In [6]:
def save_images(out_path, image_name, results, key):
    func_time = time.time()
    img_path = os.path.join(out_path, image_name)
    res_list = results[key]
    pil_preds = []

    if res_list:
        # Stack Images
        start_time = time.time()
        for img_tensor in res_list:
            pil_img = tensor_to_pil(img_tensor)
            pil_preds.append(pil_img)
        imgs_comb = np.vstack(pil_preds)

        # Write image
        start_time = time.time()
        cv2.imwrite(img_path,imgs_comb)
    print(f"Func {key}: ", time.time() - func_time)
    return imgs_comb

In [7]:
def get_cc_loss(sum_tensor):
    lambda_cc = 1.0
    pred_cc = torch.mean(sum_tensor, dim=1, keepdims=True)
    cc_loss = (
        F.l1_loss(sum_tensor[:, 0:1, :, :], pred_cc)
        + F.l1_loss(sum_tensor[:, 1:2, :, :], pred_cc)
        + F.l1_loss(sum_tensor[:, 2:3, :, :], pred_cc)
    ) * (
        1 / 3
    )  # Color Constancy Loss
    return lambda_cc * cc_loss

def get_recon_loss(sum_tensor, img_in):
    lambda_recon = 1.0
    recon_loss = F.l1_loss(sum_tensor, img_in)
    return lambda_recon * recon_loss    


def smooth_loss(pred_tensor):
    def gradient(pred):
        D_dy = pred[:, :, 1:] - pred[:, :, :-1]
        D_dx = pred[:, :, :, 1:] - pred[:, :, :, :-1]
        return D_dx, D_dy
    dx, dy = gradient(pred_tensor)
    dx2, dxdy = gradient(dx)
    dydx, dy2 = gradient(dy)
    loss = dx2.abs().mean() + dxdy.abs().mean() + dydx.abs().mean() + dy2.abs().mean()
    return loss

# def criterion(pred_tensor, img_in):
#     lambda_smooth = 1.0
#     lambda_excl = 0.01        
#     # loss = get_cc_loss()
#     loss = get_recon_loss()
#     return loss

In [8]:
def train(model, optimizer, tqdm_bar, loss_name):
    func_time = time.time()
    predictions = []
    sums = []
    model.train()
    excl_loss = ExclusionLoss().type(torch.cuda.FloatTensor)
    for batch_idx, input in enumerate(tqdm_bar):

        img_in = Variable(torch.FloatTensor(input["img_in"])).cuda()
        pred_tensor = model(img_in)
        sum_tensor = img_in + pred_tensor
        
        if loss_name == "cc":
            loss = get_cc_loss(sum_tensor)
        elif loss_name == "recon":
            loss = get_recon_loss(sum_tensor, img_in)
        elif loss_name == "excl":
            loss = excl_loss(pred_tensor, sum_tensor)
        elif loss_name == "smooth":
            loss = smooth_loss(pred_tensor)
            
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
        if (batch_idx+1) % int(ITERATION/NUM_IMG) == 0:
            predictions.append(pred_tensor)
            sums.append(sum_tensor)

        tqdm_bar.update()

    results = {
        "pred": predictions,
        "final": sums,
    }
    print("Training: ", time.time() - func_time)
    return results

In [9]:
torch.manual_seed(0)
img_name = "DSC00580.JPG"
out_dir = "/home/catchall/Documents/thesis/enhance/data/output/night_enhance/investigate/"
data_dir = "/home/catchall/Documents/thesis/enhance/data/light-effects/"
out_path = ""

In [10]:
# Make out_dir if not existing
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

In [11]:
# Place image path in list
images = sorted(
    [
        os.path.join(data_dir, img_name)
        for file in os.listdir(data_dir)
        if file == img_name
    ]
)
image_list = images * ITERATION

In [12]:
ImageLoader = torch.utils.data.DataLoader(
    LoadImage(image_list),
    batch_size=1,
    shuffle=True,
    num_workers=8,
    drop_last=False,
)

In [13]:
model = Net(input_nc=CHANNELS, output_nc=CHANNELS)
model = nn.DataParallel(model).cuda()

optimizer = torch.optim.Adam(
    model.parameters(), lr=0.0001, betas=(0.9, 0.999)
)
tqdm_bar = tqdm(ImageLoader)
loss_name = "cc" # cc, recon, excl, smooth
results = train(model, optimizer, tqdm_bar, loss_name)

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx

In [18]:
for key, value in results.items():
    image_name = f"{loss_name}-{key}_{img_name}"
    img = save_images(out_dir, image_name, results, key)

Func pred:  0.043663740158081055
Func final:  0.041159629821777344


In [19]:
images = []
img_path = os.path.join(out_dir,img_name)
input_img = cv2.imread(os.path.join(data_dir, img_name))
inputs = [input_img * NUM_IMG]
input_stack = np.vstack(inputs)
import glob
glob_list = glob.glob(os.path.join(out_dir, "*.JPG"))
if img_path in glob_list:
    os.remove(img_path)
    glob_list.pop(glob_list.index(img_path))
for img in sorted(glob_list):
    images.append(cv2.imread(img))
imgs_comb = np.hstack(images)
cv2.imwrite(img_path, imgs_comb)

True