In [None]:
import numpy as np
import matplotlib.pyplot as plt

def unpickle(file):
    """Load and return the object stored in a pickle file.

    Args:
        file: Path of the pickle file; it is opened in binary mode.

    Returns:
        The unpickled object. ``encoding='bytes'`` makes 8-bit strings
        written by Python 2 come back as ``bytes`` objects.
    """
    import pickle

    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only call this on files from a trusted source.
    with open(file, 'rb') as fo:
        # Return directly instead of binding to a local named `dict`,
        # which shadowed the builtin.
        return pickle.load(fo, encoding='bytes')

def relu(z):
    """Apply the rectified linear unit element-wise: max(0, z)."""
    rectified = np.maximum(0, z)
    return rectified

def deriv_relu(z):
    """Return the ReLU derivative as integers: 1 where z > 0, otherwise 0."""
    positive_mask = z > 0
    return positive_mask.astype(int)

def softmax(z):
    """Compute the softmax along axis 1 (one distribution per row).

    The row maximum is subtracted before exponentiating so that large
    inputs do not overflow; this does not change the result.
    """
    row_max = np.max(z, axis=1, keepdims=True)
    exps = np.exp(z - row_max)
    row_totals = np.sum(exps, axis=1, keepdims=True)
    return exps / row_totals


def deriv_softmax(a):
    """Return a * (1 - a) element-wise for softmax outputs ``a``.

    This is only the diagonal term of the softmax Jacobian.
    """
    complement = 1 - a
    return a * complement

def categorical_cross_entropy_loss(pred, targets):
    """Return the categorical cross-entropy averaged over the rows of ``pred``.

    Args:
        pred: Predicted probabilities, one row per sample.
        targets: Target weights with the same shape as ``pred``
            (one-hot rows in this file).

    Returns:
        The summed ``-targets * log(pred)`` divided by the number of rows.
    """
    eps = 1e-7
    # Keep probabilities away from 0 and 1 so the logarithm stays finite.
    clipped = np.clip(pred, eps, 1 - eps)
    weighted_log = targets * np.log(clipped)
    return -np.sum(weighted_log) / len(clipped)

def deriv_categorical_cross_entropy(a, targets):
    """Return the element-wise difference between outputs ``a`` and ``targets``."""
    error = a - targets
    return error


def normalize_data(X_train, X_test):
    """Standardize both arrays using the per-column statistics of ``X_train``.

    Args:
        X_train: Array whose axis-0 mean and standard deviation are used.
        X_test: Array scaled with the same statistics.

    Returns:
        Tuple ``(X_train_normalized, X_test_normalized)``.
    """
    column_mean = np.mean(X_train, axis=0)
    column_std = np.std(X_train, axis=0)

    # Floor near-zero deviations so constant columns do not divide by zero.
    floor = 1e-8
    column_std[column_std < floor] = floor

    def _standardize(data):
        return (data - column_mean) / column_std

    return _standardize(X_train), _standardize(X_test)

def init_param(input_size, output_size, hidden_size, features):
    """Draw standard-normal weights and biases for the three layers.

    Args:
        input_size: Width of the first layer's output.
        output_size: Width of the final (softmax) layer.
        hidden_size: Width of the second layer's output.
        features: Number of columns in the input data.

    Returns:
        Tuple ``(w1, b1, w2, b2, w3, b3)`` with shapes
        ``(features, input_size)``, ``(1, input_size)``,
        ``(input_size, hidden_size)``, ``(1, hidden_size)``,
        ``(hidden_size, output_size)``, ``(1, output_size)``.
    """
    w1 = np.random.randn(features, input_size)
    b1 = np.random.randn(1, input_size)
    # w2 multiplies the first layer's output, so its row count must be
    # input_size. It was previously `features`, which only lined up when
    # the caller passed input_size == features.
    w2 = np.random.randn(input_size, hidden_size)
    b2 = np.random.randn(1, hidden_size)
    w3 = np.random.randn(hidden_size, output_size)
    b3 = np.random.randn(1, output_size)
    return w1, b1, w2, b2, w3, b3


