### Libs

In [None]:
from torchvision.datasets import CelebA
import os
import torch.nn as nn
import glob
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch
import math
# main libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch.nn import functional as F
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset
import math
import torchvision.transforms.functional as TF
import cv2
from keras.utils import image_dataset_from_directory
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.nn.parameter import Parameter

from torch.nn import init
from torchvision.utils import make_grid
from collections import OrderedDict

import torch
from torch.utils.data import Dataset

from skimage import transform
import skimage.io as io
import numpy as np
from torchvision.utils import save_image

### Helper Functions

In [None]:
def count_parameters(model):
    """Return how many scalar values in ``model`` are trainable.

    Parameters whose ``requires_grad`` flag is off (frozen) are skipped.
    """
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

In [None]:
def ploter(image, image_hat):
    """Draw a reconstruction and its source image side by side, then show them.

    Args:
        image: the original image, expected (H, W).
        image_hat: the reconstructed / generated image, expected (H, W).

    The reconstruction goes in the left panel and the original in the right.
    No ``cmap`` is passed, so matplotlib's default colormap is used.
    """
    panels = (("Reconstruct", image_hat), ("Original", image))

    plt.figure()
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)
        plt.tight_layout()
        plt.title(caption)

    plt.show()

### Dataloader

In [None]:
class Anomaly_Dataset(Dataset):
    """Map-style dataset of single-channel images scaled to [-1, 1].

    All images under ``root`` are loaded eagerly with Keras'
    ``image_dataset_from_directory`` (96x96, RGB, integer labels) and cached
    as numpy arrays. Each item is ``(image, label)``, with ``image`` of shape
    (H, W).
    """

    def __init__(self,
                 root
                 ):
        """
        Args:
            root: directory handed to ``image_dataset_from_directory``;
                presumably one sub-folder per class (Keras' inferred labels)
                -- TODO confirm the expected layout.
        """
        super(Anomaly_Dataset, self).__init__()

        # The raw Keras dataset is kept as an attribute for compatibility;
        # indexing and length both use the cached numpy arrays below.
        self.data = Anomaly_Dataset.load_dataset(root)
        self.image, self.label = Anomaly_Dataset.get_numpy(self.data)

    def __getitem__(self, item):
        x, y = self.image[item], self.label[item]

        # Keep only the first channel: (H, W, C) -> (H, W).
        # NOTE(review): this picks a channel rather than computing luminance;
        # it equals a gray conversion only when the source images are grayscale
        # stored as RGB -- confirm for this dataset.
        x = x[:, :, 0]

        # Scale to [-1, 1]. The result stays (H, W); callers add the
        # channel axis themselves.
        x = Anomaly_Dataset.normalization(x)

        return x, y

    def __len__(self):
        # Use the same arrays __getitem__ indexes, so the two always agree
        # and no query of the Keras dataset's cardinality is needed.
        return len(self.image)

    @staticmethod
    def load_dataset(path):
        """Read every image under ``path`` as 96x96 RGB with integer labels.

        Files are not shuffled, and ``batch_size=None`` yields one
        (image, label) pair per element. ``crop_to_aspect_ratio=True`` is
        passed to the Keras loader.
        """
        img_rows = 96
        img_cols = 96
        return image_dataset_from_directory(directory = path,
                                               label_mode = 'int',
                                               color_mode = 'rgb',
                                               shuffle = False,
                                               batch_size = None,
                                               image_size = (img_rows, img_cols),
                                               crop_to_aspect_ratio = True)

    @staticmethod
    def get_numpy(PrefetchDataset):
        """Materialise an iterable of (image, label) pairs into two arrays.

        return:
            (N, H, W, C) , (N,)
        """
        images = []
        labels = []
        for (image, label) in PrefetchDataset:
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

    @staticmethod
    def rgb_2_gray(x):
        """Luminance conversion (H, W, 3) RGB -> (H, W).

        The loader uses ``color_mode='rgb'``, so the RGB conversion code is
        used. ``COLOR_BGR2GRAY`` would swap the red and blue weights.
        """
        return cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def normalization(x):
        """Min-max scale ``x`` to [-1, 1].

        Args:
            x : np.array : (H, W)

        Return:
            np.array : (H, W). A constant image (max == min) maps to all -1
            rather than producing NaNs from a 0/0 division.
        """
        x = x - x.min(keepdims=True)
        peak = x.max()
        if peak > 0:
            x = x / peak
        x = x - 0.5
        return x / 0.5

### Nets

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent vector to an image in [-1, 1].

    The latent vector is projected to a 128 x (img_size/4) x (img_size/4)
    feature map. It is then upsampled twice by 2 through conv / batch-norm /
    LeakyReLU stages and mapped to ``channels`` output channels through tanh.

    Args:
        img_size: side of the square output image; must be a positive
            multiple of 4 (two x2 upsamplings).
        latent_dim: length of the noise vector ``z``.
        channels: number of output image channels.
        bn_eps: ``eps`` of the two inner BatchNorm2d layers. The original code
            passed ``0.8`` positionally, and ``nn.BatchNorm2d`` reads that
            position as ``eps``, not ``momentum``. 0.8 is kept as the default
            so existing behaviour is unchanged; pass ``1e-5`` for PyTorch's
            standard epsilon.

    Raises:
        ValueError: if ``img_size`` is not a positive multiple of 4. The
            network would otherwise silently emit images of a different size.
    """

    def __init__(self, img_size, latent_dim, channels, *, bn_eps=0.8):
        super(Generator, self).__init__()
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")

        self.init_size = img_size // 4
        self.linear_layer = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        # Layer order/indices are unchanged, so state_dict keys stay compatible.
        self.conv_layers = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, eps=bn_eps),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=bn_eps),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map ``z`` of shape (N, latent_dim) to images (N, channels, img_size, img_size)."""
        out = self.linear_layer(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        image = self.conv_layers(out)
        return image


class Discriminator(nn.Module):
    """Convolutional classifier scoring images with a sigmoid probability.

    Four 3x3 / stride-2 conv blocks (LeakyReLU, Dropout2d and, except for the
    first, BatchNorm2d) are followed by a linear layer and a sigmoid.

    Args:
        channels: number of input image channels.
        img_size: side of the square input image.
        bn_eps: ``eps`` of the BatchNorm2d layers. The original code passed
            ``0.8`` positionally, and ``nn.BatchNorm2d`` reads that position
            as ``eps``, not ``momentum``. 0.8 is kept as the default so
            existing behaviour is unchanged.
    """

    def __init__(self, channels, img_size, *, bn_eps=0.8):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            # 3x3 conv, stride 2, padding 1: a side of s becomes ceil(s / 2).
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]
            if bn:
                block.append(nn.BatchNorm2d(out_filters, eps=bn_eps))
            return block

        self.conv_layers = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Track the spatial size through the four stride-2 convs. Each rounds
        # up, so plain img_size // 16 was only right for multiples of 16.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.ds_size = ds_size
        self.adverse_layer = nn.Sequential(nn.Linear(128 * self.ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, image):
        """Return an (N, 1) sigmoid score for a batch of (N, channels, H, W) images."""
        out = self.conv_layers(image)
        out = out.view(out.shape[0], -1)
        validity = self.adverse_layer(out)
        return validity

### Hyperparameters

In [None]:
# Paths: checkpoints are written under root_save; training images come from root_ds.
root_save = "./best_models"
root_ds = "./../../dataset/kaggle1/no"

# Geometry. img_size should match the 96x96 size hard-coded in the dataset loader.
img_size = 96
channels = 1          # the dataset yields single-channel images
latent_dim = 100      # length of the generator's noise vector

# Schedule
batch_size = 10
num_epochs = 1_000
sample_interval = 20  # epochs between preview + checkpoint

### Dataset

In [None]:
# Load every image once, then serve shuffled mini-batches of (H, W) images.
dataset = Anomaly_Dataset(root=root_ds)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

### Models, Loss and Optimizers

In [None]:
# Build both networks and move them to the current CUDA device.
generator = Generator(latent_dim=latent_dim, channels=channels, img_size=img_size).cuda()
discriminator = Discriminator(img_size=img_size, channels=channels).cuda()

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()
# One Adam optimizer per network, built with PyTorch's default hyper-parameters
# (generator's first, then the discriminator's).
optimizer_G, optimizer_D = (optim.Adam(net.parameters()) for net in (generator, discriminator))

In [None]:
def weights_init_normal(m):
    """DCGAN-style initialisation, intended for ``nn.Module.apply``.

    - Convolution / transposed-convolution weights ~ N(0, 0.02); their biases
      are left untouched.
    - BatchNorm2d weights ~ N(1, 0.02); biases are set to 0.
    - Every other module (Linear, activations, containers, ...) is unchanged.

    Modules are matched with ``isinstance`` rather than by class-name
    substring. A custom container whose name merely contains "Conv" (e.g. a
    ``ConvBlock`` without its own ``weight``) therefore no longer crashes.
    Affine-less BatchNorm (``weight is None``) is skipped.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d,
                  nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0.0)

In [None]:
# Re-initialise every sub-module of both networks in place. ``apply`` returns
# the module itself, so both names keep referring to the same objects.
generator, discriminator = (net.apply(weights_init_normal) for net in (generator, discriminator))

### Train

In [None]:
# Non-saturating GAN training. Each batch first updates G so that D labels its
# samples "valid", then updates D on real images vs. the detached fakes.
device = next(generator.parameters()).device  # follow wherever the models live
n_batches = len(dataloader)
if n_batches == 0:
    raise ValueError("dataloader yields no batches; check root_ds")
os.makedirs(root_save, exist_ok=True)  # torch.save does not create directories

# Epochs are numbered 1..num_epochs inclusive, matching the "[Epoch x/num_epochs]" log.
for epoch in range(1, num_epochs + 1):
    total_G_loss = 0.0
    total_D_loss = 0.0
    for i, (real_images, _) in enumerate(dataloader):
        n = real_images.shape[0]
        valid = torch.ones(n, 1, device=device)
        fake = torch.zeros(n, 1, device=device)
        # Dataset items are (H, W); add the channel axis -> (N, 1, H, W).
        real_images = real_images.unsqueeze(1).float().to(device)

        #  Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(n, latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()
        total_G_loss += g_loss.item()

        #  Train Discriminator
        # zero_grad also discards the gradients g_loss.backward() left on D.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_images), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()
        total_D_loss += d_loss.item()

    if epoch % sample_interval == 0:

        # Show one random real/generated pair from the last batch. randint's
        # upper bound is exclusive, unlike floor(uniform(0, n)), which can hit n.
        num_rnd = np.random.randint(real_images.shape[0])
        ploter(real_images[num_rnd, 0].detach().cpu(), gen_imgs[num_rnd, 0].detach().cpu())

        # Report mean per-batch losses so values are comparable across dataset sizes.
        print(
                "[Epoch {}/{}] \t[D loss: {:.3f}] \t[G loss: {:.3f}]".format(
                    epoch, num_epochs, total_D_loss / n_batches, total_G_loss / n_batches)
            )

        # Checkpoint both networks and their optimizers.
        torch.save({"epoch": epoch,
            "state_dict_G": generator.state_dict(),
            "state_dict_D": discriminator.state_dict(),
            "optimizer_G": optimizer_G.state_dict(),
            "optimizer_D": optimizer_D.state_dict()
           }, os.path.join(root_save, f"epoch_{epoch}.pt"))

In [None]:
# Disabled snippet, kept as an inert string literal (it has no effect when run).
# If enabled, it would call save_image on the first generated image of selected
# batches, writing it under "images/" with a zero-padded "<epoch>_<batch>.png"
# name. It refers to the loop variables epoch, i and gen_imgs, so it would have
# to be moved inside the training loop to work.
"""
if epoch % sample_interval == 0 and i % (len(dataloader)/5) == 0:
    save_image(gen_imgs.data[0,0],
               "images/{}_{}.png".format(str(epoch).zfill(len(str(num_epochs))),
                                         str(i).zfill(len(str(len(dataloader))))),
               normalize=True)
"""

### End