# Very Deep Variational Autoencoder + SVDD

Preliminaries

In [None]:
!pip install onnx

In [None]:
from google.colab import drive
drive.mount('/content/drive')

### Code taken and adapted from: https://github.com/vvvm23/vdvae

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image

import numpy as np
import itertools
import datetime
from tqdm.notebook import tqdm, trange

In [None]:
import locale
locale.getpreferredencoding = lambda: "UTF-8"

!mkdir ./imgs/
!mkdir ./saved\_checkpoints/
!mkdir ./results

### Checkpoint | Helper | HyperParameters

In [None]:
class Checkpoint(dict):
    """Dictionary whose entries can also be read and written as attributes.

    Reading an attribute that has no matching key yields ``None`` instead of
    raising, so optional checkpoint fields can be probed directly.
    """

    def __getattr__(self, attr):
        # Only reached when regular attribute lookup fails; a missing key
        # maps to None.
        return self.get(attr)

    def __setattr__(self, attr, value):
        # Attribute assignment is stored as a dictionary item.
        self.__setitem__(attr, value)

In [None]:
def info(s):
    """Print *s* as a '>'-prefixed status line wrapped in ANSI colour code 92."""
    print(f"\33[92m> {s}\33[0m")


def error(s):
    """Print *s* as a '!'-prefixed error line wrapped in ANSI colour code 31."""
    print(f"\33[31m! {s}\33[0m")


def warning(s):
    """Print *s* as a '$'-prefixed warning line wrapped in ANSI colour code 94."""
    print(f"\33[94m$ {s}\33[0m")

def get_device(try_cuda):
    """Pick the torch device to run on and report the decision.

    Args:
        try_cuda: CUDA is attempted unless this value compares equal to
            ``False``.

    Returns:
        ``torch.device('cuda')`` when CUDA is requested and available,
        otherwise ``torch.device('cpu')``.
    """
    # Equality (not truthiness) is kept on purpose: a value such as None does
    # not compare equal to False and therefore still attempts CUDA.
    cuda_requested = not (try_cuda == False)

    if cuda_requested and torch.cuda.is_available():
        info("CUDA is available.")
        return torch.device('cuda')

    if cuda_requested:
        error("CUDA is unavailable but selected in hyperparameters.")
        error("Falling back to default device.")
    else:
        info("CUDA disabled by hyperparameters.")
    return torch.device('cpu')

In [None]:
class Hyperparameters(dict):
    """Hyperparameter mapping with attribute-style access.

    ``hps.name`` reads ``hps['name']`` and evaluates to ``None`` when the key
    is absent; ``hps.name = value`` writes the dictionary entry.
    """

    def __getattr__(self, attr):
        # Invoked only after normal attribute lookup has failed.
        return dict.get(self, attr)

    def __setattr__(self, attr, value):
        # Route attribute writes into the underlying dictionary.
        dict.__setitem__(self, attr, value)

def get_parameters(task):
    """Build the hyperparameter set for a named task.

    Args:
        task: One of ``'cifar10'``, ``'stl10'``, ``'mnist'`` or
            ``'fashion_mnist'``.

    Returns:
        A ``Hyperparameters`` mapping holding the run settings (``cuda``,
        ``checkpoint``, ``tqdm``), the dataset name and batch size, the model
        widths/depths and the optimisation settings.

    An unrecognized task is reported through ``error`` and then ``exit()`` is
    called.
    """
    print("Loading HPS for task", task)
    HPS = Hyperparameters()
    HPS['cuda'] = True
    HPS['checkpoint'] = 5
    HPS['tqdm'] = True

    if task == 'cifar10':
        HPS['dataset'] = 'cifar10'
        HPS['batch_size'] = 32

        HPS['in_channels'] = 3
        HPS['h_width'] = 64
        HPS['m_width'] = 32
        HPS['z_dim'] = 16
        HPS['nb_blocks'] = 4
        HPS['nb_res_blocks'] = 3
        HPS['scale_rate'] = 2

        # Same value as the former literal 1_100_00, with unambiguous grouping.
        HPS['nb_iterations'] = 110_000
        HPS['lr'] = 2e-4
        HPS['decay'] = 1e-2
    elif task == 'stl10':
        # The dataset name now follows the task; it used to be 'cifar10',
        # which made the stl10 settings load CIFAR-10 data.
        HPS['dataset'] = 'stl10'
        HPS['batch_size'] = 16

        HPS['in_channels'] = 3
        HPS['h_width'] = 128
        HPS['m_width'] = 64
        HPS['z_dim'] = 32
        HPS['nb_blocks'] = 5
        HPS['nb_res_blocks'] = 3
        HPS['scale_rate'] = 2

        HPS['nb_iterations'] = 110_000
        HPS['lr'] = 2e-4
        HPS['decay'] = 1e-2
    elif task == 'mnist' or task == 'fashion_mnist':
        # Both grayscale tasks share one configuration; only the dataset name differs.
        HPS['dataset'] = task
        HPS['batch_size'] = 32

        HPS['in_channels'] = 1
        HPS['h_width'] = 32
        HPS['m_width'] = 16
        HPS['z_dim'] = 8
        HPS['nb_blocks'] = 2
        HPS['nb_res_blocks'] = 3
        HPS['scale_rate'] = 2

        HPS['nb_iterations'] = 100_000
        HPS['lr'] = 2e-4
        HPS['decay'] = 1e-2
    else:
        error("Unrecognized HPS task! Exiting..")
        exit()

    return HPS

### VD-VAE

Helper functions and classes

In [None]:
"""
    Encoder Components:
        - Encoder, contains all the EncoderBlocks and manages data flow through them.
        - EncoderBlock, contains sub-blocks of residual units and a pooling layer.
        - ResidualBlock, contains a block of residual connections, as described in the paper (1x1,3x3,3x3,1x1)
            - We could slightly adapt, and make it a ReZero connection. Needs some testing.

    Decoder Components:
        - Decoder, contains all DecoderBlocks and manages data flow through them.
        - DecoderBlock, contains sub-blocks of top-down units and an unpool layer.
        - TopDownBlock, implements the topdown block from the original paper.

    All is encapsulated in the main VAE class.

"""

class ConvBuilder:
    """Namespace of factory helpers for the convolutions used by the model.

    The helpers are static methods, so they work both as
    ``ConvBuilder.b3x3(...)`` and when called on an instance.
    """

    @staticmethod
    def _bconv(in_dim, out_dim, kernel_size, stride, padding):
        """Return an ``nn.Conv2d`` with the given channel counts and geometry."""
        return nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=padding)

    @staticmethod
    def b1x1(in_dim, out_dim):
        """Return a 1x1 convolution (stride 1, no padding)."""
        return ConvBuilder._bconv(in_dim, out_dim, 1, 1, 0)

    @staticmethod
    def b3x3(in_dim, out_dim):
        """Return a 3x3 convolution (stride 1, padding 1)."""
        return ConvBuilder._bconv(in_dim, out_dim, 3, 1, 1)

"""
    Diagonal Gaussian distribution sampling and KL loss.
    Taken directly from the OpenAI implementation.
    The decorators mean these functions are compiled as TorchScript.
"""
@torch.jit.script
def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2):
    """Elementwise KL divergence between two diagonal Gaussians.

    Computes KL(N(mu1, sigma1) || N(mu2, sigma2)) where each sigma is given
    as its logarithm. The result has the broadcast shape of the inputs.
    """
    var1 = logsigma1.exp() ** 2
    var2 = logsigma2.exp() ** 2
    mean_gap_sq = (mu1 - mu2) ** 2
    return -0.5 + logsigma2 - logsigma1 + 0.5 * (var1 + mean_gap_sq) / var2

