In [None]:
# Third-party libraries
import torch
import torchvision
from torchvision.transforms import Compose
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot
import seaborn as sns

# Project helpers: training loop and ARCENE dataset loader
from tools.training_cycle import fit
from tools.ARCENE_loader import arcene_load

# Apply seaborn's default theme to every matplotlib figure in this notebook.
sns.set()

Download the ARCENE dataset from https://archive.ics.uci.edu/ml/datasets/Arcene <br>
The loader reads the whole dataset directly into memory (about 64 MB).<br>
It expects the files `ARCENE/arcene_train.data`, `ARCENE/arcene_train.labels`, `ARCENE/arcene_valid.data` and `ARCENE/arcene_valid.labels` inside the folder `DATA_PATH`. <br>

The training and validation sets each contain 100 instances with 10,000 attributes. The validation set serves as the test set in this notebook.

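ARCENE labels are -1/1. The loader remaps the -1 labels so that the class labels become 0/1, which is what `CrossEntropyLoss` and the confusion matrix below require (presumably -1 → 0; **TODO** confirm in `arcene_load`).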

In [None]:
# Training hyperparameters and problem dimensions.
bs = 10           # training batch size (the validation loader uses bs * 2)
lr = 0.0001       # SGD learning rate
epochs = 100      # number of training epochs passed to `fit`
INP_SIZE = 10000  # ARCENE has 10,000 attributes per instance
OUTP_SIZE = 2     # two classes -> two output logits
SEED = 0          # fixed seed for reproducible runs

# Seed torch's global RNG (CPU and all CUDA devices). This makes weight init and
# DataLoader shuffling repeatable across Restart & Run All. NOTE(review): if the
# models draw from numpy's RNG as well, seed that too.
torch.manual_seed(SEED)

In [None]:
# Model selection: Vpnn is trained below; the other imported architectures can be
# swapped in for comparison.
from models.mixed1 import Mixed1
from models.mixed2 import Mixed2
from models.vpnn import Vpnn
from models.vpnn_t import Vpnn_t
from models.s_relu import S_ReLU

model = Vpnn(
    INP_SIZE,
    OUTP_SIZE,
    hidden_layers=3,
    rotations=3,
    chebyshev_M=2,
    diagonal_M=0.01,
    svd_u=True,
)
# Move the parameters to the GPU when one is present.
model = model.cuda() if torch.cuda.is_available() else model

# Momentum SGD over the model's parameters; cross-entropy on the model's logits.
opt = torch.optim.SGD(model.parameters(), lr, momentum=0.9)
loss_funct = torch.nn.CrossEntropyLoss()

In [None]:
DATA_PATH = "data"  # folder containing the ARCENE/ data files listed in the dataset notes
FILE_PATH = ''  # path for saving data/model

In [None]:
# ARCENE: 100 training and 100 validation instances with 10,000 attributes each.
# `arcene_load` returns the (train, valid) datasets.

# Transforms applied to the data as it is loaded (they act on torch tensors).
# Empty by default; add preprocessing steps to this list.
PREPROCESSING = Compose([])


train_ds, valid_ds = arcene_load(DATA_PATH, preprocessing=PREPROCESSING)

# Pinned host memory only speeds up host->GPU copies. Without CUDA it is useless
# and PyTorch warns about it, so pin only when a GPU is available.
PIN_MEMORY = torch.cuda.is_available()

train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=bs, shuffle=True, pin_memory=PIN_MEMORY
)
valid_dl = torch.utils.data.DataLoader(
    valid_ds, batch_size=bs * 2, pin_memory=PIN_MEMORY
)

In [None]:
# Run the training loop. `fit` returns a dict whose 'losses' and 'accs' entries
# are plotted in the next cell. One-hot width comes from OUTP_SIZE rather than a
# duplicated literal, so it stays in sync with the model's output size.
data = fit(model, loss_funct, train_dl, valid_dl, opt, epochs, one_hot_size=OUTP_SIZE)

In [None]:
# Plot the training history recorded by `fit`: one figure per tracked series.
# Each entry: (key in `data`, title, x-axis label, y-axis label).
HISTORY_PLOTS = (
    ('losses', 'losses', 'training batch', 'loss'),
    ('accs', 'accuracy on validation data', 'epoch', '% accuracy'),
)

for series_key, series_title, x_label, y_label in HISTORY_PLOTS:
    fig, ax = pyplot.subplots()
    ax.plot(data[series_key])
    ax.set_title(series_title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    pyplot.show()

In [None]:
# Confusion matrix of the trained model on the validation set.
from tools.helpers import expand, acc

device = next(model.parameters()).device  # send batches to wherever the model lives
was_training = model.training
model.eval()  # switch off train-only behaviour (e.g. dropout) while predicting
true_labels, pred_labels = [], []
try:
    with torch.no_grad():  # inference only: no autograd graph needed
        for x, y in valid_dl:
            logits = model(x.to(device))
            true_labels.append(y.cpu())
            # argmax over the class dimension = predicted class index
            pred_labels.append(logits.argmax(dim=1).cpu())
finally:
    model.train(was_training)  # restore the model's previous train/eval mode

# sklearn convention: rows of `mat` are true labels, columns are predictions.
mat = confusion_matrix(
    torch.cat(true_labels).numpy(),
    torch.cat(pred_labels).numpy(),
    labels=list(range(OUTP_SIZE)),
)

# Transposed so the x-axis shows true labels and the y-axis predicted labels.
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels='auto', yticklabels='auto')
pyplot.ylim(OUTP_SIZE, 0)  # show every row in full, with class 0 at the top
pyplot.xlabel('true label')
pyplot.ylabel('predicted label')
pyplot.show()