## FLAGS (IMPORTANT)

In [None]:
# Pipeline switches. When a flag is True the stage runs and saves its result;
# when False the result saved by an earlier run is loaded instead.
#   VALIDATE_NN  -> NN hyper-parameter grid search
#                   (configs/nn/orthants-single-empty.json)
#   TRAIN        -> final NN training (models/orthants-single-empty.pth);
#                   also gates the training-history / radial plots
#   VALIDATE_SVM -> SVM grid search over RBF and NTK kernels
#                   (configs/svm/orthants-single-empty.json)
VALIDATE_NN: bool = True
TRAIN: bool = True
VALIDATE_SVM: bool = True

# TODO: make every stage honour these flags (run it, otherwise load what a
# previous run saved) and update the docstrings.

## Constants

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
# Training-set composition, forwarded to generate_train_data in the Data section.
HIGH_COUNT: int = 100     # passed as both low_count and high_count
LOW_FRAC: float = 1 / 64  # passed as low_frac
ZERO_FRAC: float = 0.5    # passed as zero_frac
# Size argument given to generate_test_data.
TEST_COUNT: int = 100

In [None]:
# Reference point (4, 4, ..., 4) in 7 dimensions, default float dtype.
CENTRE = torch.full((7,), 4.0)
CENTRE

In [None]:
# Inner and outer radius constants (floats).
LOW_RADIUS, HIGH_RADIUS = 1.0, 2.0

## Data

In [None]:
from scripts.data.orthants import generate_train_data

In [None]:
# Build the labelled training pool with fixed seed 7357. orthant_counts is indexed
# by orthant id in later cells.
# NOTE(review): HIGH_COUNT is used for low_count as well as high_count — confirm intended.
train_data_kwargs = dict(
    low_count=HIGH_COUNT,
    high_count=HIGH_COUNT,
    low_spread=0,
    high_spread=0,
    low_frac=LOW_FRAC,
    zero_frac=ZERO_FRAC,
    random_state=7357,
)
X_training, Y_training, orthant_counts = generate_train_data(**train_data_kwargs)
X_training.shape, Y_training.shape

In [None]:
# Find the orthant that received no training points; later cells evaluate the
# models on this orthant specifically. There are 128 = 2**7 orthants of the
# 7-dimensional input space, indexed 0..127 (indices are looked up explicitly so
# a mapping that omits empty orthants, e.g. a Counter, still works).
empty_orthants = [idx for idx in range(128) if orthant_counts[idx] == 0]
if not empty_orthants:
    # Previously ZERO_ORTHANT_INDEX was simply never bound, causing a NameError later.
    raise ValueError('No orthant has zero training points; expected an empty orthant.')
# Matches the original forward scan: if several orthants are empty, the last one wins.
ZERO_ORTHANT_INDEX = empty_orthants[-1]
ZERO_ORTHANT_INDEX

In [None]:
# Hold out 20% of the training pool for validation (fixed seed 535).
train_val_split = train_test_split(
    X_training,
    Y_training,
    test_size=0.2,
    random_state=535,
)
X_train, X_val, Y_train, Y_val = train_val_split
X_train.shape, X_val.shape, Y_train.shape, Y_val.shape

In [None]:
from scripts.data.orthants import generate_test_data

# Test data, indexed by orthant in later cells (fixed seed 7753).
test_data = generate_test_data(TEST_COUNT, random_state=7753)
X_test, Y_test = test_data
X_test.shape, Y_test.shape

In [None]:
from scripts.utils import make_dataloader

# Mini-batch loaders (batch size 32) for the train and validation splits.
# The trailing True is presumably `shuffle` — TODO confirm against make_dataloader.
BATCH_SIZE = 32
train_dataloader = make_dataloader(X_train, Y_train, BATCH_SIZE, True)
val_dataloader = make_dataloader(X_val, Y_val, BATCH_SIZE, True)

