In [38]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.onnx import export
import onnx
import onnxruntime as ort
import time

In [39]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [40]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.flatten = nn.Flatten()

        fc1_input_size = self.get_dim()

        self.fc1 = nn.Linear(fc1_input_size, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        x = torch.softmax(x, dim=1)
        return x

    def get_dim(self):
      sample_input = torch.zeros(1, 1, 28, 28)
      output = self.conv2(self.conv1(sample_input))
      fc1_input_size = output.flatten().shape[0]
      return fc1_input_size

In [41]:
transform=transforms.Compose([
        transforms.ToTensor()
        ])

In [42]:
# Load the MNIST dataset:
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

In [43]:
# Prepare the MNIST dataset:

batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [44]:
# Instantiate the model :
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [45]:
# Train the model on the data :
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item()))




In [46]:
# Test the model on the test data:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

In [47]:
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

Accuracy of the model on the test images: 98.65 %


In [55]:
dummy_input = torch.randn(1, 1, 28, 28).to(device)

torch.onnx.export(
    model,
    dummy_input,
    "simple_cnn.onnx",
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)

In [58]:
ort_session = ort.InferenceSession("simple_cnn.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def benchmark(batch_size):
    dummy_input = torch.randn(batch_size, 1, 28, 28)
    ort_inputs = {'input': to_numpy(dummy_input)}
    
    start_time = time.time()
    ort_outs = ort_session.run(None, ort_inputs)
    return time.time() - start_time
    

# Benchmark for each batch size
batch_sizes = [1, 8, 32, 128]
times_onnx = {bs: benchmark(bs) for bs in batch_sizes}

print("Benchmark times for different batch sizes:", times_onnx)

Benchmark times for different batch sizes: {1: 0.0008008480072021484, 8: 0.0009520053863525391, 32: 0.003454923629760742, 128: 0.01295018196105957}


In [62]:
# We want to quantize the model using quantize_dynamic method:
quantized_model = torch.quantization.quantize_dynamic(model, {nn.Conv2d, nn.Linear}, dtype=torch.qint8)

In [63]:
# Save the quantized model:
torch.save(quantized_model.state_dict(), "quantized_model.pth")

In [65]:
# Compare the size of the quantized model with the original model:

print("Size of the original model:", sum(p.numel() for p in model.parameters()))

print("Size of the quantized model:", sum(p.numel() for p in quantized_model.parameters()))

Size of the original model: 4738826
Size of the quantized model: 18816


In [None]:
# TODO: refaire avec CIFAR10

In [None]:
# We want to use onnx runtime web to run the model in the browser:

