In [1]:
import os
import sys
import numpy as np
import tensorflow as tf
from keras import datasets
import matplotlib.pyplot as plt

sys.path.append(os.getcwd() + "/../")

from bfcnn import BFCNN, collage

In [2]:
# ==============================================================================
# Runtime configuration: quiet TensorFlow logging and switch to graph mode.

# NOTE(review): the C++ side of TensorFlow reads TF_CPP_MIN_LOG_LEVEL from the
# environment, but `tensorflow` is already imported in the first cell, so any
# C++ messages emitted during that import are not silenced by this line.
# Setting it above `import tensorflow` would cover those too — TODO confirm.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")

# TF1-style graph execution; this runs before the model is built further down.
tf.compat.v1.disable_eager_execution()

# Python-side TensorFlow logging: only ERROR and above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# ==============================================================================

In [3]:
# get dataset: MNIST train/test splits. Pixels are cast to float32 and a
# trailing channel axis of size 1 is appended (N, H, W) -> (N, H, W, 1).
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.astype(np.float32)[..., np.newaxis]
x_test = x_test.astype(np.float32)[..., np.newaxis]

In [4]:
# ---- model ------------------------------------------------------------------
NO_LAYERS = 5
INPUT_SHAPE = (28, 28, 1)  # MNIST digit: 28x28 pixels, single channel

# ---- noise range passed to BFCNN.train (min_noise_std / max_noise_std) -------
# presumably in the same 0-255 pixel units as x_train — TODO confirm in bfcnn
MIN_STD, MAX_STD = 0.1, 30

# ---- optimisation -----------------------------------------------------------
EPOCHS, BATCH_SIZE = 10, 32
LR_INITIAL, LR_DECAY = 0.1, 0.8
# NOTE(review): CLIP_NORMAL is not referenced by any of the cells shown here;
# either pass it to BFCNN.train or drop it.
CLIP_NORMAL = 1.0

# ---- logging ----------------------------------------------------------------
PRINT_EVERY_N_BATCHES = 1000

In [5]:
# build model: a BFCNN for INPUT_SHAPE-sized images with NO_LAYERS layers
model = BFCNN(input_dims=INPUT_SHAPE, no_layers=NO_LAYERS)

In [None]:
# train dataset
# The clean training images go in together with the noise-std range and the
# learning-rate schedule; noise is presumably added inside BFCNN.train (TODO
# confirm). It returns the trained wrapper (its `.model` is used for
# prediction later) and a history whose `.history["loss"]` is plotted next.
trained_model, history = BFCNN.train(
    model=model,
    input_dims=INPUT_SHAPE,
    dataset=x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    min_noise_std=MIN_STD,
    max_noise_std=MAX_STD,
    lr_initial=LR_INITIAL,
    lr_decay=LR_DECAY,
    print_every_n_batches=PRINT_EVERY_N_BATCHES,
)



2021-03-26 13:22:20,442 INFO model.py:train:330] begin training


Epoch 1/10




Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
 272/1875 [===>..........................] - ETA: 3:39 - batch: 135.5000 - size: 32.0000 - loss: 1.1680 - mae_loss: 1.1680

In [None]:
# summarize history for loss: the recorded training loss, presumably one value
# per epoch (Keras-style History) — TODO confirm against BFCNN.train
plt.figure(figsize=(13, 5))
plt.plot(history.history["loss"], label="train")
plt.gca().set(title="model loss", xlabel="epoch", ylabel="loss")
plt.legend(loc="upper right")
plt.show()

In [None]:
# Qualitative check on held-out digits: corrupt clean test images with
# Gaussian noise, run the trained model, and show the noisy input, the model
# output, and their absolute difference (what the model removed) side by side.
SAMPLE_START, SAMPLE_END = 100, 200  # which x_test images to visualize
SAMPLE_NOISE_STD = MAX_STD / 2       # half of the maximum training noise std
SAMPLE_SEED = 0                      # fixed so the figure is reproducible

clean = x_test[SAMPLE_START:SAMPLE_END]
noise_rng = np.random.default_rng(SAMPLE_SEED)
sample = clean + noise_rng.normal(0.0, SAMPLE_NOISE_STD, clean.shape)
# keep pixels inside [0, 255] and back in x_test's float32 dtype (the noise
# draw is float64 and would otherwise upcast the whole batch)
sample = np.clip(sample, 0.0, 255.0).astype(np.float32)
results = trained_model.model.predict(sample)

panels = [
    ("noisy input", sample),
    ("denoised output", results),
    ("|noisy - denoised|", np.abs(sample - results)),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 15))
for ax, (title, images) in zip(axes, panels):
    ax.imshow(collage(images), cmap="gray_r")
    ax.set_title(title)
    ax.set_axis_off()
plt.show()