@torch.jit.script
def draw_gaussian_diag_samples(mu, logsigma):
    """Draw a reparameterised sample ``mu + exp(logsigma) * eps``.

    ``eps`` is standard-normal noise with the same shape, dtype and device
    as ``mu``.
    """
    noise = torch.empty_like(mu).normal_(0., 1.)
    return torch.exp(logsigma) * noise + mu

"""
    Helper module to call super().__init__() for us
"""
class HelperModule(nn.Module):
    """``nn.Module`` base whose constructor forwards its arguments to ``build``.

    Subclasses implement ``build`` instead of ``__init__``; the
    ``nn.Module`` initialisation has already run when ``build`` is called.
    """

    def __init__(self, *args, **kwargs):
        nn.Module.__init__(self)
        self.build(*args, **kwargs)

    def build(self, *args, **kwargs):
        """Create the sub-modules; must be overridden by subclasses."""
        raise NotImplementedError

Encoder Components

In [None]:
class ResidualBlock(HelperModule):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 3x3 -> 1x1 convolutions.

    Every convolution is preceded by a GELU and the stack output is added
    back onto the input, scaled by ``gate``.
    """

    def build(self, in_width, hidden_width, rezero=False):
        """Create the convolution stack and the residual gate.

        Args:
            in_width: Channel count of the input and of the output.
            hidden_width: Channel count inside the block; meant to act as a
                bottleneck.
            rezero: When true the gate is a learnable scalar initialised to
                0; otherwise it is the constant 1.0.
        """
        stack = [
            ConvBuilder.b1x1(in_width, hidden_width),
            ConvBuilder.b3x3(hidden_width, hidden_width),
            ConvBuilder.b3x3(hidden_width, hidden_width),
            ConvBuilder.b1x1(hidden_width, in_width),
        ]
        self.conv = nn.ModuleList(stack)
        self.gate = nn.Parameter(torch.tensor(0.0)) if rezero else 1.0

    def forward(self, x):
        """Return ``x + gate * stack(x)``."""
        branch = x
        for conv in self.conv:
            branch = conv(F.gelu(branch))
        return x + self.gate*branch

class EncoderBlock(HelperModule):
    """A run of residual blocks followed by average pooling."""

    def build(self, in_dim, middle_width, nb_r_blocks, downscale_rate):
        """Create ``nb_r_blocks`` residual blocks and remember the pooling rate."""
        self.downscale_rate = downscale_rate
        residuals = [ResidualBlock(in_dim, middle_width) for _ in range(nb_r_blocks)]
        self.res_blocks = nn.ModuleList(residuals)

    def forward(self, x):
        """Return ``(pooled, activations)``.

        ``pooled`` feeds the next encoder block; ``activations`` is the
        pre-pooling output that is handed to the top-down layers.
        """
        features = x
        for res_block in self.res_blocks:
            features = res_block(features)
        rate = self.downscale_rate
        pooled = F.avg_pool2d(features, kernel_size=rate, stride=rate)
        return pooled, features


class Encoder(HelperModule):
    """Bottom-up path: an input convolution followed by the encoder blocks."""

    def build(self, in_dim, hidden_width, middle_width, nb_encoder_blocks, nb_res_blocks=3, downscale_rate=2):
        """Create the input convolution and ``nb_encoder_blocks`` encoder blocks.

        Every block pools by ``downscale_rate`` except the last one, which
        uses a rate of 1.
        """
        self.in_conv = ConvBuilder.b3x3(in_dim, hidden_width)

        last_index = nb_encoder_blocks - 1
        blocks = []
        for i in range(nb_encoder_blocks):
            rate = 1 if i == last_index else downscale_rate
            blocks.append(EncoderBlock(hidden_width, middle_width, nb_res_blocks, rate))
        self.enc_blocks = nn.ModuleList(blocks)

        # Shrink the last convolution of every residual block by
        # sqrt(1 / total number of residual blocks).
        # TODO: the factor is identical for every block and could be passed
        # to EncoderBlock instead of being applied in this loop.
        for enc_block in self.enc_blocks:
            for res_block in enc_block.res_blocks:
                res_block.conv[-1].weight.data *= np.sqrt(1 / (nb_encoder_blocks*nb_res_blocks))

    def forward(self, x):
        """Return the list of activations, starting with the input-convolution output."""
        hidden = self.in_conv(x)
        activations = [hidden]
        for enc_block in self.enc_blocks:
            hidden, skip = enc_block(hidden)
            activations.append(skip)
        return activations

In [None]:
class SVDDLayer(HelperModule):
    """Dense bottleneck producing the SVDD feature vector.

    The input is flattened, projected to ``out_features`` (the features
    scored against the SVDD centre) and projected back to its original size.
    """

    def build(self, in_features, out_features, bias=False):
        """Create the extraction and reconstruction linear layers.

        A deeper variant (a 1000-unit hidden layer with LeakyReLU on each
        side, and optional dropout of 0.2) was tried previously and is not
        active.
        """
        self.flatten = nn.Flatten()
        self.dense_extract = nn.Linear(in_features=in_features, out_features=out_features, bias=bias)
        self.dense_reconstruct = nn.Linear(in_features=out_features, out_features=in_features, bias=bias)

    def forward(self, x):
        """Return ``(reconstruction, features)``.

        ``reconstruction`` has the shape of ``x``; ``features`` is the
        leaky-ReLU output of the extraction layer.
        """
        original_shape = x.shape
        flat = self.flatten(x)
        lin_x = F.leaky_relu(self.dense_extract(flat))
        restored = F.leaky_relu(self.dense_reconstruct(lin_x))
        return restored.view(*original_shape), lin_x

Decoder Components

In [None]:
class Block(HelperModule):
    """Non-residual 1x1 -> 3x3 -> 3x3 -> 1x1 convolution stack with GELU before each layer."""

    def build(self, in_width, hidden_width, out_width):
        """Create the stack; ``hidden_width`` is meant to act as a bottleneck."""
        layers = [
            ConvBuilder.b1x1(in_width, hidden_width),
            ConvBuilder.b3x3(hidden_width, hidden_width),
            ConvBuilder.b3x3(hidden_width, hidden_width),
            ConvBuilder.b1x1(hidden_width, out_width),
        ]
        self.conv = nn.ModuleList(layers)

    def forward(self, x):
        """Apply GELU followed by each convolution in turn."""
        out = x
        for conv in self.conv:
            out = conv(F.gelu(out))
        return out

class TopDownBlock(HelperModule):
    """Top-down unit holding one stochastic layer."""

    def build(self, in_width, middle_width, z_dim):
        """Create the posterior, prior, latent-projection and output sub-modules."""
        # Posterior head: sees [x, a] and outputs 2 * z_dim channels
        # (location and log-scale of q).
        self.cat_conv = Block(in_width*2, middle_width, z_dim*2)
        # Prior head: outputs the two prior parameter maps plus in_width
        # channels that are added to x.
        self.prior = Block(in_width, middle_width, z_dim*2 + in_width)
        self.out_res = ResidualBlock(in_width, middle_width)
        self.z_conv = ConvBuilder.b1x1(z_dim, in_width)
        self.z_dim = z_dim

    def _prior_parts(self, x):
        """Split the prior head output into (location, log-scale, residual)."""
        pfeat = self.prior(x)
        zd = self.z_dim
        return pfeat[:, :zd], pfeat[:, zd:zd*2], pfeat[:, zd*2:]

    def forward(self, x, a):
        """Run the block with encoder activations ``a``.

        Returns:
            ``(x_out, kl)`` where ``kl`` is the elementwise KL between the
            posterior and the prior of this layer.
        """
        q_loc, q_logsigma = self.cat_conv(torch.cat([x, a], dim=1)).chunk(2, dim=1)
        p_loc, p_logsigma, p_res = self._prior_parts(x)
        hidden = x + p_res

        z = draw_gaussian_diag_samples(q_loc, q_logsigma)
        kl = gaussian_analytical_kl(q_loc, p_loc, q_logsigma, p_logsigma)

        hidden = hidden + self.z_conv(z)
        return self.out_res(hidden), kl

    def sample(self, x):
        """Run the block without encoder input, drawing the latent from the prior."""
        p_loc, p_logsigma, p_res = self._prior_parts(x)
        hidden = x + p_res

        z = draw_gaussian_diag_samples(p_loc, p_logsigma)

        hidden = hidden + self.z_conv(z)
        return self.out_res(hidden)

class DecoderBlock(HelperModule):
    """Upsampling followed by a run of top-down blocks at one resolution."""

    def build(self, in_dim, middle_width, z_dim, nb_td_blocks, upscale_rate):
        """Create ``nb_td_blocks`` top-down blocks and remember the upscale rate."""
        self.upscale_rate = upscale_rate
        top_downs = [TopDownBlock(in_dim, middle_width, z_dim) for _ in range(nb_td_blocks)]
        self.td_blocks = nn.ModuleList(top_downs)

    def forward(self, x, a):
        """Upsample ``x`` and run every top-down block with activations ``a``.

        Returns:
            ``(x_out, kls)`` with one KL tensor per top-down block.
        """
        hidden = F.interpolate(x, scale_factor=self.upscale_rate)
        kls = []
        for td_block in self.td_blocks:
            hidden, kl = td_block(hidden, a)
            kls.append(kl)
        return hidden, kls

    def sample(self, x):
        """Upsample ``x`` and run every top-down block in sampling mode."""
        hidden = F.interpolate(x, scale_factor=self.upscale_rate)
        for td_block in self.td_blocks:
            hidden = td_block.sample(hidden)
        return hidden

class Decoder(HelperModule):
    """Top-down path: the decoder blocks followed by an output convolution."""

    def build(self, in_dim, middle_width, out_dim, z_dim, nb_decoder_blocks, nb_td_blocks=3, upscale_rate=2):
        """Create ``nb_decoder_blocks`` decoder blocks and the output convolution.

        The first block uses an upscale rate of 1; all others use
        ``upscale_rate``.
        """
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(in_dim, middle_width, z_dim, nb_td_blocks, 1 if i == 0 else upscale_rate)
            for i in range(nb_decoder_blocks)])
        self.in_dim = in_dim
        self.out_conv = ConvBuilder.b3x3(in_dim, out_dim)

        # Shrink the latent projection and the last residual convolution of
        # every top-down block by sqrt(1 / total number of top-down blocks).
        for bd in self.dec_blocks:
            for bt in bd.td_blocks:
                bt.z_conv.weight.data *= np.sqrt(1 / (nb_decoder_blocks*nb_td_blocks))
                bt.out_res.conv[-1].weight.data *= np.sqrt(1 / (nb_decoder_blocks*nb_td_blocks))

    def forward(self, activations, x_encoder):
        """Decode from the encoder activations.

        Args:
            activations: Encoder activations in encoder order; they are
                consumed in reverse, one per decoder block.
            x_encoder: Initial top-down state, or ``None`` to start from
                zeros shaped like the first consumed activation.

        Returns:
            ``(output, decoder_kl)`` where ``decoder_kl`` lists the KL tensor
            of every top-down block.
        """
        activations = activations[::-1]
        x = x_encoder
        decoder_kl = []
        for i, b in enumerate(self.dec_blocks):
            a = activations[i]
            # Identity test: `==` on a tensor is an elementwise comparison.
            if x is None:
                x = torch.zeros_like(a)
            x, block_kl = b(x, a)
            decoder_kl.extend(block_kl)

        x = self.out_conv(x)
        return x, decoder_kl

    def sample(self, nb_samples, size=4):
        """Generate ``nb_samples`` outputs by sampling every prior.

        Args:
            nb_samples: Number of samples to draw.
            size: Height and width of the all-zero initial state. The default
                of 4 is the previously hard-coded value.
        """
        # Create the seed on the device the module lives on rather than
        # unconditionally on 'cuda'.
        device = self.out_conv.weight.device
        x = torch.zeros(nb_samples, self.in_dim, size, size, device=device)
        for b in self.dec_blocks:
            x = b.sample(x)
        return self.out_conv(x)

VDVAE

In [None]:
class VAE(HelperModule):
    """Encoder, SVDD bottleneck and decoder assembled into one model."""

    def build(self, in_dim, hidden_width, middle_width, z_dim, nb_blocks=4, nb_res_blocks=3, scale_rate=2):
        """Create the three sub-modules.

        NOTE(review): the SVDD input size is fixed at 32*14*14 and is not
        derived from the arguments - presumably 32 channels at 14x14; confirm
        it matches the chosen widths and input resolution.
        """
        self.encoder = Encoder(
            in_dim, hidden_width, middle_width, nb_blocks,
            nb_res_blocks=nb_res_blocks, downscale_rate=scale_rate,
        )
        self.svdd_layer = SVDDLayer(in_features=32*14*14, out_features=128, bias=False)
        self.decoder = Decoder(
            hidden_width, middle_width, in_dim, z_dim, nb_blocks,
            nb_td_blocks=nb_res_blocks, upscale_rate=scale_rate,
        )

    def forward(self, x):
        """Return ``(reconstruction, decoder_kl, svdd_features)``."""
        activations = self.encoder(x)
        # The deepest activation goes through the SVDD bottleneck and seeds
        # the decoder.
        seed, features = self.svdd_layer(activations[-1])
        recon, decoder_kl = self.decoder(activations, seed)
        return recon, decoder_kl, features

    def sample(self, nb_samples):
        """Draw ``nb_samples`` samples from the decoder."""
        return self.decoder.sample(nb_samples)

In [None]:
class VDVAE_MNIST(HelperModule):
    """VD-VAE with an SVDD bottleneck sized 32*14*14 -> 128."""

    def build(self, in_dim, hidden_width, middle_width, z_dim, nb_blocks=4, nb_res_blocks=3, scale_rate=2):
        """Create encoder, SVDD layer and decoder.

        NOTE(review): ``in_features=32*14*14`` is hard-coded - presumably the
        deepest activation for the 'mnist' hyperparameters (h_width=32); it is
        not derived from the arguments, so confirm before changing them.
        """
        self.encoder = Encoder(
            in_dim, hidden_width, middle_width, nb_blocks,
            nb_res_blocks=nb_res_blocks, downscale_rate=scale_rate,
        )
        self.svdd_layer = SVDDLayer(in_features=32*14*14, out_features=128, bias=False)
        self.decoder = Decoder(
            hidden_width, middle_width, in_dim, z_dim, nb_blocks,
            nb_td_blocks=nb_res_blocks, upscale_rate=scale_rate,
        )

    def forward(self, x):
        """Return ``(reconstruction, decoder_kl, svdd_features)``."""
        activations = self.encoder(x)
        deepest = activations[-1]
        seed, features = self.svdd_layer(deepest)
        recon, decoder_kl = self.decoder(activations, seed)
        return recon, decoder_kl, features

    def sample(self, nb_samples):
        """Draw ``nb_samples`` samples from the decoder."""
        return self.decoder.sample(nb_samples)

In [None]:
class VDVAE_CIFAR10(HelperModule):
    """VD-VAE with an SVDD bottleneck sized 64*4*4 -> 128."""

    def build(self, in_dim, hidden_width, middle_width, z_dim, nb_blocks=4, nb_res_blocks=3, scale_rate=2):
        """Create encoder, SVDD layer and decoder.

        NOTE(review): ``in_features=64*4*4`` is hard-coded - presumably the
        deepest activation for the 'cifar10' hyperparameters (h_width=64); it
        is not derived from the arguments, so confirm before changing them.
        """
        self.encoder = Encoder(
            in_dim, hidden_width, middle_width, nb_blocks,
            nb_res_blocks=nb_res_blocks, downscale_rate=scale_rate,
        )
        self.svdd_layer = SVDDLayer(in_features=64*4*4, out_features=128, bias=False)
        self.decoder = Decoder(
            hidden_width, middle_width, in_dim, z_dim, nb_blocks,
            nb_td_blocks=nb_res_blocks, upscale_rate=scale_rate,
        )

    def forward(self, x):
        """Return ``(reconstruction, decoder_kl, svdd_features)``."""
        activations = self.encoder(x)
        deepest = activations[-1]
        seed, features = self.svdd_layer(deepest)
        recon, decoder_kl = self.decoder(activations, seed)
        return recon, decoder_kl, features

    def sample(self, nb_samples):
        """Draw ``nb_samples`` samples from the decoder."""
        return self.decoder.sample(nb_samples)

### SVDD Utils

In [None]:
def svdd_score(x, model, c):
    """Return the squared distance of each sample's SVDD features to ``c``.

    Args:
        x: Input batch; moved to ``device`` before the forward pass.
        model: Callable returning a 3-tuple whose last element is the
            feature batch.
        c: Centre the features are compared against.

    NOTE(review): ``device`` is not a parameter; it is read from module
    scope and must be defined before this function is called.
    """
    batch = x.to(device)
    _, _, features = model(batch)
    offsets = features - c
    return torch.sum(offsets ** 2, dim=1)

def update_radius(train_loader, model, c, device, nu: float = 0.05):
    """Recompute the SVDD radius from the training data after training.

    Args:
        train_loader: Iterable yielding ``(inputs, labels, idx)`` triples.
        model: Model passed on to ``svdd_score``; it is switched to eval mode.
        c: SVDD centre.
        device: Accepted for interface compatibility; not used here, since
            ``svdd_score`` moves the inputs itself.
        nu: Fraction passed to ``get_radius``.

    Returns:
        The radius computed by ``get_radius`` over all training scores.
    """
    info('Updating radius...')
    model.eval()
    distances = []
    with torch.no_grad():
        # Only the inputs are needed; labels and indices are ignored.
        for inputs, _labels, _idx in tqdm(train_loader):
            distances.append(svdd_score(inputs, model, c))

    return get_radius(torch.cat(distances, dim=0), nu)


def get_radius(dist: torch.Tensor, nu: float = 0.05):
    """Return the (1 - nu)-quantile of the square roots of ``dist``.

    ``dist`` holds squared distances; the input tensor is left untouched.
    """
    squared = dist.clone().data.cpu().numpy()
    distances = np.sqrt(squared)
    return np.quantile(distances, 1 - nu)

### Training and Evaluation

In [None]:
from abc import ABC, abstractmethod
from PIL import Image
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10, MNIST
import torchvision.transforms as transforms

class BaseADDataset(ABC):
    """Abstract base for anomaly-detection datasets."""

    def __init__(self, root: str):
        super().__init__()
        # Root path to the data.
        self.root = root

        # Binary task: 0 = normal, 1 = outlier.
        self.n_classes = 2
        # Original class labels that make up the normal / outlier side.
        self.normal_classes = None
        self.outlier_classes = None

        # Both must be of type torch.utils.data.Dataset once assigned.
        self.train_set = None
        self.test_set = None

    @abstractmethod
    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Return torch DataLoaders for ``train_set`` and ``test_set``."""

    def __repr__(self):
        return type(self).__name__

