### Libs

In [30]:
from torchvision.datasets import CelebA
import os
import torch.nn as nn
import glob
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import os
import torch
import math
# main libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torch.utils.data import SubsetRandomSampler
from torch.nn import functional as F
from torch import nn
from torch import optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Dataset
import math
import torchvision.transforms.functional as TF
import cv2
from keras.utils import image_dataset_from_directory
from torchvision import datasets, transforms
from torch.utils.data import random_split
from torch.nn.parameter import Parameter

from torch.nn import init
from torchvision.utils import make_grid
from collections import OrderedDict

import torch
from torch.utils.data import Dataset

from skimage import transform
import skimage.io as io
import numpy as np
from torchvision.utils import save_image

In [31]:
from torchvision.models import vgg16

### Helper Functions

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [33]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 96, 96)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    image_unflat = image_tensor.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [34]:
def ploter(image, image_hat):
    """
    (H, W)
    """
    plt.figure()
    plt.subplot(1,2,1)
    #plt.imshow(image_hat, cmap='gray', vmin=-1, vmax=1)
    plt.imshow(image_hat)
    plt.tight_layout()
    plt.title("Reconstruct")

    plt.subplot(1,2,2)
    #plt.imshow(image, cmap='gray', vmin=-1, vmax=1)
    plt.imshow(image)
    plt.tight_layout()
    plt.title("Original")

    plt.show()

### Dataloader

In [35]:
class Anomaly_Dataset(Dataset):
    def __init__(self,
                 root
                 ):
        super(Anomaly_Dataset, self).__init__()

        self.data = Anomaly_Dataset.load_dataset(root)
        self.image, self.label = Anomaly_Dataset.get_numpy(self.data)

    def __getitem__(self, item):
        x, y =  self.image[item], self.label[item]

        # RGB -> GRAY : (H, W)
        x = x[:,:,0]

        # (1, H, W)
        x = Anomaly_Dataset.normalization(x)
        #x = np.expand_dims(x, axis=0)

        return x, y


    def __len__(self):
        return len(self.data)


    @staticmethod
    def load_dataset(path):
        img_rows = 96
        img_cols = 96
        return image_dataset_from_directory(directory = path,
                                               label_mode = 'int',
                                               color_mode = 'rgb',
                                               shuffle = False,
                                               batch_size = None,
                                               image_size = (img_rows, img_cols),
                                               crop_to_aspect_ratio = True)

    @staticmethod
    def get_numpy(PrefetchDataset):
        """
        return:
            (N, H, W, C) , (N,)
        """
        images = []
        labels = []
        for (image, label) in PrefetchDataset:
            images.append(image)
            labels.append(label)
        return np.array(images), np.array(labels)

    @staticmethod
    def rgb_2_gray(x):
        """
        (H, W, C) --> (H, W)
        """
        return cv2.cvtColor(x, cv2.COLOR_BGR2GRAY)

    @staticmethod
    def normalization(x):
        """
        Args:
            x : np.array : (H, W)

        Return:
            np.array : (H, W)
        """
        x = x - x.min(keepdims=True)
        x = x / x.max(keepdims=True)
        x = x - 0.5
        return  x / 0.5

### Generator

