<a href="https://colab.research.google.com/github/JackBlake-zkq/robust-edge-inference/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from fa_ensemble import FiniteAggregationEnsemble
from threading import Thread
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor
import torch
from torch import nn, optim
import random
import numpy
from NetworkInNetwork import NetworkInNetwork
from tqdm import tqdm

In [2]:
# Build the CIFAR-10 train and test splits under ./data, in that order.
# ToTensor() is the only transform attached here; no normalization or
# augmentation is applied at load time.
trainset, testset = (
    CIFAR10(root='./data', train=is_train_split, download=True, transform=ToTensor())
    for is_train_split in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def _select_device():
    """Return the preferred torch device string: CUDA first, then MPS, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _seed_everything(seed):
    """Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs with `seed`."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _accuracy_percent(net, dataset, device, batch_size=128):
    """Return the top-1 accuracy of `net` on `dataset`, as a percentage.

    Puts `net` into eval mode and runs the whole pass under `torch.no_grad()`.
    """
    net.eval()
    # Order is irrelevant for an accuracy count, so the loader is not shuffled.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = net(inputs).max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)
    return 100. * correct / total


def train_base_model(partition_number, train_subset, mean, std):
    """Train one NetworkInNetwork base model on `train_subset` and return it.

    The model is trained for 100 epochs with SGD (lr 0.1, momentum 0.9,
    weight decay 0.0005, Nesterov) on shuffled batches of 128. Afterwards its
    accuracy on the module-level `testset` is printed.

    Args:
        partition_number: Identifier of the partition being trained. It is also
            used as the seed for every RNG, so each partition trains
            deterministically but differently from the others.
        train_subset: Dataset yielding (inputs, targets) pairs to train on.
        mean: Not used by this function.
        std: Not used by this function.

    Returns:
        The trained network, left in eval mode on the selected device.

    NOTE(review): `mean` and `std` are accepted but never applied, so inputs
    reach the network un-normalized. Confirm whether the caller normalizes
    elsewhere before wiring them in here.
    NOTE(review): evaluation reads the global `testset`, so the cell defining
    it must have been run before this function is called.
    """
    device = _select_device()
    _seed_everything(partition_number)

    curr_lr = 0.1
    epochs = 100
    # With epochs = 100 only the milestone at 60 is ever reached; 120 and 160
    # are kept so the schedule is already in place if `epochs` is raised.
    lr_decay_epochs = (60, 120, 160)
    lr_decay_factor = 0.2

    trainloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True, num_workers=1)

    net = NetworkInNetwork({'num_classes': 10, 'num_inchannels': 3})
    net = net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=curr_lr, momentum=0.9, weight_decay=0.0005, nesterov=True)

    # Training
    net.train()
    for epoch in tqdm(range(epochs)):
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(net(inputs), targets)
            loss.backward()
            optimizer.step()
        # Step decay: applied after the milestone epoch (0-based) has finished.
        if epoch in lr_decay_epochs:
            curr_lr = curr_lr * lr_decay_factor
            for param_group in optimizer.param_groups:
                param_group['lr'] = curr_lr

    acc = _accuracy_percent(net, testset, device)
    print(f'Accuracy for base model {partition_number}: {acc}%')

    return net


In [4]:
# Build the ensemble over the CIFAR-10 training set, handing it
# `train_base_model` as the per-partition training callback.
# NOTE(review): d=3 is presumably the Finite Aggregation spreading degree;
# confirm against fa_ensemble.
ensemble = FiniteAggregationEnsemble(
    trainset,
    train_base_model,
    d=3,
)

Computing partitions...


In [5]:
# Train a single base model (argument 0) through the ensemble.
# NOTE(review): the saved output shows this run was interrupted by hand after
# 5 of 100 epochs, so no accuracy line was recorded. Re-run to completion or
# clear the output before sharing.
ensemble.train_base_model(0)


tensor([0.4860, 0.4799, 0.4448]) tensor([0.2483, 0.2457, 0.2654])


  5%|▌         | 5/100 [00:19<06:11,  3.91s/it]


KeyboardInterrupt: 