# SynAug

In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, StratifiedKFold

import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow_hub as hub

from glob import glob
from tqdm import tqdm

import cv2
import gc

import argparse
import wandb
from wandb.keras import WandbCallback
# Start the Weights & Biases run that WandbCallback logs to during training.
wandb.init(project="DACON_235894", name="SynAug")

# Hyperparameters are declared through argparse so the whole namespace can be
# recorded in wandb.config in one call.
parser = argparse.ArgumentParser(description='SynAug')
parser.add_argument('--resize_size', default=224, type=int)
parser.add_argument('--optimizer', default="adam", type=str) # adam or sgd
parser.add_argument('--learning_rate', default=0.0003, type=float)
parser.add_argument('--label_smoothing', default=0, type=float) # 0 or 0.1
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--validation_split', default=0.2, type=float)
parser.add_argument('--seed', default=1011, type=int)
# An explicit empty argument list makes the defaults above apply instead of
# whatever command line the notebook kernel was started with.
args = parser.parse_args([])

wandb.config.update(args)

resize_size=args.resize_size
BATCH_SIZE=args.batch_size
EPOCHS=args.epochs
VALIDATION_SPLIT=args.validation_split
SEED=args.seed

if args.optimizer == "adam":
    # NOTE(review): with decay_steps=1000 and the default alpha, CosineDecay
    # should reach a learning rate of 0 after 1000 optimizer steps — confirm
    # that this matches the intended number of training steps.
    lr = tf.keras.optimizers.schedules.CosineDecay(args.learning_rate, decay_steps=1000)
    optim = tf.keras.optimizers.Adam(learning_rate=lr)
elif args.optimizer == "sgd":
    # Fixed: this branch previously referenced the undefined name `f`.
    optim = tf.keras.optimizers.SGD(learning_rate=args.learning_rate, momentum=0.9)
else:
    # Fail here rather than with a NameError on `optim` when the model is compiled.
    raise ValueError(f"Unsupported optimizer {args.optimizer!r}; expected 'adam' or 'sgd'")

loss_function = tf.keras.losses.CategoricalCrossentropy(label_smoothing=args.label_smoothing)

def set_seeds(seed=SEED):
    """Seed every random number source this notebook configures.

    Seeds Python's ``random`` module, NumPy's global generator and
    TensorFlow's global generator with ``seed``, and exports the same value
    as the ``PYTHONHASHSEED`` environment variable.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

set_seeds()

## Preprocessing

In [None]:
def img_load(path):
    """Load one image file as a square RGB array.

    The file is read with OpenCV, its channel axis is reversed (OpenCV
    returns BGR), the central 90% of the image is kept, and the result is
    resized to ``(resize_size, resize_size)``.

    Args:
        path: Path of the image file to read.

    Returns:
        The resized image as a NumPy array.

    Raises:
        OSError: If OpenCV cannot read ``path`` (``cv2.imread`` returned None).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise OSError(f"cv2.imread could not read image: {path}")
    img = bgr[:, :, ::-1]
    img = tf.image.central_crop(img, 0.9).numpy()
    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, not `interpolation`.
    img = cv2.resize(img, (resize_size, resize_size), interpolation=cv2.INTER_AREA)
    return img

# Sorted paths give every image a stable position in the arrays below.
# NOTE(review): later cells index train_imgs with columns of train_df.csv, which
# presumably relies on the CSV rows following this sorted file order — confirm.
train_png = sorted(glob('raw_data/train/*.png'))
test_png = sorted(glob('raw_data/test/*.png'))

# Decode each split up front and stack it into a single array.
train_imgs = np.array([img_load(path) for path in tqdm(train_png)])
test_imgs = np.array([img_load(path) for path in tqdm(test_png)])

train_y = pd.read_csv("raw_data/train_df.csv")
train_labels = train_y["label"]

# Map every distinct label string to an integer id, assigned in sorted order.
label_unique = {name: idx for idx, name in enumerate(sorted(np.unique(train_labels)))}

