### 在資料集上測試 (MVTec AD)
資料集: 
> THE MVTEC ANOMALY DETECTION DATASET (MVTEC AD)
> https://www.mvtec.com/company/research/datasets/mvtec-ad


<img src="https://www.mvtec.com/fileadmin/Redaktion/mvtec.com/company/research/datasets/dataset_overview_large.png" alt="drawing" width="400"/>

In [None]:
# Fetch unet.py from the course release; it provides the UNet class imported via `from unet import UNet`.
!wget https://github.com/TA-aiacademy/course_3.0/releases/download/CVCNN_Data/unet.py

In [None]:
import os
import cv2

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Input, Model, Sequential, layers
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
from functools import partial
from IPython.display import display

from sklearn.model_selection import train_test_split

from unet import UNet

import albumentations as A
from albumentations import DualTransform
from typing_extensions import Concatenate
from typing_extensions import Concatenate

In [None]:
import warnings

# Suppress every warning category for the rest of the process so the
# notebook output stays readable.
warnings.simplefilter("ignore")

In [None]:
# Samples per training batch, and number of mask labels (one binary mask).
BATCH_SIZE, NUM_LABELS = 16, 1
# Model input resolution; the albumentations pipelines resize every
# image/mask pair to this size.
HEIGHT = WIDTH = 128

In [None]:
# Sanity check: push a single all-zero (1, HEIGHT, WIDTH, 3) batch through a
# throwaway UNet, print the resulting output shape, then discard the model.
x = np.zeros((1, HEIGHT, WIDTH, 3), dtype=np.float32)
unet = UNet(retain_dim=(WIDTH, HEIGHT), padding='same', num_class=1)
y_pred = unet(x)
print(y_pred.shape)
del unet

In [None]:
import subprocess

# Download the MVTec AD "capsule" category archive and unpack it into ./data.
# Skipped entirely when ./data/capsule already exists, so re-running is cheap.
if not os.path.isdir('./data/capsule'):
    archive = 'data/MVtech-capsule.tar.xz'
    url = ('https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/'
           'download/420937454-1629951595/capsule.tar.xz')
    os.makedirs('./data', exist_ok=True)
    # Argument lists (no shell) with check=True: a failed download or
    # extraction stops here with CalledProcessError instead of surfacing later
    # as a confusing missing-directory error when the images are listed.
    subprocess.run(['wget', '-q', url, '-O', archive], check=True)
    subprocess.run(['tar', '-Jxf', archive, '--overwrite', '--directory', './data'],
                   check=True)

In [None]:
# Only the "scratch" defect class of the chosen MVTec AD category is used:
# test images and their ground-truth masks live in parallel folders.
# Both directory strings end with '/' because later cells append file names
# to them by plain string concatenation.
item = 'capsule'
path = os.getcwd()
category_root = f'{path}/data/{item}'
img_dir = f'{category_root}/test/scratch/'
print(len(os.listdir(img_dir)))  # number of scratch test images
anno_dir = f'{category_root}/ground_truth/scratch/'

### 取得影像與遮罩路徑清單
輸出: data_dic (字典)
- key: X_train, X_test (影像路徑), y_train, y_test (對應的 mask 路徑)

In [None]:
def mask_dir(anno_dir, X_lis):
    """Return the ground-truth mask path for each image file name in *X_lis*.

    The part of each name before its first ``.`` gets ``_mask.png`` appended
    and *anno_dir* is prepended verbatim (so it must already end with a path
    separator), e.g. ``'000.png'`` -> ``anno_dir + '000_mask.png'``.
    """
    stems = (name.split('.')[0] for name in X_lis)
    return [anno_dir + stem + '_mask.png' for stem in stems]

imgs_path_list = sorted(os.listdir(img_dir))

# Train on 10% of the scratch images and keep the remaining 90% for testing;
# random_state=0 makes the split reproducible across runs.
size = 0.1
train, test = train_test_split(imgs_path_list, train_size=size, random_state=0)

# data_dic maps each key to a list of full paths: image paths for X_*, and
# the matching ground-truth mask paths (built by mask_dir) for y_*.
key = 'X_train, X_test, y_train, y_test'.split(', ')
lis = [[img_dir + name for name in names] for names in (train, test)]  # X
lis += [mask_dir(anno_dir, names) for names in (train, test)]          # y
data_dic = dict(zip(key, lis))
data_dic['X_train']