class TorchvisionDataset(BaseADDataset):
    """Base for datasets that are already implemented in torchvision.datasets."""

    def __init__(self, root: str):
        super().__init__(root)

    def loaders(self, batch_size: int, shuffle_train=True, shuffle_test=False, num_workers: int = 0) -> (
            DataLoader, DataLoader):
        """Return ``(train_loader, test_loader)`` built from the two stored sets."""

        def make_loader(dataset, shuffle):
            # Both loaders share the batch size and the worker count.
            return DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                num_workers=num_workers
            )

        train_loader = make_loader(self.train_set, shuffle_train)
        test_loader = make_loader(self.test_set, shuffle_test)
        return train_loader, test_loader

def get_target_label_idx(labels, targets):
    """
    Find the positions of the labels that belong to targets.
    :param labels: array of labels
    :param targets: list/tuple of target labels
    :return: list with indices of target labels
    """
    is_target = np.isin(labels, targets)
    positions = np.argwhere(is_target)
    return positions.ravel().tolist()


def global_contrast_normalization(x, scale='l2'):
    """
    Apply global contrast normalization to a tensor in place: subtract the mean over all of its
    elements and divide by a scale.
    With scale='l1' the scale is the mean absolute value of the centred tensor; with scale='l2' it is
    the L2 norm of the centred tensor divided by the number of elements.
    This is a *per sample* normalization (not across the dataset). The input tensor is modified and
    returned.
    """

    assert scale in ('l1', 'l2')

    # Centre in place.
    x -= torch.mean(x)

    if scale == 'l1':
        x_scale = torch.mean(torch.abs(x))
    elif scale == 'l2':
        n_features = int(np.prod(x.shape))
        x_scale = torch.sqrt(torch.sum(x ** 2)) / n_features

    x /= x_scale
    return x

