<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from fa_ensemble import FiniteAggregationEnsemble
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch
from torch import nn, optim
import random
import numpy
from tqdm import tqdm
import ssl
from models.resnet import ResNet18, ResNetSmall

In [2]:
# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, so TLS certificates are no longer checked for downloads that
# rely on that default. It is here to let the CIFAR10 download succeed;
# prefer fixing the local certificate store if possible.
ssl._create_default_https_context = ssl._create_unverified_context

# One normalisation step (per-channel mean / std) shared by both splits.
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training images get random-crop and horizontal-flip augmentation before
# being converted to tensors; test images are only converted and normalised.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

trainset = CIFAR10(root='./datasets/CIFAR10', train=True, download=True, transform=train_transform)
testset = CIFAR10(root='./datasets/CIFAR10', train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def _select_device(partition_number):
    """Return the device string on which base model ``partition_number`` trains.

    CUDA devices are handed out round-robin by partition index; the MPS and
    CPU backends are returned without a device index.
    """
    if torch.cuda.is_available():
        # Wrap around the visible GPU count so a partition index larger than
        # the number of GPUs still resolves to an existing device.
        return "cuda:" + str(partition_number % torch.cuda.device_count())
    # Older torch builds have no ``torch.backends.mps``; treat that as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def train_base_model(partition_number, train_subset):
    """Train one ResNet18 base classifier on a single partition of the data.

    Args:
        partition_number: Integer index of the partition. It is used as the
            seed for ``random``, ``numpy`` and ``torch`` and to choose which
            CUDA device the model trains on when CUDA is available.
        train_subset: Dataset with this partition's training samples; it is
            wrapped in a shuffling ``DataLoader`` with batch size 128.

    Returns:
        The trained network, in eval mode, on the device it was trained on.

    Side effects:
        Prints the subset size and an accuracy estimate computed on the first
        batch of the module-level ``testset``.
    """
    # Seed every RNG from the partition index so each base model is repeatable.
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    curr_lr = 0.01
    epochs = 100
    # Epoch indices after which the learning rate is multiplied by lr_decay.
    # NOTE(review): with epochs = 100 only the 60 milestone is ever reached;
    # 120 and 160 have no effect unless epochs is raised.
    lr_milestones = (60, 120, 160)
    lr_decay = 0.2
    device = _select_device(partition_number)

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=1)
    print("subset has ", len(train_subset), "data points")

    # NOTE(review): disables TLS certificate verification for the default
    # HTTPS context; kept for compatibility with the original behaviour.
    ssl._create_default_https_context = ssl._create_unverified_context
    net = ResNet18()
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=curr_lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Training
    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        # The decay is applied after the milestone epoch has finished, so the
        # reduced rate takes effect from the following epoch.
        if epoch in lr_milestones:
            curr_lr = curr_lr * lr_decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    # Quick sanity check: accuracy on the first test batch only, hence the
    # "Estimated" wording in the message below.
    net.eval()
    nomtestloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=1)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in nomtestloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
            break
    acc = 100.*correct/total
    print(f'Estimated accuracy for base model {partition_number}: {str(acc)}%')

    return net


In [4]:
# Build the ensemble rooted at "ensembles/cifar10_k50" over the CIFAR10 splits.
# NOTE(review): the trailing 10 and 50 are positional; presumably the class
# count and the number of base models — confirm against fa_ensemble.
ensemble = FiniteAggregationEnsemble(
    "ensembles/cifar10_k50",
    trainset,
    testset,
    10,
    50,
)

# Hand every base-model index 0..49 to the ensemble together with the trainer.
for partition_idx in range(50):
    ensemble.train_base_model(partition_idx, train_base_model)

Base model 0 already exists
Base model 1 already exists
Base model 2 already exists
Base model 3 already exists
Base model 4 already exists
Base model 5 already exists
Base model 6 already exists
Base model 7 already exists
Base model 8 already exists
Base model 9 already exists
Base model 10 already exists
Base model 11 already exists
Base model 12 already exists
Base model 13 already exists
Base model 14 already exists
Base model 15 already exists
Base model 16 already exists
Base model 17 already exists
Base model 18 already exists
Base model 19 already exists
Base model 20 already exists
Base model 21 already exists
Base model 22 already exists
Base model 23 already exists
Base model 24 already exists
Base model 25 already exists
Base model 26 already exists
Base model 27 already exists
Base model 28 already exists
Base model 29 already exists
Base model 30 already exists
Base model 31 already exists
Base model 32 already exists
Base model 33 already exists
Base model 34 already ex

In [5]:
# Evaluate the ensemble in "softmax_median" mode.
ensemble.eval(mode="softmax_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 50.0012
Clean Accuracy: 65.60000000000001%
Median Certified Radius: 8


In [6]:
# Evaluate the ensemble in "logit_median" mode.
ensemble.eval(mode="logit_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 50.0012
Clean Accuracy: 65.49000000000001%
Median Certified Radius: 5


In [7]:
# Evaluate the ensemble in "label_voting" mode.
ensemble.eval(mode="label_voting")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 50.0012
Clean Accuracy: 65.36%
Median Certified Radius: 6


In [8]:
# Evaluate the ensemble in "label_runoff" mode.
ensemble.eval(mode="label_runoff")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 50.0012
Clean Accuracy: 65.34%
Median Certified Radius: 7


In [9]:
# Distill into a fresh ResNetSmall student: "label_runoff" mode, lr 1e-3, 10 epochs.
ensemble.distill(ResNetSmall(), mode='label_runoff', lr=1e-3, epochs=10)

ensembles/cifar10_k50/students/student_label_runoff2305843009213694100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_label_runoff2305843009213694100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:11<00:00,  7.06it/s]

Accuracy for student: 63.19%





In [10]:
# Distill into a fresh ResNetSmall student: "label_runoff" mode, lr 1e-2, 10 epochs.
ensemble.distill(ResNetSmall(), mode='label_runoff', lr=1e-2, epochs=10)

ensembles/cifar10_k50/students/student_label_runoff23058430092136940100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_label_runoff23058430092136940100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:10<00:00,  7.45it/s]

Accuracy for student: 63.66%





In [11]:
# Distill into a fresh ResNetSmall student: "softmax_median" mode, lr 1e-3, 10 epochs.
ensemble.distill(ResNetSmall(), mode='softmax_median', lr=1e-3, epochs=10)

ensembles/cifar10_k50/students/student_softmax_median2305843009213694100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_softmax_median2305843009213694100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:10<00:00,  7.37it/s]

Accuracy for student: 45.01%





In [12]:
# Distill into a fresh ResNetSmall student: "softmax_median" mode, lr 1e-2, 10 epochs.
ensemble.distill(ResNetSmall(), mode='softmax_median', lr=1e-2, epochs=10)

ensembles/cifar10_k50/students/student_softmax_median23058430092136940100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_softmax_median23058430092136940100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:11<00:00,  7.12it/s]

Accuracy for student: 59.19%





In [13]:
# Distill into a fresh ResNetSmall student: "softmax_median" mode, lr 1e-1, 10 epochs.
ensemble.distill(ResNetSmall(), mode='softmax_median', lr=1e-1, epochs=10)

ensembles/cifar10_k50/students/student_softmax_median230584300921369408100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_softmax_median230584300921369408100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:11<00:00,  6.93it/s]

Accuracy for student: 61.21%





In [14]:
# Distill into a fresh ResNetSmall student: "logit_median" mode, lr 1e-3, 10 epochs.
ensemble.distill(ResNetSmall(), mode='logit_median', lr=1e-3, epochs=10)

ensembles/cifar10_k50/students/student_logit_median2305843009213694100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_logit_median2305843009213694100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:10<00:00,  7.27it/s]

Accuracy for student: 64.05%





In [15]:
# Distill into a fresh ResNetSmall student: "logit_median" mode, lr 1e-2, 10 epochs.
ensemble.distill(ResNetSmall(), mode='logit_median', lr=1e-2, epochs=10)

ensembles/cifar10_k50/students/student_logit_median23058430092136940100.pkl
Student already trained with those parameters at ensembles/cifar10_k50/students/student_logit_median23058430092136940100.pkl, loading it in instead of training again...
Evaluating Student


100%|██████████| 79/79 [00:10<00:00,  7.30it/s]

Accuracy for student: 64.17%





In [16]:
# Distill into a fresh ResNetSmall student: "label_voting" mode, lr 1e-3, 10 epochs.
ensemble.distill(ResNetSmall(), mode='label_voting', lr=1e-3, epochs=10)

ensembles/cifar10_k50/students/student_label_voting2305843009213694100.pkl
trainset predictions already computed, using those...
Epoch 1/10


100%|██████████| 391/391 [01:36<00:00,  4.06it/s]


Trainset Accuracy: 38.403999999999996%
Epoch 2/10


100%|██████████| 391/391 [01:33<00:00,  4.17it/s]


Trainset Accuracy: 52.205999999999996%
Epoch 3/10


100%|██████████| 391/391 [01:33<00:00,  4.19it/s]


Trainset Accuracy: 56.728%
Epoch 4/10


