In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.autograd import Variable

import time
import pickle
import numpy as np
import argparse
import datetime
import itertools
import seaborn as sns
from sklearn import metrics
from matplotlib import pyplot as plt

import modules
from utils import dataset
from imp import reload
from modules import model

In [2]:
# Target device for every network below; this notebook assumes a CUDA GPU is present.
device = torch.device("cuda")

In [3]:
# Training-time hyperparameters (optimizer learning rate / betas, epoch count) and the latent
# dimensionality used for each dataset's encoder/generator.
learning_rate = 1e-4
beta1, beta2 = 0.5, 0.999
num_epochs = 10
latent_size_CIFAR10, latent_size_MNIST = 100, 200
# Accuracy per lambda value; filled in by the metric cells.
acc_lam = dict()

# CIFAR10 — class 0 (airplane) is held out as the anomalous class; all other classes are "normal"

In [None]:
# Evaluation-time preprocessing: ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then
# maps them to [-1, 1].
# No random flips here: this notebook only evaluates, and random test-time flips would make every
# anomaly score (and therefore AUC / accuracy) change from one run to the next.
normalize_CIFAR10 = torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform_CIFAR10 = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                    normalize_CIFAR10])

In [None]:
# Both splits are read from the local copy under ./data/CIFAR10/ (no download) with the same preprocessing.
train_CIFAR10 = torchvision.datasets.CIFAR10(root='./data/CIFAR10/', train=True, download=False,
                                             transform=transform_CIFAR10, target_transform=None)
test_CIFAR10 = torchvision.datasets.CIFAR10(root='./data/CIFAR10/', train=False, download=False,
                                            transform=transform_CIFAR10, target_transform=None)

In [None]:
# Split the training set: every class except 0 is "normal"; class-0 images are set aside as anomalies.
idx = torch.tensor(train_CIFAR10.targets) != 0
dset_train = torch.utils.data.dataset.Subset(train_CIFAR10, np.flatnonzero(idx.numpy()))
dset_train_anomalous = torch.utils.data.dataset.Subset(train_CIFAR10, np.flatnonzero(~idx.numpy()))
trainloader = torch.utils.data.DataLoader(dset_train, batch_size=1, shuffle=True)
print(len(trainloader))

In [None]:
# Evaluation sets: non-zero classes of the test split are "normal"; the anomalous pool is the class-0
# images of BOTH the train and test splits.
idx = torch.tensor(test_CIFAR10.targets) != 0
dset_test = torch.utils.data.dataset.Subset(test_CIFAR10, np.flatnonzero(idx.numpy()))
dset_test_anomalous = torch.utils.data.dataset.Subset(test_CIFAR10, np.flatnonzero(~idx.numpy()))
testloader_normal = torch.utils.data.DataLoader(dset_test, batch_size=1, shuffle=True)
testloader_anomalous = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([dset_train_anomalous, dset_test_anomalous]),
    batch_size=1, shuffle=True)
print(len(trainloader))
print(len(testloader_normal), len(testloader_anomalous))

In [None]:
# Loss objects from the training setup; none of them is referenced by the evaluation cells below.
aen_criterion = nn.MSELoss()
dis_criterion = nn.CrossEntropyLoss()
dis_BCElogit_criterion = nn.BCEWithLogitsLoss()

In [None]:
# Build the CIFAR10 encoder, generator and the three discriminators (on (x, z), (x, x) and (z, z)
# pairs, judging by their names), all sized by latent_size_CIFAR10. Their weights are then
# replaced by the checkpoint-loading cell that follows.
# NOTE(review): the 0.2 passed to each discriminator is presumably a dropout rate or LeakyReLU
# slope — confirm against modules/model.py.
enc = model.Encoder_CIFAR10(latent_size_CIFAR10)
gen = model.Generator_CIFAR10(latent_size_CIFAR10)
dis_xz = model.Discriminator_xz_CIFAR10(latent_size_CIFAR10, 0.2)
dis_xx = model.Discriminator_xx_CIFAR10(latent_size_CIFAR10, 0.2)
dis_zz = model.Discriminator_zz_CIFAR10(latent_size_CIFAR10, 0.2)

