In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import FashionMNIST

import random
import copy
import tqdm
import numpy as np
import matplotlib.pyplot as plt

def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

In [2]:
class Args():
    device = 'cpu'
    # device = torch.device('mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu')

    input_size = 784
    output_size = 10
    num_neurons = 50
    num_committees = 500
    size_committee = 5

    batch_size = 128
    lr = 1e-2
    weight_decay = 1e-5
    max_epochs = 100

In [3]:
train_data = FashionMNIST(root='./data', train=True, download=False)
test_data = FashionMNIST(root='./data', train=False, download=False)

X_train = train_data.data.unsqueeze(1).float().flatten(start_dim=1)/255
X_test = test_data.data.unsqueeze(1).float().flatten(start_dim=1)/255

y_train = F.one_hot(train_data.targets).float()
y_test = F.one_hot(test_data.targets).float()

train_data = torch.utils.data.TensorDataset(X_train, y_train)
test_data = torch.utils.data.TensorDataset(X_test, y_test)

In [4]:
class NPF(nn.Module):

    def __init__(self, input_size, output_size, num_neurons, size_committee, num_committees, args):
        super().__init__()
        self.args = args
        self.input_size = input_size
        self.output_size = output_size
        self.num_neurons = num_neurons
        self.num_committees = num_committees

        self.neurons = nn.ModuleList([nn.Linear(input_size, output_size, bias=False) for i in range(num_neurons)])
        self.npf = []
        for i in range(num_committees):
            while True:
                choice = torch.randint(0, num_neurons, size=(size_committee,))
                if len(torch.unique(choice)) == size_committee:
                    break
            self.npf.append(choice.tolist())

        self.init('gaussian')

        self.sigmoid = nn.Sigmoid()

    def init(self, type):
        if not hasattr(self, 'type'):
            self.type = type
        else:
            self.type = 'gaussian'

        if self.type == 'gaussian':
            for i in range(self.num_neurons):
                nn.init.normal_(self.neurons[i].weight, mean=0.0, std=1.0)
        else:
            raise NotImplementedError 

    def forward(self, x):
        z = torch.zeros((x.shape[0], self.num_committees, self.output_size)).to(self.args.device)
        for i in range(self.num_committees):
            z[:, i] = torch.prod(torch.stack([self.sigmoid(x @ self.neurons[self.npf[i][j]].weight.T) for j in range(len(self.npf[i]))]), dim=0).squeeze(1)
        return z
    
class NPV(nn.Module):

    def __init__(self, input_size, output_size, num_neurons, size_committee, num_committees, args):
        super().__init__()
        self.args = args
        self.input_size = input_size
        self.output_size = output_size
        self.num_neurons = num_neurons
        self.num_committees = num_committees

        self.npv = nn.Linear(num_committees, 1, bias=False)

        self.init('gaussian')

    def init(self, type):
        if not hasattr(self, 'type'):
            self.type = type
        else:
            self.type = 'gaussian'

        if self.type == 'gaussian':
            nn.init.normal_(self.npv.weight, mean=0.0, std=1.0)
        else:
            raise NotImplementedError 

    def forward(self, z):
        z = self.npv(z.permute(0, 2, 1)) 
        return z.squeeze(-1)


