In [None]:
import torch
import torchvision
from torchvision.transforms.transforms import Compose
from tools.training_cycle import fit
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot
import seaborn as sns
from tools.aclimdb_loader import aclimdb_load, test_review, review_to_words
# apply seaborn's default styling to the matplotlib figures drawn below
sns.set()

aclimdb data set<br>

download at https://ai.stanford.edu/~amaas/data/sentiment/<br>
expects folder aclImdb in DATA_PATH<br>
needs files aclImdb/imdb.vocab, aclImdb/test/labeledBow.feat and aclImdb/train/labeledBow.feat<br>

requires the nltk stopwords corpus, which is downloaded automatically<br>

each sample is returned from the dataset as a 1xINP_SIZE vector, which contains the most common INP_SIZE words,<br>
together with a label of 0 or 1<br>
label 0 = bad, label 1 = good<br>

to keep stopwords, pass remove_stopwords=False to the loader<br>

In [None]:
# --- optimisation settings ---
bs = 100      # batch size of the training loader (the validation loader uses bs * 2)
lr = 0.1      # learning rate handed to the optimizer
epochs = 10   # number of epochs handed to fit()

# --- model dimensions ---
INP_SIZE = 5000  # number of words to accept (length of each input vector)
OUTP_SIZE = 2    # model output size; labels are 0 (bad) and 1 (good)

In [None]:
# model selection
from models.mixed1 import Mixed1
from models.mixed2 import Mixed2
from models.vpnn import Vpnn
from models.vpnn_t import Vpnn_t
from models.s_relu import S_ReLU

# Build the selected network; swap Vpnn for one of the other imported
# model classes to try a different architecture.
model = Vpnn(
    INP_SIZE,
    OUTP_SIZE,
    hidden_layers=3,
    rotations=3,
    chebyshev_M=2,
    diagonal_M=0.01,
    svd_u=True,
)

# Run on the GPU whenever one is available, otherwise stay on the CPU.
model = model.cuda() if torch.cuda.is_available() else model

# SGD with momentum, optimising a cross-entropy objective.
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
loss_funct = torch.nn.CrossEntropyLoss()

In [None]:
# folder that is expected to contain the aclImdb dataset directory
DATA_PATH = "data"
# path for saving data/model
FILE_PATH = ""

In [None]:
# by default, data is 1xINP_SIZE vectors
# warning: slow due to overhead, expect 20s-1min to run

# transforms applied to the data on load (an empty Compose applies none)
PREPROCESSING = Compose([])

# nmap is a dict that maps index -> word in input data
# wmap is a dict that maps word -> index
train_ds, valid_ds, nmap, wmap = aclimdb_load(
    DATA_PATH,
    INP_SIZE,
    preprocessing=PREPROCESSING,
    remove_stopwords=True,
)

# Pinned host memory only speeds up host-to-GPU copies, so request it only
# when CUDA is present; on CPU-only machines pin_memory=True has no benefit
# and makes PyTorch emit a warning.
train_dl = torch.utils.data.DataLoader(
    train_ds,
    batch_size=bs,
    shuffle=True,  # reshuffle the training samples every epoch
    pin_memory=torch.cuda.is_available(),
)
# validation runs with twice the training batch size and without shuffling
valid_dl = torch.utils.data.DataLoader(
    valid_ds,
    batch_size=bs * 2,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Run the training cycle; the returned mapping is read below through its
# 'losses' and 'accs' entries.
# NOTE(review): one_hot_size presumably has to match OUTP_SIZE (both are 2
# here) - confirm against tools.training_cycle.fit before changing either.
data = fit(
    model,
    loss_funct,
    train_dl,
    valid_dl,
    opt,
    epochs,
    one_hot_size=2,
)

In [None]:
# Draw the training history returned by fit(): one figure per curve.

# loss values, plotted against the training batch index
loss_fig, loss_ax = pyplot.subplots()
loss_ax.plot(data['losses'])
loss_ax.set_title('losses')
loss_ax.set_xlabel('training batch')
loss_ax.set_ylabel('loss')
pyplot.show()

# validation accuracy, plotted against the epoch index
acc_fig, acc_ax = pyplot.subplots()
acc_ax.plot(data['accs'])
acc_ax.set_title('accuracy on validation data')
acc_ax.set_xlabel('epoch')
acc_ax.set_ylabel('% accuracy')
pyplot.show()

In [None]:
# test a custom review on the model
# warning: the input is not sanitized, so words with punctuation,
# uppercase letters etc. will not be counted
review = 'this was a good movie'
test_review(review, wmap, INP_SIZE, model)

In [None]:
# see word count in a given review
# review_idx is the index of the validation sample to inspect; it gets its
# own name so the string `review` used by the custom-review cell is not
# overwritten with an int.
review_idx = 10
# element [0] of a dataset item is the input vector (the label follows it)
review_to_words(valid_ds[review_idx][0], nmap)

In [None]:
# confusion matrix over the whole validation set
model.eval()  # make sure train-only behaviour (e.g. dropout) is switched off
mat = 0  # becomes an OUTP_SIZE x OUTP_SIZE array after the first batch is added
with torch.no_grad():  # inference only: skip building the autograd graph
    for x, y in valid_dl:
        if torch.cuda.is_available():
            x = x.cuda()
        # predicted class = index of the largest output along dim 1
        _, pred = torch.max(model(x), 1)
        # passing labels explicitly keeps the matrix shape fixed even when a
        # batch happens to miss one of the classes
        mat += confusion_matrix(
            y.numpy(), pred.cpu().numpy(), labels=list(range(OUTP_SIZE))
        )

# confusion_matrix puts true labels on rows; transpose so that the x axis is
# the true label and the y axis is the predicted label
sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False,
            xticklabels='auto', yticklabels='auto')
# pin the y range to the full grid with row 0 at the top
# (presumably guards against clipped heatmap rows in some matplotlib versions)
pyplot.ylim(OUTP_SIZE, 0)
pyplot.xlabel('true label')
pyplot.ylabel('predicted label')
pyplot.show()