In [1]:
import pwd
from IPython import get_ipython

# Turn on IPython's autoreload extension in mode 2 so imported modules are
# re-imported before each cell runs (handy while editing the local package).
# `run_line_magic(name, args)` replaces the deprecated `.magic("name args")`.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

  get_ipython().magic('load_ext autoreload')
  get_ipython().magic('autoreload 2')


In [2]:
from jax import random
import jax.numpy as jnp
from scipy.io import arff

from src.dbopt.FCNN import FCNN
from src.dbopt.DB_sampler import DecisionBoundarySampler
from src.dbopt.DB_Top_opt import DecisionBoundrayOptimizer

In [3]:
# Root PRNG key for the notebook; later cells split their subkeys off `key`.
seed = 24
key = random.PRNGKey(seed)

# Report which platform JAX is running on ('cpu', 'gpu' or 'tpu').
# `jax.default_backend()` is the public API for this; the private
# `jax.lib.xla_bridge` module is deprecated/removed in recent JAX releases.
import jax
print(jax.default_backend())

gpu


### Importing the data

In [4]:
# `loadarff` returns a (records, metadata) pair; only the records are used below.
data, _arff_meta = arff.loadarff('data/column_2C_weka.arff')

def row_void_to_array(entry, num_features=6):
    """Collect the leading fields of one record into a 1-D JAX array.

    Args:
        entry: An indexable record; positions ``0 .. num_features - 1`` are read.
        num_features: How many leading fields to keep. Defaults to 6, the
            value that was previously hard-coded.

    Returns:
        A ``jnp`` array of shape ``(num_features,)`` holding those fields.
    """
    return jnp.array([entry[i] for i in range(num_features)])

# Dedicated subkey for the shuffle; `key` itself is refreshed for later use.
key, ds_key = random.split(key)

# Features: one row per record, built from the record's first six fields.
x = jnp.array([row_void_to_array(record) for record in data])
# Labels: 0 when field 6 is b'Normal', 1 otherwise.
y = jnp.array([0 if record[6] == b'Normal' else 1 for record in data])

# Label goes in column 0, features in the remaining columns; then shuffle rows.
labelled_rows = jnp.concatenate((y[:, None], x), axis=1)
dataset = random.permutation(ds_key, labelled_rows, axis=0)

# Train/test split on the shuffled rows; the training size is rounded up.
train_fraction = 8/10
num_training_examples = int(jnp.ceil(dataset.shape[0]*train_fraction))
train_dataset = dataset[:num_training_examples]
test_dataset = dataset[num_training_examples:]

# Column 0 holds the 0/1 labels, so its mean is the share of positives.
positive_share = jnp.sum(dataset[:, 0])/dataset.shape[0]

print("dataset shape : ", dataset.shape)
print("proportion of positives in the dataset : ", positive_share)
print("training set shape : ", train_dataset.shape)
print("test dataset shape : ", test_dataset.shape)

dataset shape :  (310, 7)
proportion of positives in the dataset :  0.67741936
training set shape :  (248, 7)
test dataset shape :  (62, 7)


### Fit a network

In [5]:
# Build the network; the list is passed as `num_neurons_per_layer`
# (three entries of 100 followed by a final entry of 2).
model = FCNN(num_neurons_per_layer=[100, 100, 100, 2])
# Random 6-element input handed to `model.init` together with its own subkey.
# The length 6 matches the six fields kept per record by `row_void_to_array`.
key, init_x_key = random.split(key)
x_init = random.uniform(init_x_key, (6,))
key, init_key = random.split(key)
params = model.init(init_key, x_init)

# Train with a fresh subkey. The order of the `random.split` calls above fixes
# which subkey each step receives, so reordering them changes the results.
# NOTE(review): the positional 60 is presumably the number of epochs — confirm
# against FCNN.train. The saved output below lists epochs up to 90, which would
# not match 60 epochs; the cell output may be stale and should be re-run.
key, train_key = random.split(key)
params = model.train(train_key, params, train_dataset, 60, lr=0.0001, logs_frequency=10, test_set=test_dataset)
print(f"final training accuracy : {model.accuracy(params, train_dataset)}, final test accuracy : {model.accuracy(params, test_dataset)}")

epoch 0, loss = 404.28289794921875, training accuracy = [0.43548387],  test accuracy = [0.37096775]
epoch 10, loss = 72.06607055664062, training accuracy = [0.71370965],  test accuracy = [0.6935484]
epoch 20, loss = 26.540523529052734, training accuracy = [0.8104839],  test accuracy = [0.7419355]
epoch 30, loss = 17.966552734375, training accuracy = [0.8508064],  test accuracy = [0.80645156]
epoch 40, loss = 14.566466331481934, training accuracy = [0.87903225],  test accuracy = [0.82258064]
epoch 50, loss = 18.227752685546875, training accuracy = [0.87096775],  test accuracy = [0.79032254]
epoch 60, loss = 8.341793060302734, training accuracy = [0.85483867],  test accuracy = [0.7741935]
epoch 70, loss = 13.33534049987793, training accuracy = [0.86693543],  test accuracy = [0.7741935]
epoch 80, loss = 11.52695083618164, training accuracy = [0.87903225],  test accuracy = [0.82258064]
epoch 90, loss = 11.532136917114258, training accuracy = [0.87903225],  test accuracy = [0.79032254]
fina

### Try the regularizer

In [7]:
# Target passed to the optimizer as `desired_homology`.
# NOTE(review): presumably maps homology dimension -> desired count, i.e. a
# single connected component for the decision boundary — confirm in DB_Top_opt.
desired_homology = {0: 1}

# Optimizer wrapping the trained model and its current parameters.
# `input_dimension=6` matches the six feature columns of the dataset.
db_opt = DecisionBoundrayOptimizer(
    model,
    params,
    n_sampling=1000,
    input_dimension=6,
    desired_homology=desired_homology,
    sampling_epochs=2,
)

In [8]:
# Accuracy on the held-out split before optimization, as a baseline.
test_accuracy_before = model.accuracy(params, test_dataset)
print("test accuracy before db opt : ", test_accuracy_before)

# Run the optimizer and rebind `params` to whatever it returns.
# NOTE(review): the positional 5 is presumably the number of optimization
# epochs — confirm against DecisionBoundrayOptimizer.optimize.
params = db_opt.optimize(5, train_dataset, test_dataset)

test accuracy before db opt :  [0.7580645]
FrozenDict({
    params: {
        layers_0: {
            bias: Array([-1.2677349e+01, -1.4318197e+00,  5.5122309e+00,  6.3009028e+00,
                    1.0594342e+01, -6.1178784e+00, -1.1187552e+01,  4.3522048e+00,
                   -1.4844434e+01, -2.4932404e+01, -9.7647762e-01, -2.0957588e+01,
                   -2.9713926e+01, -3.8438702e-01, -1.1918629e+01, -6.3295703e+00,
                    2.5693724e+01, -1.7415041e+01,  3.7332721e+00, -6.9758701e+00,
                   -1.4607015e+00, -2.4454496e+01, -2.6605446e+00, -1.7161328e+00,
                    8.7304564e+00, -3.8343937e+01,  4.3455772e+00,  3.1817052e+00,
                    1.1340516e+01,  1.2741053e+00, -2.2092075e+01,  9.8906832e+00,
                   -3.1341949e+00, -1.6457214e+01, -2.9254913e+01, -4.9231300e+00,
                    1.2650509e+01, -7.5199366e-01,  1.7099738e+00,  1.8446345e+00,
                    1.2496313e+01, -3.5436034e-02,  5.4330635e-01,  1.7172