def forward_prop(w1, b1, w2, b2, w3, b3, X):
    """Run the three-layer forward pass on the rows of ``X``.

    Both hidden activations are ReLU outputs that are then standardized
    with ``normalize_data`` using the statistics of the current input
    itself. The returned ``a1`` and ``a2`` are these standardized values.

    Returns:
        Tuple ``(z1, a1, z2, a2, z3, a3)`` where ``a3`` is the softmax output.
    """
    # First layer: affine -> ReLU -> standardize over the batch.
    z1 = X.dot(w1) + b1
    hidden1 = relu(z1)
    a1 = normalize_data(hidden1, hidden1)[0]

    # Second layer: same pattern.
    z2 = a1.dot(w2) + b2
    hidden2 = relu(z2)
    a2 = normalize_data(hidden2, hidden2)[0]

    # Output layer: affine -> softmax.
    z3 = a2.dot(w3) + b3
    a3 = softmax(z3)

    return z1, a1, z2, a2, z3, a3
    
def one_hot(Y, num_classes=None):
    """One-hot encode a 1-D sequence of non-negative integer labels.

    Args:
        Y: Integer class labels (array or list).
        num_classes: Number of columns in the result. When omitted it is
            ``max(Y) + 1``, so every label has a valid column even if some
            smaller class ids are absent from ``Y``.

    Returns:
        Float array of shape ``(len(Y), num_classes)`` with a single 1 per row.
    """
    labels = np.asarray(Y)
    if labels.size == 0:
        return np.zeros((0, 0 if num_classes is None else num_classes))
    if num_classes is None:
        # Counting distinct labels (the previous approach) produced too few
        # columns whenever a class id was missing, raising IndexError below.
        num_classes = int(labels.max()) + 1
    one_hot_Y = np.zeros((len(labels), num_classes))
    one_hot_Y[np.arange(len(labels)), labels] = 1
    return one_hot_Y
    
    
def back_prop(z1, a1, z2, a2, z3, a3, w2, w3, X, Y):
    """Compute batch-averaged gradients of the cross-entropy loss.

    Args:
        z1, z2: Pre-activations of the two hidden layers.
        a1, a2: Hidden activations as returned by ``forward_prop``.
        z3: Output pre-activation (unused; kept for a uniform signature).
        a3: Softmax output, one row per sample.
        w2, w3: Weights of the second and output layers.
        X: Input batch, one row per sample.
        Y: Integer class label for each row of ``X``.

    Returns:
        Tuple ``(dw1, db1, dw2, db2, dw3, db3)``. Each ``dw`` has the shape
        of its weight matrix; each ``db`` is 1-D with one entry per unit.
    """
    m = len(Y)

    # Build the targets with the same width as the network output, so a
    # mini-batch that happens to miss a class still lines up with a3.
    one_hot_Y = np.zeros(a3.shape)  # m x classes
    one_hot_Y[np.arange(m), Y] = 1

    # For softmax followed by cross-entropy, dL/dz3 is exactly (a3 - y).
    # Multiplying additionally by a3 * (1 - a3), as was done before,
    # applies the softmax derivative a second time and shrinks the gradient.
    dz3 = deriv_categorical_cross_entropy(a3, one_hot_Y)  # m x classes
    dw3 = (1 / m) * (a2.T).dot(dz3)  # hidden2 x classes
    db3 = (1 / m) * np.sum(dz3, axis=0)  # (classes,)

    # NOTE: only the ReLU derivative is applied for the hidden layers; the
    # standardization forward_prop performs after each ReLU is not
    # differentiated here.
    dz2 = dz3.dot(w3.T) * deriv_relu(z2)  # m x hidden2
    dw2 = (1 / m) * (a1.T).dot(dz2)  # hidden1 x hidden2
    db2 = (1 / m) * np.sum(dz2, axis=0)  # (hidden2,)

    dz1 = dz2.dot(w2.T) * deriv_relu(z1)  # m x hidden1
    dw1 = (1 / m) * (X.T).dot(dz1)  # features x hidden1
    db1 = (1 / m) * np.sum(dz1, axis=0)  # (hidden1,)
    return dw1, db1, dw2, db2, dw3, db3

