In [None]:
import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tools.training_cycle import fit
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot
from tools.emnist_setup import emnist_setup
import seaborn as sns
# Switch on seaborn's default styling; this runs once at import time so the
# matplotlib figures drawn later in the notebook pick it up.
sns.set()

In [None]:
# Training hyperparameters.
bs = 100      # batch size of the training loader (the validation loader uses bs * 2)
lr = 1e-2     # learning rate passed to the SGD optimizer
epochs = 20   # epoch count handed to the training cycle

Description of the EMNIST dataset splits


Available splits: 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'<br>
'byclass':[0-9] [a-z] [A-Z], 814255 characters, 62 classes<br>
'bymerge':[0-9] [a,b,d,e,f,g,h,n,q,r,t] [A-Z], 814255 characters, 47 classes<br>
'balanced':[0-9] [a,b,d,e,f,g,h,n,q,r,t] [A-Z], 131600 characters, 47 classes<br>
'letters':[a-z] and [A-Z] merged into one class per letter, 145600 characters, 26 classes<br>
'digits':[0-9], 280000 characters, 10 classes<br>
'mnist':[0-9], 70000 characters, 10 classes<br>

In [None]:
# Name of the sub-folder under graphs/ that this run's graphs belong to; change per experiment.
DIRECTORY = 'Pooling_Chebyshev_diagonal_rotational'
GRAPH_PATH = '/'.join(('graphs', DIRECTORY))

# Root folder handed to the EMNIST dataset loaders.
DATA_PATH = "data"

# Path for saving data/model.
FILE_PATH = ''

# Which EMNIST split to use.
SPLIT = 'letters'

# Set to True if the data is not on disk yet; the dataset loader then downloads it.
DOWNLOAD_EMNIST = False

# Retrieve the split-dependent constants (label mapping, class count, target transform).
LABEL_DICT, CLASSES, TARGET_TRANSFORM = emnist_setup(SPLIT)


In [None]:
# Model input width: one value per pixel of a flattened 28x28 image.
INP_SIZE = 28 * 28
# Model output width: one output per class of the selected split.
OUTP_SIZE = CLASSES

In [None]:
# model selection
from models.mixed1 import Mixed1
from models.mixed2 import Mixed2
from models.vpnn import Vpnn
from models.vpnn_t import Vpnn_t
from models.s_relu import S_ReLU

# Build the network that maps INP_SIZE inputs to OUTP_SIZE outputs.
model = Vpnn(
    INP_SIZE,
    OUTP_SIZE,
    hidden_layers=3,
    rotations=3,
    chebyshev_M=2,
    diagonal_M=0.01,
    svd_u=True,
)

# Run on the GPU whenever one is available; otherwise stay on the CPU.
model = model.cuda() if torch.cuda.is_available() else model

# SGD with momentum, using the notebook-level learning rate, and cross-entropy loss.
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
loss_funct = torch.nn.CrossEntropyLoss()

In [None]:
# changes data from 2d image to 1d list
# HxW -> L
def flatten(x):
    """Collapse a tensor into a 1-d tensor holding all of its elements.

    The output length is derived from the input instead of being fixed at
    784, so images of any size are accepted; a 1x28x28 image still yields a
    tensor of shape (784,).

    Args:
        x: tensor to flatten (e.g. the output of ``ToTensor``).

    Returns:
        A 1-d tensor with ``x.numel()`` elements, in the input's row-major order.
    """
    # reshape (unlike view) also copes with non-contiguous inputs.
    return x.reshape(-1)

In [None]:
# EMNIST samples arrive as 28x28 PIL images by default.

# Transform pipeline applied to every sample on load:
# PIL image -> tensor -> flattened 1d tensor.
PREPROCESSING = Compose([torchvision.transforms.ToTensor(), flatten])

# Training and validation sets differ only in the `train` flag, so both are
# built by the same expression (training set first, then validation set).
train_ds, valid_ds = (
    torchvision.datasets.EMNIST(
        DATA_PATH,
        split=SPLIT,
        train=is_train,
        download=DOWNLOAD_EMNIST,
        transform=PREPROCESSING,
        target_transform=TARGET_TRANSFORM,
    )
    for is_train in (True, False)
)

# Training batches are shuffled; validation batches are twice as large and
# keep the dataset order.
train_dl = torch.utils.data.DataLoader(
    train_ds, batch_size=bs, shuffle=True, pin_memory=True
)
valid_dl = torch.utils.data.DataLoader(
    valid_ds, batch_size=2 * bs, pin_memory=True
)

In [None]:
# Run the training cycle. The returned `data` is read by the plotting cell
# through its 'losses' and 'accs' entries.
data = fit(
    model, loss_funct,
    train_dl, valid_dl,
    opt, epochs,
    one_hot_size=CLASSES,
)

In [None]:
# Show the loss and accuracy curves in the UI, one figure per series.
# Each entry: (key in `data`, figure title, x-axis label, y-axis label).
for series_key, title, x_label, y_label in (
    ('losses', 'losses', 'training batch', 'loss'),
    ('accs', 'accuracy on validation data', 'epoch', '% accuracy'),
):
    pyplot.plot(data[series_key])
    pyplot.title(title)
    pyplot.xlabel(x_label)
    pyplot.ylabel(y_label)
    pyplot.show()

In [None]:
# Confusion matrix of the model's predictions over the validation set.
ANNOT=False  # show values in each square, will get messy if too many values

# The label ids and display names are the same for every batch, so build them once.
label_ids = list(LABEL_DICT.keys())
label_names = list(LABEL_DICT.values())

mat=0  # running sum of the per-batch confusion matrices
was_training = model.training
model.eval()  # evaluation mode, in case the model has train-only layers (e.g. dropout)
try:
    # No gradients are needed for inference; skipping autograd saves memory and time.
    with torch.no_grad():
        for x, y in valid_dl:
            if torch.cuda.is_available():
                x = x.cuda()
            pred = model(x)
            # index of the highest score along dim 1 is the predicted class
            _, pred = torch.max(pred, 1)
            mat += confusion_matrix(y.numpy(), pred.cpu().numpy(), labels=label_ids)
finally:
    # Put the model back into whatever mode it was in before this cell ran.
    model.train(was_training)

# confusion_matrix puts true labels on rows; transpose so that the x axis is
# the true label and the y axis the predicted label, matching the axis titles.
sns.heatmap(mat.T, square=True, fmt='d',annot=ANNOT, cbar=False,
           xticklabels=label_names, yticklabels=label_names)
pyplot.xlabel('true label')
pyplot.ylabel('predicted label')
pyplot.show()