In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# Release cached-but-unused GPU memory held by PyTorch's CUDA caching allocator
# (useful when re-running the notebook in a kernel that already trained a model).
torch.cuda.empty_cache()

In [None]:
path = 'D:/Images/'


def find_images(root, extensions=(".jpg",)):
    """Recursively collect image file paths under ``root``.

    Args:
        root: Directory to walk. A missing directory yields an empty list.
        extensions: File suffixes to accept, matched case-insensitively so
            that e.g. ``.JPG`` files are not silently skipped.

    Returns:
        Sorted list of full file paths. Sorting makes the order independent
        of the filesystem's directory-listing order, so the seeded
        ``train_test_split`` downstream produces the same split every run.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    found = [
        os.path.join(dirpath, name)
        for dirpath, _dirnames, filenames in os.walk(root)
        for name in filenames
        if name.lower().endswith(suffixes)
    ]
    return sorted(found)


images = find_images(path)

In [None]:
from PIL import Image
import imgaug.augmenters as iaa
import imgaug.augmentables.lines as ial
from random import randrange


class ImageDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(degraded_input, clean_target)`` pairs for denoising.

    For each item the image is loaded (and cached in memory after the first
    access) and passed through ``transform`` to obtain the clean target.
    The network input is then built from it by optionally

    * drawing near-white random polylines on the clean image,
    * running the imgaug augmenter ``aug`` on the result,
    * drawing bluish random polylines on the augmented image,

    and finally adding Gaussian noise with standard deviation ``noise_std``.

    Args:
        images: Sequence of image file paths.
        transform: Callable mapping an RGB PIL image to a (C, H, W) float
            tensor. Values are expected in [0, 1]; the uint8 conversion used
            for augmentation assumes that range.
        aug: Optional imgaug augmenter, called as ``aug(image=hwc_uint8)``.
            When falsy, neither augmentation nor line drawing is applied.
        line_aug: If truthy (and ``aug`` is set), draw the random polylines.
        noise_std: Standard deviation of the additive Gaussian noise
            (default 0.05, the previously hard-coded value).
    """

    def __init__(self, images, transform, aug=None, line_aug: bool = None, noise_std: float = 0.05):
        self.images = images
        # path -> decoded RGB PIL image. Grows to hold every image accessed and
        # is per-process when the dataset is used with DataLoader workers.
        self.cache = {}
        self.transform = transform
        self.aug = aug
        self.line_aug = line_aug
        self.noise_std = noise_std

    def __len__(self):
        return len(self.images)

    def _load(self, image_key):
        """Return the (cached) RGB PIL image for the path ``image_key``."""
        if image_key not in self.cache:
            # convert() forces a full decode, so the file can be closed right
            # away; "RGB" guarantees 3 channels even for grayscale/CMYK JPEGs.
            with Image.open(image_key) as img:
                self.cache[image_key] = img.convert("RGB")
        return self.cache[image_key]

    @staticmethod
    def _random_polylines(height, width, max_lines):
        """Return 1..max_lines-1 polylines of 2-4 random (x, y) points inside a height x width image."""
        return [
            [(randrange(0, width), randrange(0, height)) for _ in range(randrange(2, 5))]
            for _ in range(randrange(1, max_lines))
        ]

    def _draw_random_lines(self, image, max_lines, color, size):
        """Draw random polylines with the given colour/thickness on an HxWxC uint8 image."""
        height, width = image.shape[:2]
        lines = ial.LineStringsOnImage(
            [ial.LineString(points) for points in self._random_polylines(height, width, max_lines)],
            shape=image.shape)
        return lines.draw_on_image(image,
                                   color_lines=color, color_points=color,
                                   size_lines=size, size_points=size)

    def __getitem__(self, idx):
        image = self.transform(self._load(self.images[idx]))
        train_image = image
        if self.aug:
            # imgaug works on HxWxC uint8. Round (rather than truncate) so the
            # ToTensor (/255) -> *255 round trip reproduces the original bytes.
            train_image = np.transpose(
                np.clip(np.rint(255 * image.numpy()), 0, 255).astype(np.uint8), (1, 2, 0))
            if self.line_aug:
                white = (randrange(235, 255), randrange(235, 255), randrange(235, 255))
                train_image = self._draw_random_lines(train_image, max_lines=3,
                                                      color=white, size=randrange(2, 4))
            train_image = self.aug(image=train_image)
            if self.line_aug:
                blue = (randrange(20, 80), randrange(20, 80), randrange(140, 220))
                train_image = self._draw_random_lines(train_image, max_lines=5,
                                                      color=blue, size=randrange(1, 3))
            train_image = torch.from_numpy(np.transpose(train_image, (2, 0, 1)) / 255.0).float()
        noisy_image = train_image + torch.randn_like(train_image) * self.noise_std
        return noisy_image, image