train_imgs.shape, train_labels.shape, test_imgs.shape, train_imgs.dtype

### Augmentation : synthetic

In [None]:
def img_load_gray(path):
    """Load one image file as a single-channel array resized to a square.

    NOTE(review): ``img_load`` keeps only the central 90% of each image before
    resizing, while this loader resizes the full frame. The masks loaded here
    are multiplied onto images produced by ``img_load`` — confirm that the two
    are expected to line up spatially.

    Args:
        path: Path of the image file to read.

    Returns:
        A 2-D NumPy array of shape ``(resize_size, resize_size)``.

    Raises:
        OSError: If OpenCV cannot read ``path`` (``cv2.imread`` returned None).
    """
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise OSError(f"cv2.imread could not read image: {path}")
    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is `dst`, not `interpolation`.
    img = cv2.resize(img, (resize_size, resize_size), interpolation=cv2.INTER_AREA)
    return img

gt_png = sorted(glob('gt_data/*.png'))
gt_imgs = [img_load_gray(n) for n in tqdm(gt_png)]
# Add a trailing channel axis so the masks broadcast against 3-channel images.
gt_imgs = np.array(gt_imgs)[:, :, :, np.newaxis]

# Build the metadata rows from the very same path list that produced gt_imgs.
# os.listdir returns entries in arbitrary order (and would include non-PNG
# files), which could pair a mask with the wrong row of train_y.
gt_files = [os.path.basename(p) for p in gt_png]
gt_y = train_y.set_index("file_name").loc[gt_files].reset_index()

# mask_imgs is True where the ground-truth image is non-zero; unmask_imgs is
# its complement.
mask_imgs = gt_imgs != 0
unmask_imgs = ~mask_imgs

# Indexing with the "index" column produces a copy, so the in-place masking
# below leaves train_imgs untouched.
bad_imgs = train_imgs[gt_y["index"]]
bad_imgs *= mask_imgs

good_imgs = train_imgs[train_y["state"]=="good"]
good_y = train_y[train_y["state"]=="good"]

In [None]:
syn_imgs=[]
syn_labels=[]

# For these classes one defect label is left out of the synthetic compositing.
excluded_label = {
    "cable": "cable-cable_swap",
    "metal_nut": "metal_nut-flip",
}

for c in tqdm(train_y["class"].unique()):

    # No synthetic samples are generated for the "screw" class.
    if c == "screw":
        continue

    # Rows of gt_y (and therefore of unmask_imgs / bad_imgs) used for this class.
    bad_rows = gt_y["class"] == c
    if c in excluded_label:
        bad_rows = bad_rows & (gt_y["label"] != excluded_label[c])

    temp_good_y = good_y[good_y["class"]==c]
    temp_bad_y = gt_y[bad_rows]

    good_len = len(temp_good_y)
    bad_len = len(temp_bad_y)

    temp_good_imgs = train_imgs[temp_good_y["index"]]
    temp_unmask_imgs = unmask_imgs[bad_rows]
    temp_bad_imgs = bad_imgs[bad_rows]

    # Pair every good image with every defect: each good image is repeated
    # bad_len times in a row, while the defect arrays are tiled good_len times.
    temp_good_imgs = np.repeat(temp_good_imgs, bad_len, axis=0)
    temp_unmask_imgs = np.tile(temp_unmask_imgs, (good_len, 1, 1, 1))
    temp_bad_imgs = np.tile(temp_bad_imgs, (good_len, 1, 1, 1))
    temp_bad_labels = np.tile(temp_bad_y["label"].values, good_len)

    # Blank the defect region of the good image, then add the masked defect
    # pixels; the two operands are non-zero on complementary regions.
    temp_good_imgs *= temp_unmask_imgs

    syn_imgs.append(temp_good_imgs + temp_bad_imgs)
    syn_labels.append(temp_bad_labels)