In [None]:
# Restore the trained CIFAR10 weights. All five checkpoints share one path prefix and differ only by
# a suffix, so the prefix is written once instead of being repeated per network.
# map_location='cpu': the networks are still on the CPU here (they are moved to `device` afterwards),
# so this avoids needing the exact GPU the checkpoints were saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt_prefix_CIFAR10 = './models/CIFAR10/Gaussian/9/CIFAR1010epochs2020-03-12-19-13-39'
for net, suffix in ((gen, 'G'), (enc, 'E'), (dis_xz, 'D_xz'), (dis_xx, 'D_xx'), (dis_zz, 'D_zz')):
    net.load_state_dict(torch.load(ckpt_prefix_CIFAR10 + suffix, map_location='cpu'))

In [None]:
# Move every network onto `device`, in the same order as before.
for net in (dis_xz, dis_xx, dis_zz, enc, gen):
    net.to(device)

In [None]:
# Anomaly score of every test sample: the L1 distance between dis_xx's intermediate features for the
# pair (x, x) and for the pair (x, gen(enc(x))). Only enc, gen and dis_xx contribute to the score, so
# no other forward passes are run.
lam = 0.1
loss_neg = torch.zeros((len(testloader_normal), 1), device=device)     # scores of normal samples
loss_pos = torch.zeros((len(testloader_anomalous), 1), device=device)  # scores of anomalous samples

# Put all networks in inference mode once, rather than on every iteration.
for net in (dis_xz, dis_xx, dis_zz, enc, gen):
    net.eval()

# Pure inference: no_grad stops autograd from building (and keeping) a graph for every forward pass.
with torch.no_grad():
    for loader, score_buf in ((testloader_normal, loss_neg), (testloader_anomalous, loss_pos)):
        for step, (images, labels) in enumerate(loader):
            x_real_test = images.view(-1, 3, 32, 32).to(device)
            rec_x = gen(enc(x_real_test))
            _, inter_layer_inp = dis_xx(x_real_test, x_real_test)
            _, inter_layer_rct = dis_xx(x_real_test, rec_x)
            # L1 norm over the whole feature tensor. The loaders use batch_size=1, so this is one
            # score per sample and `step` is the sample index.
            score_buf[step] = torch.norm(inter_layer_inp - inter_layer_rct, 1)

print ('mean negative: %0.4f, std negative: %0.4f' %(torch.mean(loss_neg), torch.std(loss_neg)))
print ('mean positive: %0.4f, std positive: %0.4f' %(torch.mean(loss_pos), torch.std(loss_pos)))

In [None]:
# Score distributions of the two groups (KDE only). sns.kdeplot is the replacement for the deprecated
# sns.distplot(hist=False, kde=True); plt.legend() is required for the 'Normal'/'Anomalous' labels to show.
x1 = loss_neg.cpu().numpy()
x2 = loss_pos.cpu().numpy()
sns.kdeplot(x1.ravel(), linewidth=3, label='Normal')
sns.kdeplot(x2.ravel(), linewidth=3, label='Anomalous')
plt.title('Distribution of normal and abnormal samples')
plt.xlabel('Anomaly Score')
plt.legend();

# Hard decision rule: a score strictly above the threshold is predicted "anomalous".
# NOTE(review): hand-tuned constant — consider deriving it from the ROC curve or a validation split.
THRESHOLD_CIFAR10 = 120
neg_pre_wrong = int((loss_neg > THRESHOLD_CIFAR10).sum())   # normal samples flagged as anomalous
pos_pre_wrong = int((loss_pos <= THRESHOLD_CIFAR10).sum())  # anomalous samples not flagged
print("number of normal samples missclassified: %d, number of anomalous samples missclassified: %d"
      % (neg_pre_wrong, pos_pre_wrong))

