In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder

# Data

In [None]:
# Load the dataset returned by scikit-learn's `load_digits`.
# NOTE(review): the variable is called `mnist`, but the loader used here is
# `load_digits`, not an MNIST loader. The name is kept because the next cell
# reads `mnist.data` and `mnist.target`.
mnist = load_digits()

In [None]:
# Pull the feature matrix and the class labels out of the loaded dataset.
X, y = mnist.data, mnist.target

In [None]:
# One-hot encode the class labels into a dense (n_samples, n_classes) array.
# The encoder is fed from `mnist.target` rather than from `y` itself so that
# re-running this cell always starts from the raw labels; encoding `y` in
# place would hand an already-encoded array back to the encoder on a second
# run.
ohe = OneHotEncoder()
y = ohe.fit_transform(mnist.target.reshape(-1, 1)).toarray()

In [None]:
# Hold out 20% of the samples for validation. `random_state` pins the shuffle
# so the same split is produced on every run of the notebook.
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=0
)

# Algorithm

In [None]:
def predict_probabilities(X, Theta):
    """Return softmax class probabilities for a linear model.

    Args:
        X: 2-D array of shape (n_samples, n_features).
        Theta: 2-D weight array of shape (n_features, n_classes).

    Returns:
        Array of shape (n_samples, n_classes) in which every row is a
        probability distribution over the classes (non-negative, sums to 1).
    """
    activation = X.dot(Theta)
    # Shift each row by its own maximum before exponentiating: softmax is
    # unchanged by a per-row constant, and the shift keeps np.exp from
    # overflowing on large activations.
    shifted = activation - activation.max(axis=1, keepdims=True)
    numerator = np.exp(shifted)
    # Normalise across the class axis (axis=1) so each sample's probabilities
    # sum to 1. Summing over axis=0 would normalise across samples instead.
    denominator = numerator.sum(axis=1, keepdims=True)
    return numerator / denominator

def cost(X, y, Theta, epsilon=1e-8):
    """Return the mean cross-entropy between targets `y` and the model output.

    Args:
        X: 2-D array of shape (n_samples, n_features).
        y: Target array with the same shape as the output of
            `predict_probabilities(X, Theta)` (one row per sample, one column
            per class).
        Theta: Weight array passed through to `predict_probabilities`.
        epsilon: Small constant added inside the logarithm so that a
            predicted probability of exactly 0 does not produce -inf.

    Returns:
        The cross-entropy summed over all entries and divided by the number
        of rows in `X`.
    """
    P = predict_probabilities(X, Theta)
    # Weight the log-probabilities by the targets *outside* the logarithm.
    # Entries where y == 0 then contribute exactly 0, instead of each adding
    # a spurious -log(epsilon) term to the total.
    return -(y * np.log(P + epsilon)).sum() / X.shape[0]

def accuracy(X, y, theta):
    """Return the fraction of rows whose predicted class matches the target.

    The predicted class of a row is the column with the largest value in
    `predict_probabilities(X, theta)`; the target class is the column with
    the largest value in `y`.
    """
    predicted_labels = np.argmax(predict_probabilities(X, theta), axis=1)
    true_labels = np.argmax(y, axis=1)
    hits = np.sum(predicted_labels == true_labels)
    return hits / X.shape[0]

def gradient(X, y, Theta):
    """Return ``-X.T.dot(y - P)`` with ``P = predict_probabilities(X, Theta)``.

    When `P` is the row-wise softmax of ``X.dot(Theta)`` and every row of `y`
    sums to 1, this is the gradient with respect to `Theta` of the
    cross-entropy summed over samples.

    Note that the result is summed over samples and is not divided by the
    number of rows in `X`, whereas `cost` is; the step size used with this
    gradient therefore scales with the number of samples.

    Args:
        X: 2-D array of shape (n_samples, n_features).
        y: Target array of shape (n_samples, n_classes).
        Theta: Weight array of shape (n_features, n_classes).

    Returns:
        Array with the same shape as `Theta`.
    """
    P = predict_probabilities(X, Theta)
    return -(X.T.dot(y - P))

In [None]:
n_iterations = 10
learning_rate = 0.001
log_every = 1  # print metrics every `log_every` iterations
seed = 0

# Draw the initial weights from a seeded generator so that the starting point
# of the optimisation is the same on every run.
rng = np.random.default_rng(seed)
Theta = rng.standard_normal((X_train.shape[1], y_train.shape[1]))


def report_progress(epoch, Theta):
    """Print train/validation cost and accuracy for `Theta` at `epoch`."""
    print(f'Epoch {epoch:2} of {n_iterations}')
    print('\tCosts:')
    print(f'\t\tTrain={cost(X_train, y_train, Theta):.2f}')
    print(f'\t\t  Val={cost(X_val, y_val, Theta):.2f}')
    print('\tAccuracies:')
    print(f'\t\tTrain={accuracy(X_train, y_train, Theta):.2f}')
    print(f'\t\t  Val={accuracy(X_val, y_val, Theta):.2f}')
    print()


# Baseline metrics for the randomly initialised weights (epoch 0).
report_progress(0, Theta)

for i in range(n_iterations):
    # Plain full-batch gradient descent step on the training set.
    Theta = Theta - learning_rate * gradient(X_train, y_train, Theta)

    if (i + 1) % log_every == 0:
        report_progress(i + 1, Theta)