In [None]:
import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tools.training_cycle import fit
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot
import seaborn as sns
# activate seaborn's default theme; it applies to the matplotlib figures drawn in later cells
sns.set()

In [None]:
# --- run configuration -------------------------------------------------------
# Fixed RNG seed: model initialisation and DataLoader shuffling both draw from
# torch's global generator, so seeding it here (before the model is built)
# makes a Restart & Run All reproduce the same training run.
SEED = 0
torch.manual_seed(SEED)

bs = 100          # training mini-batch size (the validation loader uses bs * 2)
lr = 0.01         # learning rate handed to the SGD optimizer
epochs = 100      # number of epochs passed to fit()
INP_SIZE = 784    # model input width: one flattened 28x28 image
OUTP_SIZE = 10    # model output width: one score per digit class

In [None]:
# Model selection.
# All candidate architectures are imported here; presumably so the constructor
# call below can be swapped for another one without touching the imports.
from models.mixed1 import Mixed1
from models.mixed2 import Mixed2
from models.vpnn import Vpnn
from models.vpnn_t import Vpnn_t
from models.s_relu import S_ReLU

model = Vpnn(
    INP_SIZE,
    OUTP_SIZE,
    hidden_layers=3,
    rotations=3,
    chebyshev_M=2,
    diagonal_M=0.01,
    svd_u=True,
)

# Put the model on the GPU when one is present; this happens before the
# optimizer is created from model.parameters().
model = model.cuda() if torch.cuda.is_available() else model

# Classification loss and SGD-with-momentum optimizer used by the training cell.
loss_funct = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# --- filesystem settings ---

# Where saved data / model files should go.
FILE_PATH = ''

# Flip to True when the MNIST files are not on disk yet, so that the dataset
# constructors download them.
DOWNLOAD_MNIST = False

# Folder that holds (or will receive) the MNIST data.
DATA_PATH = "data"

In [None]:
# changes data from an image tensor to a 1d vector
# (C)xHxW -> L
def flatten(x):
    """Return ``x`` collapsed into a one-dimensional tensor.

    The element count is inferred from the input instead of being fixed to
    784, so the transform keeps working if the image size changes.
    ``reshape`` is used rather than ``view`` so that non-contiguous tensors
    (e.g. a transposed image) are accepted as well; for contiguous input it
    still returns a view without copying.

    Args:
        x: tensor of any shape.

    Returns:
        A tensor of shape ``(x.numel(),)`` with the elements of ``x`` in
        row-major order.
    """
    return x.reshape(-1)

In [None]:
# by default, data is 28x28 PIL images

# transforms applied to every sample on load: PIL image -> tensor -> 1d vector
PREPROCESSING = Compose([torchvision.transforms.ToTensor(), flatten])

# training split and held-out split of MNIST, both read from DATA_PATH
train_ds = torchvision.datasets.MNIST(
      DATA_PATH, train=True, download=DOWNLOAD_MNIST,
      transform=PREPROCESSING
)
valid_ds = torchvision.datasets.MNIST(
      DATA_PATH, train=False, download=DOWNLOAD_MNIST,
      transform=PREPROCESSING
)

# Pinned (page-locked) host memory only helps host -> GPU copies. Asking for it
# on a CPU-only machine brings no benefit and recent PyTorch versions warn
# about it, so request it only when CUDA is actually available -- the same
# condition the rest of the notebook uses to decide whether to use the GPU.
pin_memory = torch.cuda.is_available()

# only the training batches are shuffled; validation uses double-sized batches
train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=bs, shuffle=True, pin_memory=pin_memory
)
valid_dl = torch.utils.data.DataLoader(
    valid_ds, batch_size=bs * 2, pin_memory=pin_memory
)

In [None]:
# run training cycle
# `data` holds what fit() returns; the plotting cell reads its 'losses' and
# 'accs' entries.
# one_hot_size is tied to OUTP_SIZE (presumably both are the number of
# classes -- confirm against tools.training_cycle.fit) instead of repeating
# the literal 10, so the two values cannot drift apart.
data = fit(
    model, loss_funct, train_dl, valid_dl, opt, epochs,
    one_hot_size=OUTP_SIZE,
)

In [None]:
# Show the recorded training history: one figure per series.
# Each entry is (key in `data`, figure title, x-axis label, y-axis label).
for series_key, heading, x_label, y_label in (
    ('losses', 'losses', 'training batch', 'loss'),
    ('accs', 'accuracy on validation data', 'epoch', '% accuracy'),
):
    pyplot.plot(data[series_key])
    pyplot.title(heading)
    pyplot.xlabel(x_label)
    pyplot.ylabel(y_label)
    pyplot.show()

In [None]:
# confusion matrix over the validation set
class_labels = list(range(OUTP_SIZE))
true_labels = []
pred_labels = []

# Evaluation only: put the model in eval mode (matters for layers such as
# dropout / batch-norm, if the model has any) and turn autograd off so no
# graph is built for these forward passes.
model.eval()
with torch.no_grad():
    for x, y in valid_dl:
        if torch.cuda.is_available():
            x = x.cuda()
        # predicted class = index of the largest output score per sample
        pred = model(x).argmax(dim=1)
        true_labels.extend(y.tolist())
        pred_labels.extend(pred.cpu().tolist())

# Build the matrix once from all collected labels. Passing `labels` fixes its
# shape to OUTP_SIZE x OUTP_SIZE even if some class never occurs (or the
# loader is empty), which the previous `mat = 0` accumulator could not handle.
mat = confusion_matrix(true_labels, pred_labels, labels=class_labels)

# confusion_matrix puts true labels on rows; transpose so that rows are the
# predicted labels and columns the true labels, matching the axis titles.
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels='auto', yticklabels='auto')
# pin the y-range to the full grid with row 0 on top (presumably guards against
# matplotlib versions that clip the outer heatmap rows)
pyplot.ylim(OUTP_SIZE, 0)
pyplot.xlabel('true label')
pyplot.ylabel('predicted label')
pyplot.show()