# Transformation class that ensures the data is in the range [0, 1], with no 0-eps or 1+eps values
class ClampImageTo01():
    """Callable transform that limits every value of an image tensor to [0, 1]."""

    def __call__(self, img):
        # Values below 0 become 0 and values above 1 become 1.
        return torch.clamp(img, min=0, max=1)

class MNIST_Dataset(TorchvisionDataset):
    """One-class MNIST: one digit is the normal class, the other nine are outliers.

    The training set is restricted to the normal class; the test set keeps
    all classes. Targets are mapped to 0 (normal) / 1 (outlier).
    """

    def __init__(self, root: str, normal_class=0):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = list(range(0, 10))
        self.outlier_classes.remove(normal_class)

        # Pre-computed min and max values (after applying GCN) from train data per class.
        # Currently unused: the GCN / min-max steps below are disabled.
        min_max = [(-0.8826567065619495, 9.001545489292527),
                   (-0.6661464580883915, 20.108062262467364),
                   (-0.7820454743183202, 11.665100841080346),
                   (-0.7645772083211267, 12.895051191467457),
                   (-0.7253923114302238, 12.683235701611533),
                   (-0.7698501867861425, 13.103278415430502),
                   (-0.778418217980696, 10.457837397569108),
                   (-0.7129780970522351, 12.057777597673047),
                   (-0.8280402650205075, 10.581538445782988),
                   (-0.7369959242164307, 10.697039838804978)]

        # Active preprocessing: conversion to a tensor and clamping to [0, 1].
        # Disabled steps, kept for reference:
        #   transforms.Lambda(lambda x: global_contrast_normalization(x, scale='l1')),
        #   transforms.Normalize([min_max[normal_class][0]],
        #                        [min_max[normal_class][1] - min_max[normal_class][0]]),
        #   transforms.Lambda(torch.flatten)  # TODO: remove this sometimes
        steps = [transforms.ToTensor(), ClampImageTo01()]
        transform = transforms.Compose(steps)

        def to_binary_label(label):
            # 1 when the original label is an outlier class, 0 otherwise.
            return int(label in self.outlier_classes)

        target_transform = transforms.Lambda(to_binary_label)

        full_train = MyMNIST(root=self.root, train=True, download=True,
                             transform=transform, target_transform=target_transform)
        # Keep only the training samples of the normal class.
        train_labels = full_train.targets.clone().data.cpu().numpy()
        train_idx_normal = get_target_label_idx(train_labels, self.normal_classes)
        info(f"Num of normal indices: {len(train_idx_normal)}")
        self.train_set = Subset(full_train, train_idx_normal)

        self.test_set = MyMNIST(root=self.root, train=False, download=True,
                                transform=transform, target_transform=target_transform)


