In [1]:
from typing import Any, Callable, Dict, List, Optional, Tuple
from collections import OrderedDict

import flwr as fl
import torch
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import warnings
warnings.filterwarnings('ignore')

# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using {DEVICE} for inference')


Using cpu for inference


In [2]:
def load_data():
    """Build the CIFAR-10 training and test datasets with a shared transform.

    Every image is resized to 256, center-cropped to 224, converted to a
    tensor and normalized with the fixed per-channel mean/std listed below.
    Both splits are rooted at ``./dataset`` and requested with
    ``download=True``.

    Returns:
        A ``(trainset, testset, num_examples)`` tuple, where ``num_examples``
        maps ``"trainset"`` and ``"testset"`` to the respective dataset sizes.
    """
    preprocessing_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
    transform = transforms.Compose(preprocessing_steps)

    # Same root directory and transform for both splits; only `train` differs.
    trainset, testset = (
        CIFAR10("./dataset", train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )

    num_examples = {"trainset": len(trainset), "testset": len(testset)}
    return trainset, testset, num_examples

In [3]:
def load_partition(idx: int, num_partitions: int = 10):
    """Load one contiguous shard of the training and test data.

    The full datasets returned by ``load_data`` are cut into
    ``num_partitions`` equally sized, non-overlapping index ranges and the
    ``idx``-th range of each split is returned. Examples left over when the
    dataset size is not divisible by ``num_partitions`` are not assigned to
    any partition.

    Args:
        idx: Zero-based partition index; must satisfy
            ``0 <= idx < num_partitions``.
        num_partitions: Number of shards to split each dataset into.
            Defaults to 10, i.e. each partition holds 1/10th of the data.

    Returns:
        A ``(train_partition, test_partition)`` tuple of
        ``torch.utils.data.Subset`` objects.

    Raises:
        AssertionError: If ``num_partitions`` is not positive or ``idx`` is
            outside the valid range.
    """
    assert num_partitions > 0, "num_partitions must be a positive integer"
    assert idx in range(num_partitions), (
        f"idx must be in [0, {num_partitions}), got {idx!r}"
    )
    trainset, testset, num_examples = load_data()

    # Floor division keeps the shard sizes as exact integers.
    n_train = num_examples["trainset"] // num_partitions
    n_test = num_examples["testset"] // num_partitions

    train_partition = torch.utils.data.Subset(
        trainset, range(idx * n_train, (idx + 1) * n_train)
    )
    test_partition = torch.utils.data.Subset(
        testset, range(idx * n_test, (idx + 1) * n_test)
    )
    return train_partition, test_partition



In [4]:
def train(net, trainloader, valloader, epochs):
    """Fit ``net`` on ``trainloader`` and report train/validation metrics.

    Runs ``epochs`` passes of SGD (lr=0.1, momentum=0.9, weight_decay=1e-4)
    with a cross-entropy loss, moving each batch to ``DEVICE``. After the
    last epoch the network is scored with ``test`` on both loaders.

    Returns:
        A dict with the keys ``"train_loss"``, ``"train_accuracy"``,
        ``"val_loss"`` and ``"val_accuracy"``.
    """
    loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
    sgd = torch.optim.SGD(
        net.parameters(),
        lr=0.1,
        momentum=0.9,
        weight_decay=1e-4,
    )

    net.train()
    for _epoch in range(epochs):
        for batch_images, batch_labels in trainloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            sgd.zero_grad()
            batch_loss = loss_fn(net(batch_images), batch_labels)
            batch_loss.backward()
            sgd.step()

    # Score the trained network on the data it saw and on the held-out data.
    train_loss, train_acc = test(net, trainloader)
    val_loss, val_acc = test(net, valloader)

    return {
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
    }


In [5]:
def test(net, testloader):
    """Evaluate the network on every batch yielded by ``testloader``.

    The network is switched to eval mode and run without gradient tracking;
    each batch is moved to ``DEVICE`` first.

    Args:
        net: Model to evaluate; called as ``net(images)``.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        A ``(loss, accuracy)`` tuple: the mean cross-entropy per example and
        the fraction of examples whose arg-max prediction equals the label.
        ``(0.0, 0.0)`` is returned when the loader yields no examples.
    """
    # Sum (not average) the loss inside each batch so that dividing by the
    # number of examples below yields a true per-example mean. Summing batch
    # means and dividing by the dataset size would shrink the reported loss
    # by roughly the batch size.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    correct, total, loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # Nothing was evaluated (e.g. an empty validation split); avoid a
        # ZeroDivisionError and report neutral metrics instead.
        return 0.0, 0.0

    return loss / total, correct / total

In [6]:
def replace_classifying_layer(efficientnet_model, num_classes: int = 10):
    """Swap ``efficientnet_model.classifier.fc`` for a fresh linear layer.

    The replacement keeps the input width of the existing layer and emits
    ``num_classes`` outputs. It is created on the same device and with the
    same dtype as the layer it replaces, so a model that has already been
    moved to an accelerator (or cast to another precision) stays consistent.

    Args:
        efficientnet_model: Model exposing a ``classifier.fc`` linear layer.
        num_classes: Number of output classes for the new layer.

    Returns:
        None. The model is modified in place.
    """
    old_fc = efficientnet_model.classifier.fc
    new_fc = torch.nn.Linear(old_fc.in_features, num_classes)
    efficientnet_model.classifier.fc = new_fc.to(
        device=old_fc.weight.device, dtype=old_fc.weight.dtype
    )

In [7]:
class CifarClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a model on local data.

    Inside ``fit`` the leading ``validation_split`` fraction of ``trainset``
    is held out for validation and the remainder is used for training.
    ``evaluate`` scores the model on ``testset``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        trainset: torch.utils.data.Dataset,
        testset: torch.utils.data.Dataset,
        validation_split: float = 0.1,
    ):
        """Store the model, the local datasets and the validation fraction.

        Args:
            model: Network whose state dict is exchanged with the server.
            trainset: Local training data (also the source of the
                validation split).
            testset: Local data used by ``evaluate``.
            validation_split: Fraction of ``trainset`` held out for
                validation; must satisfy ``0 <= validation_split < 1``.

        Raises:
            ValueError: If ``validation_split`` is outside ``[0, 1)``.
        """
        if not 0.0 <= validation_split < 1.0:
            raise ValueError(
                f"validation_split must be in [0, 1), got {validation_split!r}"
            )
        self.model = model
        self.trainset = trainset
        self.testset = testset
        self.validation_split = validation_split

    def get_parameters(self, config=None):
        """Get parameters of the local model.

        Always raises: this client expects the server to provide the initial
        parameters. ``config`` is accepted (and ignored) so the method can be
        called both with and without a config argument.
        """
        raise Exception("Not implemented (server-side parameter initialization)")

    def set_parameters(self, parameters):
        """Load ``parameters`` into the local model.

        ``parameters`` is paired positionally with the keys of the model's
        state dict, so it must follow the state-dict order.
        """
        state_keys = self.model.state_dict().keys()
        state_dict = OrderedDict(
            (key, torch.tensor(value)) for key, value in zip(state_keys, parameters)
        )
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train parameters on the locally held training set.

        Args:
            parameters: Model parameters to start from, in state-dict order.
            config: Must contain ``"batch_size"`` and ``"local_epochs"``.

        Returns:
            ``(updated_parameters, num_training_examples, metrics)`` where
            ``metrics`` is the dict produced by ``train``.
        """
        # Update local model parameters
        self.set_parameters(parameters)

        # Get hyperparameters for this round
        batch_size: int = config["batch_size"]
        epochs: int = config["local_epochs"]

        # The first n_valset examples form the validation set; the rest are
        # used for training.
        n_total = len(self.trainset)
        n_valset = int(n_total * self.validation_split)

        valset = torch.utils.data.Subset(self.trainset, range(0, n_valset))
        trainset = torch.utils.data.Subset(self.trainset, range(n_valset, n_total))

        trainLoader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
        valLoader = DataLoader(valset, batch_size=batch_size)

        results = train(self.model, trainLoader, valLoader, epochs)

        # Hand the updated weights back as NumPy arrays in state-dict order.
        parameters_prime = [
            val.cpu().numpy() for val in self.model.state_dict().values()
        ]
        num_examples_train = len(trainset)

        return parameters_prime, num_examples_train, results

    def evaluate(self, parameters, config):
        """Evaluate parameters on the locally held test set.

        Args:
            parameters: Model parameters to evaluate, in state-dict order.
            config: Must contain ``"val_steps"``.

        Returns:
            ``(loss, num_test_examples, {"accuracy": accuracy})``.
        """
        # Update local model parameters
        self.set_parameters(parameters)

        # Get config values
        steps: int = config["val_steps"]

        # Evaluate global model parameters on the local test data and return
        # results. NOTE(review): the "val_steps" value is passed to the
        # DataLoader as its batch size - confirm this is the intended meaning.
        testloader = DataLoader(self.testset, batch_size=steps)

        loss, accuracy = test(self.model, testloader)
        return float(loss), len(self.testset), {"accuracy": float(accuracy)}


In [16]:
def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: Optional[int] = None):
    """Load a pretrained EfficientNet from torch.hub and place it on ``DEVICE``.

    Args:
        entrypoint: Name of the hub entrypoint in the
            ``NVIDIA/DeepLearningExamples:torchhub`` repository.
        classes: When given, the classifier head is replaced with a new
            layer producing this many outputs; when ``None`` the head
            shipped with the hub model is kept.

    Returns:
        The loaded model, moved to ``DEVICE``.
    """
    model = torch.hub.load(
        'NVIDIA/DeepLearningExamples:torchhub', entrypoint, pretrained=True)
    if classes is not None:
        replace_classifying_layer(model, classes)
    # Move to DEVICE only after the head has been swapped, so the freshly
    # created layer ends up on the same device as the rest of the network.
    model.to(DEVICE)
    return model


In [14]:
def client_dry_run():
    """Smoke-test ``CifarClient`` end to end on a tiny slice of partition 0.

    Loads the model with a 10-class head, keeps the first 10 training and
    the first 10 test examples of partition 0, then runs one ``fit`` round
    and one ``evaluate`` round using the model's own weights as the
    incoming parameters.
    """
    model = load_efficientnet(classes=10)
    trainset, testset = load_partition(0)

    # Shrink both splits so the dry run finishes quickly. Each split is cut
    # from its own source and the client is built from these locals rather
    # than from notebook globals.
    trainset = torch.utils.data.Subset(trainset, range(10))
    testset = torch.utils.data.Subset(testset, range(10))

    client = CifarClient(model, trainset, testset)

    def current_weights():
        # Snapshot of the model weights as NumPy arrays in state-dict order.
        return [val.cpu().numpy() for _, val in model.state_dict().items()]

    client.fit(current_weights(), {'batch_size': 32, 'local_epochs': 1})
    client.evaluate(current_weights(), {'val_steps': 32})


In [17]:
# Run the single-client smoke test (one fit round and one evaluate round).
client_dry_run()

Using cache found in /Users/cozek/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub


Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Take partition 0 and keep only the first 10 training examples for quick,
# interactive experiments; `tests` keeps the full test partition.
trains, tests = load_partition(0)
trains = torch.utils.data.Subset(
    trains, range(10)
)
EFFICIENTNET_MODEL = 'nvidia_efficientnet_b0'

# Reuse the shared loader instead of repeating the hub-load / device-move /
# head-replacement steps inline, so this cell and client_dry_run build the
# model the same way.
efficientnet = load_efficientnet(EFFICIENTNET_MODEL, classes=10)


Files already downloaded and verified
Files already downloaded and verified


Using cache found in /Users/cozek/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub


In [10]:
# Display the classifier head to check the replaced `fc` layer.
efficientnet.classifier


Sequential(
  (pooling): AdaptiveAvgPool2d(output_size=1)
  (squeeze): Flatten()
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=1280, out_features=10, bias=True)
)

In [11]:
# Sanity check: push every batch of the tiny training subset through the
# model. Gradients are not needed here, and each batch is moved to DEVICE so
# the check also works when the model lives on a GPU. `images` stays bound to
# the last batch after the loop.
ldr = DataLoader(trains, batch_size=32)
with torch.no_grad():
    for images, labels in ldr:
        images = images.to(DEVICE)
        efficientnet(images)


In [12]:
# Output shape for `images`, the last batch left bound by the preceding loop.
efficientnet(images).shape


torch.Size([10, 10])

In [33]:
def replace_classifying_layer(num_classes: int = 10):
    """Replace the ``fc`` layer of the global ``efficientnet`` classifier.

    NOTE(review): this definition shadows the earlier two-argument
    ``replace_classifying_layer(efficientnet_model, num_classes)`` once the
    cell runs - confirm that is intended, or remove this scratch variant.

    Args:
        num_classes: Number of outputs of the new linear layer.

    Returns:
        None. ``efficientnet.classifier`` is modified in place and the layer
        name and input width are printed.
    """
    classifier = efficientnet.classifier
    # Snapshot the module list first so the classifier can be modified safely
    # while looping.
    for name, layer in list(classifier.named_modules()):
        if name == 'fc':
            num_features = layer.in_features
            new_layer = torch.nn.Linear(num_features, num_classes).to(
                layer.weight.device
            )
            # Rebinding the loop variable would leave the model untouched;
            # the new layer has to be assigned on the parent module.
            setattr(classifier, name, new_layer)
            print(name, new_layer.in_features)


fc 1280


In [44]:
# Replace the head with a 10-way linear layer. The input width is read from
# the existing layer instead of being hard-coded, and the new layer is moved
# to DEVICE to match the rest of the model.
efficientnet.classifier.fc = torch.nn.Linear(
    efficientnet.classifier.fc.in_features, 10
).to(DEVICE)


In [45]:
# Display the full model architecture.
efficientnet


EfficientNet(
  (stem): Sequential(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (activation): SiLU(inplace=True)
  )
  (layers): Sequential(
    (0): Sequential(
      (block0): MBConvBlock(
        (depsep): Sequential(
          (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (se): SequentialSqueezeAndExcitation(
          (squeeze): Linear(in_features=32, out_features=8, bias=True)
          (expand): Linear(in_features=8, out_features=32, bias=True)
          (activation): SiLU(inplace=True)
          (sigmoid): Sigmoid()
          (mul_a_quantizer): Identity()
          (mul_b_quantizer): Identity()
        )
      