In [1]:
%matplotlib inline
import scipy.stats as sct
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from ipywidgets import interact, fixed
import glob
import pickle
from pprint import pprint
import os
from PIL import Image
from tqdm import tqdm

In [2]:
# %%html
# <style>
# div.input {
#     display:none;
# }
# </style>

In [3]:
# Count the train/test PNG images of every class in the dataset.
dataset = os.path.expanduser('~/group/msuzuki/MVTechAD')


def count_split_images(class_dir, split):
    """Return the number of ``*.png`` files in ``<class_dir>/<split>/<any subdir>/``."""
    return len(glob.glob(os.path.join(class_dir, split, '*', '*.png')))


# Deliberately NOT named ``classnames``: the configuration cell below rebuilds
# ``classnames`` from the output directories, and re-running this cell after it
# would otherwise silently replace that list with dataset paths.
dataset_classdirs = sorted(d for d in glob.glob(dataset + '/*') if os.path.isdir(d))

# class name -> (number of train images, number of test images)
image_counts = {
    os.path.basename(class_dir): (count_split_images(class_dir, 'train'),
                                  count_split_images(class_dir, 'test'))
    for class_dir in dataset_classdirs
}

In [4]:
# Result directories are named "output4all/<class>.<suffix>"; everything before
# the first '.' ("output4all/<class>") is used as the class key.
logdirs = sorted(glob.glob('output4all/*_mbn_*'))
output4all_kfold = "output4all"  # alternative: "output4all_kfold_val"
classnames = sorted({logdir.split('.')[0] for logdir in logdirs})

# Parameter values offered by the widgets / used to build run directory names.
n_layers = [19]        # earlier sweep: [0, 6, 12, 18]
thresholds = [0.999]   # earlier sweep: [0.85, 0.9, 0.95, 0.99, 0.999]
seeds = [6, 5, 15, 32, 85, 55, 71, 16, 78, 69]
mul_sigs = list(range(6))

In [5]:
def loadfile(getfilename):
    """Load and return the pickled object stored in the file matching a glob pattern.

    Parameters
    ----------
    getfilename : str
        Glob pattern (may contain ``*``) identifying the pickle file.

    Returns
    -------
    object
        The unpickled object.

    Raises
    ------
    FileNotFoundError
        If no file matches ``getfilename``.

    Notes
    -----
    ``glob.glob`` returns matches in arbitrary order, so they are sorted to make
    the chosen file deterministic when the pattern matches several files.
    ``pickle.load`` can execute arbitrary code: only use this on trusted result files.
    """
    matches = sorted(glob.glob(getfilename))
    if not matches:
        raise FileNotFoundError("No file matches pattern: {0}".format(getfilename))
    with open(matches[0], 'rb') as f:
        return pickle.load(f)

In [6]:
def _show_png(path, figsize):
    """Display the image at ``path`` in a new figure of size ``figsize``.

    The pixels are copied into a NumPy array inside a ``with`` block so the
    file handle is closed before plotting.
    """
    with Image.open(path, "r") as img:
        pixels = np.array(img)
    plt.rcParams['figure.figsize'] = figsize
    plt.figure()
    plt.imshow(pixels)
    plt.show()


def showtestimg(dirname, i):
    """Show the ``D_test`` image and the ``learned`` image for step ``i`` of a run.

    Parameters
    ----------
    dirname : str
        Run directory containing ``D_test_<k>_*.png`` and ``learned<k>.png``.
    i : int
        1-based step index; the files with ``k = i - 1`` are displayed.

    Raises
    ------
    ValueError
        If ``i < 1`` (the file index ``i - 1`` would be negative).
    IndexError
        If a matching image file does not exist.
    """
    if i < 1:
        raise ValueError("i is 1-based and must be >= 1, got {0}".format(i))
    k = i - 1

    # sorted() makes the choice deterministic if several files match.
    dtest_path = sorted(glob.glob(os.path.join(dirname, "D_test_{0}_*.png".format(k))))[0]
    _show_png(dtest_path, (10.0, 10.0))

    learned_path = glob.glob(os.path.join(dirname, "learned{0}.png".format(k)))[0]
    print(learned_path)
    _show_png(learned_path, (5.0, 5.0))