syn_imgs = np.vstack(syn_imgs)
syn_labels = np.hstack(syn_labels)

syn_imgs.shape, syn_labels.shape

### Augmentation : rotate

In [None]:
def aug_rotate(imgs, labels, times):
    """Stack ``times`` rotated copies of an image batch with matching labels.

    Copy ``i`` (for ``i`` in ``range(times)``) is rotated by ``pi`` when
    ``i == 0`` and by ``pi / i`` otherwise, so the first two copies both use
    an angle of ``pi``. Pixels outside the source image are filled with
    ``fill_mode="nearest"``.

    Args:
        imgs: Batch of images passed to ``tfa.image.rotate``.
        labels: Labels belonging to ``imgs``.
        times: Number of rotated copies to generate.

    Returns:
        A tuple ``(images, labels)``: the rotated batches stacked with
        ``np.vstack`` and ``labels`` repeated ``times`` times with ``np.tile``.
    """
    angles = [np.pi if step == 0 else np.pi / step for step in range(times)]
    rotated = [
        tfa.image.rotate(imgs, tf.constant(angle), fill_mode="nearest")
        for angle in angles
    ]
    return np.vstack(rotated), np.tile(labels, times)

In [None]:
# Rotation augmentation for the 'cable-cable_swap' label: 180 rotated copies.
is_cable_swap = train_y["label"]=='cable-cable_swap'
cable_imgs, cable_labels = aug_rotate(
    train_imgs[is_cable_swap], train_labels[is_cable_swap], 180)

cable_imgs.shape, cable_labels.shape, cable_imgs.dtype

In [None]:
# Rotation augmentation for the 'metal_nut-flip' label: 90 rotated copies.
is_metal_flip = train_y["label"]=='metal_nut-flip'
metal_imgs, metal_labels = aug_rotate(
    train_imgs[is_metal_flip], train_labels[is_metal_flip], 90)

metal_imgs.shape, metal_labels.shape, metal_imgs.dtype

In [None]:
# Rotation augmentation for every non-"good" screw sample: 90 rotated copies.
is_bad_screw = (train_y['state'] != "good") & (train_y['class'] == 'screw')
screw_imgs, screw_labels = aug_rotate(
    train_imgs[is_bad_screw], train_labels[is_bad_screw], 90)

screw_imgs.shape, screw_labels.shape, screw_imgs.dtype

In [None]:
# Rotation augmentation for the 'toothbrush-good' label: 18 rotated copies.
is_tooth_good = train_y["label"]=='toothbrush-good'
tooth_imgs, tooth_labels = aug_rotate(
    train_imgs[is_tooth_good], train_labels[is_tooth_good], 18)

tooth_imgs.shape, tooth_labels.shape, tooth_imgs.dtype

In [None]:
# Rotation augmentation for every sample whose state is 'good': 4 rotated copies.
# good_imgs is rebound here to the augmented array.
is_good_state = train_y["state"]=='good'
good_imgs, good_labels = aug_rotate(
    train_imgs[is_good_state], train_labels[is_good_state], 4)

good_imgs.shape, good_labels.shape, good_imgs.dtype

### Augmentation Dataset

In [None]:
# Each image array is listed next to its label array so both concatenations
# are guaranteed to use the same order.
aug_sources = [
    (train_imgs, train_labels),
    (syn_imgs, syn_labels),
    (good_imgs, good_labels),
    (screw_imgs, screw_labels),
    (metal_imgs, metal_labels),
    (cable_imgs, cable_labels),
    (tooth_imgs, tooth_labels),
]
aug_imgs = np.concatenate([imgs for imgs, _ in aug_sources], axis=0)
aug_labels = np.concatenate([labels for _, labels in aug_sources], axis=0)
# Drop the helper list so it does not keep the source arrays alive.
del aug_sources

aug_imgs.shape, aug_labels.shape

