<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from fa_ensemble import FiniteAggregationEnsemble
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch
from torch import nn, optim
import random
import numpy
from tqdm import tqdm
import ssl
from models.resnet import ResNet18, ResNetSmall

In [2]:
# Swap in an unverified HTTPS context so the dataset download does not fail on
# certificate errors. NOTE(review): this turns off certificate verification for
# every default HTTPS context created afterwards in this process.
ssl._create_default_https_context = ssl._create_unverified_context

# Per-channel normalization constants shared by the train and test pipelines.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)
cifar_root = './datasets/CIFAR10'

# Training pipeline: random crop + horizontal flip augmentation, then normalize.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

# Test pipeline: no augmentation, same normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_mean, cifar_std),
])

trainset = CIFAR10(root=cifar_root, train=True, download=True, transform=train_transform)
testset = CIFAR10(root=cifar_root, train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def train_base_model(partition_number, train_subset):
    """Train one ResNet18 base classifier on a single training partition.

    The Python, NumPy and PyTorch RNGs are all seeded with ``partition_number``
    so each base model is reproducible. The network is trained with Nesterov
    SGD for a fixed number of epochs, a quick accuracy estimate is printed, and
    the trained network is returned.

    Args:
        partition_number: Integer index of the partition. Used as the RNG seed
            and, on CUDA machines, to choose which GPU trains this model.
        train_subset: Dataset holding this partition's training samples; it is
            passed directly to ``torch.utils.data.DataLoader``.

    Returns:
        The trained network, left in eval mode on the device it was trained on.

    Notes:
        The printed accuracy is only an estimate: it is computed on the first
        batch (128 samples) of the module-level ``testset``, not the whole set.
    """
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    curr_lr = 0.01
    epochs = 100

    if torch.cuda.is_available():
        # Spread base models round-robin over the visible GPUs. Using the raw
        # partition number as the device index is an invalid ordinal as soon as
        # it reaches the GPU count.
        device = f"cuda:{partition_number % torch.cuda.device_count()}"
    elif torch.backends.mps.is_available():
        # No per-partition device index is meaningful here.
        device = "mps"
    else:
        device = "cpu"

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=1)
    print("subset has ", len(train_subset), "data points")

    # NOTE(review): disables HTTPS certificate verification process-wide;
    # presumably only needed if model construction downloads anything - confirm.
    ssl._create_default_https_context = ssl._create_unverified_context
    net = ResNet18()

    net = net.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(net.parameters(), lr=curr_lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Training
    net.train()
    for epoch in tqdm(range(epochs)):
        for (inputs, targets) in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        # Step decay: scale the learning rate by 0.2 after each milestone epoch.
        # With epochs = 100 only the milestone at 60 is ever reached.
        if epoch in (60, 120, 160):
            curr_lr = curr_lr * 0.2
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    # Quick sanity check on the first test batch only (hence "Estimated").
    net.eval()
    nomtestloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=1)
    correct = 0
    total = 0
    with torch.no_grad():
        for (inputs, targets) in nomtestloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
            break
    acc = 100.*correct/total
    print(f'Estimated accuracy for base model {partition_number}: {str(acc)}%')

    return net


In [4]:
# Construct the ensemble around the datasets and the per-partition training
# function; ensemble state is kept under the "cifar10_k50" state directory.
ensemble = FiniteAggregationEnsemble(
    trainset,
    testset,
    train_base_model,
    10,  # presumably the number of classes - confirm against FiniteAggregationEnsemble
    k=50,
    d=1,
    state_dir="cifar10_k50",
)

# Train the base models sequentially, one per partition index.
for partition_idx in range(50):
    ensemble.train_base_model(partition_idx)

Partitions not computed yet, computing now...
Computing partitions...
Finished computing partitions, saving to ensembles/cifar10_k50/partition_info.pth
Partitions saved
Training Base model 0..
subset has  938 data points


100%|██████████| 100/100 [07:40<00:00,  4.61s/it]


Estimated accuracy for base model 0: 47.65625%
Saving Base model 0..
Base model 0 saved
Training Base model 1..
subset has  1022 data points


100%|██████████| 100/100 [07:42<00:00,  4.62s/it]


Estimated accuracy for base model 1: 49.21875%
Saving Base model 1..
Base model 1 saved
Training Base model 2..
subset has  1023 data points


100%|██████████| 100/100 [07:44<00:00,  4.65s/it]


Estimated accuracy for base model 2: 53.125%
Saving Base model 2..
Base model 2 saved
Training Base model 3..
subset has  965 data points


100%|██████████| 100/100 [07:39<00:00,  4.59s/it]