In [36]:
class Generator(nn.Module):
    def __init__(self, img_size, latent_dim, channels):
        super(Generator, self).__init__()

        self.init_size = img_size // 4
        self.linear_layer = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))

        self.conv_layers = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        out = self.linear_layer(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        image = self.conv_layers(out)
        return image

In [37]:
generator = Generator(img_size=96, latent_dim=100, channels=3).cuda()
count_parameters(generator)

7670275

### Disc New

In [38]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # Load VGG (Pretrain = False)
        self.conv_layers = vgg16().features[:9]


        self.adverse_layer = nn.Sequential(
            nn.Linear(128 * 48 * 48, 1),
            nn.Sigmoid())

    def forward(self, image):
        out = self.conv_layers(image)
        out = out.view(out.shape[0], -1)
        validity = self.adverse_layer(out)
        return validity

In [39]:
# TEST
x = torch.ones(10, 3, 96, 96)
discriminator = Discriminator()
discriminator(x).shape

torch.Size([10, 1])

In [40]:
def convert_relu_to_softplus(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.LeakyReLU(negative_slope=0.2))
        else:
            convert_relu_to_softplus(child)

#convert_relu_to_softplus(discriminator)

### Losses

In [41]:
def discriminator_layers_loss(discriminator, vgg, real_images, lamda = 0):
    # CONV2D # 0, 2, 5, 7
    # RELU   # 1, 3, 6, 8


    criterion = nn.MSELoss()
    similarity_loss = torch.nn.CosineSimilarity()

    y_pred_1, y_pred_3, y_pred_6, y_pred_8 = discriminator.conv_layers[:1+1](real_images),\
                                             discriminator.conv_layers[:3+1](real_images),\
                                             discriminator.conv_layers[:6+1](real_images),\
                                             discriminator.conv_layers[:8+1](real_images)

    y_1, y_3, y_6, y_8 = vgg.features[:1+1](real_images),\
                         vgg.features[:3+1](real_images),\
                         vgg.features[:6+1](real_images),\
                         vgg.features[:8+1](real_images)


    abs_loss_1 = criterion(y_pred_1, y_1)
    loss_1 = torch.mean(1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))

    abs_loss_3 = criterion(y_pred_3, y_3)
    loss_3 = torch.mean(1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))

    abs_loss_6 = criterion(y_pred_6, y_6)
    loss_6 = torch.mean(1 - similarity_loss(y_pred_6.view(y_pred_6.shape[0], -1), y_6.view(y_6.shape[0], -1)))

    abs_loss_8 = criterion(y_pred_8, y_8)
    loss_8 = torch.mean(1 - similarity_loss(y_pred_8.view(y_pred_8.shape[0], -1), y_8.view(y_8.shape[0], -1)))


    total_loss = loss_1 + loss_3 + loss_6 + loss_8 + lamda * (
                    abs_loss_1 + abs_loss_3 + abs_loss_6 + abs_loss_8)

    return total_loss

In [42]:
def generator_layers_loss(vgg, real_images, gen_imgs, lamda = 0):
    # CONV2D # 0, 2, 5, 7
    # RELU   # 1, 3, 6, 8

    criterion = nn.MSELoss()
    similarity_loss = torch.nn.CosineSimilarity()

    y_pred_1, y_pred_3, y_pred_6, y_pred_8 = vgg.features[:1+1](gen_imgs),\
                                             vgg.features[:3+1](gen_imgs),\
                                             vgg.features[:6+1](gen_imgs),\
                                             vgg.features[:8+1](gen_imgs)

    y_1, y_3, y_6, y_8 = vgg.features[:1+1](real_images),\
                         vgg.features[:3+1](real_images),\
                         vgg.features[:6+1](real_images),\
                         vgg.features[:8+1](real_images)


    abs_loss_1 = criterion(y_pred_1, y_1)
    loss_1 = torch.mean(1 - similarity_loss(y_pred_1.view(y_pred_1.shape[0], -1), y_1.view(y_1.shape[0], -1)))

    abs_loss_3 = criterion(y_pred_3, y_3)
    loss_3 = torch.mean(1 - similarity_loss(y_pred_3.view(y_pred_3.shape[0], -1), y_3.view(y_3.shape[0], -1)))

    abs_loss_6 = criterion(y_pred_6, y_6)
    loss_6 = torch.mean(1 - similarity_loss(y_pred_6.view(y_pred_6.shape[0], -1), y_6.view(y_6.shape[0], -1)))

    abs_loss_8 = criterion(y_pred_8, y_8)
    loss_8 = torch.mean(1 - similarity_loss(y_pred_8.view(y_pred_8.shape[0], -1), y_8.view(y_8.shape[0], -1)))


    total_loss = loss_1 + loss_3 + loss_6 + loss_8 + lamda * (
                    abs_loss_1 + abs_loss_3 + abs_loss_6 + abs_loss_8)

    return total_loss

### Hyperparameters

In [52]:
#root_save = "./best_models"
root_ds = "./../../../dataset/kaggle1/no"
device = 'cuda'

#root_save = "/content/drive/MyDrive/MRI Arman/best_models"
#root_ds = "/content/drive/MyDrive/MRI Arman/dataset/kaggle1/no"

batch_size = 10
img_size = 96
latent_dim = 100
channels = 3
num_epochs = 5000
sample_interval = 1

num_data = 91

### Dataset

In [53]:
indx = list(range(num_data))

train_sampler = SubsetRandomSampler(indx[:81])
test_sampler = SubsetRandomSampler(indx[81:])