In [5]:
class Trainer():

    def __init__(self, data, models, optimzers, criterion, args):
        self.args = args
        
        self.traindata, self.testdata = data
        self.trainloader = torch.utils.data.DataLoader(self.traindata, batch_size=self.args.batch_size, shuffle=True, drop_last=True)
        self.testloader = torch.utils.data.DataLoader(self.testdata, batch_size=self.args.batch_size, shuffle=False, drop_last=False)

        self.npf, self.npv = models
        self.npf, self.npv = self.npf.to(self.args.device), self.npv.to(self.args.device)
        self.npf_opt, self.npv_opt = optimzers
        self.criterion = criterion

        self.loss = []
        self.accuracy = []

    def train_epoch(self, step):
        

        train_loss = 0
        train_accuracy = 0

        for idx, (x, y) in tqdm.tqdm(enumerate(self.trainloader)):
            self.npf.train()
            self.npv.train()
            x = x.to(self.args.device)
            y = y.to(self.args.device)

            self.npf_opt.zero_grad()
            self.npv_opt.zero_grad()

            if step%5 != 0:
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                loss.backward()
                self.npf_opt.step()

            if step%5 == 0:
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                loss.backward()
                self.npv_opt.step()

            self.npf.eval()
            self.npv.eval()
            with torch.no_grad():
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                accuracy = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean()

            train_loss += loss.item()/len(self.trainloader)
            train_accuracy += accuracy.item()/len(self.trainloader)

        return train_loss, train_accuracy
    
    def test(self):
        self.npf.eval()
        self.npv.eval()
        test_loss = 0
        test_accuracy = 0
        for idx, (x, y) in enumerate(self.testloader):
            x = x.to(self.args.device)
            y = y.to(self.args.device)

            with torch.no_grad():
                z = self.npf(x)
                y_pred = self.npv(z)
                loss = self.criterion(y_pred, y)
                accuracy = (y_pred.argmax(dim=1) == y.argmax(dim=1)).float().mean()

            test_loss += loss.item()/len(self.testloader)
            test_accuracy += accuracy.item()/len(self.testloader)
        
        return loss.item(), accuracy.item()
    
    def train(self, epochs):
        for epoch in range(epochs):
            train_loss, train_accuracy = self.train_epoch(epoch)
            self.loss.append(train_loss)
            self.accuracy.append(train_accuracy)

            test_loss, test_accuracy = self.test()

            print(f'Epoch: {epoch+1:03d}/{epochs:03d} | Train Loss: {train_loss:.3f} | Train Accuracy: {train_accuracy:.3f} | Test Loss: {test_loss:.3f} | Test Accuracy: {test_accuracy:.3f}')
        

In [6]:
set_seed(0)
args = Args()

npf = NPF(args.input_size, args.output_size, args.num_neurons, args.size_committee, args.num_committees, args)
npv = NPV(args.input_size, args.output_size, args.num_neurons, args.size_committee, args.num_committees, args)

npf_opt = optim.SGD(npf.parameters(), lr=args.lr, weight_decay=0)
npv_opt = optim.SGD(npv.parameters(), lr=args.lr, weight_decay=args.weight_decay)

criterion = nn.CrossEntropyLoss()

trainer = Trainer((train_data, test_data), (npf, npv), (npf_opt, npv_opt), criterion, args)
trainer.train(args.max_epochs)

0it [00:00, ?it/s]

468it [02:31,  3.09it/s]


Epoch: 001/100 | Train Loss: 5.024 | Train Accuracy: 0.108 | Test Loss: 5.826 | Test Accuracy: 0.062


468it [02:31,  3.09it/s]


Epoch: 002/100 | Train Loss: 3.206 | Train Accuracy: 0.224 | Test Loss: 3.738 | Test Accuracy: 0.250


468it [02:32,  3.07it/s]


Epoch: 003/100 | Train Loss: 2.189 | Train Accuracy: 0.375 | Test Loss: 2.859 | Test Accuracy: 0.312


468it [02:31,  3.10it/s]


Epoch: 004/100 | Train Loss: 1.825 | Train Accuracy: 0.451 | Test Loss: 2.465 | Test Accuracy: 0.500


468it [02:30,  3.11it/s]


Epoch: 005/100 | Train Loss: 1.637 | Train Accuracy: 0.497 | Test Loss: 2.266 | Test Accuracy: 0.500


468it [02:31,  3.09it/s]


Epoch: 006/100 | Train Loss: 1.541 | Train Accuracy: 0.521 | Test Loss: 2.172 | Test Accuracy: 0.500


468it [02:30,  3.10it/s]


Epoch: 007/100 | Train Loss: 1.465 | Train Accuracy: 0.541 | Test Loss: 2.041 | Test Accuracy: 0.500


468it [02:30,  3.11it/s]


Epoch: 008/100 | Train Loss: 1.384 | Train Accuracy: 0.563 | Test Loss: 1.944 | Test Accuracy: 0.500


468it [02:30,  3.11it/s]


Epoch: 009/100 | Train Loss: 1.320 | Train Accuracy: 0.582 | Test Loss: 1.866 | Test Accuracy: 0.500