class MyMNIST(MNIST):
    """Torchvision MNIST whose ``__getitem__`` additionally returns the sample index."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Fetch one sample.
        Args:
            index (int): Index
        Returns:
            triple: (image, target, index) where target is index of the target class.
        """
        raw, target = self.data[index], self.targets[index]

        # Build a PIL image so the behaviour is consistent with all other datasets.
        img = Image.fromarray(raw.numpy(), mode='L')

        transform, target_transform = self.transform, self.target_transform
        if transform is not None:
            img = transform(img)
        if target_transform is not None:
            target = target_transform(target)

        # Returning the index as well is the only addition to the parent behaviour.
        return img, target, index

In [None]:
class CIFAR10_Dataset(TorchvisionDataset):
    """One-class CIFAR-10: one class is normal, the other nine are outliers.

    The training set is restricted to the normal class; the test set keeps
    all classes. Targets are mapped to 0 (normal) / 1 (outlier).
    """

    def __init__(self, root: str, normal_class=5):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = list(range(0, 10))
        self.outlier_classes.remove(normal_class)

        # Pre-computed min and max values (after applying GCN) from train data per class
        min_max = [(-28.94083453598571, 13.802961825439636),
                   (-6.681770233365245, 9.158067708230273),
                   (-34.924463588638204, 14.419298165027628),
                   (-10.599172931391799, 11.093187820377565),
                   (-11.945022995801637, 10.628045447867583),
                   (-9.691969487694928, 8.948326776180823),
                   (-9.174940012342555, 13.847014686472365),
                   (-6.876682005899029, 12.282371383343161),
                   (-15.603507135507172, 15.2464923804279),
                   (-6.132882973622672, 8.046098172351265)]

        # CIFAR-10 preprocessing: GCN (with L1 norm), min-max feature scaling
        # with the normal class's pre-computed bounds, then clamping to [0, 1].
        class_min, class_max = min_max[normal_class][0], min_max[normal_class][1]
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: global_contrast_normalization(x, scale='l1')),
            transforms.Normalize([class_min] * 3, [class_max - class_min] * 3),
            ClampImageTo01(),
        ])

        def to_binary_label(label):
            # 1 when the original label is an outlier class, 0 otherwise.
            return int(label in self.outlier_classes)

        target_transform = transforms.Lambda(to_binary_label)

        full_train = MyCIFAR10(root=self.root, train=True, download=True,
                               transform=transform, target_transform=target_transform)
        # Keep only the training samples of the normal class.
        train_idx_normal = get_target_label_idx(full_train.targets, self.normal_classes)
        self.train_set = Subset(full_train, train_idx_normal)

        self.test_set = MyCIFAR10(root=self.root, train=False, download=True,
                                  transform=transform, target_transform=target_transform)

class MyCIFAR10(CIFAR10):
    """Torchvision CIFAR10 whose ``__getitem__`` additionally returns the sample index."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Fetch one sample.
        Args:
            index (int): Index
        Returns:
            triple: (image, target, index) where target is index of the target class.
        """
        raw, target = self.data[index], self.targets[index]

        # Build a PIL image so the behaviour is consistent with all other datasets.
        img = Image.fromarray(raw)

        transform, target_transform = self.transform, self.target_transform
        if transform is not None:
            img = transform(img)
        if target_transform is not None:
            target = target_transform(target)

        # Returning the index as well is the only addition to the parent behaviour.
        return img, target, index

In [None]:
from torchvision.datasets import FashionMNIST

class FashionMNIST_Dataset(TorchvisionDataset):
    """One-class Fashion-MNIST: one class is normal, the other nine are outliers.

    The training set is restricted to the normal class; the test set keeps
    all classes. Targets are mapped to 0 (normal) / 1 (outlier).
    """

    def __init__(self, root: str, normal_class=0):
        super().__init__(root)

        self.n_classes = 2  # 0: normal, 1: outlier
        self.normal_classes = (normal_class,)
        self.outlier_classes = list(range(0, 10))
        self.outlier_classes.remove(normal_class)

        # Pre-computed min and max values (after applying GCN) from train data per class
        min_max = [
            (-2.681241989135742, 24.854305267333984),
            (-2.57785701751709, 11.16978931427002),
            (-2.8081703186035156, 19.133543014526367),
            (-1.9533653259277344, 18.656726837158203),
            (-2.6103854179382324, 19.166683197021484),
            (-1.2358521223068237, 28.46310806274414),
            (-3.251605987548828, 24.19683265686035),
            (-1.0814441442489624, 21.04704475402832),
            (-4.264486789703369, 11.350274085998535),
            (-1.3859288692474365, 11.426652908325195)
        ]

        # Fashion-MNIST preprocessing: GCN (with L1 norm), min-max feature
        # scaling with the normal class's pre-computed bounds (single
        # channel), then clamping to [0, 1].
        class_min, class_max = min_max[normal_class][0], min_max[normal_class][1]
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: global_contrast_normalization(x, scale='l1')),
            transforms.Normalize([class_min], [class_max - class_min]),
            ClampImageTo01(),
        ])

        def to_binary_label(label):
            # 1 when the original label is an outlier class, 0 otherwise.
            return int(label in self.outlier_classes)

        target_transform = transforms.Lambda(to_binary_label)

        full_train = MyFashionMNIST(root=self.root, train=True, download=True,
                                    transform=transform, target_transform=target_transform)
        # Keep only the training samples of the normal class.
        train_labels = full_train.targets.clone().data.cpu().numpy()
        train_idx_normal = get_target_label_idx(train_labels, self.normal_classes)
        self.train_set = Subset(full_train, train_idx_normal)

        self.test_set = MyFashionMNIST(root=self.root, train=False, download=True,
                                       transform=transform, target_transform=target_transform)