In [7]:
def show_eachresult(classname, n_layer, threshold, seed, mul_sig=3):
    """Plot the AUC, D and UPD logs of one run and attach an image browser.

    The run directory is
    ``<classname>._mbn_1.0.<n_layer>.<seed>.<threshold>.<mul_sig>``.

    Parameters
    ----------
    classname : str
        Class prefix such as ``"output4all/bottle"``.
    n_layer, threshold, seed, mul_sig :
        Run parameters used to build the directory name.
    """
    dirname = classname + '._mbn_1.0.{0}.{1}.{2}.{3}'.format(n_layer, seed, threshold, mul_sig)

    print("AUC")
    auclog = loadfile(os.path.join(dirname, "AUClog*.pcl"))
    print(len(auclog))
    print(auclog[:5])
    plt.rcParams['figure.figsize'] = (5.0, 5.0)
    plt.figure()
    plt.ylim(0.7, 1)
    plt.plot(auclog, label="Max = %.3f" % max(auclog))
    plt.legend()
    plt.show()

    # The D and UPD logs are displayed identically.
    for title, pattern in (("D", "Dlog*.pcl"), ("UPD", "UPDlog*.pcl")):
        print(title)
        log = loadfile(os.path.join(dirname, pattern))
        print(len(log))
        plt.rcParams['figure.figsize'] = (10.0, 2.0)
        plt.figure()
        plt.plot(log)
        plt.show()

    # ``fixed`` keeps dirname constant; a bare str would be turned into an
    # editable Text widget by interact. The slider starts at 1 because
    # showtestimg reads the files with index ``i - 1``.
    interact(showtestimg, dirname=fixed(dirname), i=(1, len(auclog), 1))

In [8]:
# Pass mul_sig explicitly: otherwise interact derives an IntSlider from the
# default value 3 instead of offering the configured ``mul_sigs``.
interact(show_eachresult, classname=classnames, n_layer=n_layers, threshold=thresholds, seed=seeds, mul_sig=mul_sigs)

