<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from fa_ensemble import FiniteAggregationEnsemble
from torchvision.datasets import CIFAR10, MNIST
from torchvision import transforms
import torch
from torch import nn, optim
import random
import numpy
from tqdm import tqdm
import ssl
from models.resnet import ResNet18, ResNet18_1C, ResNetSmall_1C

In [None]:
# Both splits share one preprocessing pipeline: convert to a tensor, then
# normalize the single channel with the constants below (presumably the
# MNIST training-set mean / std -- TODO confirm).
mnist_root = './datasets/MNIST'
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# download=True fetches the data into mnist_root when it is not already there.
trainset = MNIST(
    root=mnist_root,
    train=True,
    download=True,
    transform=mnist_transform,
)
testset = MNIST(
    root=mnist_root,
    train=False,
    download=True,
    transform=mnist_transform,
)

In [None]:
def train_base_model(partition_number, train_subset):
    """Train one ``ResNet18_1C`` base model on a single partition's data.

    All RNGs (``random``, ``numpy``, ``torch`` CPU and CUDA) are seeded with
    ``partition_number``. The network is trained with SGD (Nesterov momentum)
    for a fixed number of epochs, then a quick accuracy estimate is printed
    using only the first batch of the module-level ``testset``.

    Args:
        partition_number: Integer index of the base model. Used as the RNG
            seed and, on CUDA hosts, to pick which GPU to train on.
        train_subset: Dataset handed to ``torch.utils.data.DataLoader`` for
            training.

    Returns:
        The trained network, in eval mode, still on the training device.
    """
    seed = partition_number
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    curr_lr = 0.01
    epochs = 10
    # Epochs after which the learning rate is multiplied by 0.2. With
    # epochs == 10 none of these is ever reached, so the schedule is inert
    # unless `epochs` is raised.
    lr_decay_epochs = (60, 120, 160)

    # Pick the device. The partition index is only meaningful as a CUDA
    # ordinal, and it must be wrapped to the number of GPUs actually present:
    # requesting e.g. "cuda:29" on a host with fewer GPUs raises an error, and
    # "mps:<n>" / "cpu:<n>" are not usable device names for n > 0.
    if torch.cuda.is_available():
        gpu_index = partition_number % torch.cuda.device_count()
        device = f"cuda:{gpu_index}"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    trainloader = torch.utils.data.DataLoader(
        train_subset, batch_size=128, shuffle=True, num_workers=1
    )
    print("subset has ", len(train_subset), "data points")

    # NOTE(review): this disables HTTPS certificate verification for the whole
    # process, not just this call. Kept for backward compatibility -- confirm
    # whether anything here still needs it and remove it if not.
    ssl._create_default_https_context = ssl._create_unverified_context
    net = ResNet18_1C()
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        net.parameters(),
        lr=curr_lr,
        momentum=0.9,
        weight_decay=0.0005,
        nesterov=True,
    )

    # Training
    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        if epoch in lr_decay_epochs:
            curr_lr = curr_lr * 0.2
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    # Quick sanity check: accuracy on the FIRST test batch only (hence
    # "Estimated" in the message below), not on the full test set.
    net.eval()
    nomtestloader = torch.utils.data.DataLoader(
        testset, batch_size=128, shuffle=False, num_workers=1
    )
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in nomtestloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
            break
    acc = 100.*correct/total
    print(f'Estimated accuracy for base model {partition_number}: {acc}%')

    return net


In [None]:
# Build the ensemble around the MNIST splits and the per-partition training
# callback. The positional 10 is presumably the number of classes -- TODO
# confirm against FiniteAggregationEnsemble's signature.
ensemble = FiniteAggregationEnsemble(
    trainset,
    testset,
    train_base_model,
    10,
    k=30,
    d=1,
    state_dir="mnist_k30_d1",
)

# Train base models 0..29 one after another. NOTE(review): the count 30 looks
# tied to k=30 (with d=1) above; keep the two in sync if either changes.
for base_model_index in range(30):
    ensemble.train_base_model(base_model_index)

In [None]:
# Evaluate the ensemble with mode "softmax_median" (mode semantics are
# defined in fa_ensemble, not shown in this notebook).
ensemble.eval(mode="softmax_median")

In [None]:
# Evaluate the ensemble with mode "logit_median" (mode semantics are
# defined in fa_ensemble, not shown in this notebook).
ensemble.eval(mode="logit_median")

In [None]:
# Evaluate the ensemble with mode "label_voting" (mode semantics are
# defined in fa_ensemble, not shown in this notebook).
ensemble.eval(
    mode="label_voting",
)

In [None]:
# Distill the ensemble into a fresh ResNetSmall_1C using mode
# 'softmax_median'. Note the learning rate here (0.05) differs from the
# 1e-3 used by the other distillation cells.
ensemble.distill(
    ResNetSmall_1C(),
    mode='softmax_median',
    lr=0.05,
    epochs=5,
)

In [None]:
# Distill the ensemble into a fresh ResNetSmall_1C using mode
# 'logit_median', with lr=1e-3 for 5 epochs.
ensemble.distill(
    ResNetSmall_1C(),
    mode='logit_median',
    lr=1e-3,
    epochs=5,
)

In [None]:
# Distill the ensemble into a fresh ResNetSmall_1C using mode
# 'label_voting', with lr=1e-3 for 5 epochs.
ensemble.distill(
    ResNetSmall_1C(),
    mode='label_voting',
    lr=1e-3,
    epochs=5,
)