In [1]:
%matplotlib inline
from matplotlib import pyplot
import os
import time
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import utils as vutils
from Net import Net

In [4]:
class LoadImage(data.Dataset):
    def __init__(self, image_list):
        self.image_list = image_list
        self.data_list = {"img_in": []}
        for num_img in range(len(image_list)):
            self.data_list["img_in"].append(image_list[num_img])

    def __getitem__(self, index):
        img_in_path = self.data_list["img_in"][index]
        img_in = Image.open(img_in_path).convert("RGB")
        img_in = img_in.resize((512, 512))
        t_list = [transforms.ToTensor()]
        composed_transform = transforms.Compose(t_list)
        img_in = composed_transform(img_in)
        inputs = {"img_in": img_in}
        return inputs

    def __len__(self):
        return len(self.data_list["img_in"])

In [5]:
def tensor_to_pil(image_tensor):
    image_tensor = image_tensor.detach().cpu()
    image_tensor = image_tensor[0, :, :, :]
    grid = vutils.make_grid(image_tensor)
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return Image.fromarray(ndarr)

In [21]:
def save_images(out_path, img_name, results, key):
    func_time = time.time()
    img_path = os.path.join(out_path,f"{key}_{img_name}")
    res_list = results[key]
    pil_preds = []

    if res_list:
        # Stack Images
        start_time = time.time()
        for img_tensor in res_list:
            pil_img = tensor_to_pil(img_tensor)
            pil_preds.append(pil_img)
        imgs_comb = np.vstack(pil_preds)
        print("Stack: ", time.time() - start_time)

        # Write image
        start_time = time.time()
        cv2.imwrite(img_path,imgs_comb)
        print("Write: ", time.time() - start_time)
        print("Func: ", time.time() - func_time)

In [7]:
def train(model, optimizer, tqdm_bar):
    predictions = []
    sums = []
    for batch_idx, input in enumerate(tqdm_bar):
        model.train()

        img_in = Variable(torch.FloatTensor(input["img_in"])).cuda()
        optimizer.zero_grad()

        net_tensor = model(img_in)
        sum_tensor = img_in + net_tensor

        if (batch_idx+1) % 10 == 0:
            predictions.append(net_tensor)
            sums.append(sum_tensor)
        tqdm_bar.update()

    results = {
        "pred": predictions,
        "sum": sums
    }
    return results

In [8]:
torch.manual_seed(0)
img_name = "DSC00580.JPG"
out_dir = "/home/catchall/Documents/thesis/enhance/data/output/night_enhance"
data_dir = "/home/catchall/Documents/thesis/enhance/data/light-effects/"
out_path = "/home/catchall/Documents/thesis/enhance/data/output/night_enhance/investigate/"
channels = 3
iteration = 60

In [9]:
# Make out_dir if not existing
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

In [10]:
# Place image path in list
images = sorted(
    [
        os.path.join(data_dir, img_name)
        for file in os.listdir(data_dir)
        if file == img_name
    ]
)
image_list = images * iteration

In [11]:
ImageLoader = torch.utils.data.DataLoader(
    LoadImage(image_list),
    batch_size=1,
    shuffle=True,
    num_workers=8,
    drop_last=False,
)

In [12]:
model = Net(input_nc=channels, output_nc=channels)
model = nn.DataParallel(model).cuda()

optimizer = torch.optim.Adam(
    model.parameters(), lr=0.0001, betas=(0.9, 0.999)
)
tqdm_bar = tqdm(ImageLoader)
results = train(model, optimizer, tqdm_bar)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 17.49it/s]


In [13]:
len(results)

2

In [22]:
save_images(out_path, img_name, results, "pred")

Stack:  0.06850504875183105
Write:  0.03966856002807617
Func:  0.10846304893493652


In [23]:
save_images(out_path, img_name, results, "sum")

Stack:  0.06294822692871094
Write:  0.03293251991271973
Func:  0.0959775447845459
