In [1]:
import itertools
import random
import ssl
from threading import Thread

import numpy
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.datasets import CIFAR10, MNIST
from tqdm import tqdm

from fa_ensemble import FiniteAggregationEnsemble
from models.NetworkInNetwork import NetworkInNetwork

In [2]:
# Load MNIST (downloading on first run). The dataset root and the normalization
# mean/std are defined once so the train and test pipelines cannot drift apart.
MNIST_ROOT = './datasets/MNIST'
MNIST_NORMALIZE = transforms.Normalize((0.1307,), (0.3081,))

# Training pipeline: light augmentation (pad by 4 pixels, then take a random 28x28 crop).
trainset = MNIST(root=MNIST_ROOT, train=True, download=True, transform=transforms.Compose([
    transforms.RandomCrop(28, padding=4),
    transforms.ToTensor(),
    MNIST_NORMALIZE,
]))

# Test pipeline: no augmentation, same normalization as training.
testset = MNIST(root=MNIST_ROOT, train=False, download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    MNIST_NORMALIZE,
]))

In [3]:
def _select_device():
    """Return the preferred torch device name: "cuda" if available, else "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _estimate_accuracy(net, dataset, device, max_batches=1, batch_size=128):
    """Return the top-1 accuracy (in percent) of `net` on the leading batches of `dataset`.

    Only the first `max_batches` batches are scored. The loader is unshuffled, so these
    are always the same examples. This keeps the per-model sanity check cheap when many
    base models are trained. Pass `max_batches=None` to score the whole dataset.
    Leaves `net` in eval mode.

    Raises:
        ValueError: if no examples were scored (empty dataset or `max_batches == 0`).
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        # islice(..., None) iterates the whole loader; otherwise stop after max_batches.
        for inputs, targets in itertools.islice(loader, max_batches):
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = net(inputs).max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    if total == 0:
        raise ValueError("no evaluation examples were scored")
    return 100. * correct / total


def train_base_model(partition_number, train_subset, epochs=200, lr=0.1,
                     lr_milestones=(60, 120, 160), lr_decay=0.2,
                     eval_set=None, eval_batches=1):
    """Train one NetworkInNetwork base classifier (10 classes, 1 input channel) on one partition.

    This function is passed to `ensemble.train_base_model(i, train_base_model)`, which
    supplies the partition index and that partition's training subset.

    Args:
        partition_number: Partition index. It is also the seed for Python, NumPy and
            torch RNGs, so each base model is reproducible.
        train_subset: Dataset yielding (input, target) pairs for this partition.
        epochs: Number of passes over `train_subset`.
        lr: Initial SGD learning rate.
        lr_milestones: Epoch indices after which the learning rate is multiplied by
            `lr_decay`. The decay takes effect from the following epoch.
        lr_decay: Multiplicative learning-rate decay factor.
        eval_set: Dataset used for the accuracy printout. Defaults to the global `testset`.
        eval_batches: Number of 128-example batches of `eval_set` to score
            (None = the whole set). The default of 1 gives only a rough estimate.

    Returns:
        The trained network, on the selected device and in eval mode.

    NOTE(review): the ensemble appears to cache base models on disk ("Base model N
    already exists" in its output). Changing these defaults therefore does not retrain
    models that are already cached.
    """
    # Seed every RNG from the partition index so each base model is reproducible.
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    device = _select_device()

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=1)
    print("subset has ", len(train_subset), "data points")

    # Certificate verification is disabled only while the model is constructed, then
    # restored, instead of being left off for the rest of the kernel session.
    # TODO(review): confirm whether NetworkInNetwork needs network access at all;
    # if not, drop this override entirely.
    default_https_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        net = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 1})
    finally:
        ssl._create_default_https_context = default_https_context
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Training
    curr_lr = lr
    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()
        # Step decay, applied once the milestone epoch has finished.
        if epoch in lr_milestones:
            curr_lr *= lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    if eval_set is None:
        eval_set = testset
    acc = _estimate_accuracy(net, eval_set, device, max_batches=eval_batches)
    print(f'Estimated accuracy for base model {partition_number}: {str(acc)}%')

    return net


In [4]:
# Build the ensemble and train (or reuse from disk) one base model per partition.
# NOTE(review): assumes the two trailing constructor arguments are the number of classes
# and the number of base models/partitions. Confirm against FiniteAggregationEnsemble.
NUM_CLASSES = 10
NUM_BASE_MODELS = 1200
ENSEMBLE_DIR = f"ensembles/mnist_benchmark_k={NUM_BASE_MODELS}_d=1"

ensemble = FiniteAggregationEnsemble(ENSEMBLE_DIR, trainset, testset, NUM_CLASSES, NUM_BASE_MODELS)
for i in range(NUM_BASE_MODELS):
    ensemble.train_base_model(i, train_base_model)

