In [57]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os
import onnx
import onnxruntime

# Building blocks: standard 3x3 convolution (conv_bn) and depthwise separable convolution (conv_dw)

In [58]:
def conv_bn(inp, oup, stride):
    """Standard 3x3 convolution followed by batch norm and ReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: convolution stride; padding is fixed at 1.

    Returns:
        nn.Sequential of Conv2d(bias=False) -> BatchNorm2d -> ReLU(inplace=True).
    """
    layers = [
        # No conv bias: the BatchNorm right after it has its own learnable shift.
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def conv_dw(inp, oup, stride):
    """Depthwise separable convolution block (MobileNetV1 style).

    A 3x3 depthwise convolution (``groups=inp``: one filter per input channel)
    applies the stride, then a 1x1 pointwise convolution maps ``inp`` channels
    to ``oup``. Each convolution is followed by BatchNorm2d and an in-place
    ReLU; the convolutions have no bias since BatchNorm follows each of them.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution (pointwise stride is 1).

    Returns:
        nn.Sequential with six layers:
        depthwise conv, BN, ReLU, pointwise conv, BN, ReLU.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)

In [59]:
class MobileNetV1(nn.Module):
    """MobileNetV1-style image classifier sized for small (28x28) inputs.

    A stride-2 standard convolution stem is followed by 13 depthwise separable
    blocks, global average pooling, and a linear layer producing logits.
    The inline shape comments (H x W x channels) assume a 28x28 input.

    Args:
        num_classes: number of output logits.
        in_channels: number of input image channels (1 for MNIST; default keeps
            the original single-channel behavior).
    """

    def __init__(self, num_classes=10, in_channels=1):
        super(MobileNetV1, self).__init__()
        self.features = nn.Sequential(
            conv_bn(in_channels, 32, 2),  # 28x28xC -> 14x14x32
            conv_dw(32, 64, 1),           # 14x14x64
            conv_dw(64, 128, 2),          # 7x7x128
            conv_dw(128, 128, 1),         # 7x7x128
            conv_dw(128, 256, 2),         # 4x4x256
            conv_dw(256, 256, 1),         # 4x4x256
            conv_dw(256, 512, 2),         # 2x2x512
            conv_dw(512, 512, 1),         # 2x2x512
            conv_dw(512, 512, 1),         # 2x2x512
            conv_dw(512, 512, 1),         # 2x2x512
            conv_dw(512, 512, 1),         # 2x2x512
            conv_dw(512, 512, 1),         # 2x2x512
            conv_dw(512, 1024, 2),        # 1x1x1024
            conv_dw(1024, 1024, 1),       # 1x1x1024
        )
        # Global average pooling makes the head independent of the final spatial size.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        """Map a batch of images (N, in_channels, H, W) to logits (N, num_classes)."""
        x = self.features(x)
        x = self.avg_pool(x)     # (N, 1024, 1, 1)
        # flatten from dim 1 keeps dim 0 as the batch; view(-1, 1024) would let
        # the -1 silently absorb any unexpected extra elements into the batch.
        x = torch.flatten(x, 1)  # (N, 1024)
        return self.fc(x)

In [60]:
def preprocess_data(device, batch_size=64, root='./data'):
    """Build MNIST train/test DataLoaders with normalization.

    MNIST is downloaded into ``root`` if it is not already present.

    Args:
        device: target device (``torch.device`` or string). Used only to decide
            whether to pin host memory, which helps only for CUDA transfers.
        batch_size: samples per batch for both loaders.
        root: directory where the MNIST dataset is stored / downloaded.

    Returns:
        (train_loader, test_loader). The training loader reshuffles every
        epoch; the test loader keeps a fixed order.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        # 0.1307 / 0.3081 are the commonly used MNIST pixel mean / std for pixels in [0, 1].
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Pinned (page-locked) host memory only speeds up copies to a CUDA device;
    # on CPU-only runs it is wasted work and PyTorch may warn about it.
    pin_memory = str(device).startswith('cuda')

    train_dataset = torchvision.datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=pin_memory)

    return train_loader, test_loader

