In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
from sklearn.model_selection import train_test_split
from preprocess_utils import sample_per_class, imbalance_train, get_dataset

import tensorflow as tf
from tensorflow import keras
from result_utils import plot_curve, Result
import absl.logging
# Restrict absl logging to messages of ERROR severity and above.
absl.logging.set_verbosity(absl.logging.ERROR)

In [None]:
# Train dataset
# image_size is the image side length: it is passed to get_dataset() and used
# as both spatial dimensions when augmentation() reshapes the arrays.
image_size = 128
# NOTE(review): batch_size is not referenced by any other visible cell;
# model.fit() passes batch_size=64 explicitly. Confirm which value is intended.
batch_size = 32
cwru_path = 'Dataset/pywt_speed/1'
# get_dataset returns three values; the third (class_names) is later handed to Result().
features, labels, class_names = get_dataset(cwru_path, image_size)

In [None]:
# Test dataset
# Loaded from a different directory ('.../0') than the training data ('.../1').
# NOTE: the "Method 1" and "Method 3" cells reassign x_test / y_test with a
# hold-out split of the training data, so this test set only survives when
# neither of those cells is run after this one.
test_path = 'Dataset/pywt_speed/0'
x_test, y_test, _ = get_dataset(test_path, image_size)

In [None]:
'''Method 1 for Fault type classification'''
# Prepare equally sized dataset and split into 3 subsets
# (sample_per_class is called with 800 -- presumably samples per class; confirm
# in preprocess_utils). The two stratified splits below give 60% train,
# 20% validation and 20% test.
# NOTE: this overwrites the x_test / y_test loaded in the "Test dataset" cell.
x_resampled, y_resampled = sample_per_class(features, labels, 800)
# X_test / Y_test (capitalised) are the temporary 40% remainder, split again below.
x_train, X_test, y_train, Y_test = train_test_split(x_resampled, y_resampled, train_size=0.6, random_state=7, stratify=y_resampled)
x_val, x_test, y_val, y_test = train_test_split(X_test, Y_test, test_size=0.5, random_state=7, stratify=Y_test)

# Create imbalanced training dataset; 0.05 is the value passed to imbalance_train.
# NOTE(review): this comment previously said "1% of original size", which does
# not match 0.05 -- confirm the intended ratio and how imbalance_train uses it.
x_train, y_train = imbalance_train(x_train, y_train, 0.05)

# Show data distribution for each class
Counter(y_train), Counter(y_val), Counter(y_test)

In [None]:
'''Method 2 For Speed Domain'''
# Prepare equally sized dataset and split it into train (80%) and validation
# (20%) only, stratified by label. No test split is made here: x_test / y_test
# keep whatever value the last executed cell assigned to them (the separately
# loaded test set if the "Test dataset" cell ran last).
x_resampled, y_resampled = sample_per_class(features, labels, 800)
x_train, x_val, y_train, y_val = train_test_split(x_resampled, y_resampled, train_size=0.8, random_state=7, stratify=y_resampled)

# Show data distribution for each class
Counter(y_train), Counter(y_val), Counter(y_test)

In [None]:
'''Method 3'''
# 1. Prepare equally sized dataset
# 2. Turn into imbalanced dataset
# 3. Split into 3 subsets
# Unlike Method 1, the imbalance is applied to the whole resampled set before
# splitting, so train, validation and test all come from the imbalanced data.
# NOTE: this overwrites the x_test / y_test loaded in the "Test dataset" cell.
x_resampled, y_resampled = sample_per_class(features, labels, 4700)
x_resampled, y_resampled = imbalance_train(x_resampled, y_resampled, 0.01)
# 60% train, then the remaining 40% (X_test / Y_test) is halved into val and test.
x_train, X_test, y_train, Y_test = train_test_split(x_resampled, y_resampled, train_size=0.6, random_state=7, stratify=y_resampled)
x_val, x_test, y_val, y_test = train_test_split(X_test, Y_test, test_size=0.5, random_state=7, stratify=Y_test)
# Show data distribution for each class
(Counter(y_train), Counter(y_val), Counter(y_test))

In [None]:
'''Add fake data'''
def augmentation(x_real, y_real, fake_root_dir, image_size, sample_size):
    '''Extend a real dataset with synthetic images loaded from disk.

    The synthetic set is read from ``fake_root_dir`` with ``get_dataset``,
    passed through ``sample_per_class`` together with ``sample_size``, and
    appended after the real samples.

    Args:
        x_real: real images; any array that can be reshaped to
            (-1, image_size, image_size, 3).
        y_real: labels belonging to ``x_real``.
        fake_root_dir: directory handed to ``get_dataset``.
        image_size: image side length.
        sample_size: forwarded unchanged to ``sample_per_class``.

    Returns:
        Tuple ``(augmented_x, augmented_y)``: the images with shape
        (n_real + n_fake, image_size, image_size, 3) and the matching labels,
        real samples first. The label distribution is printed as a side effect.
    '''
    # Three channels are assumed for every image, real or synthetic.
    image_shape = (-1, image_size, image_size, 3)

    synth_x, synth_y, _ = get_dataset(fake_root_dir, image_size)
    synth_x, synth_y = sample_per_class(synth_x, synth_y, sample_size)

    # Bring both sets to the same 4-D layout, then stack along the sample axis.
    synth_images = synth_x.reshape(image_shape)
    real_images = x_real.reshape(image_shape)
    augmented_x = np.concatenate((real_images, synth_images))
    augmented_y = np.concatenate((y_real, synth_y))

    print(f'Augmented dataset distribution: {Counter(augmented_y)}')
    return augmented_x, augmented_y