# Confusion matrix with "anomalous" as the positive class.
tp = len(loss_pos) - pos_pre_wrong
fn = pos_pre_wrong
fp = neg_pre_wrong
tn = len(loss_neg) - neg_pre_wrong
# Each ratio falls back to 0.0 when its denominator is empty instead of raising ZeroDivisionError.
precision = tp / (tp + fp) if tp + fp else 0.0
## recall / sensitivity / True Positive Rate
recall = tp / (tp + fn) if tp + fn else 0.0
## False Positive Rate / 1 - Specificity
fp_rate = fp / (fp + tn) if fp + tn else 0.0
specificity = tn / (tn + fp) if tn + fp else 0.0
f1 = 2 * ((precision * recall) / (precision + recall)) if precision + recall else 0.0
accuracy = (tp + tn) / (tp + tn + fp + fn)
acc_lam[lam] = accuracy
print("tp: %d, fp: %d, fn: %d, tn: %d" % (tp, fp, fn, tn))
print("precision: %.5f, recall: %.5f, specificity: %.5f, f1: %.5f, fp_rate: %.5f, accuracy: %.5f"
      % (precision, recall, specificity, f1, fp_rate, accuracy))

# Threshold-free metrics: anomalous samples are labelled 1, normal samples 0, the raw score ranks them.
anomalous = torch.ones((len(loss_pos), 1))
normal = torch.zeros((len(loss_neg), 1))
y = torch.cat((anomalous, normal), 0)
scores = torch.cat((loss_pos, loss_neg), 0)
y_true = y.numpy().ravel()
y_score = scores.cpu().numpy().ravel()
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
average_precision = metrics.average_precision_score(y_true, y_score)
auc = metrics.auc(fpr, tpr)
print('AUC', auc)
print('average precision :', average_precision)

plt.figure()
plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % auc)
plt.plot([0.0, 1.0], [0.0, 1.0], color='navy', linestyle='--')  # chance-level diagonal
plt.xlim([-0.01, 1.0])
plt.ylim([-0.01, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.show();

# MNIST — digit 0 is held out as the anomalous class; all other digits are "normal"

In [None]:
# Evaluation-time preprocessing: ToTensor maps pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then
# maps them to [-1, 1].
# No random flips here: this notebook only evaluates, and random test-time flips would make every
# anomaly score (and therefore AUC / accuracy) change from one run to the next.
normalize_MNIST = torchvision.transforms.Normalize((0.5, ), (0.5, ))
transform_MNIST = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize_MNIST])

In [None]:
# Both splits are read from the local copy under ./data/MNIST/ (no download) with the same preprocessing.
train_MNIST = torchvision.datasets.MNIST(root='./data/MNIST/', train=True, download=False,
                                         transform=transform_MNIST, target_transform=None)
test_MNIST = torchvision.datasets.MNIST(root='./data/MNIST/', train=False, download=False,
                                        transform=transform_MNIST, target_transform=None)

In [None]:
# Split the training set: every digit except 0 is "normal"; images of digit 0 are set aside as anomalies.
idx = torch.as_tensor(train_MNIST.targets) != 0
dset_train = torch.utils.data.dataset.Subset(train_MNIST, np.flatnonzero(idx.numpy()))
dset_train_anomalous = torch.utils.data.dataset.Subset(train_MNIST, np.flatnonzero(~idx.numpy()))

trainloader = torch.utils.data.DataLoader(dset_train, batch_size=1, shuffle=True)
print(len(trainloader))

In [None]:
# Evaluation sets: non-zero digits of the test split are "normal"; the anomalous pool is the digit-0
# images of BOTH the train and test splits.
idx = torch.as_tensor(test_MNIST.targets) != 0
dset_test = torch.utils.data.dataset.Subset(test_MNIST, np.flatnonzero(idx.numpy()))
dset_test_anomalous = torch.utils.data.dataset.Subset(test_MNIST, np.flatnonzero(~idx.numpy()))
testloader_normal = torch.utils.data.DataLoader(dset_test, batch_size=1, shuffle=True)
testloader_anomalous = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([dset_train_anomalous, dset_test_anomalous]),
    batch_size=1, shuffle=True)
