<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from fa_ensemble import FiniteAggregationEnsemble
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch
from torch import nn, optim
import random
import numpy
from tqdm import tqdm
import ssl
from models.resnet import ResNet18

In [2]:
# Per-channel normalization statistics shared by the train and test pipelines.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std"; the true
# per-pixel std is closer to (0.247, 0.243, 0.261). Kept unchanged so results stay comparable.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop + horizontal flip augmentation, then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = CIFAR10(root='./data', train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def train_base_model(partition_number, train_subset, epochs=30, lr=0.1, batch_size=128):
    """Train one ResNet-18 base model on a single partition and report its test accuracy.

    Args:
        partition_number: Index of the partition being trained. It is also used as the seed
            for `random`, `numpy` and `torch`, so each base model is reproducible.
        train_subset: Dataset for this partition (anything accepted by
            ``torch.utils.data.DataLoader``).
        epochs: Number of passes over ``train_subset``. Defaults to 30.
        lr: Initial SGD learning rate. Defaults to 0.1.
        batch_size: Training mini-batch size. Defaults to 128.

    Returns:
        The trained network, left in eval mode on the selected device.

    Side effects:
        Prints the network's top-1 accuracy on the module-level ``testset``.
    """
    # Seed every RNG from the partition index so each base model is reproducible.
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Prefer CUDA, then Apple MPS, then fall back to CPU.
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=1)

    net = ResNet18().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Step decay: multiply the LR by 0.2 after 30%, 60% and 80% of the epochs.
    # At 200 epochs this is exactly the classic 60/120/160 schedule. It now scales with
    # `epochs`; the previous hard-coded milestones were unreachable with epochs=30.
    lr_milestones = {epochs * 3 // 10, epochs * 6 // 10, epochs * 8 // 10}

    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()
        if epoch in lr_milestones:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.2

    # Evaluate top-1 accuracy on the held-out test set.
    net.eval()
    nomtestloader = torch.utils.data.DataLoader(testset, batch_size=2000, shuffle=False, num_workers=1)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in nomtestloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = net(inputs).argmax(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    acc = 100.*correct/total
    print(f'Accuracy for base model {partition_number}: {str(acc)}%')

    return net


In [5]:
# Finite Aggregation hyper-parameters, named so the training loop cannot drift from them.
NUM_CLASSES = 10  # NOTE(review): presumably the 4th positional arg is the class count (CIFAR-10) — confirm against fa_ensemble
FA_K = 50
FA_D = 3
# Assumes the ensemble expects one base model per partition, k*d in total (150 here,
# matching the previously hard-coded loop bound) — TODO confirm against fa_ensemble.
NUM_BASE_MODELS = FA_K * FA_D

ensemble = FiniteAggregationEnsemble(trainset, testset, train_base_model, NUM_CLASSES, state_dir="resnet_single", k=FA_K, d=FA_D, channels=3)
for partition_number in range(NUM_BASE_MODELS):
    ensemble.train_base_model(partition_number)
ensemble.certify()