In [None]:
def sampling_func(data, sample_n=200):
    """Return up to ``sample_n`` randomly chosen rows of ``data``.

    Rows are picked by position from a permutation drawn from NumPy's global
    random generator, so the result is reproducible after ``np.random.seed``.
    When ``data`` has fewer than ``sample_n`` rows, all rows are returned in
    shuffled order.

    Args:
        data: pandas object supporting ``len()`` and ``.take()``.
        sample_n: Maximum number of rows to return (default 200).

    Returns:
        The selected rows, keeping their original index labels.
    """
    chosen = np.random.permutation(len(data))[:sample_n]
    return data.take(chosen)

# Draw the per-label sample. The resulting index is used below to index the
# image and label arrays.
sample_y = (
    pd.DataFrame(aug_labels, columns=["label"])
    .groupby('label', group_keys=False)
    .apply(sampling_func)
)

sample_imgs, sample_labels = aug_imgs[sample_y.index], aug_labels[sample_y.index]

sample_imgs.shape, sample_labels.shape, sample_imgs.dtype

In [None]:
# Drop the intermediate arrays and frames that are no longer referenced below;
# sample_imgs, sample_labels, test_imgs and label_unique are intentionally kept.
del train_imgs, train_y, train_labels
del gt_imgs, gt_y, mask_imgs, unmask_imgs, bad_imgs
del good_imgs, good_y, good_labels
del temp_good_imgs, temp_good_y, temp_unmask_imgs, temp_bad_imgs, temp_bad_y, temp_bad_labels
del syn_imgs, screw_imgs, metal_imgs, cable_imgs, tooth_imgs
del syn_labels, screw_labels, metal_labels, cable_labels, tooth_labels
del aug_imgs, aug_labels, sample_y
gc.collect()

## Training

In [None]:
sample_imgs = sample_imgs.astype("float32")

# Stratified hold-out split on the label strings.
X_train, X_val, y_train, y_val = train_test_split(
    sample_imgs, sample_labels,
    test_size=VALIDATION_SPLIT, random_state=SEED, stratify=sample_labels)

# Encode the label strings as the integer ids defined by label_unique.
y_train = np.array([label_unique[k] for k in y_train])
y_val = np.array([label_unique[k] for k in y_val])

# Fix the one-hot width to the full label set. Without num_classes,
# to_categorical infers the width from the largest id present, so the two
# splits could end up with different numbers of columns.
num_classes = len(label_unique)
y_train=tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_val=tf.keras.utils.to_categorical(y_val, num_classes=num_classes)

X_train.shape, X_val.shape, y_train.shape, y_val.shape

In [None]:
# Random augmentation applied to training batches only: vertical flip, a crop
# to 90% of the side length that is resized back to full size, and a random
# rotation whose empty corners are filled with zeros.
crop_size = int(resize_size*0.9)
augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip("vertical"),
    layers.experimental.preprocessing.RandomCrop(crop_size, crop_size),
    layers.experimental.preprocessing.Resizing(resize_size, resize_size),
    layers.experimental.preprocessing.RandomRotation(0.5, fill_mode='constant', fill_value=0.0),
])


def _augment_batch(x, y):
    """Apply the augmentation pipeline to a batch of images; labels pass through."""
    return augmentation(x, training=True), y