print(len(testloader_normal))
print(len(testloader_anomalous))

In [None]:
# Loss objects from the training setup; none of them is referenced by the evaluation cells below.
aen_criterion = nn.MSELoss()
dis_criterion = nn.CrossEntropyLoss()
dis_BCElogit_criterion = nn.BCEWithLogitsLoss()

In [None]:
# Build the MNIST encoder, generator and the three discriminators (on (x, z), (x, x) and (z, z)
# pairs, judging by their names), all sized by latent_size_MNIST. Their weights are then
# replaced by the checkpoint-loading cell that follows.
# NOTE(review): the 0.2 passed to each discriminator is presumably a dropout rate or LeakyReLU
# slope — confirm against modules/model.py.
enc = model.Encoder_MNIST(latent_size_MNIST)
gen = model.Generator_MNIST(latent_size_MNIST)
dis_xz = model.Discriminator_xz_MNIST(latent_size_MNIST, 0.2)
dis_xx = model.Discriminator_xx_MNIST(latent_size_MNIST, 0.2)
dis_zz = model.Discriminator_zz_MNIST(latent_size_MNIST, 0.2)

In [None]:
# Restore the trained MNIST weights. All five checkpoints share one path prefix and differ only by
# a suffix, so the prefix is written once instead of being repeated per network.
# map_location='cpu': the networks are still on the CPU here (they are moved to `device` afterwards),
# so this avoids needing the exact GPU the checkpoints were saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
ckpt_prefix_MNIST = './models/MNIST/Gaussian/1/MNIST10epochs2020-03-24-23-34-04'
for net, suffix in ((gen, 'G'), (enc, 'E'), (dis_xz, 'D_xz'), (dis_xx, 'D_xx'), (dis_zz, 'D_zz')):
    net.load_state_dict(torch.load(ckpt_prefix_MNIST + suffix, map_location='cpu'))

In [None]:
# Move every network onto `device`, in the same order as before.
for net in (dis_xz, dis_xx, dis_zz, enc, gen):
    net.to(device)

In [None]:
# Anomaly score of every test sample: the L1 distance between dis_xx's intermediate features for the
# pair (x, x) and for the pair (x, gen(enc(x))). Only enc, gen and dis_xx contribute to the score, so
# no other forward passes are run.
lam = 0.1
loss_neg = torch.zeros((len(testloader_normal), 1), device=device)     # scores of normal samples
loss_pos = torch.zeros((len(testloader_anomalous), 1), device=device)  # scores of anomalous samples

# Put all networks in inference mode once, rather than on every iteration.
for net in (dis_xz, dis_xx, dis_zz, enc, gen):
    net.eval()

# Pure inference: no_grad stops autograd from building (and keeping) a graph for every forward pass.
with torch.no_grad():
    for loader, score_buf in ((testloader_normal, loss_neg), (testloader_anomalous, loss_pos)):
        for step, (images, labels) in enumerate(loader):
            x_real_test = images.to(device)
            rec_x = gen(enc(x_real_test))
            _, inter_layer_inp = dis_xx(x_real_test, x_real_test)
            _, inter_layer_rct = dis_xx(x_real_test, rec_x)
            # L1 norm over the whole feature tensor. The loaders use batch_size=1, so this is one
            # score per sample and `step` is the sample index.
            score_buf[step] = torch.norm(inter_layer_inp - inter_layer_rct, 1)

print ('mean negative: %0.4f, std negative: %0.4f' %(torch.mean(loss_neg), torch.std(loss_neg)))
print ('mean positive: %0.4f, std positive: %0.4f' %(torch.mean(loss_pos), torch.std(loss_pos)))