In [61]:
def calculate_accuracy(model, data_loader, device):
    """Return top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so this can be called mid-training
    without changing the model's mode.

    Args:
        model: classifier whose output holds class scores along dim 1.
        data_loader: iterable of (inputs, integer class targets) batches.
        device: device the batches are moved to (should match the model's).

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for data, targets in data_loader:
                data, targets = data.to(device), targets.to(device)
                # argmax returns the first maximal index, same as torch.max(..., 1).
                predicted = model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

In [62]:
def train(model, train_loader, criterion, optimizer, device, num_epochs, log_interval=100):
    """Train ``model`` in place for ``num_epochs`` epochs.

    After every ``log_interval`` mini-batches, prints the mean loss over that
    window. At the end of each epoch, prints accuracy on the full training
    set, which costs an extra inference pass over ``train_loader``.

    Args:
        model: module to optimize (updated in place).
        train_loader: iterable of (inputs, targets) batches; ``len()`` is used
            in the progress message.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``train_loader``.
        log_interval: mini-batches between loss printouts (default 100).

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    for epoch in range(num_epochs):
        # Re-enter training mode every epoch; the accuracy pass below may
        # switch the model to eval mode.
        model.train()
        running_loss = 0.0
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if (batch_idx + 1) % log_interval == 0:
                # Mean loss over the last `log_interval` mini-batches.
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/log_interval:.4f}')
                running_loss = 0.0

        # Calculate and print training accuracy at the end of each epoch
        train_accuracy = calculate_accuracy(model, train_loader, device)
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Accuracy: {train_accuracy:.2f}%')