468it [02:31,  3.09it/s]


Epoch: 010/100 | Train Loss: 1.268 | Train Accuracy: 0.597 | Test Loss: 1.797 | Test Accuracy: 0.500


468it [02:31,  3.09it/s]


Epoch: 011/100 | Train Loss: 1.235 | Train Accuracy: 0.606 | Test Loss: 1.753 | Test Accuracy: 0.500


468it [02:31,  3.10it/s]


Epoch: 012/100 | Train Loss: 1.205 | Train Accuracy: 0.614 | Test Loss: 1.693 | Test Accuracy: 0.500


468it [02:31,  3.09it/s]


Epoch: 013/100 | Train Loss: 1.169 | Train Accuracy: 0.625 | Test Loss: 1.640 | Test Accuracy: 0.500


468it [02:29,  3.12it/s]


Epoch: 014/100 | Train Loss: 1.138 | Train Accuracy: 0.634 | Test Loss: 1.592 | Test Accuracy: 0.500


468it [02:24,  3.23it/s]


Epoch: 015/100 | Train Loss: 1.110 | Train Accuracy: 0.641 | Test Loss: 1.547 | Test Accuracy: 0.562


468it [02:20,  3.32it/s]


Epoch: 016/100 | Train Loss: 1.092 | Train Accuracy: 0.647 | Test Loss: 1.518 | Test Accuracy: 0.562


468it [02:21,  3.30it/s]


Epoch: 017/100 | Train Loss: 1.073 | Train Accuracy: 0.652 | Test Loss: 1.475 | Test Accuracy: 0.562


468it [02:22,  3.29it/s]


Epoch: 018/100 | Train Loss: 1.051 | Train Accuracy: 0.658 | Test Loss: 1.436 | Test Accuracy: 0.562


468it [02:22,  3.29it/s]


Epoch: 019/100 | Train Loss: 1.031 | Train Accuracy: 0.664 | Test Loss: 1.399 | Test Accuracy: 0.562


468it [02:22,  3.29it/s]


Epoch: 020/100 | Train Loss: 1.014 | Train Accuracy: 0.669 | Test Loss: 1.365 | Test Accuracy: 0.562


468it [02:21,  3.30it/s]


Epoch: 021/100 | Train Loss: 1.002 | Train Accuracy: 0.672 | Test Loss: 1.343 | Test Accuracy: 0.562


468it [02:22,  3.29it/s]


Epoch: 022/100 | Train Loss: 0.989 | Train Accuracy: 0.676 | Test Loss: 1.309 | Test Accuracy: 0.562


468it [02:21,  3.30it/s]


Epoch: 023/100 | Train Loss: 0.974 | Train Accuracy: 0.679 | Test Loss: 1.278 | Test Accuracy: 0.562


468it [02:20,  3.33it/s]


Epoch: 024/100 | Train Loss: 0.960 | Train Accuracy: 0.683 | Test Loss: 1.249 | Test Accuracy: 0.562


468it [02:21,  3.31it/s]


Epoch: 025/100 | Train Loss: 0.947 | Train Accuracy: 0.687 | Test Loss: 1.221 | Test Accuracy: 0.562


468it [02:21,  3.31it/s]


Epoch: 026/100 | Train Loss: 0.938 | Train Accuracy: 0.689 | Test Loss: 1.205 | Test Accuracy: 0.562


468it [02:21,  3.32it/s]


Epoch: 027/100 | Train Loss: 0.929 | Train Accuracy: 0.692 | Test Loss: 1.177 | Test Accuracy: 0.625


468it [02:21,  3.31it/s]


Epoch: 028/100 | Train Loss: 0.918 | Train Accuracy: 0.695 | Test Loss: 1.150 | Test Accuracy: 0.625


468it [02:21,  3.31it/s]


Epoch: 029/100 | Train Loss: 0.907 | Train Accuracy: 0.698 | Test Loss: 1.124 | Test Accuracy: 0.625


468it [02:21,  3.32it/s]


Epoch: 030/100 | Train Loss: 0.897 | Train Accuracy: 0.701 | Test Loss: 1.098 | Test Accuracy: 0.625