In [None]:
# Score distributions of the two groups (KDE only). sns.kdeplot is the replacement for the deprecated
# sns.distplot(hist=False, kde=True); plt.legend() is required for the 'Normal'/'Anomalous' labels to show.
x1 = loss_neg.cpu().numpy()
x2 = loss_pos.cpu().numpy()
sns.kdeplot(x1.ravel(), linewidth=3, label='Normal')
sns.kdeplot(x2.ravel(), linewidth=3, label='Anomalous')
plt.title('Distribution of normal and abnormal samples')
plt.xlabel('Anomaly Score')
plt.legend();

# Hard decision rule: a score strictly above the threshold is predicted "anomalous".
# NOTE(review): hand-tuned constant — consider deriving it from the ROC curve or a validation split.
THRESHOLD_MNIST = 50
neg_pre_wrong = int((loss_neg > THRESHOLD_MNIST).sum())   # normal samples flagged as anomalous
pos_pre_wrong = int((loss_pos <= THRESHOLD_MNIST).sum())  # anomalous samples not flagged
print("number of normal samples missclassified: %d, number of anomalous samples missclassified: %d"
      % (neg_pre_wrong, pos_pre_wrong))

# Confusion matrix with "anomalous" as the positive class.
tp = len(loss_pos) - pos_pre_wrong
fn = pos_pre_wrong
fp = neg_pre_wrong
tn = len(loss_neg) - neg_pre_wrong
# Each ratio falls back to 0.0 when its denominator is empty instead of raising ZeroDivisionError.
precision = tp / (tp + fp) if tp + fp else 0.0
## recall / sensitivity / True Positive Rate
recall = tp / (tp + fn) if tp + fn else 0.0
## False Positive Rate / 1 - Specificity
fp_rate = fp / (fp + tn) if fp + tn else 0.0
# Computed here so the print below never reuses a stale value from another cell.
specificity = tn / (tn + fp) if tn + fp else 0.0
f1 = 2 * ((precision * recall) / (precision + recall)) if precision + recall else 0.0
accuracy = (tp + tn) / (tp + tn + fp + fn)
acc_lam[lam] = accuracy
print("tp: %d, fp: %d, fn: %d, tn: %d" % (tp, fp, fn, tn))
print("precision: %.5f, recall: %.5f, specificity: %.5f, f1: %.5f, fp_rate: %.5f, accuracy: %.5f"
      % (precision, recall, specificity, f1, fp_rate, accuracy))

# Threshold-free metrics: anomalous samples are labelled 1, normal samples 0, the raw score ranks them.
anomalous = torch.ones((len(loss_pos), 1))
normal = torch.zeros((len(loss_neg), 1))
y = torch.cat((anomalous, normal), 0)
scores = torch.cat((loss_pos, loss_neg), 0)
y_true = y.numpy().ravel()
y_score = scores.cpu().numpy().ravel()
fpr, tpr, thresholds = metrics.roc_curve(y_true, y_score)
average_precision = metrics.average_precision_score(y_true, y_score)
auc = metrics.auc(fpr, tpr)
print('AUC', auc)
print('average precision :', average_precision)

plt.figure()
plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % auc)
plt.plot([0.0, 1.0], [0.0, 1.0], color='navy', linestyle='--')  # chance-level diagonal
plt.xlim([-0.01, 1.0])
plt.ylim([-0.01, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.show();

In [None]:
# Sanity check: print the shape of, and display, the first image of the first shuffled batch from
# the anomalous loader, then stop iterating.
for ids, (images, labels) in enumerate(testloader_anomalous, start=0):
    print(images[0].shape)
    plt.imshow(images[0].squeeze(dim=0), cmap="gray_r")
    plt.show()
    break

In [None]:
# Sanity check: print the shape of, and display, the first image of the first shuffled batch from
# the normal loader, then stop iterating.
for ids, (images, labels) in enumerate(testloader_normal, start=0):
    print(images[0].shape)
    plt.imshow(images[0].squeeze(dim=0), cmap="gray_r")
    plt.show()
    break