In [41]:
import numpy as np
import torch
from torchvision.datasets import mnist
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.nn import Module
from torch import nn

In [47]:
#Normal Model
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

In [43]:
if __name__ == '__main__':
    batch_size = 256
    train_dataset = mnist.MNIST(download=True, root='./data', train=True, transform=ToTensor())
    test_dataset = mnist.MNIST(download=True, root='./data', train=False, transform=ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    model = Model()
    sgd = SGD(model.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()
    all_epoch = 5

    for current_epoch in range(all_epoch):
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 50 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
            loss.backward()
            sgd.step()

        all_correct_num = 0
        all_sample_num = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.float()).detach()
            predict_y = np.argmax(predict_y, axis=-1)
            current_correct_num = predict_y == test_label
            all_correct_num += np.sum(current_correct_num.numpy(), axis=-1)
            all_sample_num += current_correct_num.shape[0]
        acc = all_correct_num / all_sample_num
        print('accuracy: {:.2f}'.format(acc))

idx: 0, loss: 2.3026349544525146
idx: 50, loss: 2.2941477298736572
idx: 100, loss: 2.2012548446655273
idx: 150, loss: 1.2413570880889893
idx: 200, loss: 0.5782398581504822
accuracy: 0.78
idx: 0, loss: 0.6540509462356567
idx: 50, loss: 0.447021484375
idx: 100, loss: 0.5069711804389954
idx: 150, loss: 0.5374249815940857
idx: 200, loss: 0.31400439143180847
accuracy: 0.84
idx: 0, loss: 0.4617568254470825
idx: 50, loss: 0.3370366394519806
idx: 100, loss: 0.44619306921958923
idx: 150, loss: 0.46176448464393616
idx: 200, loss: 0.29563072323799133
accuracy: 0.86
idx: 0, loss: 0.4282064735889435
idx: 50, loss: 0.3117685616016388
idx: 100, loss: 0.4261307716369629
idx: 150, loss: 0.4244333505630493
idx: 200, loss: 0.2847810685634613
accuracy: 0.86
idx: 0, loss: 0.4109957218170166
idx: 50, loss: 0.2945746183395386
idx: 100, loss: 0.41346248984336853
idx: 150, loss: 0.40534743666648865
idx: 200, loss: 0.2772273123264313
accuracy: 0.86


In [44]:
from brevitas.nn import QuantLinear
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.quant import Int8Bias
from brevitas.nn import QuantIdentity
from brevitas.nn import QuantReLU
from brevitas.nn import QuantConv2d

In [45]:
#Making model with brevitas
class Brevitas_Model(Module):
    def __init__(self):
        super(Brevitas_Model, self).__init__()
        self.identity = QuantIdentity(return_quant_tensor=True)
        self.conv1 = self.conv1 = QuantConv2d(1, 6, 5, bias=True, input_quant=Int8ActPerTensorFloat, weight_quant=Int8WeightPerTensorFloat,
                                      output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.relu1 = QuantReLU(return_quant_tensor=True)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = QuantConv2d(6, 16, 5, bias=True, input_quant=Int8ActPerTensorFloat, weight_quant=Int8WeightPerTensorFloat,
                                output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.relu2 = QuantReLU(return_quant_tensor=True)
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = QuantLinear(256, 120, bias=True, input_quant=Int8ActPerTensorFloat, weight_quant=Int8WeightPerTensorFloat,
                                output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.relu3 = QuantReLU(return_quant_tensor=True)
        self.fc2 = QuantLinear(120, 84, bias=True, input_quant=Int8ActPerTensorFloat, weight_quant=Int8WeightPerTensorFloat,
                                output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.relu4 = QuantReLU(return_quant_tensor=True)
        self.fc3 = QuantLinear(84, 10, bias=True, input_quant=Int8ActPerTensorFloat, weight_quant=Int8WeightPerTensorFloat,
                                output_quant=Int8ActPerTensorFloat, bias_quant=Int8Bias, return_quant_tensor=True)
        self.relu5 = QuantReLU()

    def forward(self, x):
        y = self.identity(x)
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.view(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y

In [46]:
if __name__ == '__main__':
    batch_size = 256
    train_dataset = mnist.MNIST(download=True, root='./data', train=True, transform=ToTensor())
    test_dataset = mnist.MNIST(download=True, root='./data', train=False, transform=ToTensor())
    trin_loader = DataLoader(train_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    model = Brevitas_Model()
    sgd = SGD(model.parameters(), lr=1e-1)
    loss_fn = CrossEntropyLoss()
    all_epoch = 5

    for current_epoch in range(all_epoch):
        model.train()
        for idx, (train_x, train_label) in enumerate(train_loader):
            sgd.zero_grad()
            predict_y = model(train_x.float())
            loss = loss_fn(predict_y, train_label.long())
            if idx % 50 == 0:
                print('idx: {}, loss: {}'.format(idx, loss.sum().item()))
            loss.backward()
            sgd.step()

        all_correct_num = 0
        all_sample_num = 0
        model.eval()
        for idx, (test_x, test_label) in enumerate(test_loader):
            predict_y = model(test_x.float()).detach()
            predict_y = np.argmax(predict_y, axis=-1)
            current_correct_num = predict_y == test_label
            all_correct_num += np.sum(current_correct_num.numpy(), axis=-1)
            all_sample_num += current_correct_num.shape[0]
        acc = all_correct_num / all_sample_num
        print('accuracy: {:.2f}'.format(acc))

idx: 0, loss: 2.302600860595703
idx: 50, loss: 2.290087938308716
idx: 100, loss: 2.1032142639160156
idx: 150, loss: 1.774557113647461
idx: 200, loss: 1.0252306461334229
accuracy: 0.68
idx: 0, loss: 0.9428538084030151
idx: 50, loss: 0.8009646534919739
idx: 100, loss: 0.7890003323554993
idx: 150, loss: 0.8983732461929321
idx: 200, loss: 0.5346361994743347
accuracy: 0.74
idx: 0, loss: 0.7262919545173645
idx: 50, loss: 0.6671064496040344
idx: 100, loss: 0.7094277143478394
idx: 150, loss: 0.6719374060630798
idx: 200, loss: 0.49384498596191406
accuracy: 0.75
idx: 0, loss: 0.6772816777229309
idx: 50, loss: 0.6196312308311462
idx: 100, loss: 0.6735409498214722
idx: 150, loss: 0.6238174438476562
idx: 200, loss: 0.47341811656951904
accuracy: 0.73
idx: 0, loss: 0.7591606974601746
idx: 50, loss: 0.5993204712867737
idx: 100, loss: 0.650047242641449
idx: 150, loss: 0.3051925003528595
idx: 200, loss: 0.2878704369068146
accuracy: 0.84