interactive(children=(Dropdown(description='classname', options=('output4all/bottle', 'output4all/cable', 'out…

<function __main__.show_eachresult(classname, n_layer, threshold, seed, mul_sig=3)>

In [9]:
## 関数群

In [10]:
def get_maxauc(classname, defe=True, seed=55, n_layer_list=None, threshold_list=None):
    """Find the best AUC of ``classname`` over a grid of (layer, threshold) runs.

    Each run's AUC log is loaded from ``<classname>._mbn_1.0.<l>.<seed>.<t>*/AUClog*.pcl``.
    Runs whose log has at most one entry are skipped.

    Parameters
    ----------
    classname : str
        Class prefix such as ``"output4all/bottle"``.
    defe : bool
        True: score of a run is ``max(auclog[1:])``, i.e. the good-only
        entry ``auclog[0]`` is excluded. False: score is ``auclog[0]``.
    seed : int
        Seed component of the run directory name (default 55).
    n_layer_list, threshold_list : iterable, optional
        Grids to scan; default to the global ``n_layers`` / ``thresholds``.

    Returns
    -------
    tuple
        ``(best_auc, best_layer, best_threshold)``, or ``(0.0, None, None)``
        when no run has a usable log (always a 3-tuple so callers can unpack it).
    """
    if n_layer_list is None:
        n_layer_list = n_layers
    if threshold_list is None:
        threshold_list = thresholds

    maxauclog = {}
    for l in n_layer_list:
        for t in threshold_list:
            dirname = classname + '._mbn_1.0.{0}.{1}.{2}*'.format(l, seed, t)
            auclog = loadfile(os.path.join(dirname, 'AUClog*.pcl'))
            if len(auclog) <= 1:
                # nothing beyond the good-only entry (or an empty log)
                continue
            if defe:
                maxauclog[(l, t)] = max(auclog[1:])  # exclude the good-only entry
            else:  # good only
                maxauclog[(l, t)] = auclog[0]

    if not maxauclog:
        return 0.0, None, None
    best_key = max(maxauclog, key=maxauclog.get)
    return maxauclog[best_key], best_key[0], best_key[1]

In [11]:
def get_maxauc_lt(classname, l, t, defe=True, seed=55):
    """Return an AUC summary for the single run with layer ``l`` and threshold ``t``.

    The AUC log is loaded from ``<classname>._mbn_1.0.<l>.<seed>.<t>*/AUClog*.pcl``.

    Parameters
    ----------
    classname : str
        Class prefix such as ``"output4all/bottle"``.
    l, t :
        Layer and threshold components of the run directory name.
    defe : bool
        True: maximum over the whole log. Unlike ``get_maxauc``, this
        includes the good-only entry ``auclog[0]``. False: ``auclog[0]`` only.
    seed : int
        Seed component of the run directory name (default 55).
    """
    dirname = classname + '._mbn_1.0.{0}.{1}.{2}*'.format(l, seed, t)
    auclog = loadfile(os.path.join(dirname, 'AUClog*.pcl'))
    if defe:
        return max(auclog)
    return auclog[0]  # good only

In [12]:
def get_auclog_lt(classname, l, t, seed=55):
    """Load the AUC log of the run with layer ``l``, threshold ``t`` and ``seed``.

    The run directory is matched with the glob ``<classname>._mbn_1.0.<l>.<seed>.<t>*``.
    If several directories match, which one is read is decided by ``loadfile``.
    """
    run_glob = classname + '._mbn_1.0.{0}.{1}.{2}*'.format(l, seed, t)
    return loadfile(os.path.join(run_glob, 'AUClog*.pcl'))

In [13]:
def get_searchresult(classname, dirname):
    """Load a parameter-search result and return its items sorted by value (ascending).

    The pickle is read from ``<dirname>/<class>*/paramsearch.pcl``, where
    ``<class>`` is the last path component of ``classname``.

    Parameters
    ----------
    classname : str
        Class path such as ``"output4all/bottle"``, or a bare class name.
    dirname : str
        Directory holding the per-class search results.

    Returns
    -------
    list of (key, value)
        ``paramsearch.items()`` sorted by value, smallest first.
    """
    # basename(normpath(.)) is the class name for "out/bottle", "./out/bottle",
    # "out/bottle/" and "bottle" alike (split('/')[1] is not).
    class_name = os.path.basename(os.path.normpath(classname))
    paramsearch = loadfile(os.path.join(dirname, class_name + '*/paramsearch.pcl'))
    return sorted(paramsearch.items(), key=lambda item: item[1])

In [14]:
## 異常データまで学習した場合の最大のAUC

In [15]:
# Best AUC when training on good images only vs. after also learning defective images.
print("label\t\tAUCgood\tAUCdef\tdiff")
for classname in classnames:
    # trained on normal (good) images only
    v, _, _ = get_maxauc(classname, defe=False)
    # after additionally learning defective images
    vv, _, _ = get_maxauc(classname, defe=True)

    print("{0:<12}\t{1}\t{2}\t{3}".format(classname.split('/')[1], "%.2f" % v, "%.2f" % vv, "%.2f" % (vv-v)))

label		AUCgod	AUCdef	diff
bottle      	1.00	1.00	0.00
cable       	0.89	0.92	0.03
capsule     	0.93	0.96	0.03
carpet      	0.79	0.84	0.05
grid        	0.52	0.51	-0.00
hazelnut    	0.97	0.99	0.02
leather     	0.99	1.00	0.01
metal_nut   	0.90	0.94	0.04
pill        	0.86	0.90	0.04
screw       	0.82	0.91	0.08
tile        	1.00	1.00	0.00
toothbrush  	1.00	1.00	0.00
transistor  	0.91	0.92	0.01
wood        	0.96	1.00	0.03
zipper      	0.99	0.99	0.00


In [16]:
## 学習枚数（１〜５）ごとのAUCの変化

In [17]:
# AUC after each number of learned images: good only, then 1..4 defective images.
LAYER, THRESHOLD = 19, 0.999
print("Class, AE, Good, Def1, Def2, Def3, Def4")
for classname in classnames:
    auclog = get_auclog_lt(classname, LAYER, THRESHOLD)
    print("{:<12}".format(classname.split('/')[-1]), end=', ')
    print("0.00", end=', ')  # AE column: hard-coded placeholder value
    for auc in auclog[:5]:
        print("%.2f" % auc, end=',\t')
    print('')

Class, AE, Good, Def1, Def2, Def3, Def4
bottle      , 0.00, 1.00,	0.99,	1.00,	1.00,	1.00,	
cable       , 0.00, 0.89,	0.89,	0.89,	0.88,	0.88,	
capsule     , 0.00, 0.93,	0.94,	0.94,	0.93,	0.94,	
carpet      , 0.00, 0.79,	0.82,	0.80,	0.80,	0.81,	
grid        , 0.00, 0.52,	0.51,	0.50,	0.51,	0.51,	
hazelnut    , 0.00, 0.97,	0.98,	0.98,	0.99,	0.98,	
leather     , 0.00, 0.99,	0.99,	1.00,	0.99,	0.99,	
metal_nut   , 0.00, 0.90,	0.92,	0.93,	0.93,	0.93,	
pill        , 0.00, 0.86,	0.86,	0.86,	0.86,	0.87,	
screw       , 0.00, 0.82,	0.85,	0.85,	0.86,	0.86,	
tile        , 0.00, 1.00,	1.00,	1.00,	1.00,	1.00,	
toothbrush  , 0.00, 1.00,	0.99,	0.99,	0.99,	1.00,	
transistor  , 0.00, 0.91,	0.91,	0.91,	0.91,	0.91,	
wood        , 0.00, 0.96,	0.98,	0.98,	0.98,	0.98,	
zipper      , 0.00, 0.99,	0.99,	0.99,	0.98,	0.98,	


## 学習する異常データの選別 (Reject)
k-foldでvalidationデータに対してポアソン分布を仮定して，正常と異常の閾値を決定．  
その閾値を用いて，異常データを学習に使用するかしないかを決定．

左のグラフから
1. 全ての異常データを学習したAUCの推移
2. ポアソン分布を仮定した３σの閾値で異常と判定されたデータのみを学習した場合のAUCの推移
3. 2の正常部分空間を用い，３σを閾値として判定した場合のaccuracyの推移

軸
- 横軸は学習した異常のデータ数
- 1と2の縦軸はAUC
- 3の縦軸はaccuracy

In [18]:
def plot_ac(auclog, title, lim=(0.7, 1), color=None):
    """Draw a metric curve on the current pyplot axes, with its maximum in the legend.

    Parameters
    ----------
    auclog : sequence of float
        Metric values (AUC or accuracy); the x position is the index in the sequence.
    title : str
        Used both as the axes title and as the y-axis label.
    lim : tuple
        ``(bottom, top)`` passed to ``plt.ylim``.
    color : optional
        Line colour forwarded to ``plt.plot``.
    """
    for setter, text in ((plt.title, title),
                         (plt.xlabel, "Number of Defective Images"),
                         (plt.ylabel, title)):
        setter(text)
    plt.ylim(*lim)
    best = max(auclog)
    plt.plot(auclog, label="Max = %.3f" % best, color=color)
    plt.legend()

In [19]:
def show_eachresult_reject(classn, n_layer, threshold):
    """Draw three panels for one class: AUC, AUC with Reject, accuracy with Reject.

    Panel 1 reads ``output4all_flatten/<run>/AUClog*.pcl``; panels 2 and 3 read
    ``AUClog*.pcl`` and ``ACClog*.pcl`` from ``output4all_reject/<run>``, where
    ``<run>`` is ``<classn>._mbn_1.0.<n_layer>.55.<threshold>``.
    """
    plt.rcParams['figure.figsize'] = (15.0, 5.0)
    plt.figure()

    run = classn + '._mbn_1.0.{0}.55.{1}'.format(n_layer, threshold)
    # (results root, log pattern, panel title, extra plot_ac arguments)
    panels = (
        ('output4all_flatten/', "AUClog*.pcl", "AUC", ()),
        ('output4all_reject/', "AUClog*.pcl", "AUC with Reject", ()),
        ('output4all_reject/', "ACClog*.pcl", "Accuracy with Reject", ((0, 1), "red")),
    )
    for position, (root, pattern, title, extra) in enumerate(panels, start=1):
        log = loadfile(os.path.join(root + run, pattern))
        plt.subplot(1, 3, position)
        plot_ac(log, title, *extra)

In [20]:
# Class names ("bottle", ...) taken from the flatten result directories.
logdirs0 = sorted(glob.glob('output4all_flatten/*_mbn_*'))
classn = sorted({logdir.split('.')[0].split('/')[1] for logdir in logdirs0})

# Rendered statically (not through an interact widget) so every class is shown.
for cname in classn:
    print(cname)
    show_eachresult_reject(classn=cname, n_layer=19, threshold=0.999)
    plt.show()

- Rejectを使用するとAUCの低下を防げているものも存在する
- Accuracyは横ばいか向上

In [21]:
## AUCの最大値の比較

In [22]:
def show_allresult_reject(classn, n_layer, threshold):
    """Print one table row comparing the best AUC without and with Reject.

    Row contents: class name, best AUC from ``output4all_flatten``, best AUC
    from ``output4all_reject``, and their difference (reject - flatten).

    Parameters
    ----------
    classn : str
        Class name, e.g. ``"bottle"``.
    n_layer, threshold :
        Currently unused: ``get_maxauc`` scans its own layer/threshold grid.
        Kept so existing call sites keep working.
    """
    v0, *_ = get_maxauc('output4all_flatten/' + classn)
    v1, *_ = get_maxauc('output4all_reject/' + classn)
    # The diff uses the same 3-decimal precision as the AUC columns; at 2 decimals
    # it could hide differences that are visible in those columns.
    print("{:<12}".format(classn), "\t%.3f" % v0, "\t%.3f" % v1, "\t\t%.3f" % (v1 - v0))
    
# Recomputed here so this cell does not depend on the plotting cell above.
logdirs0 = sorted(glob.glob('output4all_flatten/*_mbn_*'))
classn = sorted({logdir.split('.')[0].split('/')[1] for logdir in logdirs0})

print("label\t\tauc\taucreject\tdiff")
print("-" * 50)
for cname in classn:
    show_allresult_reject(cname, 19, 0.999)

label		auc	aucreject	diff
--------------------------------------------------