# Training pipeline: shuffle over the whole set, batch, augment, prefetch.
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_ds = train_ds.shuffle(len(X_train)).batch(BATCH_SIZE)
train_ds = train_ds.map(_augment_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

# Validation pipeline: batched and prefetched, with no shuffling or augmentation.
val_ds = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
def create_encoder():
    """Build the feature extractor: an EfficientNetB0 backbone named "encoder".

    The backbone is created with ImageNet weights, without its classification
    head and with average pooling, and is wrapped in a model whose input has
    shape ``(resize_size, resize_size, 3)``.

    Returns:
        A ``tf.keras.Model`` mapping an image batch to the backbone's pooled output.
    """
    backbone = tf.keras.applications.EfficientNetB0(
        weights="imagenet", include_top=False, pooling='avg')

    image_input = tf.keras.Input(shape=(resize_size, resize_size, 3))
    pooled_features = backbone(image_input)

    return tf.keras.Model(inputs=image_input, outputs=pooled_features, name="encoder")

encoder = create_encoder()
encoder.summary()

In [None]:
def create_classifier(encoder):
    """Put a softmax classification head on ``encoder`` and compile the model.

    The head has one unit per column of the module-level ``y_train``. The
    model is compiled with the module-level ``optim`` and ``loss_function``
    and reports a macro-averaged F1 score.

    Args:
        encoder: Model mapping ``(resize_size, resize_size, 3)`` images to features.

    Returns:
        The compiled ``tf.keras.Model`` named "classifier".
    """
    n_classes = y_train.shape[1]

    image_input = tf.keras.Input(shape=(resize_size, resize_size, 3))
    encoded = encoder(image_input)
    probabilities = layers.Dense(n_classes, activation="softmax")(encoded)
    model = tf.keras.Model(inputs=image_input, outputs=probabilities, name="classifier")

    macro_f1 = tfa.metrics.F1Score(num_classes=n_classes, average="macro")
    model.compile(optimizer=optim, loss=loss_function, metrics=macro_f1)

    return model

classifier = create_classifier(encoder)
classifier.summary()

In [None]:
checkpoint_path=f"load_model/{parser.description}"

# Both callbacks track the validation macro-F1 (higher is better): training
# stops after 5 epochs without improvement and the best weights are restored,
# and the best weights are also written to checkpoint_path.
callback = [
    tf.keras.callbacks.EarlyStopping(
        monitor='val_f1_score',
        patience=5,
        mode="max",
        restore_best_weights=True,
    ),
    tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path,
        monitor="val_f1_score",
        save_best_only=True,
        save_weights_only=True,
        mode="max",
    )
]

# train_ds is already batched, so batch_size is not passed to fit (Keras
# documents that it must not be specified for dataset inputs). The callback
# list is unpacked so fit receives a flat list of callbacks.
history=classifier.fit(
    train_ds,
    epochs=EPOCHS,
    callbacks=[*callback, WandbCallback()],
    validation_data=val_ds,
)

In [None]:
# Per-epoch curves recorded by fit().
acc = history.history['f1_score']
val_acc = history.history['val_f1_score']

loss=history.history['loss']
val_loss=history.history['val_loss']

# One figure per metric: (training curve, training label, validation curve,
# validation label, legend location, title).
for train_curve, train_label, val_curve, val_label, legend_loc, title in (
    (acc, 'Training Macro-F1', val_acc, 'Validation Macro-F1',
     'lower right', 'Training and Validation Macro-F1'),
    (loss, 'Training Loss', val_loss, 'Validation Loss',
     'upper right', 'Training and Validation Loss'),
):
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.legend(loc=legend_loc)
    plt.title(title)
    plt.show()

In [None]:
# Reload the weights stored at checkpoint_path, the location the
# ModelCheckpoint callback writes to during training.
classifier.load_weights(checkpoint_path)

## Inference

In [None]:
# Inference pipeline over the test images only (no labels), batched and prefetched.
test_ds = tf.data.Dataset.from_tensor_slices(test_imgs)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)

In [None]:
# Predict class probabilities, take the most likely class id per image and
# translate the ids back to label strings by inverting label_unique.
pred_prob = classifier.predict(test_ds)
f_pred = pred_prob.argmax(axis=1)
label_decoder = {class_id: name for name, class_id in label_unique.items()}
f_result = [label_decoder[class_id] for class_id in f_pred]

pd.Series(f_result).value_counts()

In [None]:
# Fill the "label" column of the sample submission with the predictions and
# write it out under the experiment name.
submission = pd.read_csv("raw_data/sample_submission.csv").assign(label=f_result)
submission.to_csv(f"{parser.description}.csv", index=False)