In [None]:
# TensorFlow and tf.keras
import tensorflow as tf
import keras
import sys
sys.path.insert(0, '../')

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

%load_ext autoreload
%autoreload 2

print(tf.__version__)

### Load and preprocess the MNIST dataset

The integer labels are one-hot encoded so they can be used with `categorical_crossentropy`.

In [None]:
# Keras' MNIST loader returns ((x_train, y_train), (x_test, y_test)).
mnist = keras.datasets.mnist
mnist_splits = mnist.load_data()
(x_train, y_train), (x_test, y_test) = mnist_splits

In [None]:
# Number of label classes (the ten MNIST digits); used for one-hot encoding the labels.
num_classes = 10

In [None]:
# One-hot encode the integer labels for `categorical_crossentropy`.
# Guarded so that re-running this cell is a no-op: calling to_categorical on
# labels that are already one-hot, shape (n, num_classes), would produce a
# (n, num_classes, num_classes) array and silently break training.
if y_train.ndim == 1:
    y_train = keras.utils.to_categorical(y_train, num_classes)
if y_test.ndim == 1:
    y_test = keras.utils.to_categorical(y_test, num_classes)

### Train and test the network on MNIST

In [None]:
from keras.models import Sequential
from keras.layers import Dense, Activation, Conv2D, MaxPooling2D, Dropout, Flatten

In [None]:
from adv_util import create_fully_connected

In [None]:
def compile_train_test(x_train, y_train, x_test, y_test, reg, corrupt_func = None,
                       corrupt_fraction=0.2, epochs=15, batch_size=128,
                       train_regular=False):
    """Train a fully connected classifier on partially corrupted data and evaluate it.

    Every image is flattened to a vector. A fraction of the training images is
    corrupted with ``corrupt_data(x_train, n_corrupt, corrupt_func)``. A freshly
    built model from ``create_fully_connected`` is then trained on that data and
    evaluated on the *uncorrupted* test set. Optionally, an independent baseline
    model is first trained on the clean training images.

    Args:
        x_train: training images, shape (n, ...); each sample is flattened.
        y_train: one-hot training labels, shape (n, num_classes).
        x_test: test images, same per-sample shape as ``x_train``.
        y_test: one-hot test labels.
        reg: regularization value forwarded to ``create_fully_connected``.
        corrupt_func: corruption function forwarded to ``corrupt_data``.
        corrupt_fraction: fraction of training images to corrupt (default 0.2).
        epochs: number of training epochs per model (default 15).
        batch_size: mini-batch size for ``fit`` (default 128).
        train_regular: if True, also train and evaluate a baseline model on the
            clean training data (default False).

    Returns:
        Tuple ``(loss_regular, accuracy_regular, loss_corrupt, accuracy_corrupt)``.
        The two "regular" entries are ``None`` when ``train_regular`` is False.
    """
    # Imported here rather than relying on a later notebook cell, so the
    # function works regardless of cell execution order.
    from mnist_corruption import corrupt_data

    n = x_train.shape[0]
    # Flattened feature size; handles (n, rows, cols) as well as (n, rows, cols, 1).
    D = int(np.prod(x_train.shape[1:]))
    num_classes = y_train.shape[1]  # labels are one-hot encoded
    x_train_flattened = x_train.reshape(n, D)
    x_test_flattened = x_test.reshape(x_test.shape[0], D)
    input_shape = (D,)

    def _fresh_model():
        # Build a new, untrained model so that each run starts from scratch
        # (the corrupt run must not continue from baseline-fitted weights).
        model = create_fully_connected(input_shape=input_shape, num_classes=num_classes, reg=reg)
        model.compile(optimizer="sgd", loss='categorical_crossentropy', metrics=['accuracy'])
        return model

    print("Reg value:" + str(reg))

    loss_regular = accuracy_regular = None
    if train_regular:
        # Baseline on clean data. It runs before corruption in case corrupt_data
        # modifies its input (NOTE(review): not visible here — confirm).
        model_regular = _fresh_model()
        model_regular.fit(x_train_flattened, y_train, batch_size=batch_size, epochs=epochs,
                          verbose=True, validation_split=.1)
        loss_regular, accuracy_regular = model_regular.evaluate(x_test_flattened, y_test, verbose=False)
        print('Test accuracy on regular: ' + str(accuracy_regular))

    # Corrupt a fraction of the training images; the test set stays clean.
    n_corrupt = int(np.round(corrupt_fraction * n))
    x_train_corrupt = corrupt_data(x_train, n_corrupt, corrupt_func)
    x_train_corrupt_flat = x_train_corrupt.reshape(n, D)

    model_corrupt = _fresh_model()
    model_corrupt.fit(x_train_corrupt_flat, y_train, batch_size=batch_size, epochs=epochs,
                      verbose=True, validation_split=.1)
    loss_corrupt, accuracy_corrupt = model_corrupt.evaluate(x_test_flattened, y_test, verbose=False)
    print('Test accuracy on corrupt: ' + str(accuracy_corrupt))

    return loss_regular, accuracy_regular, loss_corrupt, accuracy_corrupt

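# Blurring

Preview the Gaussian-blurring corruption on a few training images before training on corrupted data.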

In [None]:
from mnist_corruption import gaussian_blurring, corrupt_data

In [None]:
# Blur the first three training images to preview the corruption.
# NOTE(review): the second argument of gaussian_blurring is presumably the blur
# strength (e.g. sigma) — confirm in mnist_corruption.
x_small = x_train[0:3]

x_corrupted_small = gaussian_blurring(x_small, 2)
# Flatten each image to a vector; -1 infers the pixel count (784 for 28x28
# MNIST) instead of hard-coding it.
x_corrupted_reshaped = x_corrupted_small.reshape((x_corrupted_small.shape[0], -1))

In [None]:
# Grayscale colormap so the digit looks like the original image rather than
# being rendered with the default (viridis) colormap.
plt.imshow(x_corrupted_small[2], cmap="gray")
plt.title("Blurred training image (index 2)")
plt.axis("off");

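## Stability and Generalization

Train a fully connected network with regularization value `reg` on training data in which a fraction (20% by default) of the images has been corrupted. Then report its accuracy on the clean test set.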

In [None]:
from adv_util import create_fully_connected

In [None]:
# Regularization value handed to create_fully_connected inside compile_train_test.
reg = 0.01
compile_train_test(
    x_train, y_train, x_test, y_test, reg,
    corrupt_func=gaussian_blurring,
)