In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from pathlib import Path
import os

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

<torch._C.Generator at 0x18eb3934e30>

In [9]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081))])

train_dataset = datasets.MNIST('data/', train=True, transform=transform)
test_dataset = datasets.MNIST('data/', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=True)

In [10]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

In [11]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
           print(f'Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')

In [12]:
def test(model, device, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest Accuracy: {accuracy:.2f}%\n')
    return accuracy

In [13]:
float_model = MLP().to(device)
optimizer = optim.Adam(float_model.parameters(), lr=0.001)

In [14]:
if not Path("mlp_mnist.pth").exists():
    for epoch in range(1, 6):
        train(float_model, device, train_loader, optimizer, epoch)
    torch.save(float_model.state_dict(), "mlp_mnist.pth")
else:
    float_model.load_state_dict(torch.load("mlp_mnist.pth"))

Train Epoch: 1 [0/60000] Loss: 2.318032
Train Epoch: 1 [6400/60000] Loss: 0.376820
Train Epoch: 1 [12800/60000] Loss: 0.284974
Train Epoch: 1 [19200/60000] Loss: 0.210929
Train Epoch: 1 [25600/60000] Loss: 0.341595
Train Epoch: 1 [32000/60000] Loss: 0.094097
Train Epoch: 1 [38400/60000] Loss: 0.387732
Train Epoch: 1 [44800/60000] Loss: 0.111830
Train Epoch: 1 [51200/60000] Loss: 0.203923
Train Epoch: 1 [57600/60000] Loss: 0.107853
Train Epoch: 2 [0/60000] Loss: 0.128295
Train Epoch: 2 [6400/60000] Loss: 0.025936
Train Epoch: 2 [12800/60000] Loss: 0.133234
Train Epoch: 2 [19200/60000] Loss: 0.180191
Train Epoch: 2 [25600/60000] Loss: 0.185546
Train Epoch: 2 [32000/60000] Loss: 0.072283
Train Epoch: 2 [38400/60000] Loss: 0.098344
Train Epoch: 2 [44800/60000] Loss: 0.158319
Train Epoch: 2 [51200/60000] Loss: 0.157196
Train Epoch: 2 [57600/60000] Loss: 0.191299
Train Epoch: 3 [0/60000] Loss: 0.113984
Train Epoch: 3 [6400/60000] Loss: 0.216496
Train Epoch: 3 [12800/60000] Loss: 0.015608
Tra

In [15]:
print("Float model:")
test(float_model, device, test_loader)

Float model:

Test Accuracy: 97.12%



97.12

In [16]:
float_model.cpu()
quantized_model = torch.ao.quantization.quantize_dynamic(
    float_model,
    {nn.Linear},
    dtype=torch.qint8
)

In [17]:
print("Quantized model:")
test(quantized_model, torch.device("cpu"), test_loader)

Quantized model:

Test Accuracy: 97.13%



97.13

In [19]:
def print_size_of_model(model):
    torch.save(model.state_dict(), 'temp_delme.p')
    print('Size (KB)', os.path.getsize('temp_delme.p') / 1e3)
    os.remove('temp_delme.p')

In [21]:
print("Size of float model:")
print_size_of_model(float_model)

print("Size of quantized model:")
print_size_of_model(quantized_model)

Size of float model:
Size (KB) 222.822
Size of quantized model:
Size (KB) 60.066
