# Apply RVAE for FashionMNIST data set
* <b>Objective:</b> In this problem, the purpose is trian a robust varational autoencoder when the training is polluted with outliers. Here we chose shoes and sneakers as inliers classes and samples from other categories as outliers. Since these images contain a significant range of gray scales, we chose the Gaussian model. 

In [4]:
from __future__ import print_function
import argparse
import torch
import math
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
from sklearn.model_selection import train_test_split
from keras.datasets import fashion_mnist
import scipy.io as spio
import os
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score

# Define parameters
seed:random seed 
epochs:number of epochs
CODE_SIZE: z dimention
SIGMA:constant variance for Guassian loss function
batch_size:batch size for training
log_interval:how many batches to wait before logging training status

In [8]:
seed = 10004
epochs = 150 
batch_size = 120
log_interval = 10
CODE_SIZE = 20
SIGMA = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Creat data for test loader and train loader
input: anomoly percentage
output: train_loader, test_loader

In [10]:
def create_data(frac_anom):

    torch.manual_seed(seed=seed)
    np.random.seed(seed=seed)

    (X, X_lab), (_test_images, _test_lab) = fashion_mnist.load_data()
    X_lab = np.array(X_lab)

    # find other categories
    ind = np.isin(X_lab, (0, 1, 2, 3, 4, 5, 6, 8))  #(1, 5, 7, 9)
    X_lab_outliers = X_lab[ind]
    X_outliers = X[ind]

    # find sneaker and ankle boots
    ind = np.isin(X_lab, (7, 9))  # (0, 2, 3, 4, 6))  #
    X_lab = X_lab[ind]
    X = X[ind]

    #normalize the data
    X = X / 255.0
    X_outliers = X_outliers / 255.0

    # add ouliers to the data the label for outliers is 10
    Nsamp = np.int(np.rint(len(X) * frac_anom)) + 1
    X[:Nsamp, :, :] = X_outliers[:Nsamp, :, :]
    X_lab[:Nsamp] = 10

    #split data to train and test
    X_train, X_test, X_lab_train, X_lab_test = train_test_split(
        X, X_lab, test_size=0.33, random_state=10003)
    X_train = np.expand_dims(X_train, axis=1)
    X_test = np.expand_dims(X_test, axis=1)


    
    #append samples and labels
    train_data = []
    for i in range(len(X_train)):
        train_data.append(
            [torch.from_numpy(X_train[i]).float(), X_lab_train[i]])

    test_data = []
    for i in range(len(X_test)):
        test_data.append(
            [torch.from_numpy(X_test[i]).float(), X_lab_test[i]])
        
        
    #generate train loader and test loader
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data,
                                              batch_size=len(test_data),
                                              shuffle=False)

    return train_loader, test_loader

# Define MSE loss and beta loss for Guassian posterior


In [13]:
#MSE loss
def MSE_loss(Y, X):
    ret = (X - Y)**2
    ret = torch.sum(ret,1)
    return ret

#beta loss
def Gaussian_CE_loss(Y, X, beta, sigma=SIGMA):  # 784 for mnist
    Dim = Y.shape[1]
    const1 = -((1 + beta) / beta)
    const2 = 1 / pow((2 * math.pi * (sigma**2)), (beta * Dim / 2))
    MSE = MSE_loss(Y, X)
    term1 = torch.exp(-(beta / (2 * (sigma**2))) * MSE)
    loss = torch.sum(const1 * (const2* term1 - 1))
    return loss


def beta_loss_function(recon_x, x, mu, logvar, beta):

    if beta > 0:
        # If beta is nonzero, use the beta entropy
        BBCE = Gaussian_CE_loss(recon_x.view(-1, 784), x.view(-1, 784), beta)
    else:
        # if beta is zero use binary cross entropy
        BBCE = torch.sum(MSE_loss(recon_x.view(-1, 784), x.view(-1, 784)))

    # compute KL divergence
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BBCE + KLD

# Define network