In [None]:
# Partition the whole training pool by label: X_total_0 holds class-0 points,
# X_total_1 holds class-1 points.
flat_labels = Y_training.squeeze()
X_total_0, X_total_1 = X_training[flat_labels == 0], X_training[flat_labels == 1]

In [None]:
from scripts.data.orthants import find_orthant

def obtain_closest_point_orthant(
    test_point: torch.Tensor,
    pos_points: torch.Tensor,
    neg_points: torch.Tensor
) -> torch.Tensor:
    """Return the orthants of the nearest positive and nearest negative point.

    Args:
        test_point: Query point, broadcast against every row of the point sets.
        pos_points: Positive-class points, one per row.
        neg_points: Negative-class points, one per row.

    Returns:
        A 2-element tensor ``[find_orthant(nearest positive point),
        find_orthant(nearest negative point)]``, where "nearest" means smallest
        Euclidean distance to ``test_point``.
    """
    def _nearest_row(points: torch.Tensor) -> torch.Tensor:
        # Euclidean distance from test_point to every row, then pick the closest row.
        dists = torch.sqrt(((test_point - points) ** 2).sum(dim=1))
        return points[dists.argmin()]

    nearest_pos = _nearest_row(pos_points)
    nearest_neg = _nearest_row(neg_points)
    return torch.tensor([find_orthant(nearest_pos), find_orthant(nearest_neg)])

In [None]:
import pandas as pd

_CLOSEST_ORTHANT_COLUMNS = ['closest_positive_orthant', 'closest_negative_orthant']


def _closest_orthant_table(label: int) -> pd.DataFrame:
    """Nearest-neighbour orthants for the empty orthant's test points of one class.

    For every test point in orthant ZERO_ORTHANT_INDEX whose label equals `label`,
    record the orthant of the nearest class-1 (positive) training point and the
    orthant of the nearest class-0 (negative) training point.

    Returns:
        DataFrame with one row per such test point and columns
        'closest_positive_orthant', 'closest_negative_orthant' (empty if the class
        has no test points in that orthant).
    """
    mask = Y_test[ZERO_ORTHANT_INDEX].squeeze() == label
    query_points = X_test[ZERO_ORTHANT_INDEX][mask]
    rows = [
        # Positive points are the label-1 pool (X_total_1), negative the label-0 pool
        # (X_total_0); the previous call passed them swapped, mislabelling both columns.
        obtain_closest_point_orthant(x, X_total_1, X_total_0).reshape(1, -1)
        for x in query_points
    ]
    if not rows:
        # torch.cat([]) raises, so return an empty table instead.
        return pd.DataFrame(columns=_CLOSEST_ORTHANT_COLUMNS)
    return pd.DataFrame(torch.cat(rows), columns=_CLOSEST_ORTHANT_COLUMNS)


closest_orthants_0 = _closest_orthant_table(0)
closest_orthants_1 = _closest_orthant_table(1)

In [None]:
# Frequency of each nearest-positive / nearest-negative orthant for class-0 test points.
tuple(
    closest_orthants_0[column].value_counts()
    for column in ('closest_positive_orthant', 'closest_negative_orthant')
)

In [None]:
# Frequency of each nearest-positive / nearest-negative orthant for class-1 test points.
tuple(
    closest_orthants_1[column].value_counts()
    for column in ('closest_positive_orthant', 'closest_negative_orthant')
)

## Neural network

In [None]:
# Device every model in this notebook is moved to via `.to(device)`.
device = "cpu"

In [None]:
from scripts.models import SimpleNN
from scripts.train import train_model
from scripts.metrics import BinaryAccuracy

In [None]:
# Hyper-parameter grid for the NN search: depth x width x learning rate x L2 penalty.
depths = list(range(1, 6))
widths = [2 ** k for k in range(5, 8)]
etas = [1e-4, 1e-3, 1e-2]
# Eleven log-spaced penalties 1e-5 ... 1e5, followed by "no weight decay".
weight_decays = [*np.logspace(-5, 5, 11).tolist(), 0.0]

