In [None]:
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import torch

## Data

### Constants

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

In [None]:
from pathlib import Path

# Load MNIST from a local compressed cache when it exists; otherwise fetch it through
# Keras and write the cache so later runs can skip the download.
data_path = Path('./data/mnist.npz')
if data_path.is_file():
    print('Dataset found. Reading data...')
    # The cache is written with positional arrays, which NumPy stores as arr_0..arr_3 in
    # the order: train images, train labels, test images, test labels.
    # The context manager closes the archive once the four arrays have been read.
    with np.load(data_path) as data_clump:
        X_train_split, Y_train_split, X_test_split, Y_test_split = (
            data_clump[f'arr_{i}'] for i in range(4)
        )
else:
    # Imported lazily so Keras is only needed when the cache is missing.
    from keras.datasets import mnist
    print('Dataset missing. Loading data...')
    (X_train_split, Y_train_split), (X_test_split, Y_test_split) = mnist.load_data()
    # np.savez_compressed does not create missing directories, so make sure ./data exists.
    data_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(data_path, X_train_split, Y_train_split, X_test_split, Y_test_split)
X_train_split.shape, Y_train_split.shape, X_test_split.shape, Y_test_split.shape

In [None]:
# Number of parity samples generated for the training pool (the test set uses a fifth of it).
NUM_SAMPLES = 10000
# Number of MNIST digits concatenated per sample; the label is the parity of their sum.
K = 3

# Seed the global NumPy and PyTorch generators so that sampling, scikit-learn's
# resample/train_test_split (which fall back to NumPy's global state when no
# random_state is given) and PyTorch-side randomness are reproducible under
# Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
from typing import Tuple