In [None]:
from sklearn.model_selection import train_test_split

SPLIT_SEED = 42
SUBSET_DISCARD = 0.98  # fraction of the corpus thrown away; the remaining 2% is used
TEST_FRACTION = 0.05   # fraction of the kept subset held out for testing
CROP_SIZE = 320
BATCH_SIZE = 32

# Reproducible 2% subset of all images, then a train/test split of that subset.
valid_images, _ = train_test_split(images, test_size=SUBSET_DISCARD, random_state=SPLIT_SEED)
train_images, test_images = train_test_split(valid_images, test_size=TEST_FRACTION, random_state=SPLIT_SEED)

# Crop the PIL image *before* ToTensor: cropping afterwards first converts the
# whole full-resolution image to float32 on every __getitem__ call.
transform = transforms.Compose([
    transforms.RandomCrop(CROP_SIZE),
    transforms.ToTensor(),
])
# Deterministic crop for evaluation, removing one source of randomness from
# the reported test loss (brightness/line/noise degradation is still random).
test_transform = transforms.Compose([
    transforms.CenterCrop(CROP_SIZE),
    transforms.ToTensor(),
])
# Degradation applied to build the network input: random brightness shift.
aug = iaa.Sequential([
    iaa.WithBrightnessChannels(iaa.Add((-60, 10))),
])
train_dataset = ImageDataset(train_images, transform=transform, aug=aug, line_aug=True)
test_dataset = ImageDataset(test_images, transform=test_transform, aug=aug, line_aug=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class DarkMAELoss(nn.Module):
    """Pixel-weighted mean absolute error.

    Each absolute error ``|input - target|`` is scaled by
    ``1 / (eps + min(target, 1 - target))``. Pixels whose target is close to
    0 *or* close to 1 therefore get weights up to ``1 / eps``, while a
    mid-grey target (0.5) gets a weight of about 2.

    NOTE(review): despite the name, this up-weights both near-black and
    near-white targets, and assumes targets lie in [0, 1]. Confirm that this
    is the intended weighting.

    Args:
        weight: Unused; kept for signature compatibility.
        size_average: Unused; kept for signature compatibility.
        eps: Stabiliser added to the weight denominator. It bounds the
            maximum per-pixel weight at ``1 / eps``. Default ``1e-3``.
    """

    def __init__(self, weight=None, size_average=True, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, inputs, targets, smooth=1):
        """Return the scalar weighted MAE; ``smooth`` is accepted but unused."""
        weights = 1 / (self.eps + torch.min(targets, 1 - targets))
        return (weights * torch.abs(inputs - targets)).mean()
        # return ((1 / (1e-4 + targets)) * torch.abs(inputs - targets)).mean()

In [None]:
from model.custom.net import ConvAutoencoder

model = ConvAutoencoder()
model = model.cuda()
criterion = DarkMAELoss()
# NOTE(review): lr=0.1 is 100x Adam's usual default (1e-3). ReduceLROnPlateau
# only lowers it after `patience` epochs without improvement, so confirm this
# starting value is intentional.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=5, threshold=0.001)
print('Encoder Params:', sum(p.numel() for p in model.encoders.parameters()))
print('Decoder Params:', sum(p.numel() for p in model.decoders.parameters()))
print('Total Params:', sum(p.numel() for p in model.parameters()))
# Resume from the previous run's weights. Only the model weights are restored;
# the Adam moments and scheduler state start fresh.
# map_location puts the tensors on the current GPU regardless of the device
# they were saved from. torch.load unpickles arbitrary objects, so only load
# checkpoints you trust.
model.load_state_dict(torch.load("save/custom/model_1.pkl", map_location="cuda"))

In [None]:
num_epochs = 100
losses = []  # mean training loss per epoch
os.makedirs("save/custom", exist_ok=True)
pbar = tqdm(range(num_epochs))
for epoch in pbar:
    # Re-enter training mode every epoch: the evaluation cells call
    # model.eval(), and re-running this cell afterwards would otherwise train
    # with any BatchNorm/Dropout layers (if the model has them) in inference mode.
    model.train()
    total_loss = 0.0
    for iteration, (noisy_imgs, imgs) in enumerate(train_loader):
        noisy_imgs = noisy_imgs.cuda()
        imgs = imgs.cuda()

        output = model(noisy_imgs)
        loss = criterion(output, imgs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_description('ITER: [{}/{}] | LOSS: {:.4f} | LR: {:.5f}'
                             .format(iteration + 1, len(train_loader), total_loss / (iteration + 1),
                                     optimizer.param_groups[0]["lr"]))
    epoch_loss = total_loss / len(train_loader)
    # Plateau detection is driven by the mean training loss (there is no validation pass).
    scheduler.step(epoch_loss)
    losses.append(epoch_loss)
    # Overwritten every epoch: holds the latest (not necessarily best) weights.
    torch.save(model.state_dict(), "save/custom/model_2.pkl")

In [None]:
# Learning curve: mean training loss per epoch from the training cell.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(range(1, len(losses) + 1), losses)
loss_ax.set(title="Training loss per epoch", xlabel="Epoch", ylabel="Mean training loss")
plt.show()

In [None]:
model.eval()
test_loss_sum = 0.0
with torch.no_grad():
    pbar = tqdm(test_loader, total=len(test_loader))
    for batch_idx, (noisy_imgs, imgs) in enumerate(pbar, start=1):
        noisy_imgs = noisy_imgs.cuda()
        imgs = imgs.cuda()
        output = model(noisy_imgs)
        loss = criterion(output, imgs)
        test_loss_sum += loss.item()
        # Mean over all batches seen so far, not just the most recent one.
        pbar.set_description('test loss:{:.4f}'.format(test_loss_sum / batch_idx))

# Visualise up to 5 samples of the last test batch.
# Rows: noisy input, reconstruction, clean target.
noisy_imgs = noisy_imgs.cpu()
imgs = imgs.cpu()
output = output.cpu()
num_show = 5
fig, axes = plt.subplots(nrows=3, ncols=num_show, sharex=True, sharey=True, figsize=(25, 15))
in_imgs = noisy_imgs[:num_show]
reconstructed_imgs = output[:num_show]
# `row_imgs` (not `images`) so the global list of image paths is not clobbered.
for row_imgs, row in zip([in_imgs, reconstructed_imgs, imgs[:num_show]], axes):
    for img, ax in zip(row_imgs, row):
        # CHW -> HWC for imshow; clamp because the noisy inputs (and possibly
        # the reconstructions) fall outside the [0, 1] range imshow expects.
        ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
fig.tight_layout(pad=0.1)
os.makedirs("out/custom", exist_ok=True)
fig.savefig("out/custom/test_out.jpg")

In [None]:
# Crop box (left, upper, right, lower) applied before inference.
# NOTE(review): 2544 and 3504 are both multiples of 16. Presumably the
# autoencoder needs spatial sizes divisible by its downsampling factor; confirm.
# PIL pads with black if the source image is smaller than this box.
INFER_CROP_BOX = (0, 0, 2544, 3504)

model.eval()
with torch.no_grad():
    with Image.open('D:/test_noisy.jpg') as src:
        # RGB guarantees the 3-channel CHW layout the transposes below rely on.
        noisy_page = src.convert("RGB").crop(INFER_CROP_BOX)
    imgs = torch.unsqueeze(transforms.ToTensor()(noisy_page), dim=0)
    imgs = imgs.cuda()
    output = model(imgs)

# (1, C, H, W) -> (H, W, C) float arrays.
imgs = np.transpose(imgs[0].cpu().numpy(), (1, 2, 0))
output = np.transpose(output[0].cpu().numpy(), (1, 2, 0))
# Clip before the uint8 cast: the model output is not bounded to [0, 1], and
# out-of-range floats are not saturated by the cast (they can wrap around).
im = Image.fromarray(np.clip(np.rint(output * 255), 0, 255).astype(np.uint8))
# NOTE(review): the rest of this notebook writes to out/custom/; confirm that
# writing the custom model's result under out/vanilla/ is intended.
os.makedirs("out/vanilla", exist_ok=True)
im.save("out/vanilla/test_noisy_out.jpg")
im = Image.fromarray(np.clip(np.rint(imgs * 255), 0, 255).astype(np.uint8))
im.save("test_noisy.jpg")