In [None]:
import json

# Best NN hyper-parameters: reload the configuration saved by an earlier search when
# VALIDATE_NN is off, otherwise start an empty search (score -inf, nothing chosen).
if not VALIDATE_NN:
    with open('configs/nn/orthants-single-empty.json', 'r') as f:
        best_config = json.load(f)
    best_depth = best_config['depth']
    best_width = best_config['width']
    best_eta = best_config['eta']
    best_weight_decay = best_config['weight_decay']
    best_score = best_config['score']
else:
    best_depth = best_width = best_eta = best_weight_decay = None
    best_score = -torch.inf

In [None]:
import itertools
import json
import os

# Where the winning NN configuration is stored (read back when VALIDATE_NN is False).
NN_CONFIG_PATH = 'configs/nn/orthants-single-empty.json'

total_count = len(depths) * len(widths) * len(etas) * len(weight_decays)
curr_count = 0
# Short per-candidate training budget; only used to rank configurations.
EPOCHS = 10

if VALIDATE_NN:
    print(f'Cross-validating across {total_count} models.\n')

    # Same visiting order as four nested loops (weight decay varies fastest).
    for depth, width, eta, weight_decay in itertools.product(depths, widths, etas, weight_decays):
        model = SimpleNN(7, hidden_layers=depth, hidden_units=width).to(device)
        loss_fn = torch.nn.BCELoss()
        optimizer = torch.optim.Adam(params=model.parameters(), lr=eta, weight_decay=weight_decay)
        metric = BinaryAccuracy()

        history = train_model(
            model=model,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            loss_fn=loss_fn,
            optimizer=optimizer,
            metric=metric,
            epochs=EPOCHS,
            verbose=0,
            device=device
        )
        curr_count += 1
        # Final-epoch validation score, cast to a plain float so it compares cleanly
        # and json.dump below cannot fail if train_model records tensors or NumPy
        # scalars (TODO confirm the element type of history['val_score']).
        score = float(history['val_score'][-1])
        print(f'[{curr_count}/{total_count}] Depth: {depth}, width: {width}, lr: {eta}, lambda: {weight_decay} ==> score: {score:.6f}')
        # Strict '>' keeps the earliest configuration on ties.
        if score > best_score:
            best_score = score
            best_depth = depth
            best_width = width
            best_eta = eta
            best_weight_decay = weight_decay

    best_config = {
        'score': best_score,
        'depth': best_depth,
        'width': best_width,
        'eta': best_eta,
        'weight_decay': best_weight_decay
    }
    # Create configs/nn/ on a fresh checkout instead of failing with FileNotFoundError.
    os.makedirs(os.path.dirname(NN_CONFIG_PATH), exist_ok=True)
    with open(NN_CONFIG_PATH, 'w') as f:
        json.dump(best_config, f)

print(f'\nBest validation score after {EPOCHS} epochs: {best_score:.6f}')
print(f'Best configuration: depth: {best_depth}, width: {best_width}, lr: {best_eta}, lambda: {best_weight_decay}')

In [None]:
# Fresh network with the selected depth/width and 7 input features, placed on `device`.
best_model_nn = SimpleNN(
    7,
    hidden_layers=best_depth,
    hidden_units=best_width,
).to(device)

In [None]:
from scripts.utils import EarlyStopping

# Components for the final training run. `metric` is reused later for the
# per-orthant test scores.
adam_kwargs = {'lr': best_eta, 'weight_decay': best_weight_decay}
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=best_model_nn.parameters(), **adam_kwargs)
metric = BinaryAccuracy()
early_stop = EarlyStopping(patience=20, min_delta=1e-4)

In [None]:
import os

# Checkpoint of the final NN: written when TRAIN is True, read otherwise.
NN_WEIGHTS_PATH = 'models/orthants-single-empty.pth'