def make_parity_data(
    x: np.ndarray,
    y: np.ndarray,
    num_samples: int = NUM_SAMPLES,
    num_cols: int = K
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a parity dataset from groups of randomly drawn examples.

    Each generated sample concatenates ``num_cols`` examples drawn with replacement
    from ``x`` (each flattened and divided by 255). Its label is the parity of the
    sum of the corresponding entries of ``y``.

    Args:
        x: Array of examples indexed along its first axis.
        y: Numeric labels aligned with ``x`` along the first axis.
        num_samples: Number of samples to generate.
        num_cols: Number of examples concatenated into each sample.

    Returns:
        ``(x_parity, y_parity)`` as float32 tensors of shape
        ``(num_samples, num_cols * features_per_example)`` and ``(num_samples, 1)``.
    """
    indices = np.random.choice(len(x), size=(num_samples, num_cols), replace=True)
    features_per_example = int(np.prod(x.shape[1:]))
    # One fancy-indexing gather yields (num_samples, num_cols, *example_shape); flattening
    # each row keeps the examples in column order, exactly like concatenating them one by
    # one, but without creating num_samples * num_cols intermediate tensors.
    gathered = x[indices].reshape(num_samples, num_cols * features_per_example)
    x_parity = torch.tensor(gathered, dtype=torch.float32) / 255
    y_parity = torch.tensor(y[indices], dtype=torch.float32).sum(dim=1).reshape(-1, 1) % 2
    return x_parity, y_parity

### Data generation

In [None]:
# Build the training pool from the MNIST train split and a test set one fifth its size
# from the MNIST test split.
X_training, Y_training = make_parity_data(x=X_train_split, y=Y_train_split)
X_test, Y_test = make_parity_data(
    x=X_test_split,
    y=Y_test_split,
    num_samples=NUM_SAMPLES // 5,
)
X_training.shape, Y_training.shape, X_test.shape, Y_test.shape

In [None]:
# Draw 2000 distinct samples from the training pool; the hyperparameter search runs on
# this smaller subset.
X_training_cv, Y_training_cv = resample(
    X_training,
    Y_training,
    n_samples=2000,
    replace=False,
)
X_training_cv.shape, Y_training_cv.shape

In [None]:
# Hold out 20% of the full training pool as the validation set used for the final models.
X_train, X_val, Y_train, Y_val = train_test_split(
    X_training,
    Y_training,
    test_size=0.2,
)
X_train.shape, X_val.shape, Y_train.shape, Y_val.shape

In [None]:
# Same 80/20 split, applied to the smaller subset used by the hyperparameter search.
X_train_cv, X_val_cv, Y_train_cv, Y_val_cv = train_test_split(
    X_training_cv,
    Y_training_cv,
    test_size=0.2,
)
X_train_cv.shape, X_val_cv.shape, Y_train_cv.shape, Y_val_cv.shape

## Models

### Neural network

In [None]:
# Device string handed to model construction, training and prediction in the cells below.
device = 'cpu'

In [None]:
from scripts.models import SimpleNN
from scripts.metrics import BinaryAccuracy
from scripts.train import train_model
from scripts.utils import EarlyStopping, make_dataloader

In [None]:
# Hyperparameter grid for the neural-network search.
depths = [1, 2, 3, 4, 5]       # number of hidden layers
widths = [16, 32, 64]          # hidden units per layer
# Seven log-spaced values from 1e-3 to 1e3, stored as plain Python floats rather than
# 0-dim tensors so they can be passed to the optimizer and formatted without conversion.
weight_decays = torch.logspace(-3, 3, 7).tolist()
etas = [1e-4, 1e-3, 1e-2]      # learning rates
batch_sizes = [16, 32, 64]

In [None]:
# Running best of the hyperparameter search: the score starts at 0.0 and the winning
# configuration stays unset (None) until a run beats it.
best_score = 0.0
best_depth = best_width = best_weight_decay = best_eta = best_batch_size = None

In [None]:
# Exhaustive grid search over the hyperparameter lists defined above. Every configuration
# is trained from scratch on the CV subset for EPOCHS epochs and scored by the last entry
# of history['val_score'].
EPOCHS = 50
configurations = list(product(depths, widths, weight_decays, etas, batch_sizes))
total_count = len(configurations)
count = 0

# Reset the running best inside this cell so that re-running it starts a fresh search
# instead of comparing against the winner of a previous run. Starting from -inf means
# the first comparable score is always accepted.
best_score = float('-inf')
best_depth = best_width = best_weight_decay = best_eta = best_batch_size = None

print(f'Cross-validating across {total_count} models.\n')

# product() varies its last factor fastest, i.e. the same visiting order as five nested loops.
for count, (depth, width, weight_decay, eta, batch_size) in enumerate(configurations, start=1):
    # The grid may hold 0-dim tensors (torch.logspace); work with a plain float from here on.
    weight_decay = float(weight_decay)

    model = SimpleNN(input_size=784*K, hidden_layers=depth, hidden_units=width).to(device)
    loss_fn = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=eta, weight_decay=weight_decay)
    metric = BinaryAccuracy()
    train_dataloader = make_dataloader(X_train_cv, Y_train_cv, batch_size=batch_size, shuffle=True)
    val_dataloader = make_dataloader(X_val_cv, Y_val_cv, batch_size=batch_size, shuffle=True)

    history = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        epochs=EPOCHS,  # keep in sync with the summary printed after the loop
        verbose=0,
        device=device
    )
    curr_score = history['val_score'][-1]

    print(f'[{count}/{total_count}] depth={depth}, width={width}, lambda={weight_decay:.5f}, eta={eta}, batch size={batch_size} ===> validation score={curr_score:.6f}')
    if curr_score > best_score:
        best_score = curr_score
        best_depth = depth
        best_width = width
        best_weight_decay = weight_decay
        best_eta = eta
        best_batch_size = batch_size

print(f'\nValidation complete. Best validation score after {EPOCHS} epochs = {best_score:.6f}')
if best_depth is None:
    # Only reachable when the grid is empty or no score was comparable (e.g. NaN).
    print('No configuration was selected; check the grid and the validation scores.')
else:
    print(f'Best configuration: depth={best_depth}, width={best_width}, lambda={best_weight_decay:.5f}, eta={best_eta}, batch size={best_batch_size}')

In [None]:
# Fresh, untrained network built with the depth and width selected by the grid search.
best_model_nn = SimpleNN(input_size=784*K, hidden_layers=best_depth, hidden_units=best_width).to(device)

In [None]:
# Training components for the final model, using the learning rate and weight decay
# selected by the grid search.
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(
    params=best_model_nn.parameters(),
    lr=best_eta,
    weight_decay=best_weight_decay,
)
metric = BinaryAccuracy()
# Unlike the grid-search runs, the final fit is handed an early-stopping helper.
early_stopper = EarlyStopping(min_delta=1e-4, patience=20)

In [None]:
# Loaders over the full train/validation split, using the batch size chosen by the search.
train_dataloader = make_dataloader(
    X_train, Y_train, batch_size=best_batch_size, shuffle=True
)
val_dataloader = make_dataloader(
    X_val, Y_val, batch_size=best_batch_size, shuffle=True
)

In [None]:
# Final fit of the selected architecture: up to 500 epochs, with the early-stopping helper
# passed to the training routine.
history = train_model(
    model=best_model_nn,
    loss_fn=loss_fn,
    optimizer=optimizer,
    metric=metric,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=500,
    early_stopping=early_stopper,
    device=device,
)

In [None]:
from scripts.utils import plot_train_history

# Visualise the history returned by the final training run.
plot_train_history(history)

In [None]:
from scripts.test import predict

# Score the trained network on the train and validation splits with the same metric
# object that was used during training.
preds_train = predict(best_model_nn, X_train, device)
preds_val = predict(best_model_nn, X_val, device)
score_train = metric(preds_train, Y_train)
score_val = metric(preds_val, Y_val)
score_train, score_val

In [None]:
# Score the trained network on the test set, which was generated from the MNIST test split.
preds_test = predict(best_model_nn, X_test, device)
score_test = metric(preds_test, Y_test)
score_test

### SVM

In [None]:
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, make_scorer
from scripts.ntk import NTK

In [None]:
# Take the get_ntk attribute of an NTK wrapper around the trained network; it is passed to
# SVC as a custom kernel below.
# NOTE(review): presumably get_ntk(X, Y) returns the neural-tangent-kernel Gram matrix of
# best_model_nn -- confirm against scripts.ntk.
ntk = NTK(best_model_nn).get_ntk

In [None]:
# Accuracy is the selection criterion for both kernel families.
scorer = make_scorer(accuracy_score)

# SVM using the NTK callable as its kernel; only C is tuned (1e-3 ... 1e3, log-spaced).
model_base_ntk = SVC(kernel=ntk)
params_ntk = {'C': np.logspace(-3, 3, 7)}

# SVM with SVC's default kernel; C is tuned over the same range and gamma over
# 2**-5 ... 2**5 plus scikit-learn's 'scale' and 'auto' settings.
gammas = [*np.logspace(-5, 5, 11, base=2).tolist(), 'scale', 'auto']
model_base_rbf = SVC()
params_rbf = {'C': np.logspace(-3, 3, 7), 'gamma': gammas}

In [None]:
# 5-fold grid search over C for the NTK-kernel SVM. refit=False because only the selected
# parameters and their score are needed here; the final model is fitted in a later cell.
model_cv_ntk = GridSearchCV(
    estimator=model_base_ntk,
    param_grid=params_ntk,
    scoring=scorer,
    n_jobs=5,
    refit=False,
    cv=5,
    verbose=3,
)
model_cv_ntk.fit(X_train, Y_train.squeeze())
best_params_ntk = model_cv_ntk.best_params_
# best_score_ is the mean cross-validated score of best_params_. Taking it from the search
# object (rather than max() over mean_test_score) keeps score and parameters tied to the
# same candidate and avoids a NaN from a failed fit poisoning the maximum.
best_score_ntk = model_cv_ntk.best_score_
best_params_ntk

In [None]:
# 5-fold grid search over C and gamma for the default-kernel SVM. refit=False because only
# the selected parameters and their score are needed here; the final model is fitted in a
# later cell.
model_cv_rbf = GridSearchCV(
    estimator=model_base_rbf,
    param_grid=params_rbf,
    scoring=scorer,
    n_jobs=5,
    refit=False,
    cv=5,
    verbose=3,
)
model_cv_rbf.fit(X_train, Y_train.squeeze())
best_params_rbf = model_cv_rbf.best_params_
# best_score_ is the mean cross-validated score of best_params_. Taking it from the search
# object (rather than max() over mean_test_score) keeps score and parameters tied to the
# same candidate and avoids a NaN from a failed fit poisoning the maximum.
best_score_rbf = model_cv_rbf.best_score_
best_params_rbf

In [None]:
# Keep whichever kernel achieved the higher cross-validated accuracy (ties go to the
# default-kernel SVM) and build an unfitted SVC with its selected parameters.
use_ntk = best_score_ntk > best_score_rbf
if use_ntk:
    svm_kwargs = {'C': best_params_ntk['C'], 'kernel': ntk}
else:
    svm_kwargs = {'C': best_params_rbf['C'], 'gamma': best_params_rbf['gamma']}
best_model_km = SVC(tol=1e-4, **svm_kwargs)

In [None]:
# Fit the selected SVM on the full training split; labels are squeezed, as in the grid-search fits.
best_model_km.fit(X_train, Y_train.squeeze())

In [None]:
# Accuracy of the fitted SVM on the train and validation splits.
# Note: these names overwrite the neural-network results computed earlier.
preds_train = best_model_km.predict(X_train)
preds_val = best_model_km.predict(X_val)
score_train = accuracy_score(Y_train, preds_train)
score_val = accuracy_score(Y_val, preds_val)
score_train, score_val

In [None]:
# Accuracy of the fitted SVM on the test set.
# Note: preds_test / score_test overwrite the neural-network values computed earlier.
preds_test = best_model_km.predict(X_test)
score_test = accuracy_score(Y_test, preds_test)
score_test