Estimated accuracy for base model 3: 51.5625%
Saving Base model 3..
Base model 3 saved
Training Base model 4..
subset has  979 data points


100%|██████████| 100/100 [07:43<00:00,  4.64s/it]


Estimated accuracy for base model 4: 51.5625%
Saving Base model 4..
Base model 4 saved
Training Base model 5..
subset has  1013 data points


100%|██████████| 100/100 [07:42<00:00,  4.62s/it]


Estimated accuracy for base model 5: 53.125%
Saving Base model 5..
Base model 5 saved
Training Base model 6..
subset has  1055 data points


100%|██████████| 100/100 [08:15<00:00,  4.96s/it]


Estimated accuracy for base model 6: 47.65625%
Saving Base model 6..
Base model 6 saved
Training Base model 7..
subset has  930 data points


100%|██████████| 100/100 [07:38<00:00,  4.59s/it]


Estimated accuracy for base model 7: 58.59375%
Saving Base model 7..
Base model 7 saved
Training Base model 8..
subset has  1030 data points


100%|██████████| 100/100 [08:14<00:00,  4.95s/it]


Estimated accuracy for base model 8: 46.09375%
Saving Base model 8..
Base model 8 saved
Training Base model 9..
subset has  1015 data points


  3%|▎         | 3/100 [00:15<08:37,  5.33s/it]


KeyboardInterrupt: 

In [None]:
# Evaluate the ensemble using the "softmax_median" mode.
# NOTE(review): this cell has no execution count and base-model training above
# was interrupted, so the recorded output was likely not produced by this
# session - re-run after training completes.
ensemble.eval("softmax_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Ensembe Accuracy: 98.07000000000001%
Certified Radius (for at least half of inputs): 15


In [None]:
# Evaluate the ensemble using the "logit_median" mode.
# NOTE(review): cell has no execution count; the recorded output may be stale.
ensemble.eval("logit_median")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Ensembe Accuracy: 98.11999999999999%
Certified Radius (for at least half of inputs): 15


In [None]:
# Evaluate the ensemble using the "label_voting" mode.
# NOTE(review): cell has no execution count; the recorded output may be stale.
ensemble.eval(mode="label_voting")

testset predictions already computed, using those...
Certificates already computed, using those...
Base classifier accuracy: 97.14733333333334
Ensembe Accuracy: 98.06%
Certified Radius (for at least half of inputs): 15


In [None]:
# Distill the ensemble into a fresh ResNetSmall student using the
# "softmax_median" mode (learning rate 1e-3, 5 epochs).
# NOTE(review): cell has no execution count; the recorded output may be stale.
ensemble.distill(ResNetSmall(), 'softmax_median', lr=1e-3, epochs=5)

trainset predictions already computed, using those...


  softmaxes, _ = softmaxes_by_class.to(device).sort(dim=2)


Student for softmax_median distillation mode already trained
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 14.59it/s]

Accuracy for student: 97.88%





In [None]:
# Distill the ensemble into a fresh ResNetSmall student using the
# "logit_median" mode (learning rate 1e-3, 5 epochs).
# NOTE(review): the recorded output for this cell saves to
# ensembles/mnist_k30_d1/, which does not match the "cifar10_k50" state_dir
# used when the ensemble was built - it appears to come from a different run.
ensemble.distill(ResNetSmall(), mode='logit_median', lr=1e-3, epochs=5)

trainset predictions already computed, using those...
Epoch 1/5


100%|██████████| 469/469 [01:36<00:00,  4.87it/s]


Epoch 2/5


100%|██████████| 469/469 [01:35<00:00,  4.89it/s]


Epoch 3/5


100%|██████████| 469/469 [01:36<00:00,  4.84it/s]


Epoch 4/5


100%|██████████| 469/469 [01:37<00:00,  4.83it/s]


Epoch 5/5


100%|██████████| 469/469 [01:36<00:00,  4.87it/s]


Finished training student, saving to ensembles/mnist_k30_d1/student_logit_median.pkl
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 14.67it/s]

Accuracy for student: 97.85%





In [None]:
# Distill the ensemble into a fresh ResNetSmall student using the
# "label_voting" mode (learning rate 1e-3, 5 epochs).
# NOTE(review): cell has no execution count; the recorded output may be stale.
ensemble.distill(ResNetSmall(), 'label_voting', lr=1e-3, epochs=5)

trainset predictions already computed, using those...
Student for label_voting distillation mode already trained
Evaluating Student


100%|██████████| 79/79 [00:05<00:00, 14.55it/s]

Accuracy for student: 98.01%