if TRAIN:
    # Up to 500 epochs with early stopping. return_models=True presumably populates
    # history['models'], which the radial-visualisation cell consumes.
    history = train_model(
        model=best_model_nn,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        metric=metric,
        epochs=500,
        early_stopping=early_stop,
        device=device,
        return_models=True
    )
    # Create models/ on a fresh checkout instead of failing with FileNotFoundError.
    os.makedirs(os.path.dirname(NN_WEIGHTS_PATH), exist_ok=True)
    torch.save(best_model_nn.state_dict(), NN_WEIGHTS_PATH)
else:
    # map_location remaps tensors saved on another device (e.g. a GPU) onto `device`.
    best_model_nn.load_state_dict(torch.load(NN_WEIGHTS_PATH, map_location=device))

In [None]:
from scripts.utils import plot_train_history

if TRAIN:
    # `history` from the final run only exists when the model was trained this session.
    plot_train_history(history)

In [None]:
from scripts.utils import plot_radial_visualization

if TRAIN:
    # Render the stored model snapshots to 'radial_1empty' at 4 fps, centred on the
    # empty orthant.
    radial_kwargs = dict(
        models=history['models'],
        mp4_save_file_name='radial_1empty',
        orthant_counts=orthant_counts,
        main_orthant=ZERO_ORTHANT_INDEX,
        fps=4,
    )
    plot_radial_visualization(**radial_kwargs)

In [None]:
from scripts.test import predict

# Per-orthant NN test score using the BinaryAccuracy `metric`; entry i comes from
# X_test[i] / Y_test[i] for i in 0..127.
nn_orthant_scores = []
for orthant_idx in range(128):
    orthant_preds = predict(best_model_nn, X_test[orthant_idx], device)
    nn_orthant_scores.append(metric(orthant_preds, Y_test[orthant_idx]))
scores_nn = torch.tensor(nn_orthant_scores)

In [None]:
# Mean NN test score over all orthants, and the score on the empty orthant.
(scores_nn.mean(), scores_nn[ZERO_ORTHANT_INDEX].mean())

## SVM

In [None]:
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, make_scorer
from scripts.ntk import NTK
from sklearn.model_selection import GridSearchCV

In [None]:
# Kernel callable for SVC, derived from the trained network via NTK
# (presumably the neural tangent kernel).
ntk_builder = NTK(best_model_nn)
ntk = ntk_builder.get_ntk

In [None]:
# Iteration cap shared by both base SVMs used in the grid searches.
svm_max_iter = int(1e4)

# SVM with the NTK callable as kernel: tune C only.
model_base_ntk = SVC(kernel=ntk, max_iter=svm_max_iter)
params_ntk = {'C': np.logspace(-5, 5, 11)}

# RBF SVM: tune C and gamma (11 log-spaced values plus sklearn's 'scale' / 'auto').
gammas = [*np.logspace(-5, 5, 11).tolist(), 'scale', 'auto']
model_base_rbf = SVC(kernel='rbf', max_iter=svm_max_iter)
params_rbf = {'C': np.logspace(-5, 5, 11), 'gamma': gammas}

# Plain accuracy as the cross-validation score.
scorer = make_scorer(accuracy_score)

In [None]:
if VALIDATE_SVM:
    # Cross-validated search over C for the NTK-kernel SVM (no refit: only the
    # scores and best parameters are needed here).
    model_cv_ntk = GridSearchCV(
        estimator=model_base_ntk,
        param_grid=params_ntk,
        scoring=scorer,
        n_jobs=5,
        refit=False,
        verbose=3
    )
    model_cv_ntk.fit(X_train, Y_train.squeeze())
    best_params_ntk = model_cv_ntk.best_params_
    # best_score_ is the mean CV score of exactly the candidate in best_params_;
    # a bare max() over mean_test_score can yield NaN when some fit failed.
    best_score_ntk = model_cv_ntk.best_score_
    # A bare expression inside `if` is not echoed by the notebook, so print it.
    print(best_params_ntk, best_score_ntk)