def update_params(w1, b1, w2, b2, w3, b3, dw1, db1, dw2, db2, dw3, db3, lr, weight_decay=0.01):
    """Apply one gradient-descent step with L2 weight decay.

    Every parameter ``p`` with gradient ``g`` is updated as
    ``p -= lr * (g + weight_decay * p)``. The decay is applied to biases as
    well as weights. NumPy arrays are modified in place.

    Args:
        w1, b1, w2, b2, w3, b3: Current parameters.
        dw1, db1, dw2, db2, dw3, db3: Matching gradients.
        lr: Learning rate.
        weight_decay: L2 penalty coefficient (default 0.01, the value that
            was previously hard-coded).

    Returns:
        Tuple ``(w1, b1, w2, b2, w3, b3)`` of the updated parameters.
    """
    params = [w1, b1, w2, b2, w3, b3]
    grads = (dw1, db1, dw2, db2, dw3, db3)
    for idx, grad in enumerate(grads):
        # Augmented assignment keeps the in-place update of array arguments.
        params[idx] -= lr * (grad + weight_decay * params[idx])
    return tuple(params)

def get_predictions(a):
    """Return, for each row of ``a``, the column index of its largest value."""
    predicted_classes = np.argmax(a, axis=1)
    return predicted_classes
    
def get_accuracy(predictions, Y):
    """Return the fraction of entries in ``predictions`` that equal ``Y``."""
    matches = predictions == Y
    correct = np.sum(matches)
    return correct / len(Y)

def update_learning_rate(initial_learning_rate, epoch, decay = 0.1, epoch_update = 50):
    """Return the step-decayed learning rate for ``epoch``.

    The initial rate is multiplied by ``decay`` once for every full
    ``epoch_update`` epochs contained in ``epoch``.
    """
    completed_steps = epoch // epoch_update
    scale = decay ** completed_steps
    return initial_learning_rate * scale

def _plot_training_history(epoch, epoch_update, train_error, train_accuracy, test_accuracy,
                           initial_learning_rate, lr_marks):
    """Plot the logged error and accuracies, marking each learning-rate decay epoch."""
    logged_epochs = range(epoch_update, epoch + 1, epoch_update)

    def _mark_lr_decays(ax):
        # One dashed vertical line per epoch at which the learning rate drops.
        for boundary, rate in lr_marks:
            ax.axvline(boundary, linestyle='--', color='gray', linewidth=1,
                       label=f'Learning Rate = {rate:g}')

    # Plot errors
    fig_errors, ax_errors = plt.subplots(figsize=(12, 6))
    ax_errors.plot(logged_epochs, train_error, label='Train Error', marker='o', linestyle='-')
    ax_errors.set_xlabel('Epochs')
    ax_errors.set_ylabel('Error')
    ax_errors.set_title(f'Training Error Over Epochs (learning rate = {initial_learning_rate:g})')
    _mark_lr_decays(ax_errors)
    ax_errors.legend()

    # Plot accuracies in a separate figure
    fig_accuracies, ax_accuracies = plt.subplots(figsize=(12, 6))
    ax_accuracies.plot(logged_epochs, train_accuracy, label='Train Accuracy', marker='o', linestyle='-', color='blue')
    ax_accuracies.plot(logged_epochs, test_accuracy, label='Test Accuracy', marker='o', linestyle='-', color='red')
    ax_accuracies.set_xlabel('Epochs')
    ax_accuracies.set_ylabel('Accuracy')
    ax_accuracies.set_title(f'Training and Test Accuracies Over Epochs (learning rate = {initial_learning_rate:g})')
    # These markers were previously added to ax_errors by mistake, so the
    # accuracy figure never showed them.
    _mark_lr_decays(ax_accuracies)
    ax_accuracies.legend()

    # plt.tight_layout() only affects the current figure; lay out both.
    fig_errors.tight_layout()
    fig_accuracies.tight_layout()
    plt.show()