Base model 0 already exists
Base model 1 already exists
Base model 2 already exists
Base model 3 already exists
Base model 4 already exists
Base model 5 already exists
Base model 6 already exists
Base model 7 already exists
Base model 8 already exists
Base model 9 already exists
Base model 10 already exists
Base model 11 already exists
Base model 12 already exists
Base model 13 already exists
Base model 14 already exists
Base model 15 already exists
Base model 16 already exists
Base model 17 already exists
Base model 18 already exists
Base model 19 already exists
Base model 20 already exists
Base model 21 already exists
Base model 22 already exists
Base model 23 already exists
Base model 24 already exists
Base model 25 already exists
Base model 26 already exists
Base model 27 already exists
Base model 28 already exists
Base model 29 already exists
Base model 30 already exists
Base model 31 already exists
Base model 32 already exists
Base model 33 already exists
Base model 34 already ex

In [5]:
# Evaluate with softmax-median aggregation. Reports base-classifier accuracy, clean
# accuracy and median certified radius, reusing cached predictions/certificates if present.
ensemble.eval(mode="softmax_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 77.28179166666666
Clean Accuracy: 95.63000000000001%
Median Certified Radius: 496


In [6]:
# Evaluate with logit-median aggregation. Reports base-classifier accuracy, clean
# accuracy and median certified radius, reusing cached predictions/certificates if present.
ensemble.eval(mode="logit_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 77.28179166666666
Clean Accuracy: 95.39999999999999%
Median Certified Radius: 397


In [7]:
# Evaluate with label-voting aggregation. Reports base-classifier accuracy, clean
# accuracy and median certified radius, reusing cached predictions/certificates if present.
ensemble.eval(mode="label_voting")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 77.28179166666666
Clean Accuracy: 95.85000000000001%
Median Certified Radius: 451


In [8]:
# Evaluate with label-runoff aggregation. Reports base-classifier accuracy, clean
# accuracy and median certified radius, reusing cached predictions/certificates if present.
ensemble.eval(mode="label_runoff")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 77.28179166666666
Clean Accuracy: 95.53%
Median Certified Radius: 467


In [9]:
# Distill the ensemble (logit-median aggregation) into a fresh NetworkInNetwork student.
# NOTE(review): the saved run of this cell failed. While generating base-model predictions
# on the trainset (reported as a torch.Size([60000, 1200, 10]) tensor, 720M elements),
# the Apple M1 (MPS) backend ran out of memory, and the cell was then interrupted manually.
# Re-run on a machine with more memory before relying on this result.
student_logit_median = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 1})
ensemble.distill(student_logit_median, mode='logit_median', lr=1e-3, epochs=10)

ensembles/mnist_benchmark_k=1200_d=1/students/student_logit_median2305843009213694100.pkl
Generating base model predictions on trainset...
torch.Size([60000, 1200, 10])


  0%|          | 0/1200 [00:00<?, ?it/s]

Predictions for base model 0 on trainset already computed, using those...


  0%|          | 1/1200 [00:00<19:28,  1.03it/s]

Predictions for base model 1 on trainset already computed, using those...


  0%|          | 3/1200 [00:36<5:22:59, 16.19s/it]Error: command buffer exited with error status.
	The Metal Performance Shaders operations encoded on it may not have completed.
	Error: 
	(null)
	Insufficient Memory (00000008:kIOGPUCommandBufferCallbackErrorOutOfMemory)
	<AGXG13GFamilyCommandBuffer: 0x2dd40c950>
    label = <none> 
    device = <AGXG13GDevice: 0x12436ac00>
        name = Apple M1 
    commandQueue = <AGXG13GFamilyCommandQueue: 0x1234daa00>
        label = <none> 
        device = <AGXG13GDevice: 0x12436ac00>
            name = Apple M1 
    retainedReferences = 1
  0%|          | 3/1200 [00:37<4:06:08, 12.34s/it]


KeyboardInterrupt: 

In [None]:
# Distill the ensemble (softmax-median aggregation) into a fresh NetworkInNetwork student.
# Not yet executed in the saved notebook (In [None]).
student_softmax_median = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 1})
ensemble.distill(student_softmax_median, mode='softmax_median', lr=1e-3, epochs=10)

In [None]:
# Distill the ensemble (label-voting aggregation) into a fresh NetworkInNetwork student.
# Not yet executed in the saved notebook (In [None]).
student_label_voting = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 1})
ensemble.distill(student_label_voting, mode='label_voting', lr=1e-3, epochs=10)

In [None]:
# Distill the ensemble (label-runoff aggregation) into a fresh NetworkInNetwork student.
# Not yet executed in the saved notebook (In [None]).
student_label_runoff = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 1})
ensemble.distill(student_label_runoff, mode='label_runoff', lr=1e-3, epochs=10)