<a href="https://colab.research.google.com/github/TuanBC/memae-tf/blob/master/MemAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import absolute_import, print_function
import torch
from torch import nn
import math
from torch.nn.parameter import Parameter
from torch.nn import functional as F
import numpy as np

# Các thuộc tính của Memory M

class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim # N
        self.fea_dim = fea_dim # C
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # N x C
        self.bias = None
        self.shrink_thres= shrink_thres
        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)

        self.reset_parameters()

    # Hàm để khởi tạo weight và bias attention layer

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # input: latent z
        # Mem: bộ weight attention
        
        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (TxC) x (CxN) = TxN
        att_weight = F.softmax(att_weight, dim=1)  # TxN
        
        # Dùng threshold để thay đổi att_weight

        # ReLU based shrinkage, hard shrinkage for positive value
        if(self.shrink_thres>0):

            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            # att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
            # normalize???
            att_weight = F.normalize(att_weight, p=1, dim=1)
            # att_weight = F.softmax(att_weight, dim=1)
            # att_weight = self.hard_sparse_shrink_opt(att_weight)
        
        mem_trans = self.weight.permute(1, 0)  # Mem^T, MxC
        output = F.linear(att_weight, mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
        
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(
            self.mem_dim, self.fea_dim is not None
        )


# NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW
class MemModule(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025, device='cuda'):
        super(MemModule, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.shrink_thres = shrink_thres
        self.memory = MemoryUnit(self.mem_dim, self.fea_dim, self.shrink_thres)

    def forward(self, input):
        s = input.data.shape
        l = len(s)

        # Số chiều input z (latent)

        if l == 3:
            x = input.permute(0, 2, 1)
        elif l == 4:
            x = input.permute(0, 2, 3, 1)
        elif l == 5:
            x = input.permute(0, 2, 3, 4, 1)
        else:
            x = []
            print('wrong feature map size')
        x = x.contiguous()
        x = x.view(-1, s[1])
        #
        y_and = self.memory(x)
        
        # y: output
        # att: attention weight w

        y = y_and['output']
        att = y_and['att']

        if l == 3:
            y = y.view(s[0], s[2], s[1])
            y = y.permute(0, 2, 1)
            att = att.view(s[0], s[2], self.mem_dim)
            att = att.permute(0, 2, 1)
        elif l == 4:
            y = y.view(s[0], s[2], s[3], s[1])
            y = y.permute(0, 3, 1, 2)
            att = att.view(s[0], s[2], s[3], self.mem_dim)
            att = att.permute(0, 3, 1, 2)
        elif l == 5:
            y = y.view(s[0], s[2], s[3], s[4], s[1])
            y = y.permute(0, 4, 1, 2, 3)
            att = att.view(s[0], s[2], s[3], s[4], self.mem_dim)
            att = att.permute(0, 4, 1, 2, 3)
        else:
            y = x
            att = att
            print('wrong feature map size')
        return {'output': y, 'att': att}

# relu based hard shrinkage function, only works for positive values

def hard_shrink_relu(input, lambd=0, epsilon=1e-12):
    output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon)
    return output


In [0]:
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass

TensorFlow 2.x selected.


In [0]:
import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.activations import softmax, relu
from tensorflow.keras.initializers import RandomUniform

In [0]:
c = tf.constant([[[1.0, 2.0]]])
print(c)

tf.Tensor([[[1. 2.]]], shape=(1, 1, 2), dtype=float32)


In [0]:
d = tf.constant([[1.0, 2.0], [30.0, -4.0], [1,1]])
print(d)

tf.Tensor(
[[ 1.  2.]
 [30. -4.]
 [ 1.  1.]], shape=(3, 2), dtype=float32)


In [0]:
tf.linalg.normalize(d, ord=1, axis=1)

(<tf.Tensor: id=5, shape=(3, 2), dtype=float32, numpy=
 array([[ 0.33333334,  0.6666667 ],
        [ 0.88235295, -0.11764706],
        [ 0.5       ,  0.5       ]], dtype=float32)>,
 <tf.Tensor: id=4, shape=(3, 1), dtype=float32, numpy=
 array([[ 3.],
        [34.],
        [ 2.]], dtype=float32)>)

In [0]:
def tf_swap_last_2_axis(x):
    if tf.rank(x)<=2:
        return tf.transpose(x)
    else:
        # check again in case rank>2
        return tf.transpose(x, [i for i in range(tf.shape(x).shape[0]-2)] + [tf.shape(x).shape[0]-2, tf.shape(x).shape[0]-1])
    
# tf_swap_last_2_axis_test(d)

In [0]:
def compute_cosine_distances(a, b):
    # a: Input, shape = (batch * n_a * fea_dim) 
    # b: Memory, shape = (n_b * fea_dim)
    
    # output: shape = (batch * n_a * n_b)

    a_normalized, _ = tf.linalg.normalize(a, ord=1, axis=-1)
    b_normalized, _ = tf.linalg.normalize(b, ord=1, axis=-1)
    
    # b_normalized_transposed = tf_swap_last_2_axis(b_normalized)
    b_normalized_transposed = tf.transpose(b_normalized)
    
    distance = tf.matmul(a_normalized, b_normalized_transposed)

    return distance

In [0]:
tf.rank(d)

<tf.Tensor: id=6, shape=(), dtype=int32, numpy=2>

In [0]:
print(compute_cosine_distances(c, d))
print(compute_cosine_distances(c, tf.squeeze(c, 0)))

tf.Tensor([[[0.5555556  0.21568629 0.5       ]]], shape=(1, 1, 3), dtype=float32)
tf.Tensor([[[0.5555556]]], shape=(1, 1, 1), dtype=float32)


In [0]:
# next step: add regularizer

class MemoryUnit(Layer):
    def __init__(self, mem_dim, shrink_thres=0.0025):
        # C: dimension of vector z
        # M: size of the memory
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.kernel_regularizer = None
        self.shrink_thres= shrink_thres
    
    def build(self, input_shape):
        self.std = 8
        
        # M x C
        self.weight = self.add_weight(shape=(self.mem_dim, input_shape[-1]),
                                      initializer=RandomUniform(-self.std, self.std, seed=2803),
                                      regularizer=self.kernel_regularizer,
                                      trainable=True)

    def call(self, inputs):
        # att_weight = F.linear(inputs, self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        # att_weight = F.softmax(att_weight, dim=1)  # TxM

        att_weight = compute_cosine_distances(inputs, self.weight) # Fea x Mem^T, (batchxTxC) x (CxM) = TxM
        # att_weight = tf.matmul(inputs, tf.transpose(self.weight))
        att_weight = softmax(att_weight) # TxM

        if(self.shrink_thres>0):
            att_weight = relu(att_weight, threshold=self.shrink_thres)

            # normalize by p=1 (L1 normalization)
            att_weight, _ = tf.linalg.normalize(att_weight, ord=1, axis=1)
        output = tf.matmul(att_weight, self.weight)
        return output


In [0]:
(train_images, train_label), (test_images, test_label) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype('float32')

# Normalizing the images to the range of [0., 1.]
train_images /= 255.
test_images /= 255.

# Binarization
train_images[train_images >= .5] = 1.
train_images[train_images < .5] = 0.
test_images[test_images >= .5] = 1.
test_images[test_images < .5] = 0.


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [0]:
normal_class = 9

train_normal_images = train_images[train_label==normal_class]
train_normal_label = train_label[train_label==normal_class]

train_anomaly_images = train_images[train_label!=normal_class]
train_anomaly_label = train_label[train_label!=normal_class]

test_normal_images = test_images[test_label==normal_class]
test_normal_label = test_label[test_label==normal_class]

test_anomaly_images = test_images[test_label!=normal_class]
test_anomaly_label = test_label[test_label!=normal_class]

In [0]:
train_images.shape

(60000, 28, 28, 1)

In [0]:
x_input = Input(shape=[28,28,1])

# x = Conv2D(8, 3, 1, 'same')(x_input)

x = Flatten()(x_input)

x = Dense(256, activation='relu')(x)
x = Dense(64, activation='relu', name='latent')(x)
x = Dense(256, activation='relu')(x)
x = Dense(784, activation='relu')(x)

x = Reshape((28,28,1))(x)


model = Model(x_input, x)
model.summary()

Model: "model_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_5 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_19 (Dense)             (None, 256)               200960    
_________________________________________________________________
latent (Dense)               (None, 64)                16448     
_________________________________________________________________
dense_20 (Dense)             (None, 256)               16640     
_________________________________________________________________
dense_21 (Dense)             (None, 784)               201488    
_________________________________________________________________
reshape_5 (Reshape)          (None, 28, 28, 1)         0   

In [0]:
model.compile(optimizer='adam', loss='mse')

model.fit(train_normal_images, train_normal_images, batch_size=1024, epochs=100, validation_split=0.1, verbose=2)

Train on 5354 samples, validate on 595 samples
Epoch 1/100
5354/5354 - 1s - loss: 0.1037 - val_loss: 0.0779
Epoch 2/100
5354/5354 - 0s - loss: 0.0740 - val_loss: 0.0649
Epoch 3/100
5354/5354 - 0s - loss: 0.0641 - val_loss: 0.0579
Epoch 4/100
5354/5354 - 0s - loss: 0.0574 - val_loss: 0.0526
Epoch 5/100
5354/5354 - 0s - loss: 0.0522 - val_loss: 0.0480
Epoch 6/100
5354/5354 - 0s - loss: 0.0477 - val_loss: 0.0442
Epoch 7/100
5354/5354 - 0s - loss: 0.0442 - val_loss: 0.0411
Epoch 8/100
5354/5354 - 0s - loss: 0.0410 - val_loss: 0.0382
Epoch 9/100
5354/5354 - 0s - loss: 0.0380 - val_loss: 0.0354
Epoch 10/100
5354/5354 - 0s - loss: 0.0353 - val_loss: 0.0328
Epoch 11/100
5354/5354 - 0s - loss: 0.0329 - val_loss: 0.0308
Epoch 12/100
5354/5354 - 0s - loss: 0.0310 - val_loss: 0.0292
Epoch 13/100
5354/5354 - 0s - loss: 0.0293 - val_loss: 0.0277
Epoch 14/100
5354/5354 - 0s - loss: 0.0278 - val_loss: 0.0265
Epoch 15/100
5354/5354 - 0s - loss: 0.0267 - val_loss: 0.0255
Epoch 16/100
5354/5354 - 0s - lo

<tensorflow.python.keras.callbacks.History at 0x7f6ffce054a8>

In [0]:
model.evaluate(train_anomaly_images, train_anomaly_images, batch_size=1024, verbose=2)

54051/1 - 0s - loss: 0.0478


0.0456561155160148

In [0]:
x_input = Input(shape=[28,28,1])

# x = Conv2D(8, 3, 1, 'same')(x_input)

x = Flatten()(x_input)

x = Dense(256, activation='relu')(x)
x = Dense(64, activation='relu')(x)

x = MemoryUnit(100)(x)

x = Dense(256, activation='relu')(x)
x = Dense(784, activation='relu')(x)

x = Reshape((28,28,1))(x)


model_2 = Model(x_input, x)
model_2.summary()

Model: "model_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_8 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_28 (Dense)             (None, 256)               200960    
_________________________________________________________________
dense_29 (Dense)             (None, 64)                16448     
_________________________________________________________________
memory_unit_4 (MemoryUnit)   (None, 64)                6400      
_________________________________________________________________
dense_30 (Dense)             (None, 256)               16640     
_________________________________________________________________
dense_31 (Dense)             (None, 784)               2014

In [0]:
model_2.compile(optimizer='adam', loss='mse')

model_2.fit(train_normal_images, train_normal_images, batch_size=1024, epochs=100, validation_split=0.1, verbose=2)

Train on 5354 samples, validate on 595 samples
Epoch 1/100
5354/5354 - 1s - loss: 0.1108 - val_loss: 0.0961
Epoch 2/100
5354/5354 - 0s - loss: 0.0962 - val_loss: 0.0912
Epoch 3/100
5354/5354 - 0s - loss: 0.0930 - val_loss: 0.0893
Epoch 4/100
5354/5354 - 0s - loss: 0.0914 - val_loss: 0.0883
Epoch 5/100
5354/5354 - 0s - loss: 0.0904 - val_loss: 0.0872
Epoch 6/100
5354/5354 - 0s - loss: 0.0893 - val_loss: 0.0862
Epoch 7/100
5354/5354 - 0s - loss: 0.0886 - val_loss: 0.0857
Epoch 8/100
5354/5354 - 0s - loss: 0.0883 - val_loss: 0.0855
Epoch 9/100
5354/5354 - 0s - loss: 0.0882 - val_loss: 0.0854
Epoch 10/100
5354/5354 - 0s - loss: 0.1030 - val_loss: 0.1191
Epoch 11/100
5354/5354 - 0s - loss: 0.1223 - val_loss: 0.1181
Epoch 12/100
5354/5354 - 0s - loss: 0.1211 - val_loss: 0.1169
Epoch 13/100
5354/5354 - 0s - loss: 0.1199 - val_loss: 0.1157
Epoch 14/100
5354/5354 - 0s - loss: 0.1187 - val_loss: 0.1145
Epoch 15/100
5354/5354 - 0s - loss: 0.1175 - val_loss: 0.1134
Epoch 16/100
5354/5354 - 0s - lo

<tensorflow.python.keras.callbacks.History at 0x7f6ff8f2f7f0>

In [0]:
model_2.evaluate(train_anomaly_images, train_anomaly_images, batch_size=1024, verbose=2)

54051/1 - 0s - loss: 0.1150


0.11046588828595295

In [0]:
model_encoder = Model(inputs=model_2.input, outputs=model_2.get_layer('memory_unit_4').output)

model_encoder.summary()

Model: "model_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
flatten_8 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_28 (Dense)             (None, 256)               200960    
_________________________________________________________________
dense_29 (Dense)             (None, 64)                16448     
_________________________________________________________________
memory_unit_4 (MemoryUnit)   (None, 64)                6400      
Total params: 223,808
Trainable params: 223,808
Non-trainable params: 0
_________________________________________________________________


In [0]:
import matplotlib.pyplot as plt

