# Anomaly Detection with Tensor Networks (beta version)

Author: Luuk Coopmans, Trinity College Dublin and DIAS, luukcoopmans2@gmail.com

In this notebook we show how to train tensor networks to perform anomaly detection on the MNIST dataset. We follow the paper [1] (https://arxiv.org/abs/2006.02516) by J. Wang, C. Roberts, G. Vidal and S. Leichenauer.

We start by importing the required libraries and modules (mostly the latest versions as of April 12, 2021). We use the following existing libraries: jax for automatic differentiation, tensorflow to load the datasets, sklearn for the score function and tensornetwork for tensor contractions. In addition, the functions specific to the anomaly detection algorithm are collected in the newly written ML_tensor module.

In [None]:
import jax
jax.config.update('jax_enable_x64', True) # enable 64-bit precision
import jax.numpy as np
from jax import grad, random, jit
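try:
    from jax.example_libraries import optimizers # newer jax versions
except ImportError:
    from jax.experimental import optimizers # older jax versions (as used in April 2021)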

import tensorflow as tf
from sklearn.metrics import roc_auc_score
import pickle

import tensornetwork as tn
tn.set_default_backend("jax")     # different backends are possible, see the tensornetwork documentation guide

import matplotlib.pyplot as plt

from ML_tensor import *

## 1. Preparation of the training and test data

Next we prepare the training and test data of the MNIST dataset (other datasets, such as Fashion MNIST, can be loaded and prepared in the same way). We restrict the training set to the images with one specific label and apply a (2x2) pooling layer to the images. For this we use the functions **one_label_subset** and **feature_map** from our ML_tensor module:

In [None]:
normal_label = 1 # defining which MNIST label is a normal instance

# loading and renormalizing the data
mnist = tf.keras.datasets.mnist 
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0, x_test/255.0

# picking the subset of the data corresponding to the normal instance label
x_train = one_label_subset(x_train, y_train, normal_label)

# mapping the training data to input feature vectors suitable for our tensor network
f_vecs = feature_map(x_train, Pooling=True)
f_vecs_test = feature_map(x_test, Pooling=True)

# Making the test label vector (1 for a normal instance and 0 for an anomaly).
# A single comparison avoids overwriting the new 1-labels with 0 when normal_label != 1.
y_test = (y_test == normal_label).astype(int)
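The helpers **one_label_subset** and **feature_map** live in ML_tensor and their code is not shown in this notebook. As a rough, hypothetical sketch of what such a preparation step can look like, the snippet below keeps the images of one label, applies non-overlapping 2x2 average pooling (28x28 to 14x14 = 196 sites) and maps each pixel value $x \in [0, 1]$ to the local unit vector $[\cos(\pi x/2), \sin(\pi x/2)]$, a common choice in the tensor-network literature. The names `one_label_subset_sketch` and `feature_map_sketch`, the use of average pooling and the trigonometric map are all assumptions; only the output layout (samples, d, sites) mirrors how `f_vecs` is indexed later in this notebook.

```python
import jax.numpy as jnp

def one_label_subset_sketch(x, y, label):
    """Hypothetical helper: keep only the images whose label equals `label`."""
    x, y = jnp.asarray(x), jnp.asarray(y)
    return x[y == label]

def feature_map_sketch(images, pooling=True):
    """Hypothetical feature map.

    images: array of shape (N, H, W) with pixel values in [0, 1] (H, W even).
    Optionally averages non-overlapping 2x2 blocks, then maps every pixel x to
    the unit vector [cos(pi x / 2), sin(pi x / 2)].
    Returns an array of shape (N, 2, n_sites) with n_sites = number of pixels.
    """
    images = jnp.asarray(images)
    if pooling:
        n, h, w = images.shape
        # (N, H, W) -> (N, H/2, 2, W/2, 2) -> mean over the two 2-pixel axes
        images = images.reshape(n, h // 2, 2, w // 2, 2).mean(axis=(2, 4))
    pixels = images.reshape(images.shape[0], -1)  # flatten each image: (N, n_sites)
    angle = jnp.pi * pixels / 2
    return jnp.stack([jnp.cos(angle), jnp.sin(angle)], axis=1)

# usage on made-up data: 5 images of size 28x28, three of them with label 1
x_demo = jnp.linspace(0.0, 1.0, 5 * 28 * 28).reshape(5, 28, 28)
y_demo = jnp.array([1, 0, 1, 2, 1])
f_demo = feature_map_sketch(one_label_subset_sketch(x_demo, y_demo, 1))
print(f_demo.shape)  # (3, 2, 196)
```

Each local vector has unit norm, so every image becomes a normalized product state.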

As a check to see that we have selected the right data we can plot some of the training data:

In [None]:
plt.imshow(x_train[3],cmap=plt.cm.binary)
plt.show()

To train in batches, we split the feature vectors (f_vecs) into smaller batches with TensorFlow's split function. Note that we truncate f_vecs so that it contains an integer number of batches.

In [None]:
batch_size =  32

n = int(np.floor(len(f_vecs)/batch_size)) # integer number of batches
batches = np.array(tf.split(f_vecs[0:batch_size*n], n))

print(np.shape(batches))
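Because `tf.split` with an integer number of splits cuts axis 0 into equal, consecutive pieces, the same batches can also be obtained with a plain reshape, which avoids the round trip through TensorFlow tensors. A minimal sketch, assuming `f_vecs` has shape (N, d, n_sites); `make_batches` is a hypothetical name:

```python
import jax.numpy as jnp

def make_batches(f_vecs, batch_size):
    """Drop the remainder and reshape (N, ...) -> (n_batches, batch_size, ...)."""
    f_vecs = jnp.asarray(f_vecs)
    n_batches = f_vecs.shape[0] // batch_size  # integer number of batches
    kept = f_vecs[: n_batches * batch_size]
    return kept.reshape((n_batches, batch_size) + f_vecs.shape[1:])

f_demo = jnp.zeros((100, 2, 196))
print(make_batches(f_demo, 32).shape)  # (3, 32, 2, 196)
```

The order of the samples is preserved: batch k contains samples k*batch_size to (k+1)*batch_size - 1.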

## 2. The trainable tensor network (matrix product operator), loss function and penalty function

The central object we train for anomaly detection is a tensor network, more specifically a linear transformation known as a matrix product operator (MPO); see [1] for details. In code, the MPO is nothing but a nested list of arrays that contain our trainable parameters. The indices of the arrays correspond to the legs of the tensor network and are either dangling or bonded: dangling legs do not connect to the neighbouring arrays in the list, whereas bonded legs do. Increasing the dimension of these indices increases the number of trainable parameters. To initialize a random MPO we call the **rand_anomdet_MPO** function from the ML_tensor module:

In [None]:
d = 2 # dimension of the input legs (should match the dimension axis=1 of the f_vecs)
b = 5 # dimension of the bonds
p = 2 # dimension of the output legs
S = 8 # parameter that determines the number of output legs.

mean = 0
std = 0.31

MPO = rand_anomdet_MPO(np.shape(f_vecs)[2],d, b, p, S, mean, std)

To apply the MPO to a single feature vector (required for the loss function, see below) we call the **apply_MPO_to_fvec** function (the first call may be slow because of jit compilation):

In [None]:
fvec = f_vecs[0]
MPS = apply_MPO_to_fvec(fvec, MPO) # we call this a matrix product state (MPS) because of the form of the indices
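To make the index bookkeeping concrete, here is a self-contained toy version of building an MPO, applying it to a feature vector and computing the squared norm $\|P(f_{vec})\|_2^2$. It is a simplified, hypothetical sketch, not the ML_tensor code: every site carries one output leg (whereas in rand_anomdet_MPO the parameter S controls how many output legs there are), and each tensor is assumed to have index order (left bond, input d, output p, right bond). Contracting the input legs with the local feature vectors leaves an MPS, and its norm is obtained by sweeping a small transfer matrix from left to right, so the cost grows only linearly with the number of sites instead of building the $p^n$-dimensional vector. The dense contraction is included only as a check for small $n$.

```python
import jax
import jax.numpy as jnp

def toy_random_mpo(key, n_sites, d, b, p, std=0.31):
    """Hypothetical simplified MPO: one input leg (dim d) and one output leg (dim p)
    per site, bond dimension b, trivial (dim 1) boundary bonds.
    Each tensor has shape (left_bond, d, p, right_bond)."""
    mpo = []
    for i, k in enumerate(jax.random.split(key, n_sites)):
        left = 1 if i == 0 else b
        right = 1 if i == n_sites - 1 else b
        mpo.append(std * jax.random.normal(k, (left, d, p, right)))
    return mpo

def toy_apply_mpo(fvec, mpo):
    """Contract every input leg with its local feature vector fvec[:, i].
    fvec: (d, n_sites). Returns an MPS: tensors of shape (left_bond, p, right_bond)."""
    return [jnp.einsum('s,lspr->lpr', fvec[:, i], w) for i, w in enumerate(mpo)]

def toy_norm_sq(mps):
    """||P(f)||_2^2 by sweeping a transfer matrix from left to right (real tensors)."""
    env = jnp.ones((1, 1))
    for a in mps:
        env = jnp.einsum('ab,apc,bpd->cd', env, a, a)
    return env[0, 0]

def toy_to_dense(mps):
    """Contract the MPS into a dense vector of length p**n_sites (small n only)."""
    out = mps[0]
    for a in mps[1:]:
        out = jnp.einsum('...r,rps->...ps', out, a)
    return out.reshape(-1)

mpo_demo = toy_random_mpo(jax.random.PRNGKey(0), n_sites=4, d=2, b=5, p=2)
fvec_demo = jnp.full((2, 4), 1 / jnp.sqrt(2.0))  # a made-up product-state input
mps = toy_apply_mpo(fvec_demo, mpo_demo)
# the two numbers agree up to floating-point rounding
print(toy_norm_sq(mps), jnp.sum(toy_to_dense(mps) ** 2))
```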

The decision function for anomaly detection is based on the squared F (Frobenius) norm $\|P(f_{vec})\|^2_2$ of the linear transformation $P$ (the MPO) applied to a feature vector. By comparing this quantity with a manually chosen radius $\epsilon$, known as the decision boundary, it returns 1 for a normal instance and 0 for an anomaly.

In [None]:
eps = 0.5 # decision radius
print('Decision for fvec is:',decision_fun(eps, MPO, fvec))
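A minimal sketch of such a threshold rule on precomputed squared norms. The direction of the comparison is an assumption made in the spirit of [1], where training keeps normal instances at a large norm while the penalty pushes generic inputs towards the origin; the actual comparison inside **decision_fun** may differ, and `decision_sketch` is a hypothetical name.

```python
import jax.numpy as jnp

def decision_sketch(norm_sq, eps):
    """Hypothetical decision rule: 1 (normal) if ||P(f)||^2 >= eps, else 0 (anomaly)."""
    return (jnp.asarray(norm_sq) >= eps).astype(jnp.int32)

decisions = decision_sketch(jnp.array([0.02, 0.7, 1.3]), eps=0.5)
print(decisions)  # [0 1 1]
```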

For training, we define the loss of a single feature vector as the square of the log of its squared norm minus 1, i.e. $\left(\log\|P(f_{vec})\|^2_2 - 1\right)^2$. This loss is computed by the function **loss_function** from ML_tensor.

In [None]:
loss_function(MPO, fvec)
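Writing the per-sample loss as a function of the squared norm makes its minimum explicit: $(\log x - 1)^2$ vanishes at $x = e$, so, with the loss exactly as written above, normal instances are pushed towards $\|P(f_{vec})\|^2_2 = e$. The sketch below evaluates the formula on made-up norms and averages it over a batch; averaging (rather than summing) is an assumption about how the batch losses are combined, and `sample_loss_from_norm` is a hypothetical name.

```python
import jax.numpy as jnp

def sample_loss_from_norm(norm_sq):
    """Per-sample loss (log ||P(f)||_2^2 - 1)^2, given the squared norm."""
    return (jnp.log(norm_sq) - 1.0) ** 2

norms_demo = jnp.array([jnp.e, 1.0, jnp.e ** 2])  # made-up squared norms
losses = sample_loss_from_norm(norms_demo)        # approximately [0, 1, 1]
batch_loss = jnp.mean(losses)                      # assumed: mean over the batch
print(losses, batch_loss)
```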

To avoid converging to a trivial solution, we add to every batch of losses a penalty term given by the rectified linear unit of the log of the squared Frobenius norm of the operator, $\mathrm{ReLU}(\log\|P\|^2_F)$. This penalty is computed with the **penalty** function.

In [None]:
print(penalty(MPO))
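For an MPO, the squared Frobenius norm $\|P\|^2_F$ can be computed like an MPS norm, contracting both the input and the output leg of every tensor with its copy. Because the result is a product of many factors, it easily overflows, which is one possible source of the penalty instability mentioned in the notes at the end of this notebook. The hypothetical sketch below (tensors assumed to have index order left bond, d, p, right bond) therefore rescales the transfer matrix at every site and accumulates the logarithms of the scale factors, returning $\log\|P\|^2_F$ directly; `jax.nn.relu` then gives the penalty. Whether ML_tensor's **penalty** works this way is not known from this notebook.

```python
import jax
import jax.numpy as jnp

def toy_log_frobenius_sq(mpo):
    """log ||P||_F^2 for a list of MPO tensors of shape (left, d, p, right).

    The transfer matrix env is divided by its largest entry after every site
    and the log of that factor is accumulated, so long chains do not overflow.
    """
    env = jnp.ones((1, 1))
    log_norm = 0.0
    for w in mpo:
        # contract env with the tensor and its copy over the left bonds and
        # over both physical legs (input s and output p)
        env = jnp.einsum('ab,aspc,bspd->cd', env, w, w)
        scale = jnp.max(jnp.abs(env))
        env = env / scale
        log_norm = log_norm + jnp.log(scale)
    return log_norm + jnp.log(env[0, 0])

def toy_penalty(mpo):
    """Penalty ReLU(log ||P||_F^2)."""
    return jax.nn.relu(toy_log_frobenius_sq(mpo))

# usage on a random 6-site MPO with d = p = 2 and bond dimension 5
shapes = [(1, 2, 2, 5)] + [(5, 2, 2, 5)] * 4 + [(5, 2, 2, 1)]
keys = jax.random.split(jax.random.PRNGKey(1), len(shapes))
mpo_demo = [0.31 * jax.random.normal(k, s) for k, s in zip(keys, shapes)]
print(toy_log_frobenius_sq(mpo_demo), toy_penalty(mpo_demo))
```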

## 3. Gradients and the training loop

To obtain the gradients of the individual loss and penalty functions we use the grad function from the jax library. It takes a function and returns the corresponding gradient function, which is called with the same arguments. Note that running the cells below for the first time can take a while because of jit compilation; running them a second time should be much faster thanks to jit (see the jax documentation).

In [None]:
loss_gradient = jit(grad(loss_function))
loss_grad = loss_gradient(MPO,fvec)

In [None]:
penalty_gradient = jit(grad(penalty))
grad_pen = penalty_gradient(MPO)

After the slow jit compilation (which can take a few minutes), the same calls should now run fast (milliseconds) in the following cell:

In [None]:
loss_grad = loss_gradient(MPO,fvec)
grad_pen = penalty_gradient(MPO)
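Two properties of grad are worth keeping in mind here: by default it differentiates with respect to the first argument only (`argnums=0`, here the MPO), and it returns the gradient with the same nested structure as that argument. This is why the gradient of the MPO, a list of arrays, can be passed directly to the optimizer. A minimal illustration with a made-up objective whose parameters are a list of two arrays (`toy_objective` is hypothetical):

```python
import jax
import jax.numpy as jnp

def toy_objective(params, x):
    """A stand-in objective whose parameters are a list of arrays, like the MPO."""
    a, b = params
    return jnp.sum((a @ x) ** 2) + jnp.sum(b ** 2)

params = [jnp.ones((3, 2)), jnp.arange(4.0)]
x = jnp.array([1.0, -2.0])
grads = jax.jit(jax.grad(toy_objective))(params, x)  # gradient w.r.t. params only
print([g.shape for g in grads])  # [(3, 2), (4,)]
```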

To compute the loss, the penalty and the gradient for a single batch we call **batch_loss_and_gradient**; again, because of jit compilation the first call can be slow (up to 5 minutes). The hyperparameter alpha sets the relative importance of the penalty with respect to the loss and is chosen manually, as in [1].

In [None]:
alpha = 0.4
batch = batches[0]
value, pen_value, grads = batch_loss_and_gradient(MPO,batch, alpha)
print('Loss is:',value, 'Penalty is',pen_value)
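The notebook does not show how **batch_loss_and_gradient** is implemented. One plausible way to build such a function in JAX, sketched here under that assumption, is to vectorize a per-sample loss over the batch with `jax.vmap`, average it, add alpha times the penalty, and get the loss, the penalty and the gradient of the total objective in one pass with `jax.value_and_grad(..., has_aux=True)`. The builder name and the toy loss and penalty in the example are hypothetical; only the return order (loss, penalty, grads) mirrors the call above.

```python
import jax
import jax.numpy as jnp

def make_batch_loss_and_gradient(sample_loss, penalty):
    """Hypothetical builder: returns f(params, batch, alpha) -> (loss, penalty, grads),
    with loss = mean_i sample_loss(params, batch[i]) and grads the gradient of
    loss + alpha * penalty(params) with respect to params."""
    def objective(params, batch, alpha):
        losses = jax.vmap(sample_loss, in_axes=(None, 0))(params, batch)
        loss = jnp.mean(losses)
        pen = penalty(params)
        return loss + alpha * pen, (loss, pen)

    @jax.jit
    def batch_fn(params, batch, alpha):
        (_, (loss, pen)), grads = jax.value_and_grad(objective, has_aux=True)(params, batch, alpha)
        return loss, pen, grads

    return batch_fn

def toy_sample_loss(params, x):
    """Made-up per-sample loss: squared distance between params[0] and x."""
    return jnp.sum((params[0] - x) ** 2)

def toy_penalty(params):
    """Made-up penalty: squared norm of params[0]."""
    return jnp.sum(params[0] ** 2)

batch_fn = make_batch_loss_and_gradient(toy_sample_loss, toy_penalty)
loss, pen, grads = batch_fn([jnp.array([1.0, 2.0])], jnp.array([[0.0, 0.0], [2.0, 2.0]]), 0.4)
print(loss, pen, grads)  # loss 3.0, penalty 5.0, grads [0.8, 3.6]
```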

We can also compute the predictions of the MPO anomaly detector with the **anomaly_detection** function and evaluate them with the ROC AUC score (roc_auc_score from sklearn). Unfortunately, this step is also somewhat slow.

In [None]:
y_pred = anomaly_detection(MPO,f_vecs_test)
print('Starting Roc_auc score is:',roc_auc_score(y_test, y_pred))
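roc_auc_score expects continuous scores as its second argument. It is not shown here whether **anomaly_detection** returns the squared norms or hard 0/1 decisions; if it returns decisions, the ROC AUC collapses to a single operating point and depends on the choice of $\epsilon$. A small illustration with made-up labels and scores:

```python
from sklearn.metrics import roc_auc_score

y_true = [0, 0, 1, 1]                 # 1 = normal, 0 = anomaly (as in y_test)
norm_sq = [0.1, 0.4, 0.45, 0.8]       # made-up continuous scores ||P(f)||^2
eps = 0.5
decisions = [int(s >= eps) for s in norm_sq]  # hard 0/1 predictions

auc_scores = roc_auc_score(y_true, norm_sq)       # 1.0 (the ranking is perfect)
auc_decisions = roc_auc_score(y_true, decisions)  # 0.75 (depends on eps)
print(auc_scores, auc_decisions)
```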

Before defining the training loop, we initialize an optimizer from the jax library:

In [None]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size) # other optimizers such as SGD are also available
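The jax optimizers are a triple of pure functions: opt_init builds the optimizer state from the parameters, opt_update(step, grads, opt_state) returns a new state, and get_params extracts the parameters. Because they are pure functions on pytrees, the gradient evaluation and the update can be compiled together into one jitted step, a pattern used in the jax examples. The sketch below does this for a made-up least-squares objective; `toy_loss` and `train_step` are hypothetical, and whether fusing the update this way removes the slowdown described in the notes at the end is an untested assumption. Note that `step` counts all updates, since Adam uses it for its bias correction.

```python
import jax
import jax.numpy as jnp
try:
    from jax.example_libraries import optimizers  # newer jax versions
except ImportError:
    from jax.experimental import optimizers  # older jax versions

def toy_loss(params, batch):
    """Stand-in for the batch objective; params is a list of arrays like the MPO."""
    (w,) = params
    return jnp.mean((batch @ w - 1.0) ** 2)

opt_init, opt_update, get_params = optimizers.adam(1e-2)

@jax.jit
def train_step(step, opt_state, batch):
    """One compiled update: gradient evaluation and Adam step in a single call."""
    params = get_params(opt_state)
    value, grads = jax.value_and_grad(toy_loss)(params, batch)
    return value, opt_update(step, grads, opt_state)

batch = jnp.ones((8, 3))
opt_state = opt_init([jnp.zeros(3)])  # initial loss is 1.0
for step in range(200):
    value, opt_state = train_step(step, opt_state, batch)
print(value)  # much smaller than the initial loss
```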

Finally we are now ready to set up the training loop and start training our tensor network!

In [None]:
num_epochs =  5
print_roc_every = 5

loss_list = []
pen_list = []

opt_state = opt_init(MPO) # initialize the learnable parameters for the optimizer

value, pen_value, grads = batch_loss_and_gradient(MPO, batch, alpha) # initial value and gradient
loss_list.append(value)
pen_list.append(pen_value)

step = 0 # global update counter passed to the optimizer (Adam uses it for its bias correction)
for i in range(num_epochs):
    
    for j in range(len(batches)):
        
        batch = batches[j]
        
        # update the MPO
        value, pen_value, grads = batch_loss_and_gradient(MPO, batch, alpha)
        opt_state = opt_update(step, grads, opt_state)
        MPO = get_params(opt_state)
        step += 1
        
        if j%10 == 0:
            print('After batch update step:', j, 'new loss is:', value, 'new penalty is:', pen_value)
        
    # update learning curves every epoch (values of the last batch of the epoch)
    loss_list.append(value)
    pen_list.append(pen_value)
    
    # compute and print the roc_auc score every print_roc_every epochs
    if i%print_roc_every == 0:
        y_pred = anomaly_detection(MPO, f_vecs_test)
        print('New roc_auc score is:', roc_auc_score(y_test, y_pred))
    
    # save the MPO every epoch
    with open('MPO_epoch{}.pkl'.format(i), 'wb') as f:
        pickle.dump(MPO, f)

# save the final learning curve
np.save('loss_array.npy',np.array(loss_list))

We can plot the learning and penalty curves to watch the progress of the learning algorithm: 

In [None]:
plt.plot(np.array(loss_list))
plt.xlabel('epoch')
plt.ylabel('batch_loss')

In [None]:
plt.plot(np.array(pen_list))
plt.xlabel('epoch')
plt.ylabel('penalty value')

Finally, we compute the predictions and the ROC AUC score of the trained MPO:

In [None]:
y_pred = anomaly_detection(MPO,f_vecs_test)
print('Roc_auc score is:',roc_auc_score(y_test, y_pred))

## 4. Notes for future updates/improvements of the code:

- There are countless possibilities to improve the above code. Most importantly, the computational time should be reduced. Currently the code runs on only one core (the exact reason is at this point still unclear to me). I expect that it will be much faster once it runs on several CPU cores or on a GPU. Another bottleneck I discovered is the gradient update of the MPO: evaluating **batch_loss_and_gradient** over all batches of one epoch (after jit compilation) takes about 33 seconds, whereas the same epoch inside the training loop takes about 176 seconds on my laptop (see the tests below).

- Once the computational time is reduced, it will also be possible to scan over the hyperparameters $\alpha$, batch_size and step_size to determine the optimal learning settings. Currently the algorithm reaches a ROC AUC score of about 0.85 after 5 epochs. In [1], 300 epochs were used and a ROC AUC score of 0.998 was found. I expect that with optimized learning parameters the ROC AUC score of the code presented here will get closer to the score in [1].

- Another point worth investigating further is the instability of the penalty function: it can become very large because of the contraction of the many tensors in the MPO (this was also reported in [1]). Perhaps this could be resolved with different initial conditions or a different value of alpha. Another option would be to check whether the MPO can be normalized in some way (e.g. as is done for standard MPS).

- (Technical point) Finally, the feature map used here maps images to product states. In future releases it might be interesting to explore other ways of encoding the images, for example in entangled states. One example I have in mind: take a subset of images (product states) and add them together in an approximate fashion (i.e. keeping the bond dimension low), which forms a new entangled MPS. This entangled MPS could then again be combined with an MPO as done here. Alternatively, it could be interesting to compute the overlap of this MPS with other states within and outside the training set to see whether it can be used as an anomaly detector (and/or classifier).


For questions and suggestions, please email me at luukcoopmans2@gmail.com. Thanks!

## Tests:

Testing the time of an epoch (without update):

In [None]:
import time
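then = time.time()
for i in range(len(batches)):
    value, pen_value, grads = batch_loss_and_gradient(MPO, batches[i], alpha)
float(value) # wait until the last (asynchronously dispatched) computation has finished
print('Time for one epoch without updates:', time.time()-then, 's')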

Testing the time of an epoch (with adam updates):

In [None]:
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size) # other optimizers such as SGD are also available
opt_state = opt_init(MPO)

In [None]:
import time
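then = time.time()
for i in range(len(batches)):
    value, pen_value, grads = batch_loss_and_gradient(MPO, batches[i], alpha)
    opt_state = opt_update(i, grads, opt_state)
    MPO = get_params(opt_state)
float(value) # wait until the last (asynchronously dispatched) computation has finished
print('Time for one epoch with adam updates:', time.time()-then, 's')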
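JAX dispatches work asynchronously: a call to a jitted function returns as soon as the computation has been queued, while the device may still be busy. Timing loops like the ones above therefore only measure the actual computation if they wait for the results, for example with the `block_until_ready()` method of a jax array (see the jax documentation on asynchronous dispatch). This could explain part of the difference between the two measured epoch times, but that is a guess. A minimal, self-contained illustration with a made-up computation (`heavy` is hypothetical):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def heavy(x):
    """Stand-in computation: a few matrix products."""
    for _ in range(5):
        x = jnp.tanh(x @ x)
    return x

x = jax.random.normal(jax.random.PRNGKey(0), (500, 500))
heavy(x).block_until_ready()  # compile once, outside the timed region

start = time.perf_counter()
out = heavy(x)                # returns as soon as the work is dispatched
dispatch_time = time.perf_counter() - start
out.block_until_ready()       # wait until the result has actually been computed
total_time = time.perf_counter() - start
print(f'dispatch: {dispatch_time:.4f} s, dispatch + compute: {total_time:.4f} s')
```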