468it [06:53,  1.13it/s]


Epoch: 031/100 | Train Loss: 0.891 | Train Accuracy: 0.703 | Test Loss: 1.085 | Test Accuracy: 0.625


468it [02:22,  3.29it/s]


Epoch: 032/100 | Train Loss: 0.883 | Train Accuracy: 0.705 | Test Loss: 1.060 | Test Accuracy: 0.625


468it [02:22,  3.28it/s]


Epoch: 033/100 | Train Loss: 0.875 | Train Accuracy: 0.707 | Test Loss: 1.035 | Test Accuracy: 0.625


468it [02:22,  3.29it/s]


Epoch: 034/100 | Train Loss: 0.866 | Train Accuracy: 0.709 | Test Loss: 1.012 | Test Accuracy: 0.625


468it [02:22,  3.29it/s]


Epoch: 035/100 | Train Loss: 0.858 | Train Accuracy: 0.712 | Test Loss: 0.989 | Test Accuracy: 0.625


468it [02:22,  3.29it/s]


Epoch: 036/100 | Train Loss: 0.853 | Train Accuracy: 0.713 | Test Loss: 0.979 | Test Accuracy: 0.625


468it [02:22,  3.28it/s]


Epoch: 037/100 | Train Loss: 0.847 | Train Accuracy: 0.715 | Test Loss: 0.957 | Test Accuracy: 0.625


468it [02:22,  3.29it/s]


Epoch: 038/100 | Train Loss: 0.840 | Train Accuracy: 0.717 | Test Loss: 0.936 | Test Accuracy: 0.625


468it [02:20,  3.32it/s]


Epoch: 039/100 | Train Loss: 0.833 | Train Accuracy: 0.719 | Test Loss: 0.917 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 040/100 | Train Loss: 0.827 | Train Accuracy: 0.721 | Test Loss: 0.898 | Test Accuracy: 0.688


468it [02:20,  3.34it/s]


Epoch: 041/100 | Train Loss: 0.822 | Train Accuracy: 0.722 | Test Loss: 0.889 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 042/100 | Train Loss: 0.817 | Train Accuracy: 0.723 | Test Loss: 0.871 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 043/100 | Train Loss: 0.811 | Train Accuracy: 0.725 | Test Loss: 0.856 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 044/100 | Train Loss: 0.805 | Train Accuracy: 0.726 | Test Loss: 0.840 | Test Accuracy: 0.688


468it [02:20,  3.34it/s]


Epoch: 045/100 | Train Loss: 0.800 | Train Accuracy: 0.729 | Test Loss: 0.826 | Test Accuracy: 0.688


468it [02:21,  3.30it/s]


Epoch: 046/100 | Train Loss: 0.796 | Train Accuracy: 0.730 | Test Loss: 0.818 | Test Accuracy: 0.688


468it [02:22,  3.29it/s]


Epoch: 047/100 | Train Loss: 0.791 | Train Accuracy: 0.730 | Test Loss: 0.805 | Test Accuracy: 0.688


468it [02:21,  3.30it/s]


Epoch: 048/100 | Train Loss: 0.786 | Train Accuracy: 0.732 | Test Loss: 0.792 | Test Accuracy: 0.688


468it [02:22,  3.30it/s]


Epoch: 049/100 | Train Loss: 0.781 | Train Accuracy: 0.734 | Test Loss: 0.781 | Test Accuracy: 0.688


468it [02:21,  3.31it/s]


Epoch: 050/100 | Train Loss: 0.776 | Train Accuracy: 0.735 | Test Loss: 0.770 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 051/100 | Train Loss: 0.773 | Train Accuracy: 0.736 | Test Loss: 0.763 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 052/100 | Train Loss: 0.769 | Train Accuracy: 0.738 | Test Loss: 0.752 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 053/100 | Train Loss: 0.765 | Train Accuracy: 0.739 | Test Loss: 0.742 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 054/100 | Train Loss: 0.760 | Train Accuracy: 0.741 | Test Loss: 0.733 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 055/100 | Train Loss: 0.756 | Train Accuracy: 0.742 | Test Loss: 0.724 | Test Accuracy: 0.688