In [54]:
dataset = Anomaly_Dataset(root_ds)
train_loader = DataLoader(dataset, batch_size= batch_size, sampler=train_sampler)
test_loader=DataLoader(dataset, batch_size= len(test_sampler), sampler=test_sampler)

Found 91 files belonging to 1 classes.


### Loss

In [55]:
os.environ['TORCH_HOME'] = './../pretrain_models'
vgg = vgg16('vgg16-397923af.pth').to(device)

RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 2.00 GiB total capacity; 966.10 MiB already allocated; 0 bytes free; 990.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [56]:
generator = Generator(img_size=img_size, latent_dim=latent_dim, channels=channels).cuda()
discriminator = Discriminator().cuda()

In [57]:
adversarial_loss = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters())
optimizer_D = optim.Adam(discriminator.parameters())

In [58]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [59]:
generator = generator.apply(weights_init_normal)
discriminator = discriminator.apply(weights_init_normal)

### Train

In [60]:
for epoch in range(1, num_epochs+1):
    total_G_loss = 0.0
    total_D_loss = 0.0
    for i, (real_images, _) in enumerate(train_loader):
        valid = torch.FloatTensor(real_images.shape[0], 1).fill_(1.0).cuda()
        fake = torch.FloatTensor(real_images.shape[0], 1).fill_(0.0).cuda()
        real_images = real_images.unsqueeze(1).cuda()
        real_images = real_images.repeat(1,3,1,1)

        #  Train Generator
        optimizer_G.zero_grad()
        z = torch.FloatTensor(np.random.normal(0, 1, (real_images.shape[0], latent_dim))).cuda()
        gen_imgs = generator(z)
        g_loss_adv = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss_features = generator_layers_loss(vgg, real_images, gen_imgs, lamda = 0)

        ##########  C. Pixel Loss
        #g_loss_pixel = feature_loss(real_images, gen_imgs)

        ##########  C. All together
        g_loss = g_loss_adv + 0.1*g_loss_features #+ 2e-4*g_loss_pixel


        g_loss.backward()
        optimizer_G.step()
        total_G_loss += g_loss.cpu().detach().numpy()



        #  Train Discriminator
        optimizer_D.zero_grad()
        discriminator_opinion_real = discriminator(real_images)
        discriminator_opinion_fake = discriminator(gen_imgs.detach())
        real_loss = adversarial_loss(discriminator_opinion_real, valid)
        fake_loss = adversarial_loss(discriminator_opinion_fake, fake)

        d_loss_adv = (real_loss + fake_loss) / 2
        d_loss_layers = discriminator_layers_loss(discriminator, vgg, real_images, lamda = 0)

        d_loss = d_loss_adv + 0.1*d_loss_layers

        d_loss.backward()
        optimizer_D.step()
        total_D_loss += d_loss.cpu().detach().numpy()

    if epoch % sample_interval ==0:
        for _, (real_images, _) in enumerate(test_loader):
            real_images = real_images.unsqueeze(1).cuda()
            # Real and Fake images

            # Attention
            z = torch.FloatTensor(np.random.normal(0, 1, (real_images.shape[0], latent_dim))).cuda()
            gen_imgs = generator(z)

            show_tensor_images(gen_imgs)
            show_tensor_images(real_images)

        # show losses
        print(
        "[Epoch {}/{}] \t[D loss: {:.3f}] \t[G loss: {:.3f}]".format(
        epoch, num_epochs, total_D_loss, total_G_loss)
        )

        """
        # save all
        torch.save({"epoch": epoch,
        "state_dict_G": generator.state_dict(),
        "state_dict_D": discriminator.state_dict(),
        "optimizer_G": optimizer_G.state_dict(),
        "optimizer_D": optimizer_D.state_dict()
        }, root_save + f"/epoch_{epoch}.pt")
        """

RuntimeError: CUDA out of memory. Tried to allocate 24.00 MiB (GPU 0; 2.00 GiB total capacity; 1.10 GiB already allocated; 0 bytes free; 1.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
"""
if epoch % sample_interval == 0 and i % (len(dataloader)/5) == 0:
    save_image(gen_imgs.data[0,0],
               "images/{}_{}.png".format(str(epoch).zfill(len(str(num_epochs))),
                                         str(i).zfill(len(str(len(dataloader))))),
               normalize=True)
"""

### End