In [25]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
import onnx
import torch.onnx
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [26]:
class SimpleNet(nn.Module):
    def __init__(self, num_classes=10, init_weights=True):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=4, stride=2)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(in_channels=5, out_channels=5, kernel_size=4, stride=2)

        self.linear1 = nn.Linear(in_features=5 * 2 * 2, out_features=10)
        self.linear2 = nn.Linear(in_features=10, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

class MNISTloader:
    def __init__(
        self,
        batch_size: int = 64,
        data_dir: str = "./data/",
        num_workers: int = 0,
        pin_memory: bool = False,
        shuffle: bool = False,
        train_val_split: float = 0.1,
    ):
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.train_val_split = train_val_split

        self.setup()

    def setup(self):
        transform = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ]
        )

        self.train_dataset = datasets.MNIST(
            self.data_dir, train=True, download=True, transform=transform
        )
        val_split = int(len(self.train_dataset) * self.train_val_split)
        train_split = len(self.train_dataset) - val_split

        self.train_dataset, self.val_dataset = random_split(
            self.train_dataset, [train_split, val_split]
        )
        self.test_dataset = datasets.MNIST(
            self.data_dir, train=False, download=True, transform=transform
        )

        print(
            "Image Shape:    {}".format(self.train_dataset[0][0].numpy().shape),
            end="\n\n",
        )
        print("Training Set:   {} samples".format(len(self.train_dataset)))
        print("Validation Set: {} samples".format(len(self.val_dataset)))
        print("Test Set:       {} samples".format(len(self.test_dataset)))

    def load(self):
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
        )

        val_loader = DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
        )

        test_loader = DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
        )

        return train_loader, val_loader, test_loader

def train(device, lr, model, optimizer, criterion, train_loader):

    train_loss_running, train_acc_running = 0, 0

    model.train().cuda() if torch.cuda.is_available() else model.train()

    for inputs, labels in train_loader:

        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)

        _, predictions = torch.max(outputs, dim=1)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()

        train_loss_running += loss.item() * inputs.shape[0]
        train_acc_running += torch.sum(predictions == labels.data)

    train_loss = train_loss_running / len(train_loader.sampler)
    train_acc = train_acc_running / len(train_loader.sampler)
    
    return train_loss, train_acc
    
def evaluate(device, model, criterion, val_loader):

    val_loss_running, val_acc_running = 0, 0
    
    model.eval().cuda() if torch.cuda.is_available() else model.eval()

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, predictions = torch.max(outputs, dim=1)
            val_loss_running += loss.item() * inputs.shape[0]
            val_acc_running += torch.sum(predictions == labels.data)

        val_loss = val_loss_running / len(val_loader.sampler)
        val_acc = val_acc_running / len(val_loader.sampler)

    return val_loss, val_acc

In [27]:
train_loader, val_loader, test_loader = MNISTloader(train_val_split=0.1).load()
model = SimpleNet()
torchinfo.summary(model, input_size=(16, 1, 32, 32))

lr = 0.001
num_epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr)
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    train_loss, train_acc = train(device, lr, model, optimizer, criterion, train_loader)
    val_loss, val_acc = evaluate(device, model, criterion, val_loader)
    info = "Epoch: {:3}/{} \t train_Loss: {:.3f} \t train_acc: {:.3f} \t val_loss: {:.3f} \t val_acc: {:.3f}"
    print(info.format(epoch + 1, num_epochs, train_loss, train_acc, val_loss, val_acc))

Image Shape:    (1, 32, 32)

Training Set:   54000 samples
Validation Set: 6000 samples
Test Set:       10000 samples
Epoch:   1/1 	 train_Loss: 0.889 	 train_acc: 0.694 	 val_loss: 0.511 	 val_acc: 0.844


In [28]:
evaluate(device, model, criterion, test_loader)

(0.4737554032325745, tensor(0.8590, device='cuda:0'))

In [30]:
x = torch.randn(16, 1, 32, 32, requires_grad=True)
x = x.cuda() if torch.cuda.is_available() else x.cpu()

# Export the model
torch.onnx.export(model,                     # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "model.onnx",              # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=13,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

In [31]:
def onnx_check_model(onnx_model):
    try:
        onnx.checker.check_model(onnx_model)
    except onnx.checker.ValidationError as e:
        print('The model is invalid: %s' % e)
    else:
        print('The model is valid!')

model_onnx = onnx.load("model.onnx")
onnx_check_model(model_onnx)

The model is valid!


In [32]:
print(model_onnx.graph.initializer[0].dims)
print(model_onnx.graph.initializer[1].dims)
print(model_onnx.graph.initializer[2].dims)
print(model_onnx.graph.initializer[3].dims)

[3, 1, 4, 4]
[3]
[5, 3, 4, 4]
[5]


In [33]:
from converter import parse_onnx_model

%load_ext autoreload
%autoreload 2

In [34]:
parse_onnx_model(model_onnx)

OrderedDict([('input_nodes', ['input']),
             ('Conv_0',
              {'op_type': CONV2D,
               'initializer': {'weight': {'name': 'conv1.weight',
                 'raw_data': b'\xfd\xd1\x9e\xbe\x806\x98\xbeO~&\xbe\xfd\xf0\xcd\xbe\xc3\x1d\x8b\xbe\x93VI\xbe\xa5\xa5\x93\xbem\x9a\xd1\xbe\x0e^k\xbe\xa5K\xe7\xbe]1\xd5\xbe\xc2Q\xd3\xbd\xc2\xa3\xeb\xbe&l"\xbe\xc8\xfe\x89\xbe=\xd5\x92=\xca3\xe8\xbe\xf2\x04\xcc\xbd2):\xbe\xcd\xff\xd0\xbd@\x93\x05\xbf8\xb1\x05\xbf\x8ctp\xbe\x8ap\xeb\xbe\xad\xe1\xc0\xbei\x83\xf4\xbe#\x19J\xbe\xc3\xb9N\xbe\xaan@\xbe\xb4\xfe\xfe\xbdQ[\x13\xbe\xb2\x7f\xcd\xbb\xd3>\xa9=\xea\xb2r>$b\x0b>d\xe0\xad>\xb1\xe4\xf5>\xd6i\x04?2\xba\x0b>\xcb\x9ew\xbe\xe2\r\xc9>\xc2)\x03?/\x87\xb1>bq\xcb<G\xdb\xfe>\xda\xce\xab=\x0c\xb9s\xbe{\xc8\xca\xbe',
                 'dims': [3, 1, 4, 4],
                 'data_type': FLOAT},
                'bias': {'name': 'conv1.bias',
                 'raw_data': b'\x86e\x81>\xd9\x916>tOt>',
                 'dims': [3],
            