In [14]:
class RVAE(nn.Module):
    def __init__(self):
        super(RVAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, CODE_SIZE)
        self.fc22 = nn.Linear(400, CODE_SIZE)
        self.fc3 = nn.Linear(CODE_SIZE, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    
    # for reseting network weights
    def weight_reset(self):

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                m.reset_parameters()

# Define model reset
This function calls weight_reset from the network class and reset the weights of the network.

In [17]:
def model_reset():
    model.weight_reset()

# Define model and optimizer

In [18]:
model = RVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training function
input: number of epochs, value of beta

Prints loss after each log intervalof bathces and after each epoch

In [19]:
def train(epoch, beta_val):
    model.train()
    train_loss = 0
    for batch_idx, (data, data_lab) in enumerate(train_loader):
   
        data = (data).to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = beta_loss_function(recon_batch, data, mu, logvar, beta=beta_val)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

# Testing function
input: number of epochs, value of beta
out put: total loss for the test samples, loss for the inlier samples in the test, loss for the out lier samples in the test
Saves the resconstruction of 8 random samples(4 inliers, 4 outliers)
Saves data, reconstruction and anomolies lables in a npz file

In [25]:
def test(frac_anom, beta_val):
    model.eval()
    test_loss_total = 0
    test_loss_anom = 0
    num_anom = 0
    with torch.no_grad():
        for i, (data, data_lab) in enumerate(test_loader):
        
            data = (data).to(device)
            recon_batch, mu, logvar = model(data)
            anom_lab = data_lab == 10
            num_anom += np.sum(anom_lab.numpy())  # count number of anomalies
            anom_lab = (anom_lab[:, None].float()).to(device)

            test_loss_anom += torch.sum(MSE_loss(recon_batch * anom_lab,
                                      data * anom_lab)).item()
            test_loss_total += torch.sum(MSE_loss(recon_batch, data)).item()

            if i == 0:
                n = min(data.size(0), 100)
                samp=[4, 14, 50, 60, 25, 29, 32, 65]
                comparison = torch.cat([
                    data.view(len(recon_batch), 1, 28, 28)[samp],
                    recon_batch.view(len(recon_batch), 1, 28, 28)[samp]
                ])
                save_image(comparison.cpu(),
                           'results/fashion_mnist_recon_shallow_' +
                           str(beta_val) + '_' + str(frac_anom) + '.png',
                           nrow=n)

        np.savez('results/fashion_mnist_' + str(beta_val) + '_' +
                 str(frac_anom) + '.npz',
                 recon=recon_batch.cpu(),
                 data=data.cpu(),
                 anom_lab=anom_lab.cpu())

    test_loss_normals = (test_loss_total - test_loss_anom) / (
        len(test_loader.dataset) - num_anom)
    test_loss_anom /= num_anom
    test_loss_total /= len(test_loader.dataset)

    print('====> Test set loss: {:.4f}'.format(test_loss_total))

    return test_loss_total, test_loss_anom, test_loss_normals


# Main function

Runs training and testing for a givern values of beta and percentage of anomolies

In [26]:
if __name__ == "__main__":

    brange=[0,0.01]
    erange = range(1, epochs + 1)
    anrange = np.array([0.01,0.05,0.1])
    
    test_loss_total = np.zeros((len(anrange), len(brange)))
    test_loss_anom = np.zeros((len(anrange), len(brange)))
    test_loss_normals = np.zeros((len(anrange), len(brange)))

    for b, betaval in enumerate(brange):

        for a, frac_anom in enumerate(anrange):
            train_loader, test_loader = create_data(frac_anom)
            model_reset()
            for epoch in erange:

                train(epoch, beta_val=betaval)

                print('epoch: %d, beta=%g, frac_anom=%g' %
                      (epoch, betaval, frac_anom))

            # save the model
            torch.save(model, 'fashion_mnist_beta_shallow_' + str(betaval) + '_frac_anom_' + str(frac_anom))

            test_loss_total[a, b], test_loss_anom[a, b], test_loss_normals[
                a, b] = test(frac_anom, beta_val=betaval)



        np.savez('test_loss_fashionmnist_beta_shallow' + str(b) + '.npz',
                 test_loss_total=test_loss_total,
                 test_loss_anom=test_loss_anom,
                 test_loss_normals=test_loss_normals,
                 brange=brange,
                 anrange=anrange)


====> Epoch: 1 Average loss: 39.2259
epoch: 1, beta=0, frac_anom=0.01
====> Epoch: 2 Average loss: 28.3701
epoch: 2, beta=0, frac_anom=0.01
====> Epoch: 3 Average loss: 26.8836
epoch: 3, beta=0, frac_anom=0.01
====> Epoch: 4 Average loss: 25.4160
epoch: 4, beta=0, frac_anom=0.01
====> Epoch: 5 Average loss: 24.3417
epoch: 5, beta=0, frac_anom=0.01
====> Epoch: 6 Average loss: 23.6293
epoch: 6, beta=0, frac_anom=0.01
====> Epoch: 7 Average loss: 23.0973
epoch: 7, beta=0, frac_anom=0.01
====> Epoch: 8 Average loss: 22.6937
epoch: 8, beta=0, frac_anom=0.01
====> Epoch: 9 Average loss: 22.4618
epoch: 9, beta=0, frac_anom=0.01
====> Epoch: 10 Average loss: 22.0700
epoch: 10, beta=0, frac_anom=0.01
====> Epoch: 11 Average loss: 21.8762
epoch: 11, beta=0, frac_anom=0.01
====> Epoch: 12 Average loss: 21.6056
epoch: 12, beta=0, frac_anom=0.01
====> Epoch: 13 Average loss: 21.3931
epoch: 13, beta=0, frac_anom=0.01
====> Epoch: 14 Average loss: 21.2843
epoch: 14, beta=0, frac_anom=0.01
====> Epoc

====> Epoch: 21 Average loss: 20.5569
epoch: 21, beta=0, frac_anom=0.01
====> Epoch: 22 Average loss: 20.5013
epoch: 22, beta=0, frac_anom=0.01
====> Epoch: 23 Average loss: 20.4735
epoch: 23, beta=0, frac_anom=0.01
====> Epoch: 24 Average loss: 20.3910
epoch: 24, beta=0, frac_anom=0.01
====> Epoch: 25 Average loss: 20.3791
epoch: 25, beta=0, frac_anom=0.01
====> Epoch: 26 Average loss: 20.2877
epoch: 26, beta=0, frac_anom=0.01
====> Epoch: 27 Average loss: 20.2803
epoch: 27, beta=0, frac_anom=0.01
====> Epoch: 28 Average loss: 20.2325
epoch: 28, beta=0, frac_anom=0.01
====> Epoch: 29 Average loss: 20.2080
epoch: 29, beta=0, frac_anom=0.01
====> Epoch: 30 Average loss: 20.1588
epoch: 30, beta=0, frac_anom=0.01
====> Epoch: 31 Average loss: 20.0709
epoch: 31, beta=0, frac_anom=0.01
====> Epoch: 32 Average loss: 20.0877
epoch: 32, beta=0, frac_anom=0.01
====> Epoch: 33 Average loss: 20.0528
epoch: 33, beta=0, frac_anom=0.01
====> Epoch: 34 Average loss: 20.0478
epoch: 34, beta=0, frac_an

====> Epoch: 41 Average loss: 19.8616
epoch: 41, beta=0, frac_anom=0.01
====> Epoch: 42 Average loss: 19.8321
epoch: 42, beta=0, frac_anom=0.01
====> Epoch: 43 Average loss: 19.8282
epoch: 43, beta=0, frac_anom=0.01
====> Epoch: 44 Average loss: 19.7745
epoch: 44, beta=0, frac_anom=0.01
====> Epoch: 45 Average loss: 19.7748
epoch: 45, beta=0, frac_anom=0.01
====> Epoch: 46 Average loss: 19.7780
epoch: 46, beta=0, frac_anom=0.01
====> Epoch: 47 Average loss: 19.6793
epoch: 47, beta=0, frac_anom=0.01
====> Epoch: 48 Average loss: 19.6789
epoch: 48, beta=0, frac_anom=0.01
====> Epoch: 49 Average loss: 19.6680
epoch: 49, beta=0, frac_anom=0.01
====> Epoch: 50 Average loss: 19.6638
epoch: 50, beta=0, frac_anom=0.01
====> Epoch: 51 Average loss: 19.6253
epoch: 51, beta=0, frac_anom=0.01
====> Epoch: 52 Average loss: 19.5937
epoch: 52, beta=0, frac_anom=0.01
====> Epoch: 53 Average loss: 19.5873
epoch: 53, beta=0, frac_anom=0.01
====> Epoch: 54 Average loss: 19.5971
epoch: 54, beta=0, frac_an

====> Epoch: 61 Average loss: 19.4878
epoch: 61, beta=0, frac_anom=0.01
====> Epoch: 62 Average loss: 19.4299
epoch: 62, beta=0, frac_anom=0.01
====> Epoch: 63 Average loss: 19.4680
epoch: 63, beta=0, frac_anom=0.01
====> Epoch: 64 Average loss: 19.4451
epoch: 64, beta=0, frac_anom=0.01
====> Epoch: 65 Average loss: 19.4532
epoch: 65, beta=0, frac_anom=0.01
====> Epoch: 66 Average loss: 19.3805
epoch: 66, beta=0, frac_anom=0.01
====> Epoch: 67 Average loss: 19.4546
epoch: 67, beta=0, frac_anom=0.01
====> Epoch: 68 Average loss: 19.4163
epoch: 68, beta=0, frac_anom=0.01
====> Epoch: 69 Average loss: 19.3722
epoch: 69, beta=0, frac_anom=0.01
====> Epoch: 70 Average loss: 19.3607
epoch: 70, beta=0, frac_anom=0.01
====> Epoch: 71 Average loss: 19.3717
epoch: 71, beta=0, frac_anom=0.01
====> Epoch: 72 Average loss: 19.3071
epoch: 72, beta=0, frac_anom=0.01
====> Epoch: 73 Average loss: 19.3606
epoch: 73, beta=0, frac_anom=0.01
====> Epoch: 74 Average loss: 19.3573
epoch: 74, beta=0, frac_an

====> Epoch: 81 Average loss: 19.2203
epoch: 81, beta=0, frac_anom=0.01
====> Epoch: 82 Average loss: 19.2381
epoch: 82, beta=0, frac_anom=0.01
====> Epoch: 83 Average loss: 19.2246
epoch: 83, beta=0, frac_anom=0.01
====> Epoch: 84 Average loss: 19.2457
epoch: 84, beta=0, frac_anom=0.01
====> Epoch: 85 Average loss: 19.2018
epoch: 85, beta=0, frac_anom=0.01
====> Epoch: 86 Average loss: 19.2302
epoch: 86, beta=0, frac_anom=0.01
====> Epoch: 87 Average loss: 19.1801
epoch: 87, beta=0, frac_anom=0.01
====> Epoch: 88 Average loss: 19.2082
epoch: 88, beta=0, frac_anom=0.01
====> Epoch: 89 Average loss: 19.2073
epoch: 89, beta=0, frac_anom=0.01
====> Epoch: 90 Average loss: 19.1982
epoch: 90, beta=0, frac_anom=0.01
====> Epoch: 91 Average loss: 19.1470
epoch: 91, beta=0, frac_anom=0.01
====> Epoch: 92 Average loss: 19.1569
epoch: 92, beta=0, frac_anom=0.01
====> Epoch: 93 Average loss: 19.1497
epoch: 93, beta=0, frac_anom=0.01
====> Epoch: 94 Average loss: 19.1063
epoch: 94, beta=0, frac_an

====> Epoch: 100 Average loss: 19.1568
epoch: 100, beta=0, frac_anom=0.01
====> Epoch: 101 Average loss: 19.0891
epoch: 101, beta=0, frac_anom=0.01
====> Epoch: 102 Average loss: 19.0568
epoch: 102, beta=0, frac_anom=0.01
====> Epoch: 103 Average loss: 19.0749
epoch: 103, beta=0, frac_anom=0.01
====> Epoch: 104 Average loss: 19.0786
epoch: 104, beta=0, frac_anom=0.01
====> Epoch: 105 Average loss: 19.0770
epoch: 105, beta=0, frac_anom=0.01
====> Epoch: 106 Average loss: 19.0913
epoch: 106, beta=0, frac_anom=0.01
====> Epoch: 107 Average loss: 19.0682
epoch: 107, beta=0, frac_anom=0.01
====> Epoch: 108 Average loss: 19.0853
epoch: 108, beta=0, frac_anom=0.01
====> Epoch: 109 Average loss: 19.1149
epoch: 109, beta=0, frac_anom=0.01
====> Epoch: 110 Average loss: 19.0306
epoch: 110, beta=0, frac_anom=0.01
====> Epoch: 111 Average loss: 19.0336
epoch: 111, beta=0, frac_anom=0.01
====> Epoch: 112 Average loss: 19.0263
epoch: 112, beta=0, frac_anom=0.01
====> Epoch: 113 Average loss: 19.0863

====> Epoch: 120 Average loss: 19.0380
epoch: 120, beta=0, frac_anom=0.01
====> Epoch: 121 Average loss: 19.0133
epoch: 121, beta=0, frac_anom=0.01
====> Epoch: 122 Average loss: 18.9906
epoch: 122, beta=0, frac_anom=0.01
====> Epoch: 123 Average loss: 18.9691
epoch: 123, beta=0, frac_anom=0.01
====> Epoch: 124 Average loss: 18.9731
epoch: 124, beta=0, frac_anom=0.01
====> Epoch: 125 Average loss: 18.9307
epoch: 125, beta=0, frac_anom=0.01
====> Epoch: 126 Average loss: 18.9868
epoch: 126, beta=0, frac_anom=0.01
====> Epoch: 127 Average loss: 18.9555
epoch: 127, beta=0, frac_anom=0.01
====> Epoch: 128 Average loss: 18.9628
epoch: 128, beta=0, frac_anom=0.01
====> Epoch: 129 Average loss: 18.9522
epoch: 129, beta=0, frac_anom=0.01
====> Epoch: 130 Average loss: 18.9494
epoch: 130, beta=0, frac_anom=0.01
====> Epoch: 131 Average loss: 18.9528
epoch: 131, beta=0, frac_anom=0.01
====> Epoch: 132 Average loss: 18.9643
epoch: 132, beta=0, frac_anom=0.01
====> Epoch: 133 Average loss: 18.9199

====> Epoch: 139 Average loss: 18.9412
epoch: 139, beta=0, frac_anom=0.01
====> Epoch: 140 Average loss: 18.9418
epoch: 140, beta=0, frac_anom=0.01
====> Epoch: 141 Average loss: 18.8667
epoch: 141, beta=0, frac_anom=0.01
====> Epoch: 142 Average loss: 18.8705
epoch: 142, beta=0, frac_anom=0.01
====> Epoch: 143 Average loss: 18.8674
epoch: 143, beta=0, frac_anom=0.01
====> Epoch: 144 Average loss: 18.8869
epoch: 144, beta=0, frac_anom=0.01
====> Epoch: 145 Average loss: 18.9162
epoch: 145, beta=0, frac_anom=0.01
====> Epoch: 146 Average loss: 18.8922
epoch: 146, beta=0, frac_anom=0.01
====> Epoch: 147 Average loss: 18.9467
epoch: 147, beta=0, frac_anom=0.01
====> Epoch: 148 Average loss: 18.8889
epoch: 148, beta=0, frac_anom=0.01
====> Epoch: 149 Average loss: 18.9041
epoch: 149, beta=0, frac_anom=0.01
====> Epoch: 150 Average loss: 18.8629
epoch: 150, beta=0, frac_anom=0.01


RuntimeError: The size of tensor a (28) must match the size of tensor b (3960) at non-singleton dimension 2

In [None]:
# Colors and legends for ROC


In [None]:
cwd = os.getcwd()

FPRs = dict()
TPRs = dict()
AUC = dict()

lgd = {
    0: 'VAE-1%',
    1: 'VAE-5%',
    2: 'VAE-10%',
    3: 'RVAE-1%',
    4: 'RVAE-5%',
    5: 'RVAE-10%'
}
colors = {0: 'r', 1: 'b', 2: 'k', 3: 'r', 4: 'b', 5: 'k'}
lsty = {0: '--', 1: '--', 2: '--', 3: '-', 4: '-', 5: '-'}
c = 0