In [1]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
from matplotlib import pyplot as plt
import scipy.io
from sklearn.preprocessing import OneHotEncoder
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, callbacks

In [123]:
def to_binary(x):
    return np.where(x > 0, 1.0, 0.0).astype(np.uint8)
        
def relu(x):
    return np.where(x > 0, x, 0.0)

def top_k_ids(D, k):
    Di = np.argpartition(D, -k, axis=1)
    if k < 0:
        Di_top = Di[:, :(-k)]
    else:
        Di_top = Di[:, (-k):]
    Dv_top = np.take_along_axis(D, Di_top, axis=1)
    sorted_top = np.argsort(Dv_top, axis=1)
    return np.take_along_axis(
        Di_top, 
        sorted_top if k < 0 else np.flip(sorted_top, axis=1), 
        axis=1
    )

def query_coded(xc, yc, k):
    D = np.asarray([np.sum(ycv ^ xc, axis=1) for ycv in yc])
    return top_k_ids(D, k)

def overlaps(X, Y):
    assert X.shape == Y.shape
    return np.asarray([len(np.intersect1d(x, y)) for x, y in zip(X, Y)])

def shm(*matrices, **kwargs):
    plt.figure(1, figsize=(25,10))
    for m_id, matrix in enumerate(matrices):
        plt.subplot(len(matrices), 1, m_id+1)
        plt.imshow(np.squeeze(matrix).T, cmap='gray', origin='lower')
        plt.colorbar()

    if kwargs.get("file"):
        plt.savefig(kwargs["file"])
        plt.clf()
    else:
        plt.show()

def quality(proposed, gt, number_of_neighbors):
    return np.mean(overlaps(proposed, gt) / number_of_neighbors)

def quality_coded(xc, yc, gt, k):
    return quality(query_coded(xc, yc, k), gt, k)

def norm(w, factor=1.0):
    for i in range(w.shape[1]):
        w[:, i] = factor * w[:, i] / np.sqrt(np.sum(np.square(w[:,i])))
        
def plot_weights(W):
    input_d = int(np.sqrt(W.shape[0]))
    hidden_d = int(np.sqrt(W.shape[1]))

    W_plot = np.zeros((input_d * hidden_d, input_d * hidden_d))
    i = 0
    for y in range(hidden_d):
        for x in range(hidden_d):
            W_plot[
                y*input_d:(y+1)*input_d,
                x*input_d:(x+1)*input_d
            ] = W[:, i].reshape(28,28)
            i += 1
    return shm(W_plot)

def weights_inner_product(W, p):
    return np.asarray([np.linalg.norm(W[:, ni], ord=p) for ni in range(W.shape[1])])

def train_mlp(
    X, 
    Y, 
    Xt, 
    Yt, 
    hidden_layer_size=None, 
    epochs=300,
    init_learning_rate=0.02,
    verbose=0
):
    def read_mnist_tf(X, Y):
        return (
            tf.data.Dataset.from_tensor_slices((X,Y))
                .shuffle(X.shape[0])
                .batch(100)
        )
    
    def scheduler(epoch):
        return init_learning_rate * (1.0 - epoch / epochs)
        
    train_dataset = read_mnist_tf(X, Y)
    test_dataset = read_mnist_tf(Xt, Yt)
    
    if hidden_layer_size is not None:
        model = Sequential([
            layers.Dense(hidden_layer_size, activation='relu'), 
            layers.Dense(Y.shape[1])
        ])
    else:
        model = Sequential([
            layers.Dense(Y.shape[1])
        ])

    model.compile(
        optimizer=optimizers.SGD(init_learning_rate),
        loss=tf.losses.CategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
    )
    model.fit(
        train_dataset, epochs=epochs,
        validation_data=test_dataset,
        validation_steps=2,
        callbacks=[
            callbacks.EarlyStopping(monitor='val_loss', patience=10),
            callbacks.LearningRateScheduler(scheduler)
        ],
        verbose=verbose
    )    
    return max(model.history.history["val_accuracy"])

In [124]:
mat = scipy.io.loadmat('mnist_all.mat')

def read_mnist(mat, label):
    num_classes = 10 # 5
    input_dim = 784

    X = np.zeros((0, input_dim))
    Y = np.zeros((0, num_classes))
    for i in range(num_classes):
        data = mat[label + str(i)] # [:500]

        X = np.concatenate((X, data), axis=0)

        Yi = np.zeros((data.shape[0], num_classes))
        Yi[:, i] = 1.0
        Y = np.concatenate((Y, Yi), axis=0)

    X = X / 255.0
    return X.astype(np.float32), Y.astype(np.float32)