In [63]:
def test(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader``; print and return accuracy in percent.

    The accuracy loop is shared with training-time evaluation via
    ``calculate_accuracy`` instead of being duplicated here. The model is put
    in eval mode first and is left in eval mode afterwards.

    Args:
        model: classifier whose output holds class scores along dim 1.
        test_loader: iterable of (inputs, targets) batches.
        device: device the batches are moved to.

    Returns:
        Test accuracy as a float in [0, 100].
    """
    model.eval()
    accuracy = calculate_accuracy(model, test_loader, device)
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return accuracy

In [64]:
def save_model(model, path='mobilenet_mnist.pth'):
    """Write ``model.state_dict()`` (parameters and buffers) to ``path``.

    Only the weights are stored, not the architecture; recreate the model
    class before loading the file back.
    """
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")

In [65]:
def load_model(model, path='mobilenet_mnist.pth', device='cpu'):
    """Load a saved ``state_dict`` from ``path`` into ``model`` if the file exists.

    Stored tensors are mapped onto ``device`` while loading. This allows a
    checkpoint written on a GPU machine to be restored on a CPU-only one.

    Args:
        model: module whose architecture matches the checkpoint; updated in place.
        path: checkpoint produced by ``torch.save(model.state_dict(), path)``.
        device: device (string or ``torch.device``) to map stored tensors to.

    Returns:
        True if the checkpoint was found and loaded, False if ``path`` does not exist.
    """
    if not os.path.exists(path):
        return False
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (on torch >= 1.13, weights_only=True restricts what can be deserialized).
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    print(f"Model loaded from {path}")
    return True

In [66]:
def export_to_onnx(model, sample_input, onnx_path='mobilenet_mnist.onnx', opset_version=10):
    """Export ``model`` to an ONNX file with a dynamic batch dimension.

    ``sample_input`` is the example input used for export, so it should be on
    the same device as ``model``. Dim 0 of both the graph input (named
    ``'input'``) and output (named ``'output'``) is marked dynamic.

    Args:
        model: module to export.
        sample_input: example input tensor, e.g. of shape (1, 1, 28, 28).
        onnx_path: destination ``.onnx`` file.
        opset_version: ONNX opset to target (default 10).
    """
    torch.onnx.export(
        model,
        sample_input,
        onnx_path,
        export_params=True,  # embed the trained weights in the file
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )
    print(f"Model exported to ONNX format at {onnx_path}")

In [67]:
def verify_onnx(onnx_path='mobilenet_mnist.onnx'):
    """Load the ONNX file at ``onnx_path`` and validate it with the ONNX checker.

    ``onnx.checker.check_model`` raises on an invalid model, so the
    confirmation message is printed only when the check passes.
    """
    model_proto = onnx.load(onnx_path)
    onnx.checker.check_model(model_proto)
    print("ONNX model is valid")

In [68]:
def test_onnx(onnx_path, test_loader, device):
    """Run an exported ONNX model with ONNX Runtime and report test accuracy.

    Args:
        onnx_path: path to the ``.onnx`` file.
        test_loader: iterable of (inputs, integer targets) tensor batches.
        device: ``torch.device`` or string. A CUDA device requests the CUDA
            execution provider, with the CPU provider as fallback; anything
            else uses the CPU provider only.

    Returns:
        Accuracy in percent as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # Comparing a torch.device directly with the string 'cpu' is not reliable,
    # so normalize through str() before choosing providers.
    wants_cuda = str(device).startswith('cuda')
    preferred = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if wants_cuda else ['CPUExecutionProvider']
    # Request only providers this onnxruntime build offers; always keep CPU as a fallback.
    available = onnxruntime.get_available_providers()
    providers = [p for p in preferred if p in available] or ['CPUExecutionProvider']
    session = onnxruntime.InferenceSession(onnx_path, providers=providers)
    # Read the input name from the graph instead of assuming the exporter's name.
    input_name = session.get_inputs()[0].name

    correct = 0
    total = 0
    for data, targets in test_loader:
        # ONNX Runtime consumes host numpy arrays, so tensors must be on the CPU.
        logits = session.run(None, {input_name: data.cpu().numpy()})[0]
        predicted = logits.argmax(axis=1)
        total += targets.size(0)
        correct += int((predicted == targets.cpu().numpy()).sum())

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    accuracy = 100 * correct / total
    print(f'ONNX model accuracy on test set: {accuracy:.2f}%')
    return accuracy

In [69]:
def main():
    """End-to-end pipeline: train (or load) MobileNetV1 on MNIST, evaluate it,
    export it to ONNX, validate the export, and evaluate it with ONNX Runtime.
    """
    # Seed torch's RNGs so weight init and data shuffling are repeatable across runs.
    torch.manual_seed(0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_loader, test_loader = preprocess_data(device)

    model = MobileNetV1().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    num_epochs = 10
    model_path = 'mobilenet_mnist.pth'
    onnx_path = 'mobilenet_mnist.onnx'

    # Reuse a saved checkpoint when present; otherwise train from scratch and save it.
    if not load_model(model, model_path, device):
        print("Training new model...")
        train(model, train_loader, criterion, optimizer, device, num_epochs)
        save_model(model, model_path)
    else:
        print("Using pre-trained model.")

    test(model, test_loader, device)

    # Export to ONNX with a single-image example input matching MNIST's 1x28x28 shape.
    sample_input = torch.randn(1, 1, 28, 28).to(device)
    export_to_onnx(model, sample_input, onnx_path)
    verify_onnx(onnx_path)
    test_onnx(onnx_path, test_loader, device)

In [70]:
# Run the full pipeline. Inside a Jupyter kernel __name__ is also "__main__",
# so executing this cell starts training / export immediately.
if __name__ == "__main__":
    main()

Using device: cuda
Training new model...
Epoch [1/10], Step [100/938], Loss: 0.9163
Epoch [1/10], Step [200/938], Loss: 0.2285
Epoch [1/10], Step [300/938], Loss: 0.1630
Epoch [1/10], Step [400/938], Loss: 0.1369
Epoch [1/10], Step [500/938], Loss: 0.1118
Epoch [1/10], Step [600/938], Loss: 0.1010
Epoch [1/10], Step [700/938], Loss: 0.0862
Epoch [1/10], Step [800/938], Loss: 0.0857
Epoch [1/10], Step [900/938], Loss: 0.0719
Epoch [1/10], Train Accuracy: 98.44%
Epoch [2/10], Step [100/938], Loss: 0.0488
Epoch [2/10], Step [200/938], Loss: 0.0596
Epoch [2/10], Step [300/938], Loss: 0.0533
Epoch [2/10], Step [400/938], Loss: 0.0649
Epoch [2/10], Step [500/938], Loss: 0.0623
Epoch [2/10], Step [600/938], Loss: 0.0548
Epoch [2/10], Step [700/938], Loss: 0.0712
Epoch [2/10], Step [800/938], Loss: 0.0555
Epoch [2/10], Step [900/938], Loss: 0.0501
Epoch [2/10], Train Accuracy: 99.07%
Epoch [3/10], Step [100/938], Loss: 0.0362
Epoch [3/10], Step [200/938], Loss: 0.0319
Epoch [3/10], Step [300/9