### Libs

In [None]:
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch
import math
# main libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch.nn import functional as F
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset
import math
import torchvision.transforms.functional as TF
import cv2
from keras.utils import image_dataset_from_directory
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.nn.parameter import Parameter

from torch.nn import init
from torchvision.utils import make_grid
from collections import OrderedDict

### Helper

In [None]:
def l2normalize(v, eps=1e-12):
    """Scale ``v`` to (approximately) unit L2 norm.

    Args:
        v: tensor to rescale; only ``v.norm()`` and division are used.
        eps: small constant added to the norm so an all-zero input does not
            divide by zero.

    Returns:
        ``v`` divided by ``v.norm() + eps``.
    """
    magnitude = v.norm()
    return v / (magnitude + eps)

In [None]:
class SpectralNorm(nn.Module):
    """Wrap a module and divide one of its weights by an estimate of its spectral norm.

    On construction the wrapped module's parameter ``name`` is replaced by
    three parameters: ``<name>_bar`` (the raw weight) and ``<name>_u`` /
    ``<name>_v`` (power-iteration vectors, ``requires_grad=False``).  Before
    every forward pass ``<name>_bar / sigma`` is recomputed and stored as the
    plain attribute ``<name>`` of the wrapped module.

    NOTE: the power iteration runs on every call to ``forward`` regardless of
    train/eval mode, so ``<name>_u`` and ``<name>_v`` are updated on every
    forward pass.

    Args:
        module: module that owns the parameter to normalise.
        name: name of that parameter on ``module``.
        power_iterations: number of power-iteration steps per forward pass.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run the power iteration and refresh the normalised weight."""
        u = getattr(self.module, self.name + "_u")
        v = getattr(self.module, self.name + "_v")
        w = getattr(self.module, self.name + "_bar")

        height = w.data.shape[0]
        # Flatten the weight to (height, -1) once instead of on every step.
        w_mat = w.view(height, -1)
        for _ in range(self.power_iterations):
            # .data keeps the iteration itself out of the autograd graph.
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # sigma = u^T W v is computed on the tensors themselves (not .data).
        sigma = u.dot(w_mat.mv(v))
        setattr(self.module, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True if the wrapped module already has ``_u``, ``_v`` and ``_bar``."""
        return all(
            hasattr(self.module, self.name + suffix)
            for suffix in ("_u", "_v", "_bar")
        )

    def _make_params(self):
        """Replace ``<name>`` on the wrapped module by ``<name>_u``, ``<name>_v`` and ``<name>_bar``."""
        w = getattr(self.module, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # Random unit vectors as the starting point of the power iteration.
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Remove the original parameter; forward() re-creates the attribute
        # as a plain tensor before the wrapped module is called.
        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)

    def forward(self, *args):
        """Refresh the normalised weight, then run the wrapped module on ``args``."""
        self._update_u_v()
        # Call the module itself (not .forward) so that its hooks still run.
        return self.module(*args)

In [None]:
def normalization(x):
    """Linearly rescale ``x`` so its minimum maps to -1 and its maximum to 1.

    Only ``min()``, ``max()`` and arithmetic are used, so NumPy arrays and
    torch tensors are both accepted.

    Args:
        x : array or tensor : (H, W)

    Return:
        array or tensor : (H, W), values in [-1, 1].  A constant input (zero
        value range) yields -1 everywhere rather than NaN.
    """
    shifted = x - x.min()
    peak = shifted.max()
    # A constant input has zero range; skip the division to avoid 0 / 0 = NaN.
    if peak > 0:
        shifted = shifted / peak
    return (shifted - 0.5) / 0.5

In [None]:
def ploter(image, image_hat):
    """Show a reconstruction next to its original in a new two-panel figure.

    Both inputs are 2-D (H, W) images drawn in grayscale on a fixed [-1, 1]
    intensity scale; ``image_hat`` goes on the left, ``image`` on the right.
    """
    panels = (
        (image_hat, "Reconstruct"),
        (image, "Original"),
    )

    plt.figure()
    for position, (picture, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture, cmap='gray', vmin=-1, vmax=1)
        plt.tight_layout()
        plt.title(caption)

    plt.show()

In [None]:
class Anomaly_Dataset(Dataset):
    """In-memory image dataset read from a directory of images.

    All images under ``root`` are loaded once at construction (resized and
    cropped to 128x128 by ``image_dataset_from_directory``) and kept as NumPy
    arrays.  Each item is the first channel of one image rescaled to
    [-1, 1], together with its integer label.
    """

    def __init__(self,
                 root
                 ):
        super(Anomaly_Dataset, self).__init__()

        self.data = Anomaly_Dataset.load_dataset(root)
        self.image, self.label = Anomaly_Dataset.get_numpy(self.data)

    def __getitem__(self, item):
        """Return ``(x, y)``: the normalised (H, W) image and its label."""
        x, y = self.image[item], self.label[item]

        # Keep only the first channel: (H, W, C) -> (H, W).
        # NOTE(review): this equals a grayscale conversion only when all
        # channels are identical -- confirm for the dataset in use.
        x = x[:, :, 0]

        # Result stays (H, W); no channel axis is added here.
        x = Anomaly_Dataset.normalization(x)

        return x, y

    def __len__(self):
        # Count the materialised samples, so this does not depend on the
        # wrapped dataset object supporting len().
        return len(self.image)

    @staticmethod
    def load_dataset(path):
        """Load every image under ``path`` as an unbatched, unshuffled RGB dataset."""
        img_rows = 128
        img_cols = 128
        return image_dataset_from_directory(directory = path,
                                               label_mode = 'int',
                                               color_mode = 'rgb',
                                               shuffle = False,
                                               batch_size = None,
                                               image_size = (img_rows, img_cols),
                                               crop_to_aspect_ratio = True)

    @staticmethod
    def get_numpy(PrefetchDataset):
        """
        Materialise an iterable of (image, label) pairs.

        return:
            (N, H, W, C) , (N,)
        """
        images = []
        labels = []
        for (image, label) in PrefetchDataset:
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

    @staticmethod
    def rgb_2_gray(x):
        """
        (H, W, C) --> (H, W)
        """
        # Images are loaded with color_mode='rgb', so the channel order is RGB
        # (the BGR code would swap the red and blue weights).
        return cv2.cvtColor(x, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def normalization(x):
        """
        Linearly rescale ``x`` to [-1, 1].

        Args:
            x : np.array : (H, W)

        Return:
            np.array : (H, W).  A constant image (zero value range) yields -1
            everywhere rather than NaN.
        """
        shifted = x - x.min()
        peak = shifted.max()
        # Skip the division for a constant image to avoid 0 / 0 = NaN.
        if peak > 0:
            shifted = shifted / peak
        return (shifted - 0.5) / 0.5

### Net: the architecture must be the same as the one used to train the loaded checkpoint

In [None]:
class Encoder(nn.Module):
    """Contracting path: four conv blocks, each followed by 2x2 max-pooling, then a bottleneck block.

    The channel count doubles at every level, from ``init_features`` up to
    ``init_features * 16`` in the bottleneck.
    """

    def __init__(self, in_channels, init_features):
        super().__init__()

        f = init_features
        self.encoder1 = Encoder._block(in_channels, f, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = Encoder._block(f, 2 * f, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = Encoder._block(2 * f, 4 * f, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = Encoder._block(4 * f, 8 * f, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = Encoder._block(8 * f, 16 * f, name="bottleneck")

    def forward(self, x):
        """Return ``(bottleneck, enc4, enc3, enc2, enc1)``, deepest feature map first."""
        level1 = self.encoder1(x)
        level2 = self.encoder2(self.pool1(level1))
        level3 = self.encoder3(self.pool2(level2))
        level4 = self.encoder4(self.pool3(level3))
        deepest = self.bottleneck(self.pool4(level4))

        return deepest, level4, level3, level2, level1

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv without bias -> batch norm -> ReLU) stages; layer names are prefixed by ``name``."""
        layers = OrderedDict()
        channels = in_channels
        for stage in (1, 2):
            layers[f"{name}conv{stage}"] = nn.Conv2d(
                in_channels=channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[f"{name}norm{stage}"] = nn.BatchNorm2d(num_features=features)
            layers[f"{name}relu{stage}"] = nn.ReLU(inplace=True)
            # The second stage maps features -> features.
            channels = features
        return nn.Sequential(layers)

In [None]:
class Decoder(nn.Module):
    """Expanding path: four (transposed conv -> concatenate skip -> conv block) levels, then a 1x1 conv and tanh.

    The channel count halves at every level, from ``init_features * 16`` at
    the bottleneck down to ``init_features`` before the final 1x1 convolution.
    """

    def __init__(self, init_features, out_channels):
        super().__init__()

        f = init_features

        self.upconv4 = nn.ConvTranspose2d(16 * f, 8 * f, kernel_size=2, stride=2)
        self.decoder4 = Decoder._block(2 * (8 * f), 8 * f, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(8 * f, 4 * f, kernel_size=2, stride=2)
        self.decoder3 = Decoder._block(2 * (4 * f), 4 * f, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(4 * f, 2 * f, kernel_size=2, stride=2)
        self.decoder2 = Decoder._block(2 * (2 * f), 2 * f, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(2 * f, f, kernel_size=2, stride=2)
        self.decoder1 = Decoder._block(2 * f, f, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=f, out_channels=out_channels, kernel_size=1
        )

    def forward(self, bottleneck, enc4, enc3, enc2, enc1):
        """Decode ``bottleneck`` using the four skip tensors; the output is passed through tanh."""
        # At each level: upsample, concatenate the matching skip tensor along
        # the channel axis, then apply that level's conv block.
        x = self.decoder4(torch.cat((self.upconv4(bottleneck), enc4), dim=1))
        x = self.decoder3(torch.cat((self.upconv3(x), enc3), dim=1))
        x = self.decoder2(torch.cat((self.upconv2(x), enc2), dim=1))
        x = self.decoder1(torch.cat((self.upconv1(x), enc1), dim=1))
        return torch.tanh(self.conv(x))

    @staticmethod
    def _block(in_channels, features, name):
        """Build two (3x3 conv without bias -> batch norm -> ReLU) stages; layer names are prefixed by ``name``."""
        layers = OrderedDict()
        channels = in_channels
        for stage in (1, 2):
            layers[f"{name}conv{stage}"] = nn.Conv2d(
                in_channels=channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            )
            layers[f"{name}norm{stage}"] = nn.BatchNorm2d(num_features=features)
            layers[f"{name}relu{stage}"] = nn.ReLU(inplace=True)
            # The second stage maps features -> features.
            channels = features
        return nn.Sequential(layers)

In [None]:
class Generator(nn.Module):
    """Encoder-decoder generator: an ``Encoder`` whose five outputs are all fed to a ``Decoder``."""

    def __init__(self, in_channels, out_channels, init_features):
        super().__init__()

        self.encoder = Encoder(in_channels, init_features)
        self.decoder = Decoder(init_features, out_channels)

    def forward(self, x):
        """Encode ``x``, keep the encoder outputs as attributes, and return the decoded image."""
        encoded = self.encoder(x)
        # The encoder outputs remain reachable on the module after the call.
        self.bottleneck, self.enc4, self.enc3, self.enc2, self.enc1 = encoded
        return self.decoder(*encoded)

In [None]:
class Critic(nn.Module):
    """Convolutional critic that maps an image batch to one unbounded score per image.

    Four spectrally-normalised, unpadded, stride-2 convolutions are followed
    by 6x6 average pooling and a two-layer linear head.  The head expects the
    pooled feature map to be 1x1, which is the case for 128x128 inputs.

    Args:
        c_dim: number of input image channels.
        df_dim: channel count of the first convolution; later ones use 2x, 4x, 8x.
    """

    def __init__(self, c_dim , df_dim=32):
        super(Critic, self).__init__()

        # NOTE: the second positional argument of nn.BatchNorm2d is ``eps``
        # (not momentum).  The value 0.5 is kept, now spelled out as a
        # keyword, so existing checkpoints keep behaving the same.
        self.fea1 = nn.Sequential(
            SpectralNorm(nn.Conv2d(c_dim, df_dim, kernel_size=5, stride=2, padding=0)),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(nn.Conv2d(df_dim, df_dim*2, kernel_size=5, stride=2, padding=0)),
            nn.BatchNorm2d(df_dim*2, eps=0.5), nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(nn.Conv2d(df_dim*2, df_dim*4, kernel_size=3, stride=2, padding=0)),
            nn.BatchNorm2d(df_dim*4, eps=0.5), nn.LeakyReLU(negative_slope=0.2, inplace=True),
            SpectralNorm(nn.Conv2d(df_dim*4, df_dim*8, kernel_size=3, stride=2, padding=0)),
            nn.BatchNorm2d(df_dim*8, eps=0.5), nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.AvgPool2d(kernel_size=6, stride=6))

        # The head input follows the last conv width (df_dim * 8) instead of a
        # hard-coded 256, which only matched df_dim=32.
        self.concat = nn.Sequential(nn.Dropout(0.2), nn.Linear(df_dim*8, 128),
                                    nn.Dropout(0.2), nn.Linear(128, 1)
                                )

    def forward(self, img):
        """Return the critic score of shape (batch, 1) for ``img``."""
        fea1_out = self.fea1(img).flatten(start_dim=1)
        validity = self.concat(fea1_out)

        return validity


### Hyperparameters

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Learning rates and Adam betas.
# NOTE(review): lr_decoder is only referenced from disabled code in the visible cells.
lr_decoder = 0.5
lr_y = 0.001
beta_1 = 0.5
beta_2 = 0.999

# Number of optimisation iterations.
epochs = 500
# Reporting intervals, in steps.
# NOTE(review): disp_freq is not referenced in the visible cells.
disp_freq = 20
display_step = 20

# Number of consecutive losses averaged into one point of the loss curve.
step_bins = 1

### Dataset

In [None]:
# Best Generator and Critic
c_dim, gf_dim = 1, 8
df_dim = 32

gen = Generator(c_dim, c_dim, gf_dim).to(device)
crit = Critic(c_dim, df_dim).to(device)

# map_location remaps the stored tensors onto `device`, so the checkpoint
# also loads on a machine whose devices differ from the one that saved it.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
best_model = torch.load("epoch_969_loss_0.0065.pt", map_location=device)

gen.load_state_dict(best_model['Generator'])
crit.load_state_dict(best_model['Critic'])

In [None]:
# Abnormal (tumor) images
root = "./../../dataset/kaggle1/tamiz"
dataset = Anomaly_Dataset(root)
# One image per batch, in dataset order.
test_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [None]:
# Reconstruct every test image; after the loop `image_test` / `image_hat`
# hold the last batch, which is the one displayed below.
# no_grad: nothing is back-propagated here, so no autograd graph is needed.
with torch.no_grad():
    for image_test, _ in test_loader:
        # Add the channel axis expected by the generator.
        image_test = image_test.unsqueeze(1).to(device)
        image_hat = gen(image_test)

# Index of the sample to display (kept as a global; other cells may read it).
i = 0
# ploter opens and shows its own figure, so no extra plt.figure()/plt.show()
# is needed (they only produced an empty figure).
ploter(image_test[i, 0].detach().cpu(), image_hat[i, 0].detach().cpu())

### Load best

In [None]:
# Optimise with respect to bottleneck, enc4, enc3, enc2, enc1.
# Disabled experiment: the code below is wrapped in a string literal, so none of it runs.
"""
bottleneck, enc4, enc3, enc2, enc1 = gen.encoder(image_test)
x_hat = gen.decoder(bottleneck, enc4, enc3, enc2, enc1)


bottleneck = Parameter(bottleneck)
enc4 = Parameter(enc4)
enc3 = Parameter(enc3)
enc2 = Parameter(enc2)
enc1 = Parameter(enc1)


gen.requires_grad_(False)
crit.requires_grad_(False)


# Optimizers
#optim_decoder = torch.optim.Adam([bottleneck], lr=lr_decoder, betas=(beta_1, beta_2))
#optim_decoder = torch.optim.Adam([bottleneck])
"""

In [None]:
# Optimise with respect to the input image itself (the tumor image).
# After the loop `image_test` holds the last batch of the loader.
for image_test, _ in test_loader:
    cur_batch_size = image_test.shape[0]
    image_test = image_test.unsqueeze(1).to(device)

# `y` starts as a copy of that batch and is the only tensor given to the optimizer.
y = Parameter(image_test.clone())
optim_y = optim.Adam([y], lr=lr_y, betas=(beta_1, beta_2))

# Freeze both networks: their parameters no longer require gradients.
gen.requires_grad_(False)
crit.requires_grad_(False)

### Train with respect to y

In [None]:
import gc

# Collect garbage first, then release PyTorch's unused cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Iteration counter and per-iteration loss history for the optimisation loop.
cur_step, decoder_losses = 0, []

In [None]:
# Modes and the pixel loss never change during the run, so set them up once
# instead of on every iteration.
gen.train()
crit.eval()
pixel_diff = nn.MSELoss()

for epoch in range(1, epochs + 1):
    print(60 * "#")
    print(6 * "#" + " Epoch " + str(epoch) + " " + 45 * "#")
    print(60 * "#")

    # One optimizer step on y: raise the critic score of gen(y) while keeping
    # y close to the original image (weights 10 and 5).
    optim_y.zero_grad()
    y_hat = gen(y)
    y_loss = 10 * (-crit(y_hat)).mean(dim=0) + 5 * pixel_diff(y, image_test)
    # The graph is rebuilt from scratch on every iteration, so it does not
    # need to be retained after the backward pass.
    y_loss.backward()
    optim_y.step()
    decoder_losses.append(y_loss.item())

    ### Visualization code ###
    if cur_step % display_step == 0 and cur_step > 0:
        decoder_mean = sum(decoder_losses[-display_step:]) / display_step
        print(f"Epoch {epoch}, step {cur_step}: Generator loss: {decoder_mean}")

        # Original image vs. current y.  ploter opens and shows its own
        # figure; index 0 is used directly instead of a leftover global.
        ploter(image_test[0, 0].detach().cpu(), y[0, 0].detach().cpu())

        # Absolute difference between the rescaled y and the original image.
        diff = (normalization(y[0][0]) - image_test[0][0]).abs()
        plt.imshow(diff.detach().cpu(), cmap='gray', vmin=0)
        plt.show()

        # Loss curve, averaged over bins of `step_bins` iterations.
        plt.figure()
        num_examples = (len(decoder_losses) // step_bins) * step_bins
        plt.plot(
            range(num_examples // step_bins),
            torch.Tensor(decoder_losses[:num_examples]).view(-1, step_bins).mean(1),
            label="Decoder Loss"
        )
        plt.show()
    cur_step += 1

In [None]:
# Original test image next to the rescaled generator output `image_hat`.
ploter(
    image_test[0, 0].detach().cpu(),
    normalization(image_hat[0, 0]).detach().cpu(),
)

### Train

In [None]:
"""
for epoch in range(1,epochs+1):
    print(60 * "#")
    print(6 * "#" + " Epoch " + str(epoch) + " " + 45 * "#")
    print(60 * "#")

    # Set mode on "train mode"
    gen.train()
    crit.eval()
    for image_test,_ in test_loader:
        cur_batch_size = len(image_test)

        image_test = image_test.unsqueeze(1).to(device)

        # Decoder Star Learning
        optim_decoder.zero_grad()
        image_hat = gen.decoder(bottleneck, enc4, enc3, enc2, enc1)
        decoder_loss = ( -crit(image_hat) ).mean(dim=0)
        decoder_loss.backward(retain_graph=True)
        optim_decoder.step()
        decoder_losses += [decoder_loss.item()]

    ### Visualization code ###
    if cur_step % display_step == 0 and cur_step > 0:
        decoder_mean = sum(decoder_losses[-display_step:]) / display_step
        print(f"Epoch {epoch}, step {cur_step}: Generator loss: {decoder_mean}")


        plt.figure()
        #i=0
        #ploter(image_test[i,0].detach().cpu(), image_hat[i,0].detach().cpu())
        #plt.show()
        diff = (normalization(image_hat[0][0]) - image_test[0][0]).abs()
        plt.imshow(diff.detach().cpu(), cmap='gray', vmin=0)
        plt.show()


        plt.figure()
        num_examples = (len(decoder_losses) // step_bins) * step_bins
        plt.plot(
            range(num_examples // step_bins),
            torch.Tensor(decoder_losses[:num_examples]).view(-1, step_bins).mean(1),
            label="Decoder Loss"
        )
        plt.show()
    cur_step += 1
"""

### End

In [None]:
# Original test image next to the rescaled generator output `image_hat`.
ploter(
    image_test[0, 0].detach().cpu(),
    normalization(image_hat[0, 0]).detach().cpu(),
)