def gradient_descent(X, Y, X_test, Y_test, hidden_size=50, epoch = 200, epoch_update = 10, initial_learning_rate = 0.1, batch_size = 250, lr_decay = 0.1, lr_decay_interval = 50):
    """Train the three-layer network with mini-batch gradient descent and plot the results.

    The inputs are standardized with the training statistics, parameters are
    drawn by ``init_param``, and every epoch walks the training set in
    consecutive slices of ``batch_size`` rows (no shuffling; a trailing
    partial batch is skipped). Every ``epoch_update`` epochs the training
    loss, training accuracy and test accuracy are printed and recorded, and
    after training they are plotted in two figures.

    Args:
        X: Training inputs, one row per sample.
        Y: Integer training labels.
        X_test: Test inputs.
        Y_test: Integer test labels.
        hidden_size: Width of the second hidden layer.
        epoch: Number of passes over the training data.
        epoch_update: Logging interval, in epochs.
        initial_learning_rate: Learning rate before any decay.
        batch_size: Number of rows per mini-batch.
        lr_decay: Factor applied to the learning rate at each decay step.
        lr_decay_interval: Number of epochs between decay steps.

    Returns:
        None. Metrics are printed and the figures are shown.
    """
    X, X_test = normalize_data(X, X_test)
    train_samples, features = X.shape
    w1, b1, w2, b2, w3, b3 = init_param(features, len(np.unique(Y)), hidden_size, features)

    train_error = []
    train_accuracy = []
    test_accuracy = []
    num_batches = train_samples // batch_size
    for i in range(epoch):
        learning_rate = update_learning_rate(initial_learning_rate, (i + 1), decay=lr_decay, epoch_update=lr_decay_interval)
        for j in range(num_batches):
            start = batch_size * j
            stop = start + batch_size
            X_train = X[start:stop]
            Y_train = Y[start:stop]
            z1, a1, z2, a2, z3, a3 = forward_prop(w1, b1, w2, b2, w3, b3, X_train)
            dw1, db1, dw2, db2, dw3, db3 = back_prop(z1, a1, z2, a2, z3, a3, w2, w3, X_train, Y_train)
            w1, b1, w2, b2, w3, b3 = update_params(w1, b1, w2, b2, w3, b3, dw1, db1, dw2, db2, dw3, db3, learning_rate)
        if (i + 1) % epoch_update == 0:
            print(f"Epoch {i + 1} / {epoch}")
            _, _, _, _, _, a3 = forward_prop(w1, b1, w2, b2, w3, b3, X)
            train_accuracy.append(get_accuracy(get_predictions(a3), Y))
            cel = categorical_cross_entropy_loss(a3, one_hot(Y))
            train_error.append(cel)
            print("Train Error: " + str(cel))
            print("Train Accuracy: " + str(train_accuracy[-1]))
            _, _, _, _, _, a3 = forward_prop(w1, b1, w2, b2, w3, b3, X_test)
            test_accuracy.append(get_accuracy(get_predictions(a3), Y_test))
            print("Test Accuracy: " + str(test_accuracy[-1]))

    # Epochs at which the rate drops, paired with the rate in effect from
    # that epoch on (50/100/150 with the defaults).
    lr_marks = [
        (boundary, update_learning_rate(initial_learning_rate, boundary, decay=lr_decay, epoch_update=lr_decay_interval))
        for boundary in range(lr_decay_interval, epoch, lr_decay_interval)
    ]
    _plot_training_history(epoch, epoch_update, train_error, train_accuracy, test_accuracy,
                           initial_learning_rate, lr_marks)
   

            
if __name__ == "__main__":
    # File names of the pickled training batches and the test batch.
    train_paths = ('data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5')
    test_path = 'test_batch'

    test_batch = unpickle(test_path)
    train_batches = [unpickle(path) for path in train_paths]

    # Stack the per-batch labels and data into single training arrays.
    train_labels = np.concatenate([batch[b'labels'] for batch in train_batches])
    train_data = np.concatenate([batch[b'data'] for batch in train_batches])

    test_labels = test_batch[b'labels']
    test_data = test_batch[b'data']

    gradient_descent(train_data, train_labels, test_data, test_labels)