<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from fa_ensemble import FiniteAggregationEnsemble
from torchvision.datasets import CIFAR10, MNIST
from torchvision import transforms
import torch
from torch import nn, optim
import random
import numpy
from tqdm import tqdm
import ssl
from models.resnet import ResNet18, ResNet18_1C, ResNetSmall_1C

In [2]:
# One preprocessing pipeline shared by both splits so train and test can never
# drift apart. (0.1307, 0.3081) are the commonly quoted MNIST pixel mean/std.
MNIST_ROOT = './datasets/MNIST'
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
trainset = MNIST(root=MNIST_ROOT, train=True, download=True, transform=mnist_transform)
testset = MNIST(root=MNIST_ROOT, train=False, download=True, transform=mnist_transform)

In [3]:
def _select_device(partition_number):
    """Return the torch device that base model ``partition_number`` trains on.

    With CUDA, base models are spread round-robin over the visible GPUs. The
    previous ``"cuda:" + str(partition_number)`` required one GPU per base
    model and crashed once ``partition_number >= torch.cuda.device_count()``.
    It also built device names such as ``"mps:3"`` / ``"cpu:3"`` on non-CUDA
    machines that do not correspond to a real device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", partition_number % torch.cuda.device_count())
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _estimate_accuracy(net, dataset, device, max_batches):
    """Return top-1 accuracy (percent) of ``net`` on ``dataset``.

    Only the first ``max_batches`` batches of 128 are scored, or every batch
    when ``max_batches`` is None. Returns NaN when no sample was scored,
    instead of dividing by zero. Leaves ``net`` in eval mode.
    """
    import itertools

    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1)
    # islice stops before fetching batch max_batches+1 (and handles 0).
    batches = loader if max_batches is None else itertools.islice(loader, max_batches)
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for inputs, targets in batches:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = net(inputs).argmax(dim=1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    return 100. * correct / total if total else float("nan")


def train_base_model(partition_number, train_subset, epochs=10, lr=0.01,
                     eval_set=None, eval_batches=1):
    """Train one ResNet18_1C base model of the ensemble on its partition.

    Python ``random``, NumPy and torch (CPU and all CUDA devices) are seeded
    with ``partition_number``, so a base model can be retrained reproducibly.

    Args:
        partition_number: index of the base model. It is the RNG seed and
            selects the GPU (round-robin) when CUDA is available.
        train_subset: dataset containing this partition's training points.
        epochs: passes over ``train_subset`` (default 10, as before).
        lr: initial SGD learning rate (default 0.01, as before).
        eval_set: dataset for the accuracy estimate printed after training.
            ``None`` means the notebook-level ``testset``.
        eval_batches: number of 128-sample batches of ``eval_set`` to score.
            The default of 1 keeps the original quick estimate; ``None``
            scores the whole set.

    Returns:
        The trained network, in eval mode, on the selected device.
    """
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    device = _select_device(partition_number)

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=1)
    print("subset has ", len(train_subset), "data points")

    # NOTE(review): this disables HTTPS certificate verification for the whole
    # process, not just this call. Nothing visible here downloads anything,
    # so consider removing it.
    ssl._create_default_https_context = ssl._create_unverified_context
    net = ResNet18_1C().to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Step decay: multiply the LR by 0.2 after each of these epochs. The decay
    # only kicks in when `epochs` > 60; with the default of 10 the LR is constant.
    lr_milestones = (60, 120, 160)
    curr_lr = lr

    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()
        if epoch in lr_milestones:
            curr_lr *= 0.2
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    # Cheap sanity check, not a full evaluation: by default one batch is scored.
    acc = _estimate_accuracy(net, testset if eval_set is None else eval_set, device, eval_batches)
    print(f'Estimated accuracy for base model {partition_number}: {str(acc)}%')

    return net


In [4]:
# Build the finite-aggregation ensemble and train each base model in turn.
# Base models already on disk are reported as "already exists".
NUM_BASE_MODELS = 30  # presumably k (cf. "k30" in the path) — confirm against fa_ensemble
ensemble = FiniteAggregationEnsemble("ensembles/mnist_k30_d1", trainset, testset, 10, NUM_BASE_MODELS)
for partition in range(NUM_BASE_MODELS):
    ensemble.train_base_model(partition, train_base_model)

Base model 0 already exists
Base model 1 already exists
Base model 2 already exists
Base model 3 already exists
Base model 4 already exists
Base model 5 already exists
Base model 6 already exists
Base model 7 already exists
Base model 8 already exists
Base model 9 already exists
Base model 10 already exists
Base model 11 already exists
Base model 12 already exists
Base model 13 already exists
Base model 14 already exists
Base model 15 already exists
Base model 16 already exists
Base model 17 already exists
Base model 18 already exists
Base model 19 already exists
Base model 20 already exists
Base model 21 already exists
Base model 22 already exists
Base model 23 already exists
Base model 24 already exists
Base model 25 already exists
Base model 26 already exists
Base model 27 already exists
Base model 28 already exists
Base model 29 already exists


In [5]:
# Evaluate the ensemble with mode "softmax_median". This prints base-classifier
# accuracy, clean accuracy and the median certified radius.
ensemble.eval(mode="softmax_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Clean Accuracy: 98.07000000000001%
Median Certified Radius: 15


In [6]:
# Evaluate the ensemble with mode "logit_median". This prints base-classifier
# accuracy, clean accuracy and the median certified radius.
ensemble.eval(mode="logit_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Clean Accuracy: 98.11999999999999%
Median Certified Radius: 15


In [7]:
# Evaluate the ensemble with mode "label_voting". This prints base-classifier
# accuracy, clean accuracy and the median certified radius.
ensemble.eval(mode="label_voting")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Clean Accuracy: 98.06%
Median Certified Radius: 15


In [8]:
# Evaluate the ensemble with mode "label_runoff". This prints base-classifier
# accuracy, clean accuracy and the median certified radius.
ensemble.eval(mode="label_runoff")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Clean Accuracy: 98.00999999999999%
Median Certified Radius: 15


In [9]:
# Distill the "softmax_median" ensemble into a fresh ResNetSmall_1C student.
# Note the larger lr (0.05) than the other students (1e-3). A student already
# trained with the same settings is loaded from disk instead of retrained.
ensemble.distill(ResNetSmall_1C(), mode='softmax_median', lr=0.05, epochs=5)

ensembles/mnist_k30_d1/students/student_softmax_median11529215046068470450.pkl
Student already trained with those parameters at ensembles/mnist_k30_d1/students/student_softmax_median11529215046068470450.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 14.53it/s]

Accuracy for student: 98.64%





In [10]:
# Distill the "logit_median" ensemble into a fresh ResNetSmall_1C student.
# A student already trained with the same settings is loaded from disk instead.
ensemble.distill(ResNetSmall_1C(), mode='logit_median', lr=1e-3, epochs=5)

ensembles/mnist_k30_d1/students/student_logit_median230584300921369450.pkl
Student already trained with those parameters at ensembles/mnist_k30_d1/students/student_logit_median230584300921369450.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 15.16it/s]

Accuracy for student: 97.92%





In [11]:
# Distill the "label_voting" ensemble into a fresh ResNetSmall_1C student
# (same lr/epochs as the logit_median student).
ensemble.distill(ResNetSmall_1C(), mode='label_voting', lr=1e-3, epochs=5)

ensembles/mnist_k30_d1/students/student_label_voting230584300921369450.pkl
Student already trained with those parameters at ensembles/mnist_k30_d1/students/student_label_voting230584300921369450.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 15.38it/s]

Accuracy for student: 98.01%





In [12]:
# Certified accuracy of the "softmax_median" ensemble at radius 10.
# (The second argument is presumably the certified radius — confirm in fa_ensemble.)
ensemble.certified_accuracy("softmax_median", 10)

0.9651

In [13]:
# Certified accuracy of the "softmax_median" ensemble at radius 15, the median
# certified radius reported by ensemble.eval above.
ensemble.certified_accuracy("softmax_median", 15)

0.9232

In [14]:
# Certified accuracy of the "softmax_median" ensemble at radius 20.
# NOTE(review): this returned the int 0, while radii 10 and 15 returned
# fractions. Check whether 20 exceeds the largest radius certifiable with 30
# base models, or whether fa_ensemble hits a special-case return here.
ensemble.certified_accuracy("softmax_median", 20)

0