468it [02:21,  3.32it/s]


Epoch: 056/100 | Train Loss: 0.753 | Train Accuracy: 0.743 | Test Loss: 0.718 | Test Accuracy: 0.688


468it [02:20,  3.32it/s]


Epoch: 057/100 | Train Loss: 0.749 | Train Accuracy: 0.744 | Test Loss: 0.710 | Test Accuracy: 0.688


468it [02:20,  3.32it/s]


Epoch: 058/100 | Train Loss: 0.745 | Train Accuracy: 0.745 | Test Loss: 0.701 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 059/100 | Train Loss: 0.741 | Train Accuracy: 0.746 | Test Loss: 0.694 | Test Accuracy: 0.688


468it [02:21,  3.30it/s]


Epoch: 060/100 | Train Loss: 0.738 | Train Accuracy: 0.747 | Test Loss: 0.687 | Test Accuracy: 0.688


468it [02:20,  3.32it/s]


Epoch: 061/100 | Train Loss: 0.735 | Train Accuracy: 0.748 | Test Loss: 0.682 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 062/100 | Train Loss: 0.733 | Train Accuracy: 0.748 | Test Loss: 0.675 | Test Accuracy: 0.688


468it [02:20,  3.32it/s]


Epoch: 063/100 | Train Loss: 0.729 | Train Accuracy: 0.749 | Test Loss: 0.669 | Test Accuracy: 0.688


468it [02:20,  3.32it/s]


Epoch: 064/100 | Train Loss: 0.725 | Train Accuracy: 0.750 | Test Loss: 0.662 | Test Accuracy: 0.688


468it [02:20,  3.33it/s]


Epoch: 065/100 | Train Loss: 0.722 | Train Accuracy: 0.751 | Test Loss: 0.656 | Test Accuracy: 0.688


468it [02:21,  3.31it/s]


Epoch: 066/100 | Train Loss: 0.720 | Train Accuracy: 0.752 | Test Loss: 0.651 | Test Accuracy: 0.688


468it [02:21,  3.31it/s]


Epoch: 067/100 | Train Loss: 0.717 | Train Accuracy: 0.753 | Test Loss: 0.646 | Test Accuracy: 0.688


468it [02:22,  3.27it/s]


Epoch: 068/100 | Train Loss: 0.714 | Train Accuracy: 0.754 | Test Loss: 0.641 | Test Accuracy: 0.688


468it [02:22,  3.29it/s]


Epoch: 069/100 | Train Loss: 0.711 | Train Accuracy: 0.755 | Test Loss: 0.636 | Test Accuracy: 0.688


468it [02:21,  3.30it/s]


Epoch: 070/100 | Train Loss: 0.708 | Train Accuracy: 0.756 | Test Loss: 0.631 | Test Accuracy: 0.750


468it [02:22,  3.29it/s]


Epoch: 071/100 | Train Loss: 0.706 | Train Accuracy: 0.756 | Test Loss: 0.627 | Test Accuracy: 0.750


468it [02:21,  3.30it/s]


Epoch: 072/100 | Train Loss: 0.704 | Train Accuracy: 0.757 | Test Loss: 0.622 | Test Accuracy: 0.750


468it [02:20,  3.32it/s]


Epoch: 073/100 | Train Loss: 0.701 | Train Accuracy: 0.758 | Test Loss: 0.617 | Test Accuracy: 0.750


468it [02:20,  3.32it/s]


Epoch: 074/100 | Train Loss: 0.698 | Train Accuracy: 0.759 | Test Loss: 0.613 | Test Accuracy: 0.750


468it [02:20,  3.33it/s]


Epoch: 075/100 | Train Loss: 0.695 | Train Accuracy: 0.760 | Test Loss: 0.609 | Test Accuracy: 0.750


468it [02:20,  3.32it/s]


Epoch: 076/100 | Train Loss: 0.693 | Train Accuracy: 0.761 | Test Loss: 0.605 | Test Accuracy: 0.750


468it [02:23,  3.27it/s]


Epoch: 077/100 | Train Loss: 0.691 | Train Accuracy: 0.761 | Test Loss: 0.601 | Test Accuracy: 0.750