In [None]:
if VALIDATE_SVM:
    # Cross-validated search over C and gamma for the RBF SVM (no refit: only the
    # scores and best parameters are needed here).
    model_cv_rbf = GridSearchCV(
        estimator=model_base_rbf,
        param_grid=params_rbf,
        scoring=scorer,
        n_jobs=5,
        refit=False,
        verbose=3
    )
    model_cv_rbf.fit(X_train, Y_train.squeeze())
    best_params_rbf = model_cv_rbf.best_params_
    # best_score_ is the mean CV score of exactly the candidate in best_params_;
    # a bare max() over mean_test_score can yield NaN when some fit failed.
    best_score_rbf = model_cv_rbf.best_score_
    # A bare expression inside `if` is not echoed by the notebook, so print it.
    print(best_params_rbf, best_score_rbf)

In [None]:
import json
import os

# Where the winning SVM configuration is stored (read back when VALIDATE_SVM is False).
SVM_CONFIG_PATH = 'configs/svm/orthants-single-empty.json'

if VALIDATE_SVM:
    # Keep whichever kernel scored higher in cross-validation; ties go to RBF.
    if best_score_rbf >= best_score_ntk:
        best_config = {
            'kernel': 'rbf',
            'C': best_params_rbf['C'],
            'gamma': best_params_rbf['gamma']
        }
    else:
        best_config = {
            'kernel': 'ntk',
            'C': best_params_ntk['C']
        }
    # Create configs/svm/ on a fresh checkout instead of failing with FileNotFoundError.
    os.makedirs(os.path.dirname(SVM_CONFIG_PATH), exist_ok=True)
    with open(SVM_CONFIG_PATH, 'w') as f:
        json.dump(best_config, f)
else:
    with open(SVM_CONFIG_PATH, 'r') as f:
        best_config = json.load(f)

# Build the (unfitted) SVM from the config the same way in both paths; the kernel
# is spelled out explicitly rather than relying on SVC's default.
if best_config['kernel'] == 'rbf':
    best_model_km = SVC(C=best_config['C'], kernel='rbf', gamma=best_config['gamma'])
else:
    best_model_km = SVC(C=best_config['C'], kernel=ntk)

In [None]:
# Fit the selected SVM on the training split (flattened labels).
train_labels = Y_train.squeeze()
best_model_km.fit(X_train, train_labels)

In [None]:
# Accuracy of the fitted SVM on the training and validation splits.
preds_train = best_model_km.predict(X_train)
preds_val = best_model_km.predict(X_val)
score_train = accuracy_score(Y_train.squeeze(), preds_train)
score_val = accuracy_score(Y_val.squeeze(), preds_val)
score_train, score_val

In [None]:
# Per-orthant SVM test accuracy; entry i comes from X_test[i] / Y_test[i] (i in
# 0..127), the same indexing used for scores_nn.
scores_km = np.array([
    # sklearn's signature is accuracy_score(y_true, y_pred); the arguments were swapped.
    accuracy_score(Y_test[i].squeeze(), best_model_km.predict(X_test[i]))
    for i in range(128)
])

In [None]:
# Mean SVM test accuracy over all orthants, and the accuracy on the empty orthant.
(scores_km.mean(), scores_km[ZERO_ORTHANT_INDEX].mean())

## NN vs SVM

In [None]:
# Side-by-side scatter of per-orthant accuracy against the number of training
# points in that orthant, NN on the left and SVM on the right.
plt.figure(figsize=(13, 6))

for panel, (panel_scores, panel_title) in enumerate(
    [(scores_nn, 'NN'), (scores_km, 'SVM')], start=1
):
    plt.subplot(1, 2, panel)
    plt.scatter(orthant_counts, panel_scores)
    plt.xlabel('Number of points')
    plt.ylabel('Binary accuracy')
    plt.ylim((0., 1.))
    plt.title(panel_title)

plt.suptitle('Accuracy per orthant')