class MyFashionMNIST(FashionMNIST):
    """Torchvision FashionMNIST whose ``__getitem__`` additionally returns the sample index."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __getitem__(self, index):
        """Fetch one sample.
        Args:
            index (int): Index
        Returns:
            triple: (image, target, index) where target is index of the target class.
        """
        raw, target = self.data[index], self.targets[index]

        # Build a PIL image so the behaviour is consistent with all other datasets.
        img = Image.fromarray(raw.numpy(), mode='L')

        transform, target_transform = self.transform, self.target_transform
        if transform is not None:
            img = transform(img)
        if target_transform is not None:
            target = target_transform(target)

        # Returning the index as well is the only addition to the parent behaviour.
        return img, target, index

In [None]:
import time
from sklearn.metrics import roc_auc_score


def load_dataset_svdd(dataset_name, batch_size, data_path="./", normal_class=0):
    """Build the one-class dataset for ``dataset_name`` and its data loaders.

    Args:
        dataset_name: ``"mnist"``, ``"fashion_mnist"`` or ``"cifar10"``.
        batch_size: Batch size for both loaders.
        data_path: Root directory handed to the dataset class.
        normal_class: Label treated as the normal class.

    Returns:
        ``(train_loader, test_loader, dataset)``.

    Any other name is reported through ``error`` and then ``exit()`` is called.
    """
    if dataset_name == "mnist":
        ds = MNIST_Dataset(data_path, normal_class)
    elif dataset_name == "fashion_mnist":
        ds = FashionMNIST_Dataset(data_path, normal_class)
    elif dataset_name == "cifar10":
        ds = CIFAR10_Dataset(data_path, normal_class)
    else:
        error(f"Unrecognized dataset '{dataset_name}'! Exiting..")
        exit()

    # Every dataset class builds its loaders the same way.
    train_loader, test_loader = ds.loaders(batch_size=batch_size)
    return train_loader, test_loader, ds

# VD-VAE DATASET FUNCTION
def load_dataset(dataset, batch_size):
    """Build shuffled train and test loaders for the plain VD-VAE datasets.

    Args:
        dataset: ``'cifar10'``, ``'stl10'`` or ``'mnist'``.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``; both are created with
        ``shuffle=True`` and both use the same transforms, including the
        random rotation.

    Any other name is reported through ``error`` and then ``exit()`` is called.
    """
    if dataset in ('cifar10', 'stl10'):
        # The two colour datasets share one pipeline.
        dataset_transforms = transforms.Compose([
            transforms.RandomRotation(5),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        if dataset == 'cifar10':
            train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=dataset_transforms)
            test_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=dataset_transforms)
        else:
            train_dataset = torchvision.datasets.STL10('data', split='train+unlabeled', download=True, transform=dataset_transforms)
            test_dataset = torchvision.datasets.STL10('data', split='test', download=True, transform=dataset_transforms)
    elif dataset == 'mnist':
        # No normalisation is applied to MNIST.
        dataset_transforms = transforms.Compose([
            transforms.RandomRotation(5),
            transforms.ToTensor(),
        ])
        train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=dataset_transforms)
        test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=dataset_transforms)
    else:
        error(f"Unrecognized dataset '{dataset}'! Exiting..")
        exit()

    make_loader = torch.utils.data.DataLoader
    train_loader = make_loader(train_dataset, shuffle=True, batch_size=batch_size)
    test_loader = make_loader(test_dataset, shuffle=True, batch_size=batch_size)

    return train_loader, test_loader

def compute_center_c(model, loader, optim, device, center_size=128, eps=0.01):
    """Compute the SVDD centre as the mean feature vector over ``loader``.

    Args:
        model: Callable returning a 3-tuple whose last element is the
            feature batch; it is switched to eval mode.
        loader: Iterable yielding ``(x, _, _)`` triples.
        optim: Optimizer; its gradients are zeroed on every batch.
        device: Device for the centre and the inputs.
        center_size: Length of the centre vector.
        eps: Currently unused. A disabled step used it to push centre
            entries with magnitude below ``eps`` to ``-eps`` / ``+eps``.

    Returns:
        Tensor of length ``center_size``.
    """
    c = torch.zeros(center_size, device=device)
    model.eval()
    n = 0
    with torch.no_grad():
        for x, _, _ in tqdm(loader):
            optim.zero_grad()
            _, _, features = model(x.to(device))
            c += torch.sum(features, dim=0)
            n += features.shape[0]
        # Turn the running sum into a mean.
        c /= n

    return c

def vae_loss(x, model, crit, c, recon_weight=10, svdd_weight=10):
    """Compute the combined VAE + SVDD loss for one batch.

    Args:
        x: input batch; indexed as a 4-D tensor (reductions run over dims 1-3).
        model: callable returning ``(y, decoder_kl, features)`` where
            ``decoder_kl`` is an iterable of 4-D tensors.
        crit: element-wise criterion (called as ``crit(x, y)``) whose output is
            averaged over dims 1-3.
        c: SVDD center tensor; the batch is moved to its device.
        recon_weight: multiplier applied to the reconstruction term.
        svdd_weight: multiplier applied to the SVDD term.

    Returns:
        ``(y, elbo, mean reconstruction loss, mean scaled KL, mean SVDD loss,
        total loss)``.
    """
    # Follow the center's device rather than a module-level ``device`` global:
    # ``features - c`` below already requires model, input and center to agree.
    x = x.to(c.device)
    y, decoder_kl, features = model(x)

    rl = crit(x, y).mean(dim=(1, 2, 3))

    # Sum every layer's KL term per sample, then scale by the number of
    # elements in a single input sample.
    rpp = torch.zeros_like(rl)
    for k in decoder_kl:
        rpp += k.sum(dim=(1, 2, 3))
    rpp /= int(np.prod(x.shape[1:]))

    elbo = (rpp + rl * recon_weight).mean()

    # Squared Euclidean distance of each feature vector to the center.
    svdd_loss = torch.sum((features - c) ** 2, dim=1).mean()
    total_loss = elbo + svdd_loss * svdd_weight
    return y, elbo, rl.mean(), rpp.mean(), svdd_loss, total_loss

def train(model, loader, optim, crit, c, device):
    """Run one training epoch and return the per-batch average losses.

    Args:
        model: network being optimized; put into train mode here.
        loader: iterable yielding ``(x, _, _)`` triples; must support ``len``.
        optim: optimizer stepped once per batch on the total loss.
        crit: criterion forwarded to ``vae_loss``.
        c: SVDD center forwarded to ``vae_loss``.
        device: unused here; kept for signature compatibility with callers.

    Returns:
        ``(total, reconstruction, KL, SVDD, ELBO)`` losses, each averaged over
        the number of batches.
    """
    total_loss, r_loss, kl_loss = 0.0, 0.0, 0.0
    svdd_loss, elbo_loss = 0.0, 0.0
    start_time = time.time()
    model.train()
    for x, _, _ in tqdm(loader):
        optim.zero_grad()
        _, elbo, rl, kl, sl, totl = vae_loss(x, model, crit, c)
        totl.backward()
        optim.step()

        # Detach before accumulating: summing the raw loss tensors would chain
        # every batch's autograd history onto the running totals for the whole
        # epoch.
        total_loss += totl.detach()
        r_loss += rl.detach()
        kl_loss += kl.detach()
        svdd_loss += sl.detach()
        elbo_loss += elbo.detach()

    train_time = time.time() - start_time
    info('Training time: %.3f' % train_time)

    n_batches = len(loader)
    return (
        total_loss / n_batches,
        r_loss / n_batches,
        kl_loss / n_batches,
        svdd_loss / n_batches,
        elbo_loss / n_batches,
    )

def evaluate(model, loader, optim, crit, c, device, img_id=None):
    """Evaluate the model on ``loader`` and return the per-batch average losses.

    Side effects:
        * Always writes four model samples to ``imgs/vdvae-sample-{img_id}.png``.
        * When ``img_id`` is not ``None``, also writes the reconstruction of the
          first batch to ``imgs/eval-recon-{img_id}.png``.
        * Leaves the model in eval mode and clears the optimizer's gradients.

    Args:
        model: network to evaluate; must expose ``sample(n)``.
        loader: iterable yielding ``(x, _, _)`` triples; must support ``len``.
        optim: optimizer whose gradients are cleared.
        crit: criterion forwarded to ``vae_loss``.
        c: SVDD center forwarded to ``vae_loss``.
        device: unused here; kept for signature compatibility with callers.
        img_id: tag embedded in the image file names.

    Returns:
        ``(total, reconstruction, KL, SVDD, ELBO)`` losses, each averaged over
        the number of batches.
    """
    model.eval()
    # Evaluation produces no gradients, so clearing them once up front has the
    # same effect as clearing them on every batch.
    optim.zero_grad()

    # Sample in eval mode and without autograd: the images are only saved, so
    # building a graph for them is wasted memory.
    with torch.no_grad():
        sample = model.sample(4)
    save_image(sample, f"imgs/vdvae-sample-{img_id}.png", normalize=True, value_range=(-1, 1))

    total_loss, r_loss, kl_loss = 0.0, 0.0, 0.0
    svdd_loss, elbo_loss = 0.0, 0.0
    start_time = time.time()

    with torch.no_grad():
        for i, (x, _, _) in enumerate(tqdm(loader)):
            y, elbo, rl, kl, sl, totl = vae_loss(x, model, crit, c)

            total_loss += totl
            r_loss += rl
            kl_loss += kl
            svdd_loss += sl
            elbo_loss += elbo

            # Keep one reconstruction per evaluation run for visual inspection.
            if img_id is not None and i == 0:
                save_image(y, f"imgs/eval-recon-{img_id}.png", normalize=True, value_range=(-1, 1))

    eval_time = time.time() - start_time
    info('Eval time: %.3f' % eval_time)

    n_batches = len(loader)
    return (
        total_loss / n_batches,
        r_loss / n_batches,
        kl_loss / n_batches,
        svdd_loss / n_batches,
        elbo_loss / n_batches,
    )


def test(model, test_loader, c, R, device):
    """Score the test set with SVDD and report ROC AUC.

    Every batch is scored with ``svdd_score``; a sample is predicted as 1 when
    its score exceeds ``R ** 2`` and 0 otherwise. AUC is computed twice: once
    from the raw scores and once from the thresholded predictions.

    Args:
        model: network passed to ``svdd_score``; put into eval mode here.
        test_loader: iterable yielding ``(inputs, labels, idx)`` triples.
        c: SVDD center passed to ``svdd_score``.
        R: radius; its square is the decision threshold.
        device: unused here; kept for signature compatibility with callers.

    Returns:
        ``(records, test_auc, test_auc_preds)`` where ``records`` is a list of
        ``(idx, label, score, predicted_label)`` tuples.
    """
    info('Starting testing for SVDD...')
    began = time.time()
    records = []
    model.eval()
    with torch.no_grad():
        for inputs, labels, idx in tqdm(test_loader):
            scores = svdd_score(inputs, model, c)
            scores_np = scores.cpu().data.numpy()

            # Boolean mask -> 0/1 integers -> plain Python list.
            flagged = ((scores_np > (R ** 2)) * 1).tolist()

            # One (idx, label, score, predicted_label) tuple per sample.
            records.extend(
                zip(
                    idx.cpu().data.numpy().tolist(),
                    labels.cpu().data.numpy().tolist(),
                    scores_np.tolist(),
                    flagged,
                )
            )

    elapsed = time.time() - began
    info('Testing time: %.3f' % elapsed)

    # Split the records back into columns for the AUC computations.
    _, labels, scores, predicted_labels = zip(*records)
    labels, scores, predicted_labels = np.array(labels), np.array(scores), np.array(predicted_labels)

    test_auc = roc_auc_score(labels, scores)
    test_auc_preds = roc_auc_score(labels, predicted_labels)
    warning('Test set AUC: {:.2f}%'.format(100. * test_auc))
    warning('Test set predicitons AUC: {:.2f}%'.format(100. * test_auc_preds))
    info('Finished testing.')
    return records, test_auc, test_auc_preds

### Required only if testing architecture

In [None]:
# Architecture smoke test: build the CIFAR-10 model from its hyperparameters
# and push one random 32x32 RGB image through it.
#
# For the MNIST variant, switch to get_parameters("mnist") and VDVAE_MNIST
# (same constructor arguments) and feed a (1, 1, 28, 28) input instead.
HPS = get_parameters("cifar10")
test_model = VDVAE_CIFAR10(
    HPS.in_channels,
    HPS.h_width,
    HPS.m_width,
    HPS.z_dim,
    **{key: getattr(HPS, key) for key in ("nb_blocks", "nb_res_blocks", "scale_rate")},
)

# Left as a bare expression so the notebook displays the forward-pass output.
test_model(torch.rand(1, 3, 32, 32))

# Optional ONNX export of the MNIST variant:
# torch.onnx.export(test_model, torch.rand(1, 1, 28, 28), 'vdvae.onnx')

### Contents of `plot.py`

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import make_grid


def plot_images_grid(x: torch.Tensor, export_img, title: str = '', nrow=8, padding=2, normalize=False, pad_value=0):
    """Plot a 4D tensor of images of shape (B x C x H x W) as a grid and save it.

    Args:
        x: batch of images to tile.
        export_img: destination passed to ``savefig``.
        title: optional figure title; an empty string means no title.
        nrow, padding, normalize, pad_value: forwarded to ``make_grid``.

    The figure is deliberately left open after saving so that a notebook can
    still display it.
    """
    grid = make_grid(x, nrow=nrow, padding=padding, normalize=normalize, pad_value=pad_value)
    npgrid = grid.cpu().numpy()

    # Use a dedicated figure per call: drawing on the implicit current figure
    # makes consecutive calls in one cell pile their images onto the same axes.
    fig, ax = plt.subplots()
    # make_grid returns channels-first data; imshow expects channels-last.
    ax.imshow(np.transpose(npgrid, (1, 2, 0)), interpolation='nearest')

    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    if title != '':
        ax.set_title(title)

    fig.savefig(export_img, bbox_inches='tight', pad_inches=0.1)

### Main

In [None]:
from copy import deepcopy
# torch.autograd.set_detect_anomaly(True)
######## CHANGE BELOW
normal_class = 5
dataset_name = "cifar10"
HPS = get_parameters(dataset_name)
center_size = 128
######## CHANGE ABOVE
device = get_device(HPS.cuda)
train_loader, test_loader, dataset = load_dataset_svdd(HPS.dataset, HPS.batch_size, normal_class=normal_class)

# Pick the architecture matching the dataset. Fail loudly on anything else:
# otherwise `model` would be unbound, or silently reuse a stale model left over
# from an earlier run of this cell.
if dataset_name in ("mnist", "fashion_mnist"):
    model_cls = VDVAE_MNIST
elif dataset_name == "cifar10":
    model_cls = VDVAE_CIFAR10
else:
    raise ValueError(f"No VDVAE architecture available for dataset '{dataset_name}'")

model = model_cls(
    HPS.in_channels,
    HPS.h_width,
    HPS.m_width,
    HPS.z_dim,
    nb_blocks=HPS.nb_blocks,
    nb_res_blocks=HPS.nb_res_blocks,
    scale_rate=HPS.scale_rate
).to(device)

info(f"Number of trainable parameters {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
optim = torch.optim.Adam(model.parameters(), lr=HPS.lr, weight_decay=HPS.decay)
# reduction='none' keeps the per-element losses so they can be reduced per sample later.
crit = torch.nn.MSELoss(reduction='none')

# Timestamp tagging every checkpoint written by this run.
save_id = str(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

######## CHANGE BELOW
MAX_EPOCH = 7
######## CHANGE ABOVE
final_epoch_res = []  # human-readable log lines, dumped to a results file afterwards

nb_iterations = 0
# The center is estimated once from the untrained model and then kept fixed
# (re-estimating it after each epoch was tried and is disabled).
c = compute_center_c(model, train_loader, optim, device, center_size=center_size)
best_model = None
best_roc_auc = 0
for ei in tqdm(itertools.count(), desc="Epoch"):
    nb_iterations += len(train_loader)

    # --- train ---
    train_loss, r_loss, kl_loss, svdd_loss, elbo_loss = train(model, train_loader, optim, crit, c, device)
    train_msg = f"training, epoch {ei+1} \t iter: {nb_iterations} \t loss: {train_loss} \t r_loss {r_loss} \t kl_loss {kl_loss} \t svdd_loss {svdd_loss} \t elbo_loss {elbo_loss}"
    info(train_msg)
    final_epoch_res.append(f"Epoch {ei + 1}")
    final_epoch_res.append(f"> {train_msg}")

    # --- evaluate on the test set (reuses the loss variable names above) ---
    eval_loss, r_loss, kl_loss, svdd_loss, elbo_loss = evaluate(model, test_loader, optim, crit, c, device, img_id=str(ei).zfill(4))
    eval_msg = f"evaluate, epoch {ei+1} \t iter: {nb_iterations} \t loss: {eval_loss} \t r_loss {r_loss} \t kl_loss {kl_loss} \t svdd_loss {svdd_loss} \t elbo_loss {elbo_loss}"
    info(eval_msg)
    final_epoch_res.append(f"> {eval_msg}")

    # --- refresh the radius, then measure AUC with it ---
    R = update_radius(train_loader, model, c, device, nu=0.05)
    warning(f"Updated radius after training: {R}")
    final_epoch_res.append(f"$ new radius R={R}")

    _, test_auc, test_auc_preds = test(model, test_loader, c, R, device)
    final_epoch_res.append(f"$ Test set ROC AUC: {test_auc}\n$ Test set predictions ROC AUC: {test_auc_preds}\n\n")

    # Keep a snapshot of the model with the best score-based AUC so far.
    if test_auc > best_roc_auc:
        best_roc_auc = test_auc
        best_model = deepcopy(model)

    # Periodic checkpoint: every HPS.checkpoint epochs, never on the first one.
    if HPS.checkpoint > 0 and ei > 0 and ei % HPS.checkpoint == 0:
        checkpoint_path = f"saved_checkpoints/{save_id}-vdvae-{str(ei).zfill(4)}.pt"
        torch.save(model.state_dict(), checkpoint_path)

    # Stop on whichever budget runs out first: iterations or epochs.
    if nb_iterations > HPS.nb_iterations or ei + 1 >= MAX_EPOCH:
        info("Maximum iterations reached. Exiting..")
        break

# Dump the per-epoch log to disk (one entry per line, newline-terminated).
results_path = f"./results/results_epochs_class_{normal_class}.txt"
with open(results_path, "w") as output_file:
    output_file.write("\n".join(final_epoch_res) + "\n")

# Persist the weights of the best-AUC snapshot.
torch.save(best_model.state_dict(), f"saved_checkpoints/{save_id}-vdvae-final.pt")

# Final SVDD test on the best snapshot; the center `c` is reused as-is and
# only the radius is recomputed for that snapshot.
R = update_radius(train_loader, best_model, c, device, nu=0.05)
warning(f"Updated radius after training: {R}")
test_results, _, _ = test(best_model, test_loader, c, R, device)

# Columns of the (idx, label, score, predicted_label) records as arrays.
indices, labels, scores, _ = zip(*test_results)
indices = np.array(indices)
labels = np.array(labels)
scores = np.array(scores)

# Among samples labelled 0, order the indices from lowest to highest anomaly score.
normal_mask = labels == 0
idx_sorted = indices[normal_mask][np.argsort(scores[normal_mask])]

# Per-dataset mapping from class index to the tag used in result file names.
# Fashion-MNIST and CIFAR-10 tags are "<index>_<name>"; MNIST tags are the digit.
classes = {
    'fashion_mnist': {
        index: f"{index}_{name}"
        for index, name in enumerate((
            'tshirt', 'trouser', 'pullover', 'dress', 'coat',
            'sandal', 'shirt', 'sneaker', 'bag', 'ankleboot',
        ))
    },
    'mnist': {digit: str(digit) for digit in range(10)},
    "cifar10": {
        index: f"{index}_{name}"
        for index, name in enumerate((
            "airplane", "automobile", "bird", "cat", "deer",
            "dog", "frog", "horse", "ship", "truck",
        ))
    },
}

# Plot the 32 lowest- and 32 highest-scoring test images taken from idx_sorted.
if HPS.dataset in ('mnist', "fashion_mnist", 'cifar10'):
    lowest_idx, highest_idx = idx_sorted[:32], idx_sorted[-32:]

    if HPS.dataset == 'cifar10':
        # Axes are permuted to (0, 3, 1, 2), i.e. the last axis moves to position 1.
        X_normals = torch.tensor(np.transpose(dataset.test_set.data[lowest_idx, ...], (0, 3, 1, 2)))
        X_outliers = torch.tensor(np.transpose(dataset.test_set.data[highest_idx, ...], (0, 3, 1, 2)))
    else:
        # MNIST-style data gets an extra axis inserted at position 1.
        X_normals = dataset.test_set.data[lowest_idx, ...].unsqueeze(1)
        X_outliers = dataset.test_set.data[highest_idx, ...].unsqueeze(1)

    class_tag = classes[dataset_name][normal_class]
    plot_images_grid(X_normals, export_img=f"./results/normals_{class_tag}", title='Most normal examples', padding=2)
    plot_images_grid(X_outliers, export_img=f"./results/outliers_{class_tag}", title='Most anomalous examples', padding=2)

In [None]:
# !rm -r /content/results/model_*
!rm -r /content/results.zip
!zip -r /content/results.zip /content/results

In [None]:
# TODO: remove the eps part from the code

### Testing site

In [None]:
# Scratch cell: build a VAE and run one random 256x256, 3-channel input through it.
# Note that this rebinds the notebook-level `device`, `x` and `y` names.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `VAE` is not defined in the visible part of this notebook — confirm it still exists.
vae = VAE(3, 64, 32, 32, nb_blocks=6).to(device)
x = torch.randn(1, 3, 256, 256).to(device)
# NOTE(review): unpacks two return values, whereas the models used by vae_loss return three — verify.
y, kls = vae(x)