In [1]:
import torch
from src.util import get_net
import src.config as c
from src.common.network import LayerType
from torch import nn
import torchvision

In [2]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [32]:
def to_device(data, device):
    if isinstance(data, (list,tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    def __init__(self, dl, device, transform = None):
        self.dl = dl
        self.device = device
        self.transform = transform
    def __iter__(self):
        if self.transform is None:
            for b in self.dl:
                yield to_device(b, self.device)
        else:
            for b in self.dl:
                a = to_device(b, self.device)
                yield [self.transform(a[0]), a[0], a[1]]
    def __len__(self):
        return len(self.dl)

def toDeviceDataLoader(*args, device = torch.device('cuda:0'), batch_size = 16, transform = None):
    dls = [torch.utils.data.DataLoader(d, batch_size = batch_size, shuffle = True, drop_last = True) for d in args]
    return [DeviceDataLoader(d, device = device, transform = transform) for d in dls]

def load_mnist(batch_size_train, batch_size_test, device, dataset_path="/share/datasets/mnist/", normalize = True):
    if normalize:
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
    else:
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0,), (1,))])
    train_dataset = torchvision.datasets.MNIST(dataset_path, download=True, train=True, transform=transform)
    test_dataset = torchvision.datasets.MNIST(dataset_path, download=True, train=False, transform=transform)
    train_loader = toDeviceDataLoader(train_dataset, device = device, batch_size = batch_size_train)[0]
    test_loader = toDeviceDataLoader(test_dataset, device = device, batch_size = batch_size_test)[0]
    return train_loader, test_loader

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

def dsa_01(ds, k, verbose = False):
    tot = 0
    acc = 0
    for batch in ds:
        out = k(batch[0])
        acc += accuracy(out, batch[1]) * len(batch[1])
        tot += len(batch[1])
    if verbose:
        print(acc, tot)
    return (acc/tot).item()

mnist_train_loader_n, mnist_test_loader_n = load_mnist(64, 128, device, normalize = True)
mnist_train_loader, mnist_test_loader = load_mnist(64, 128, device, normalize = False)

In [85]:
torch.nn.Conv2d(10, 15, 20, padding=2).weight.shape

torch.Size([15, 10, 20, 20])

In [12]:
import onnx

In [14]:
model = onnx.load(f'src/nets/{c.MNIST_CONV_SMALL}')

In [21]:
for node in model.graph.node:
  print(node.stride)

AttributeError: stride

In [33]:
def evaluate_net(net, dataset, normalize = True):
    layers = get_net(f'src/nets/{net}', dataset)
    modules = []
    seen_linear = False
    for layer in layers:
        if layer.type == LayerType.Conv2D:
            modules.append(torch.nn.Conv2d(layer.weight.shape[1], layer.weight.shape[0], layer.weight.shape[2], stride = 2, padding=0))
            modules[-1].weight = torch.nn.Parameter(layer.weight)
            modules[-1].bias = torch.nn.Parameter(layer.bias)
        elif layer.type == LayerType.Linear:
            if not seen_linear:
                seen_linear = True
                modules.append(torch.nn.Flatten())
            modules.append(torch.nn.Linear(layer.weight.shape[1], layer.weight.shape[0]))
            modules[-1].weight = torch.nn.Parameter(layer.weight)
            modules[-1].bias = torch.nn.Parameter(layer.bias)
        elif layer.type == LayerType.ReLU: 
            modules.append(torch.nn.ReLU())
        elif layer.type == LayerType.Flatten:
            modules.append(torch.nn.Flatten())
        else:
            raise ValueError(f'Layer type {layer.type} not supported')
    model = nn.Sequential(*modules)
    model.eval().to(device)
    #print(model)

    if dataset == 'cifar':
        print(f"{net}: {dsa_01(dataset, model)}")
    else:
        if normalize:
            print(f"{net}: {dsa_01(mnist_test_loader_n, model)}")
        else:
            print(f"{net}: {dsa_01(mnist_test_loader, model)}")

    

In [37]:
print("Normalization ((0.1307,), (0.3081,))")
evaluate_net(c.MNIST_LINEAR_50, 'mnist', normalize = True)
evaluate_net(c.MNIST_LINEAR_100, 'mnist', normalize = True)
evaluate_net(c.MNIST_CONV_SMALL, 'mnist', normalize = True)
#evaluate_net(c.MNIST_CONV_MED, 'mnist', normalize = True)
evaluate_net(c.MNIST_FFN_PGD, 'mnist', normalize = True)
evaluate_net(c.MNIST_FFN_DIFFAI, 'mnist', normalize = True)
print("Normalization ((0,), (1,))")
evaluate_net(c.MNIST_LINEAR_50, 'mnist', normalize = False)
evaluate_net(c.MNIST_LINEAR_100, 'mnist', normalize = False)
evaluate_net(c.MNIST_CONV_SMALL, 'mnist', normalize = False)
#evaluate_net(c.MNIST_CONV_MED, 'mnist', normalize = False)
evaluate_net(c.MNIST_FFN_PGD, 'mnist', normalize = False)
evaluate_net(c.MNIST_FFN_DIFFAI, 'mnist', normalize = False)

Normalization ((0.1307,), (0.3081,))
mnist_relu_3_50.onnx: 0.938401460647583
mnist_relu_3_100.onnx: 0.9385015964508057
mnist_convSmallRELU__Point.onnx: 0.9824719429016113
mnistconvSmallRELU__PGDK.onnx: 0.989182710647583
mnistconvSmallRELUDiffAI.onnx: 0.9773637652397156
Normalization ((0,), (1,))
mnist_relu_3_50.onnx: 0.9588341116905212
mnist_relu_3_100.onnx: 0.9654446840286255
mnist_convSmallRELU__Point.onnx: 0.9803686141967773
mnistconvSmallRELU__PGDK.onnx: 0.9887820482254028
mnistconvSmallRELUDiffAI.onnx: 0.9536257982254028


In [36]:
evaluate_net(c.MNIST_LINEAR_50, 'mnist', normalize = False)
evaluate_net(c.MNIST_LINEAR_100, 'mnist', normalize = False)
evaluate_net(c.MNIST_CONV_SMALL, 'mnist', normalize = False)
#evaluate_net(c.MNIST_CONV_MED, 'mnist', normalize = False)
evaluate_net(c.MNIST_FFN_PGD, 'mnist', normalize = False)
evaluate_net(c.MNIST_FFN_DIFFAI, 'mnist', normalize = False)

mnist_relu_3_50.onnx: 0.9587339758872986
mnist_relu_3_100.onnx: 0.9654446840286255
mnist_convSmallRELU__Point.onnx: 0.9803686141967773
mnistconvSmallRELU__PGDK.onnx: 0.9888821840286255
mnistconvSmallRELUDiffAI.onnx: 0.9535256624221802