### Build dataset

In [None]:
def data_generater(imgs_path_list, anno_path_list, img_transform=None):
    """Yield ``(image, mask)`` tensor pairs for ``tf.data.Dataset.from_generator``.

    Args:
        imgs_path_list: Image file paths. ``from_generator`` passes ``args``
            through as NumPy arrays, so entries normally arrive as ``bytes``;
            plain ``str`` entries are accepted as well.
        anno_path_list: Mask file paths (same delivery as above). The two
            lists need not be aligned (the training image list may be padded
            with repeats), so each image's mask is looked up by file name:
            the first mask whose base name, up to its first ``.``, equals the
            image's stem or starts with ``<stem>_``
            (e.g. ``000.png`` -> ``000_mask.png``).
        img_transform: ``1`` applies the module-level ``transform`` pipeline,
            ``2`` applies ``target_transform``; any other value applies none.

    Yields:
        ``(image, mask)`` float32 tensors: ``image`` is RGB with shape
        ``(H, W, 3)`` and whatever pixel scale the chosen pipeline leaves;
        ``mask`` has shape ``(H, W, 1)`` and is the grayscale mask / 255.

    Raises:
        FileNotFoundError: if no mask matches an image, or an image/mask file
            cannot be read.
    """
    def _to_str(path):
        # tf.data delivers bytes (np.bytes_ is a bytes subclass); accept str too.
        return path.decode() if isinstance(path, bytes) else str(path)

    def _stem(path):
        return os.path.basename(path).split('.')[0]

    # Decode the mask paths once instead of once per image.
    masks = [(_stem(p), p) for p in map(_to_str, anno_path_list)]

    def _find_mask(file_name):
        # Match on the file name only; a bare substring test on the full path
        # would pair e.g. "000" with "1000_mask.png".
        for stem, mask_path in masks:
            if stem == file_name or stem.startswith(file_name + '_'):
                return mask_path
        raise FileNotFoundError(f'no mask found for image {file_name!r}')

    for raw_path in imgs_path_list:
        img_path = _to_str(raw_path)
        mask_path = _find_mask(_stem(img_path))

        bgr = cv2.imread(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports an unreadable file by returning None, not raising.
        if bgr is None:
            raise FileNotFoundError(f'cannot read image: {img_path}')
        if mask is None:
            raise FileNotFoundError(f'cannot read mask: {mask_path}')

        image = cv2.cvtColor(bgr.astype(np.float32), cv2.COLOR_BGR2RGB)
        mask = mask.astype(np.float32)

        if img_transform == 1:
            transformed = transform(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']
        elif img_transform == 2:
            transformed = target_transform(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']

        image = tf.constant(image)
        # Add the channel axis and rescale; presumably masks are 0/255 so the
        # result is a 0/1 target — TODO confirm for other datasets.
        mask = tf.expand_dims(tf.constant(mask), axis=-1) / 255.
        yield image, mask

### 使用 albumentations 進行資料擴增

In [None]:
# https://albumentations.ai/docs/getting_started/mask_augmentation/
# Each pipeline is called with image= and mask=, so spatial transforms are
# applied identically to both. A.CenterCrop and A.Resize take (height, width).
# NOTE(review): the generator feeds float32 arrays, for which ToFloat's
# inferred max_value is presumably 1.0 — i.e. pixels likely stay in 0..255.
# Confirm whether [0, 1] model input was intended.

# Training pipeline: random crop / flip / rotation, then resize.
transform = A.Compose([
    A.CenterCrop(300, 900, p=0.5),
    A.HorizontalFlip(p=0.5),
    # Nearest-neighbour image resampling (cv2.INTER_NEAREST == 0).
    A.Rotate((-30, 30), interpolation=cv2.INTER_NEAREST),
    A.ToFloat(p=1.0),
    A.Resize(HEIGHT, WIDTH),
])

# Evaluation pipeline: deterministic resize only.
target_transform = A.Compose([
    A.ToFloat(p=1.0),
    A.Resize(HEIGHT, WIDTH),
])

In [None]:
# A training list shorter than one batch is padded by cycling through its
# image paths until it holds exactly BATCH_SIZE entries. Only X_train is
# padded; data_generater finds each image's mask by file name.
n_train = len(data_dic['X_train'])
if n_train < BATCH_SIZE:
    lis = data_dic['X_train']
    data_dic['X_train'] = [lis[k % n_train] for k in range(BATCH_SIZE)]

# Both splits yield (RGB image, single-channel mask) float32 pairs; the
# spatial size is left unspecified in the signature.
pair_signature = (
    tf.TensorSpec(shape=(None, None, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
)

# The last generator argument selects the pipeline inside data_generater:
# 1 -> `transform` (training augmentation), 2 -> `target_transform`.
train_ds = tf.data.Dataset.from_generator(
    data_generater,
    output_signature=pair_signature,
    args=[data_dic['X_train'], data_dic['y_train'], 1],
)
test_ds = tf.data.Dataset.from_generator(
    data_generater,
    output_signature=pair_signature,
    args=[data_dic['X_test'], data_dic['y_test'], 2],
)

dataset_train, dataset_test = (
    ds.batch(BATCH_SIZE).prefetch(buffer_size=32) for ds in (train_ds, test_ds)
)

In [None]:
# Inspect one batch: tuple length, image batch shape, mask batch shape.
for first_batch in dataset_train.take(1):
    images, masks = first_batch
    print(f'{len(first_batch)}, {images.shape}, {masks.shape}')

#### 補充: 如何從 dataset 抽 image, mask 出來

In [None]:
def show_image_mask(*img_list, split=False):
    """Plot the given arrays side by side in one row and show the figure.

    Each array is shifted so that its minimum becomes 0. 2-D arrays, and
    ``(H, W, 1)`` arrays (squeezed to 2-D), are drawn in grayscale. Any other
    array is cast to int32 and drawn as RGB, so colour images are expected on
    a 0..255 scale (imshow clips integer RGB values outside that range).

    Args:
        *img_list: Arrays to display; anything ``np.asarray`` accepts,
            including eager TensorFlow tensors.
        split: Unused; kept so existing keyword calls keep working.
    """
    plt.figure(figsize=(10, 3))
    for i, img in enumerate(img_list, 1):
        img = np.asarray(img)
        # imshow rejects (H, W, 1); show a single channel as grayscale.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = img[..., 0]
        img = img - img.min()
        plt.subplot(1, len(img_list), i)
        if img.ndim == 2:
            plt.imshow(img, cmap='gray')
        else:
            plt.imshow(img.astype(np.int32))
    plt.show()
    plt.close()

In [None]:
# Display every (image, ground-truth mask) pair of one augmented training batch.
for images, masks in dataset_train.take(1):
    for img_t, mask_t in zip(images, masks):
        show_image_mask(img_t.numpy(), np.squeeze(mask_t.numpy()))

# 訓練模型

In [None]:
# Fresh UNet with one output channel per label (a single binary mask).
model = UNet(num_class=NUM_LABELS, padding='same', retain_dim=(WIDTH, HEIGHT))

# Per-pixel sigmoid cross-entropy treats the model output as logits.
# NOTE(review): confirm that UNet's last layer applies no sigmoid itself.
loss_fn = tf.nn.sigmoid_cross_entropy_with_logits
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

model.compile(optimizer=optimizer, loss=loss_fn)
# dataset_train is already batched with BATCH_SIZE; Keras documents that
# batch_size must not be passed to fit() for tf.data datasets.
model.fit(dataset_train, epochs=500)
# model.save_weights('UNet')

In [None]:
# Threshold applied to the raw model output to obtain a binary mask.
PRED_THRESHOLD = 0.2

# Show image, ground truth and thresholded prediction for one training batch.
# NOTE(review): dataset_test is built but never evaluated in this notebook;
# use it here to inspect results on the held-out split.
for images, masks in dataset_train.take(1):
    # One forward pass for the whole batch instead of one call per image.
    preds = model(images).numpy()
    preds = (preds >= PRED_THRESHOLD).astype(np.float32)
    for image, mask, pre in zip(images, masks, preds):
        show_image_mask(image.numpy(), mask.numpy().squeeze(), pre.squeeze())