# 0. Import Library

In [1]:
import os
if not os.path.exists("./tfdet"):
    !git clone -q http://github.com/burf/tfdetection.git
    !mv ./tfdetection/tfdet ./tfdet
    !rm -rf ./tfdetection

In [2]:
# Silence Python warnings and TensorFlow log output to keep the notebook output clean.
import warnings, os
warnings.filterwarnings(action = "ignore")
# Deliberately set before the tensorflow import below.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import tfdet
tfdet.core.util.set_seed(777) # fix the random seed for reproducibility (presumably seeds NumPy/TF RNGs - TODO confirm)
device = tfdet.core.util.select_device(0) # select device 0; later cells use the result as a `with device:` context

# 1. Init Dataset

In [3]:
import numpy as np

# Dataset configuration.
image_shape = [32, 32]               # target size handed to the resize pipeline
label = ["OK", "NG"]
train_size, test_size = 1000, 100
batch_size = 16

(tr_x, tr_y), (te_x, te_y) = tf.keras.datasets.mnist.load_data()

# Train only on images of digit 0 (the "OK" class); add a trailing channel axis.
tr_x = tr_x[tr_y == 0][:train_size][..., np.newaxis]

# Test on the first `test_size` images of any digit.
te_x = te_x[:test_size][..., np.newaxis]
te_y = (te_y[:test_size] != 0)[..., np.newaxis] #0 is ok, etc is ng.

# Replicate the single channel three times
# (presumably because the ImageNet backbone expects 3-channel input - TODO confirm).
tr_x = np.repeat(tr_x, 3, axis = -1)
te_x = np.repeat(te_x, 3, axis = -1)

# Resize to `image_shape` without keeping the aspect ratio, then normalize and batch.
tr_pipe = tfdet.dataset.pipeline.normalize(
    tfdet.dataset.pipeline.resize(tr_x, image_shape = image_shape, keep_ratio = False),
    batch_size = batch_size,
    prefetch = True,
)

te_pipe = tfdet.dataset.pipeline.normalize(
    tfdet.dataset.pipeline.resize(te_x, image_shape = image_shape, keep_ratio = False),
    batch_size = batch_size,
    prefetch = True,
)

# 2. Build Detector

In [4]:
with device:
    # Symbolic input: height x width from `image_shape`, plus 3 channels.
    x = tf.keras.layers.Input(shape = list(image_shape) + [3])
    # Backbone outputs selected by `indices`, using the "imagenet" weights.
    out = tfdet.model.backbone.wide_resnet50_2_torch(
        x,
        weights = "imagenet",
        indices = [0, 1, 2],
    )
    model = tf.keras.Model(inputs = x, outputs = out)
    # Run the training pipeline through the backbone to collect its features.
    feature = model.predict(tr_pipe, verbose = 1) #feature extract



# 3. Train

3-1. Init Hyperparameters

In [5]:
memory_reduce = True
sampling_size = 550

# Size of the last axis of every extracted feature array
# (presumably the channel count per backbone level - TODO confirm).
channel_counts = [np.shape(level)[-1] for level in feature]
n_feature = np.sum(channel_counts)

# Draw `sampling_size` distinct indices out of [0, n_feature).
sampling_index = np.random.choice(np.arange(n_feature), size = sampling_size, replace = False)

3-2. Generate Feature Vector

In [6]:
with device:
    # memory_reduce is a tradeoff between accuracy and memory
    feature_vector = tfdet.model.train.padim.train(
        feature,
        sampling_index = sampling_index,
        memory_reduce = memory_reduce,
    )

3-3. Build Predict Model

In [7]:
with device:
    # PaDiM head on top of the backbone outputs; memory_reduce must match the
    # value that was used when the feature vector was generated.
    score, mask = tfdet.model.detector.padim(
        out,
        feature_vector,
        image_shape = image_shape,
        sampling_index = sampling_index,
        memory_reduce = memory_reduce,
    )
    model = tf.keras.Model(inputs = x, outputs = [score, mask])

    # Raw predictions are only needed here to pick a decision threshold.
    # NOTE(review): the threshold is fitted on the test labels (te_y), the same
    # data the evaluation cell scores against, so that score is optimistic.
    pred_score, pred_mask = model.predict(te_pipe, verbose = 1)
    threshold = tfdet.util.get_threshold(te_y, pred_score)

    # Final model: same inputs, outputs passed through the threshold filter.
    filter_layer = tfdet.model.postprocess.padim.FilterDetection(threshold = threshold)
    filtered_out = filter_layer([score, mask])
    model = tf.keras.Model(inputs = x, outputs = filtered_out)



# 4. Evaluate

In [8]:
pred_score, pred_mask = model.predict(te_pipe, verbose = 0)
# A positive filtered score counts as an "NG" prediction
# (presumably FilterDetection zeroes scores below the threshold - TODO confirm).
is_ng = pred_score > 0
print("score : {0:.4f}".format(np.mean(is_ng == te_y)))

score : 0.9900


# 5. Save & Load

5-1. Save

In [9]:
import os, shutil, pickle

save_path = "./learn/model.pickle"
save_dir = os.path.dirname(save_path)

# Start from an empty output directory: anything already inside it is removed.
if os.path.exists(save_dir):
    shutil.rmtree(save_dir)
os.makedirs(save_dir, exist_ok = True)

# Everything the load cell needs to rebuild the detector, in the order it unpacks.
artifacts = [image_shape, feature_vector, sampling_index, memory_reduce, threshold]
tfdet.dataset.util.save_pickle(artifacts, save_path)

'./learn/model.pickle'

5-2. Load

In [10]:
save_path = "./learn/model.pickle"

# Restore the saved artifacts in the same order the save cell wrote them.
# NOTE(review): only load pickle files you produced yourself; unpickling
# untrusted data can execute arbitrary code.
(image_shape,
 feature_vector,
 sampling_index,
 memory_reduce,
 threshold) = tfdet.dataset.util.load_pickle(save_path)

with device:
    # Rebuild the full graph: input -> backbone -> PaDiM head -> threshold filter.
    x = tf.keras.layers.Input(shape = list(image_shape) + [3])
    out = tfdet.model.backbone.wide_resnet50_2_torch(
        x,
        weights = "imagenet",
        indices = [0, 1, 2],
    )
    # memory_reduce must match the value that was used at train time.
    score, mask = tfdet.model.detector.padim(
        out,
        feature_vector,
        image_shape = image_shape,
        sampling_index = sampling_index,
        memory_reduce = memory_reduce,
    )
    filter_layer = tfdet.model.postprocess.padim.FilterDetection(threshold = threshold)
    filtered_out = filter_layer([score, mask])
    model = tf.keras.Model(inputs = x, outputs = filtered_out)

# Sanity check: the reloaded model is scored the same way as in the evaluation cell.
pred_score, pred_mask = model.predict(te_pipe, verbose = 0)
is_ng = pred_score > 0
print("score : {0:.4f}".format(np.mean(is_ng == te_y)))

score : 0.9900