X, Y = read_mnist(mat, "train")
Xt, Yt = read_mnist(mat, "test")

In [125]:
train_mlp(X, Y, Xt, Yt, hidden_layer_size=2000, verbose=1)

Epoch 1/300
Epoch 2/300
Epoch 3/300
Epoch 4/300
Epoch 5/300
Epoch 6/300
Epoch 7/300
Epoch 8/300
Epoch 9/300
Epoch 10/300
Epoch 11/300
Epoch 12/300
Epoch 13/300
Epoch 14/300
Epoch 15/300
Epoch 16/300
Epoch 17/300
Epoch 18/300
Epoch 19/300
Epoch 20/300
Epoch 21/300
Epoch 22/300
Epoch 23/300
Epoch 24/300
Epoch 25/300


0.985

In [126]:
init_learning_rate = 0.02
# init_learning_rate = 0.05
num_hidden = 2000
batch_size = 200
prec = 1e-30
delta = 0.4
p = 2.0
k = 2
batch_ids = np.arange(batch_size)
number_of_train_batches = X.shape[0] // batch_size
number_of_test_batches = Xt.shape[0] // batch_size

In [None]:
epochs = 500
W = np.random.normal(0.0, 1.0, (X.shape[1], num_hidden))

Xa = np.zeros((X.shape[0], num_hidden), dtype=np.float32)
Xta = np.zeros((Xt.shape[0], num_hidden), dtype=np.float32)

for epoch in range(epochs):
    permute_ids = np.random.permutation(X.shape[0])
    X = X[permute_ids, :]
    Y = Y[permute_ids, :]
    
    learning_rate = init_learning_rate * (1-epoch / epochs)

    dW_norm = 0.0
    for i in range(number_of_train_batches):
        x = X[i*batch_size:(i+1)*batch_size,:]
        
        a = relu(np.dot(x, W) ** (p-1))
        
        top_ids = top_k_ids(a, k)

        a_deriv = np.zeros((batch_size, num_hidden))
        a_deriv[batch_ids, top_ids[:, 0]] = 1.0
        a_deriv[batch_ids, top_ids[:, k-1]] = -delta

        a_deriv_sum = np.sum(a_deriv * a, 0)

        dW = np.dot(x.T, a_deriv) - np.expand_dims(a_deriv_sum, 0) * W

        denom = np.amax(np.absolute(dW))

        W += learning_rate * dW / np.where(denom > 1e-20, denom, 1e-20)
        dW_norm += np.linalg.norm(dW)

        Xa[i*batch_size:(i+1)*batch_size,:] = a[:]
        
    for i in range(number_of_test_batches):
        xt = Xt[i*batch_size:(i+1)*batch_size,:]
        
        at = relu(np.dot(xt, W) ** (p-1))
        
        Xta[i*batch_size:(i+1)*batch_size,:] = at[:]
    
    if (epoch % (epochs // 100)) == 0 or epoch == epochs-1:
        res = train_mlp(Xa, Y, Xta, Yt)
        print("Epoch {}, |dW| = {:.4f}, min(W) = {:.4f}, max(W) = {:.4f}, |W| = {:.4f}, acc = {:.4f}".format(
            epoch, 
            dW_norm / number_of_train_batches,
            np.min(W),
            np.max(W),
            np.mean(weights_inner_product(W, 2)),
            res
        ))


Epoch 0, |dW| = 13121.6921, min(W) = -4.5353, max(W) = 5.1656, |W| = 27.2882, acc = 0.9800
Epoch 5, |dW| = 5921.7600, min(W) = -4.4206, max(W) = 5.1656, |W| = 21.9708, acc = 0.9850
Epoch 10, |dW| = 3555.7579, min(W) = -4.1871, max(W) = 5.1784, |W| = 16.5983, acc = 0.9850
Epoch 15, |dW| = 2014.4766, min(W) = -4.2032, max(W) = 4.7444, |W| = 11.6244, acc = 0.9900
Epoch 20, |dW| = 1016.9236, min(W) = -3.6790, max(W) = 3.8893, |W| = 7.2762, acc = 0.9700
Epoch 25, |dW| = 442.6131, min(W) = -3.3000, max(W) = 3.3747, |W| = 3.9761, acc = 0.9600
Epoch 30, |dW| = 176.6922, min(W) = -3.3607, max(W) = 2.7104, |W| = 2.1757, acc = 0.9500


In [None]:
plot_weights(W)