# Append synthetic samples from 'Dataset/fake_cwt' to the current training split.
# NOTE: the result is written back to x_train / y_train, so running this cell
# again without first re-running one of the split cells appends the synthetic
# samples a second time.
x_train, y_train = augmentation(x_train, y_train, 'Dataset/fake_cwt', image_size, 400)

In [None]:
# # Visualize dataset
# plt.figure(figsize=(10, 10))
# for images, labels in train_ds.take(1):
#   for i in range(9):
#     ax = plt.subplot(3, 3, i + 1)
#     plt.imshow(images[i])
#     plt.title(labels[i].numpy())
#     plt.axis("off")

In [None]:
# Model 1: ConvNet
# Three Conv2D + ReLU stages (32, 64, 128 filters). The first two are followed
# by max pooling; the last feeds global average pooling, dropout and a
# softmax classification head.
model = keras.Sequential(
    [
        # The input size is derived from `image_size` (set in the dataset cell)
        # instead of a literal 128, so the network stays in sync with the size
        # that get_dataset() and augmentation() were given.
        keras.layers.Conv2D(32, (3, 3), input_shape=(image_size, image_size, 3),
                            padding="same"),
        keras.layers.ReLU(),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, (3, 3), padding="same"),
        keras.layers.ReLU(),
        keras.layers.MaxPooling2D(),
        # keras.layers.Conv2D(96, (3, 3), padding="same"),
        # keras.layers.ReLU(),
        # keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(128, (3, 3), padding="same"),
        keras.layers.ReLU(),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.2),
        # NOTE(review): the 4 output units are hard-coded; confirm this matches
        # len(class_names) for the dataset in use.
        keras.layers.Dense(4, activation='softmax'),
    ]
)
# model.summary()

# SparseCategoricalCrossentropy takes integer class indices as targets
# (not one-hot vectors), paired here with the softmax output above.
model.compile(
    optimizer='adam',
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)

# Output directory for this run: the best-model checkpoint is written beneath
# it, and the later cells reuse the same path for their own outputs.
checkpoint_path = "./Results/cwru/cwt/hp1_400"

# def myprint(s):
#     with open(f'{checkpoint_path}/summary.txt','a') as f:
#         print(s, file=f)

# model.summary(print_fn=myprint)

# Save the full model (not only the weights) each time val_loss reaches a new minimum.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath=f'{checkpoint_path}/trained',
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=False,
)

# Stops training once val_loss has gone 5 epochs without improving.
# Defined here but not enabled in `cb` below.
early_callback = keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=5)

# Callbacks actually handed to model.fit().
cb = [
    checkpoint_callback,
    # early_callback
]

# Train for 30 epochs, evaluating on (x_val, y_val) after every epoch; `cb`
# supplies the checkpoint callback. The returned History object is used by the
# following cells for the CSV export and the training curves.
# NOTE(review): batch_size is hard-coded to 64 here although the dataset cell
# defines batch_size = 32 -- confirm which value is intended.
history = model.fit(x_train, y_train, batch_size=64, epochs=30, validation_data=(x_val, y_val), callbacks=cb)

In [None]:
# Persist the per-epoch metrics recorded by model.fit() next to the checkpoint.
history_df = pd.DataFrame(history.history)
history_csv = f'{checkpoint_path}/history.csv'
# Give pandas the path instead of an already-open text handle. pandas documents
# that a file object passed to to_csv must be opened with newline='' (otherwise
# line endings can be translated twice, e.g. blank rows on Windows); letting
# to_csv open the file itself avoids that and closes the file for us.
history_df.to_csv(history_csv)

In [None]:
# Plot the training curves from the History object returned by model.fit().
# checkpoint_path is passed as the second argument -- presumably the directory
# the figure is saved to; confirm in result_utils.plot_curve.
plot_curve(history, checkpoint_path) 

In [None]:
# Reload the model saved by the ModelCheckpoint callback (the lowest-val_loss
# epoch, since save_best_only=True) so inference uses that rather than the
# final-epoch weights still held in `model`.
# Uncomment the next line to evaluate a checkpoint from a different run.
# checkpoint_path = 'Results/cwru/cwt/imbalanced_50'
loaded_model = keras.models.load_model(f'{checkpoint_path}/trained')

In [None]:
# Inference with testing dataset
# NOTE: x_test / y_test are whatever the most recently executed data cell
# assigned: the separately loaded test set from the "Test dataset" cell, or the
# hold-out split created by the Method 1 / Method 3 cells.
# 'hp0_400' is passed as the last argument -- presumably a tag for the report
# files; confirm in result_utils.Result.
result = Result(x_test, y_test, class_names, loaded_model, checkpoint_path, 'hp0_400')
result.write_report()
result.plot_matrix()