100%|██████████| 391/391 [01:31<00:00,  4.29it/s]


Trainset Accuracy: 59.214%
Epoch 5/10


100%|██████████| 391/391 [01:31<00:00,  4.28it/s]


Trainset Accuracy: 60.67399999999999%
Epoch 6/10


100%|██████████| 391/391 [01:31<00:00,  4.29it/s]


Trainset Accuracy: 61.512%
Epoch 7/10


100%|██████████| 391/391 [01:31<00:00,  4.28it/s]


Trainset Accuracy: 62.188%
Epoch 8/10


100%|██████████| 391/391 [01:31<00:00,  4.29it/s]


Trainset Accuracy: 62.78%
Epoch 9/10


100%|██████████| 391/391 [01:31<00:00,  4.30it/s]


Trainset Accuracy: 63.27%
Epoch 10/10


100%|██████████| 391/391 [01:31<00:00,  4.29it/s]


Trainset Accuracy: 63.636%
Finished training student, saving to ensembles/cifar10_k50/students/student_label_voting2305843009213694100.pkl
Evaluating Student


100%|██████████| 79/79 [00:11<00:00,  7.14it/s]

Accuracy for student: 63.48%





In [17]:
# Distill into a fresh ResNetSmall student: "label_voting" mode, lr 1e-2, 10 epochs.
ensemble.distill(ResNetSmall(), mode='label_voting', lr=1e-2, epochs=10)

ensembles/cifar10_k50/students/student_label_voting23058430092136940100.pkl
trainset predictions already computed, using those...
Epoch 1/10


100%|██████████| 391/391 [01:30<00:00,  4.30it/s]


Trainset Accuracy: 49.834%
Epoch 2/10


100%|██████████| 391/391 [01:29<00:00,  4.37it/s]


Trainset Accuracy: 59.812%
Epoch 3/10


100%|██████████| 391/391 [01:29<00:00,  4.38it/s]


Trainset Accuracy: 61.614000000000004%
Epoch 4/10


100%|██████████| 391/391 [01:29<00:00,  4.36it/s]


Trainset Accuracy: 62.775999999999996%
Epoch 5/10


100%|██████████| 391/391 [01:29<00:00,  4.36it/s]


Trainset Accuracy: 63.273999999999994%
Epoch 6/10


100%|██████████| 391/391 [01:29<00:00,  4.36it/s]


Trainset Accuracy: 63.46000000000001%
Epoch 7/10


100%|██████████| 391/391 [01:30<00:00,  4.33it/s]


Trainset Accuracy: 63.864%
Epoch 8/10


100%|██████████| 391/391 [01:31<00:00,  4.28it/s]


Trainset Accuracy: 64.118%
Epoch 9/10


100%|██████████| 391/391 [01:32<00:00,  4.25it/s]


Trainset Accuracy: 64.288%
Epoch 10/10


100%|██████████| 391/391 [01:30<00:00,  4.33it/s]


Trainset Accuracy: 64.574%
Finished training student, saving to ensembles/cifar10_k50/students/student_label_voting23058430092136940100.pkl
Evaluating Student


100%|██████████| 79/79 [00:15<00:00,  5.26it/s]

Accuracy for student: 63.77%





In [18]:
# Distill into a fresh ResNetSmall student: "label_voting" mode, lr 1e-1, 10 epochs.
ensemble.distill(ResNetSmall(), mode='label_voting', lr=1e-1, epochs=10)

ensembles/cifar10_k50/students/student_label_voting230584300921369408100.pkl
trainset predictions already computed, using those...
Epoch 1/10


100%|██████████| 391/391 [01:31<00:00,  4.28it/s]


Trainset Accuracy: 46.078%
Epoch 2/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 57.984%
Epoch 3/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 60.362%
Epoch 4/10


100%|██████████| 391/391 [01:31<00:00,  4.27it/s]


Trainset Accuracy: 61.678%
Epoch 5/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 62.392%
Epoch 6/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 62.682%
Epoch 7/10


100%|██████████| 391/391 [01:30<00:00,  4.32it/s]


Trainset Accuracy: 62.842%
Epoch 8/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 63.054%
Epoch 9/10


100%|██████████| 391/391 [01:31<00:00,  4.26it/s]


Trainset Accuracy: 63.266%
Epoch 10/10


100%|██████████| 391/391 [01:31<00:00,  4.27it/s]


Trainset Accuracy: 63.482000000000006%
Finished training student, saving to ensembles/cifar10_k50/students/student_label_voting230584300921369408100.pkl
Evaluating Student


100%|██████████| 79/79 [00:10<00:00,  7.21it/s]

Accuracy for student: 61.47%