468it [02:23,  3.27it/s]


Epoch: 078/100 | Train Loss: 0.689 | Train Accuracy: 0.762 | Test Loss: 0.597 | Test Accuracy: 0.750


468it [02:22,  3.29it/s]


Epoch: 079/100 | Train Loss: 0.686 | Train Accuracy: 0.763 | Test Loss: 0.592 | Test Accuracy: 0.750


468it [02:22,  3.29it/s]


Epoch: 080/100 | Train Loss: 0.684 | Train Accuracy: 0.764 | Test Loss: 0.589 | Test Accuracy: 0.750


468it [02:21,  3.30it/s]


Epoch: 081/100 | Train Loss: 0.682 | Train Accuracy: 0.764 | Test Loss: 0.586 | Test Accuracy: 0.750


468it [02:21,  3.32it/s]


Epoch: 082/100 | Train Loss: 0.680 | Train Accuracy: 0.765 | Test Loss: 0.583 | Test Accuracy: 0.812


468it [02:21,  3.32it/s]


Epoch: 083/100 | Train Loss: 0.678 | Train Accuracy: 0.766 | Test Loss: 0.579 | Test Accuracy: 0.812


468it [02:20,  3.33it/s]


Epoch: 084/100 | Train Loss: 0.675 | Train Accuracy: 0.766 | Test Loss: 0.577 | Test Accuracy: 0.812


468it [02:20,  3.33it/s]


Epoch: 085/100 | Train Loss: 0.673 | Train Accuracy: 0.767 | Test Loss: 0.573 | Test Accuracy: 0.812


468it [02:20,  3.33it/s]


Epoch: 086/100 | Train Loss: 0.672 | Train Accuracy: 0.767 | Test Loss: 0.570 | Test Accuracy: 0.812


468it [02:22,  3.28it/s]


Epoch: 087/100 | Train Loss: 0.670 | Train Accuracy: 0.768 | Test Loss: 0.567 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 088/100 | Train Loss: 0.668 | Train Accuracy: 0.769 | Test Loss: 0.565 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 089/100 | Train Loss: 0.666 | Train Accuracy: 0.769 | Test Loss: 0.562 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 090/100 | Train Loss: 0.664 | Train Accuracy: 0.770 | Test Loss: 0.559 | Test Accuracy: 0.812


468it [02:22,  3.30it/s]


Epoch: 091/100 | Train Loss: 0.662 | Train Accuracy: 0.771 | Test Loss: 0.557 | Test Accuracy: 0.812


468it [02:21,  3.32it/s]


Epoch: 092/100 | Train Loss: 0.660 | Train Accuracy: 0.771 | Test Loss: 0.555 | Test Accuracy: 0.812


468it [02:20,  3.33it/s]


Epoch: 093/100 | Train Loss: 0.658 | Train Accuracy: 0.772 | Test Loss: 0.553 | Test Accuracy: 0.812


468it [02:20,  3.33it/s]


Epoch: 094/100 | Train Loss: 0.657 | Train Accuracy: 0.773 | Test Loss: 0.551 | Test Accuracy: 0.812


468it [02:20,  3.32it/s]


Epoch: 095/100 | Train Loss: 0.654 | Train Accuracy: 0.773 | Test Loss: 0.550 | Test Accuracy: 0.812


468it [02:23,  3.27it/s]


Epoch: 096/100 | Train Loss: 0.653 | Train Accuracy: 0.774 | Test Loss: 0.547 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 097/100 | Train Loss: 0.651 | Train Accuracy: 0.774 | Test Loss: 0.546 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 098/100 | Train Loss: 0.650 | Train Accuracy: 0.775 | Test Loss: 0.544 | Test Accuracy: 0.812


468it [02:22,  3.29it/s]


Epoch: 099/100 | Train Loss: 0.648 | Train Accuracy: 0.775 | Test Loss: 0.542 | Test Accuracy: 0.812


468it [02:22,  3.27it/s]


Epoch: 100/100 | Train Loss: 0.646 | Train Accuracy: 0.776 | Test Loss: 0.541 | Test Accuracy: 0.812
