# Federated PyTorch MNIST

In [1]:
#!pip install -r requirements.txt

In [2]:
import os
import glob

from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from copy import deepcopy
import torchvision
from torchvision import transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tqdm

myseed = 0

torch.manual_seed(myseed)
np.random.seed(myseed)

  from .autonotebook import tqdm as notebook_tqdm
2023-02-26 09:36:01.743459: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-26 09:36:02.056392: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-02-26 09:36:04.869652: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-02-26 09:36:04.869861: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plu

## Connect to the Federation

In [3]:
# Create a federation

# please use the same identificator that was used in signed certificate
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'director'
# 1) Run with API layer - Director mTLS 
# If the user wants to enable mTLS their must provide CA root chain, and signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051', tls=False)


In [4]:
federation.target_shape

['28', '28', '1']

In [5]:
shard_registry = federation.get_shard_registry()
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "Mnist dataset, shard number 1 out of 10"
  sample_shape: "28"
  sample_shape: "28"
  sample_shape: "1"
  target_shape: "28"
  target_shape: "28"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2023-02-26 09:36:04',
  'current_time': '2023-02-26 09:36:09',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'env_two': {'shard_info': node_info {
    name: "env_two"
  }
  shard_description: "Mnist dataset, shard number 2 out of 10"
  sample_shape: "28"
  sample_shape: "28"
  sample_shape: "1"
  target_shape: "28"
  target_shape: "28"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2023-02-26 09:36:07',
  'current_time': '2023-02-26 09:36:09',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [6]:
# First, request a dummy_shard_desc that holds information about the federated dataset 
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
print(sample.shape)
print(target.shape)

(28, 28, 1)
(28, 28, 1)


## Creating a FL experiment using Interactive API

### Register dataset

In [7]:
normalize = T.Normalize(
    mean=[0.3081],
    std=[0.1307]
)

augmentation = T.RandomApply(
    [T.RandomHorizontalFlip(),
     T.RandomCrop(28, padding=4)],
    p=.5
)

training_transform = T.Compose(
    [T.ToTensor(),
     T.Resize(28),
     augmentation,
     normalize]
)

valid_transform = T.Compose(
    [T.ToTensor(),
     T.Resize(28),
     normalize]
)


In [8]:
class TransformedDataset(Dataset):
    """Image Person ReID Dataset."""

    def __init__(self, dataset, transform=None, target_transform=None):
        """Initialize Dataset."""
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Length of dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        label = self.target_transform(label) if self.target_transform else label
        img = self.transform(img) if self.transform else img
        return img, label


In [9]:
class MNISTDataset(DataInterface):
    def __init__(self, **kwargs):
        self.kwargs = kwargs
    
    @property
    def shard_descriptor(self):
        return self._shard_descriptor
        
    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor  will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor
        
        self.train_set = TransformedDataset(
            self._shard_descriptor.get_dataset('train'),
            transform=training_transform
        )
        self.valid_set = TransformedDataset(
            self._shard_descriptor.get_dataset('val'),
            transform=valid_transform
        )
        
    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        generator=torch.Generator()
        generator.manual_seed(myseed)
        return DataLoader(
            self.train_set, batch_size=self.kwargs['train_bs'], shuffle=True, generator=generator, drop_last=True)

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return DataLoader(self.valid_set, batch_size=self.kwargs['valid_bs'], drop_last=True)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.valid_set)
    

In [10]:
fed_dataset = MNISTDataset(train_bs=128, valid_bs=128)

### Describe the model and optimizer

In [11]:
"""
MobileNetV2 model


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


model_net = Net()
"""

'\nMobileNetV2 model\n\n\nclass Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n        self.conv2_drop = nn.Dropout2d()\n        self.fc1 = nn.Linear(320, 50)\n        self.fc2 = nn.Linear(50, 10)\n\n    def forward(self, x):\n        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n        x = x.view(-1, 320)\n        x = F.relu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.fc2(x)\n        return F.log_softmax(x)\n\n\nmodel_net = Net()\n'

In [12]:
from torch.nn import Module
class BatchRenormalization2D(Module):

    def __init__(self, num_features,  eps=1e-05, momentum=0.01, r_d_max_inc_step = 0.0001):
        super(BatchRenormalization2D, self).__init__()

        self.eps = eps
        self.momentum = torch.tensor( (momentum), requires_grad = False)

        self.gamma = torch.nn.Parameter(torch.ones((1, num_features, 1, 1)), requires_grad=True)
        self.beta = torch.nn.Parameter(torch.zeros((1, num_features, 1, 1)), requires_grad=True)

        self.running_avg_mean = torch.ones((1, num_features, 1, 1), requires_grad=False)
        self.running_avg_std = torch.zeros((1, num_features, 1, 1), requires_grad=False) 

        self.max_r_max = 3.0
        self.max_d_max = 5.0

        self.r_max_inc_step = r_d_max_inc_step
        self.d_max_inc_step = r_d_max_inc_step

        self.r_max = torch.tensor( (1.0), requires_grad = False)
        self.d_max = torch.tensor( (0.0), requires_grad = False)

    def forward(self, x):

        device = self.gamma.device

        batch_ch_mean = torch.mean(x, dim=(0,2,3), keepdim=True).to(device)
        batch_ch_std = torch.clamp(torch.std(x, dim=(0,2,3), keepdim=True), self.eps, 1e10).to(device)

        self.running_avg_std = self.running_avg_std.to(device)
        self.running_avg_mean = self.running_avg_mean.to(device)
        self.momentum = self.momentum.to(device)

        self.r_max = self.r_max.to(device)
        self.d_max = self.d_max.to(device)


        if self.training:

            r = torch.clamp(batch_ch_std / self.running_avg_std, 1.0 / self.r_max, self.r_max).to(device).data.to(device)
            d = torch.clamp((batch_ch_mean - self.running_avg_mean) / self.running_avg_std, -self.d_max, self.d_max).to(device).data.to(device)

            x = ((x - batch_ch_mean) * r )/ batch_ch_std + d
            x = self.gamma * x + self.beta

            if self.r_max < self.max_r_max:
                self.r_max += self.r_max_inc_step * x.shape[0]

            if self.d_max < self.max_d_max:
                self.d_max += self.d_max_inc_step * x.shape[0]

        else:

            x = (x - self.running_avg_mean) / self.running_avg_std
            x = self.gamma * x + self.beta

        self.running_avg_mean = self.running_avg_mean + self.momentum * (batch_ch_mean.data.to(device) - self.running_avg_mean)
        self.running_avg_std = self.running_avg_std + self.momentum * (batch_ch_std.data.to(device) - self.running_avg_std)

        return x

In [13]:
resnet18 = torchvision.models.resnet18(pretrained=False)



In [14]:
resnet18

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [15]:
resnet18
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

#MOMENTUM BATCH NORM

resnet18.bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)

resnet18.layer1[0].bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer1[0].bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer1[1].bn1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer1[1].bn2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)

resnet18.layer2[0].bn1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer2[0].bn2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer2[1].bn1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer2[1].bn2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer2[0].downsample[1] = nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)

resnet18.layer3[0].bn1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer3[0].bn2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer3[1].bn1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer3[1].bn2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer3[0].downsample[1] = nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)

resnet18.layer4[0].bn1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer4[0].bn2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer4[1].bn1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer4[1].bn2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
resnet18.layer4[0].downsample[1] = nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)


#GROUP NORM
'''
resnet18.bn1 = nn.GroupNorm(32, 64)

resnet18.layer1[0].bn1 = nn.GroupNorm(32, 64)
resnet18.layer1[0].bn2 = nn.GroupNorm(32, 64)
resnet18.layer1[1].bn1 = nn.GroupNorm(32, 64)
resnet18.layer1[1].bn2 = nn.GroupNorm(32, 64)

resnet18.layer2[0].bn1 = nn.GroupNorm(32, 128)
resnet18.layer2[0].bn2 = nn.GroupNorm(32, 128)
resnet18.layer2[1].bn1 = nn.GroupNorm(32, 128)
resnet18.layer2[1].bn2 = nn.GroupNorm(32, 128)
resnet18.layer2[0].downsample[1] = nn.GroupNorm(32, 128)

resnet18.layer3[0].bn1 = nn.GroupNorm(32, 256)
resnet18.layer3[0].bn2 = nn.GroupNorm(32, 256)
resnet18.layer3[1].bn1 = nn.GroupNorm(32, 256)
resnet18.layer3[1].bn2 = nn.GroupNorm(32, 256)
resnet18.layer3[0].downsample[1] = nn.GroupNorm(32, 256)

resnet18.layer4[0].bn1 = nn.GroupNorm(32, 512)
resnet18.layer4[0].bn2 = nn.GroupNorm(32, 512)
resnet18.layer4[1].bn1 = nn.GroupNorm(32, 512)
resnet18.layer4[1].bn2 = nn.GroupNorm(32, 512)
resnet18.layer4[0].downsample[1] = nn.GroupNorm(32, 512)
'''

#INSTANCE NORM (E' UNA GROUP NORM CON TANTI GRUPPI TANTI CANALI)
'''
resnet18.bn1 = nn.GroupNorm(64, 64)

resnet18.layer1[0].bn1 = nn.GroupNorm(64, 64)
resnet18.layer1[0].bn2 = nn.GroupNorm(64, 64)
resnet18.layer1[1].bn1 = nn.GroupNorm(64, 64)
resnet18.layer1[1].bn2 = nn.GroupNorm(64, 64)

resnet18.layer2[0].bn1 = nn.GroupNorm(128, 128)
resnet18.layer2[0].bn2 = nn.GroupNorm(128, 128)
resnet18.layer2[1].bn1 = nn.GroupNorm(128, 128)
resnet18.layer2[1].bn2 = nn.GroupNorm(128, 128)
resnet18.layer2[0].downsample[1] = nn.GroupNorm(128, 128)

resnet18.layer3[0].bn1 = nn.GroupNorm(256, 256)
resnet18.layer3[0].bn2 = nn.GroupNorm(256, 256)
resnet18.layer3[1].bn1 = nn.GroupNorm(256, 256)
resnet18.layer3[1].bn2 = nn.GroupNorm(256, 256)
resnet18.layer3[0].downsample[1] = nn.GroupNorm(256, 256)

resnet18.layer4[0].bn1 = nn.GroupNorm(512, 512)
resnet18.layer4[0].bn2 = nn.GroupNorm(512, 512)
resnet18.layer4[1].bn1 = nn.GroupNorm(512, 512)
resnet18.layer4[1].bn2 = nn.GroupNorm(512, 512)
resnet18.layer4[0].downsample[1] = nn.GroupNorm(512, 512)
'''

#LAYER NORM (E' UNA GROUP NORM CON TUTTI I CANALI IN UN SOLO GRUPPO)
'''
resnet18.bn1 = nn.GroupNorm(1, 64)

resnet18.layer1[0].bn1 = nn.GroupNorm(1, 64)
resnet18.layer1[0].bn2 = nn.GroupNorm(1, 64)
resnet18.layer1[1].bn1 = nn.GroupNorm(1, 64)
resnet18.layer1[1].bn2 = nn.GroupNorm(1, 64)

resnet18.layer2[0].bn1 = nn.GroupNorm(1, 128)
resnet18.layer2[0].bn2 = nn.GroupNorm(1, 128)
resnet18.layer2[1].bn1 = nn.GroupNorm(1, 128)
resnet18.layer2[1].bn2 = nn.GroupNorm(1, 128)
resnet18.layer2[0].downsample[1] = nn.GroupNorm(1, 128)

resnet18.layer3[0].bn1 = nn.GroupNorm(1, 256)
resnet18.layer3[0].bn2 = nn.GroupNorm(1, 256)
resnet18.layer3[1].bn1 = nn.GroupNorm(1, 256)
resnet18.layer3[1].bn2 = nn.GroupNorm(1, 256)
resnet18.layer3[0].downsample[1] = nn.GroupNorm(1, 256)

resnet18.layer4[0].bn1 = nn.GroupNorm(1, 512)
resnet18.layer4[0].bn2 = nn.GroupNorm(1, 512)
resnet18.layer4[1].bn1 = nn.GroupNorm(1, 512)
resnet18.layer4[1].bn2 = nn.GroupNorm(1, 512)
resnet18.layer4[0].downsample[1] = nn.GroupNorm(1, 512)
'''

#BATCH RENORMALIZATION
'''
resnet18.bn1 = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)

resnet18.layer1[0].bn1 = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)
resnet18.layer1[0].bn2 = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)
resnet18.layer1[1].bn1 = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)
resnet18.layer1[1].bn2 = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)

resnet18.layer2[0].bn1 = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
resnet18.layer2[0].bn2 = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
resnet18.layer2[1].bn1 = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
resnet18.layer2[1].bn2 = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
resnet18.layer2[0].downsample[1] = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)

resnet18.layer3[0].bn1 = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
resnet18.layer3[0].bn2 = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
resnet18.layer3[1].bn1 = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
resnet18.layer3[1].bn2 = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
resnet18.layer3[0].downsample[1] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
resnet18.layer4[0].bn1 = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
resnet18.layer4[0].bn2 = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
resnet18.layer4[1].bn1 = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
resnet18.layer4[1].bn2 = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
resnet18.layer4[0].downsample[1] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
'''

resnet18.fc = nn.Linear(in_features=512, out_features=10, bias=True)

resnet18

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [16]:
#VGG16 WITH BATCHNORM MOMENTUM 0.9
class VGG16(nn.Module):

    def __init__(self, num_classes):
        super(VGG16, self).__init__()

        # calculate same padding:
        # (w - k + 2*p)/s + 1 = o
        # => p = (s(o-1) - w + k)/2

        self.block_1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      # (1(32-1)- 32 + 3)/2 = 1
                      padding=1),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,
                      out_channels=64,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_2 = nn.Sequential(
            nn.Conv2d(in_channels=64,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,
                      out_channels=128,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(128, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )
        
        self.block_3 = nn.Sequential(
            nn.Conv2d(in_channels=128,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=256,
                      out_channels=256,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(256, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )

        self.block_4 = nn.Sequential(
            nn.Conv2d(in_channels=256,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        ) 
        
        self.block_5 = nn.Sequential(
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=512,
                      out_channels=512,
                      kernel_size=(3, 3),
                      stride=(1, 1),
                      padding=1),
            nn.BatchNorm2d(512, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2),
                         stride=(2, 2))
        )        

        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes) 
        )

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    m.bias.detach().zero_()


    def forward(self, x):

        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = self.block_5(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

vgg16 = VGG16(10)

#GROUP NORM
'''
vgg16.block_1[1] = nn.GroupNorm(32, 64)
vgg16.block_1[4] = nn.GroupNorm(32, 64)
vgg16.block_2[1] = nn.GroupNorm(32, 128)
vgg16.block_2[4] = nn.GroupNorm(32, 128)
vgg16.block_3[1] = nn.GroupNorm(32, 256)
vgg16.block_3[4] = nn.GroupNorm(32, 256)
vgg16.block_3[7] = nn.GroupNorm(32, 256)
vgg16.block_4[1] = nn.GroupNorm(32, 512)
vgg16.block_4[4] = nn.GroupNorm(32, 512)
vgg16.block_4[7] = nn.GroupNorm(32, 512)
vgg16.block_5[1] = nn.GroupNorm(32, 512)
vgg16.block_5[4] = nn.GroupNorm(32, 512)
vgg16.block_5[7] = nn.GroupNorm(32, 512)
'''

#INSTANCE NORM
'''
vgg16.block_1[1] = nn.GroupNorm(64, 64)
vgg16.block_1[4] = nn.GroupNorm(64, 64)
vgg16.block_2[1] = nn.GroupNorm(128, 128)
vgg16.block_2[4] = nn.GroupNorm(128, 128)
vgg16.block_3[1] = nn.GroupNorm(256, 256)
vgg16.block_3[4] = nn.GroupNorm(256, 256)
vgg16.block_3[7] = nn.GroupNorm(256, 256)
vgg16.block_4[1] = nn.GroupNorm(512, 512)
vgg16.block_4[4] = nn.GroupNorm(512, 512)
vgg16.block_4[7] = nn.GroupNorm(512, 512)
vgg16.block_5[1] = nn.GroupNorm(512, 512)
vgg16.block_5[4] = nn.GroupNorm(512, 512)
vgg16.block_5[7] = nn.GroupNorm(512, 512)
'''

#LAYER NORM
'''
vgg16.block_1[1] = nn.GroupNorm(1, 64)
vgg16.block_1[4] = nn.GroupNorm(1, 64)
vgg16.block_2[1] = nn.GroupNorm(1, 128)
vgg16.block_2[4] = nn.GroupNorm(1, 128)
vgg16.block_3[1] = nn.GroupNorm(1, 256)
vgg16.block_3[4] = nn.GroupNorm(1, 256)
vgg16.block_3[7] = nn.GroupNorm(1, 256)
vgg16.block_4[1] = nn.GroupNorm(1, 512)
vgg16.block_4[4] = nn.GroupNorm(1, 512)
vgg16.block_4[7] = nn.GroupNorm(1, 512)
vgg16.block_5[1] = nn.GroupNorm(1, 512)
vgg16.block_5[4] = nn.GroupNorm(1, 512)
vgg16.block_5[7] = nn.GroupNorm(1, 512)
'''

#BATCH RENORM
'''
vgg16.block_1[1] = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)
vgg16.block_1[4] = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)
vgg16.block_2[1] = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
vgg16.block_2[4] = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)
vgg16.block_3[1] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
vgg16.block_3[4] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
vgg16.block_3[7] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)
vgg16.block_4[1] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
vgg16.block_4[4] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
vgg16.block_4[7] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
vgg16.block_5[1] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
vgg16.block_5[4] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
vgg16.block_5[7] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)
'''

'\nvgg16.block_1[1] = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)\nvgg16.block_1[4] = BatchRenormalization2D(64, eps=1e-05, momentum=0.9)\nvgg16.block_2[1] = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)\nvgg16.block_2[4] = BatchRenormalization2D(128, eps=1e-05, momentum=0.9)\nvgg16.block_3[1] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)\nvgg16.block_3[4] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)\nvgg16.block_3[7] = BatchRenormalization2D(256, eps=1e-05, momentum=0.9)\nvgg16.block_4[1] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\nvgg16.block_4[4] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\nvgg16.block_4[7] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\nvgg16.block_5[1] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\nvgg16.block_5[4] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\nvgg16.block_5[7] = BatchRenormalization2D(512, eps=1e-05, momentum=0.9)\n'

In [17]:
efficientnet_b0 = torchvision.models.efficientnet_b0(pretrained=False)
efficientnet_b0

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat

In [18]:
efficientnet_b0
efficientnet_b0.features[0][0] = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1,1), bias=False)
efficientnet_b0.classifier[1] = nn.Linear(in_features=1280, out_features=10, bias=True)

#GROUP NORM
'''
efficientnet_b0.features[0][1] = nn.GroupNorm(32, 32)
efficientnet_b0.features[1][0].block[0][1] = nn.GroupNorm(8, 32)
efficientnet_b0.features[1][0].block[2][1] = nn.GroupNorm(8, 16)  
efficientnet_b0.features[2][0].block[0][1] = nn.GroupNorm(8, 96)
efficientnet_b0.features[2][0].block[1][1] = nn.GroupNorm(8, 96)
efficientnet_b0.features[2][0].block[3][1] = nn.GroupNorm(8, 24) 
efficientnet_b0.features[2][1].block[0][1] = nn.GroupNorm(8, 144)
efficientnet_b0.features[2][1].block[1][1] = nn.GroupNorm(8, 144)
efficientnet_b0.features[2][1].block[3][1] = nn.GroupNorm(8, 24)
efficientnet_b0.features[3][0].block[0][1] = nn.GroupNorm(8, 144)
efficientnet_b0.features[3][0].block[1][1] = nn.GroupNorm(8, 144)
efficientnet_b0.features[3][0].block[3][1] = nn.GroupNorm(8, 40)
efficientnet_b0.features[3][1].block[0][1] = nn.GroupNorm(8, 240)
efficientnet_b0.features[3][1].block[1][1] = nn.GroupNorm(8, 240)
efficientnet_b0.features[3][1].block[3][1] = nn.GroupNorm(8, 40)
efficientnet_b0.features[4][0].block[0][1] = nn.GroupNorm(8, 240)
efficientnet_b0.features[4][0].block[1][1] = nn.GroupNorm(8, 240)
efficientnet_b0.features[4][0].block[3][1] = nn.GroupNorm(8, 80)
efficientnet_b0.features[4][1].block[0][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[4][1].block[1][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[4][1].block[3][1] = nn.GroupNorm(8, 80)
efficientnet_b0.features[4][2].block[0][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[4][2].block[1][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[4][2].block[3][1] = nn.GroupNorm(8, 80)
efficientnet_b0.features[5][0].block[0][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[5][0].block[1][1] = nn.GroupNorm(8, 480)
efficientnet_b0.features[5][0].block[3][1] = nn.GroupNorm(8, 112)
efficientnet_b0.features[5][1].block[0][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[5][1].block[1][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[5][1].block[3][1] = nn.GroupNorm(8, 112)
efficientnet_b0.features[5][2].block[0][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[5][2].block[1][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[5][2].block[3][1] = nn.GroupNorm(8, 112)
efficientnet_b0.features[6][0].block[0][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[6][0].block[1][1] = nn.GroupNorm(8, 672)
efficientnet_b0.features[6][0].block[3][1] = nn.GroupNorm(8, 192)
efficientnet_b0.features[6][1].block[0][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][1].block[1][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][1].block[3][1] = nn.GroupNorm(8, 192)
efficientnet_b0.features[6][2].block[0][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][2].block[1][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][2].block[3][1] = nn.GroupNorm(8, 192)
efficientnet_b0.features[6][3].block[0][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][3].block[1][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[6][3].block[3][1] = nn.GroupNorm(8, 192)
efficientnet_b0.features[7][0].block[0][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[7][0].block[1][1] = nn.GroupNorm(8, 1152)
efficientnet_b0.features[7][0].block[3][1] = nn.GroupNorm(8, 320)
efficientnet_b0.features[8][1] = nn.GroupNorm(8, 1280)
'''

#INSTANCE NORM
'''
efficientnet_b0.features[0][1] = nn.GroupNorm(32, 32)
efficientnet_b0.features[1][0].block[0][1] = nn.GroupNorm(32, 32)
efficientnet_b0.features[1][0].block[2][1] = nn.GroupNorm(16, 16)  
efficientnet_b0.features[2][0].block[0][1] = nn.GroupNorm(96, 96)
efficientnet_b0.features[2][0].block[1][1] = nn.GroupNorm(96, 96)
efficientnet_b0.features[2][0].block[3][1] = nn.GroupNorm(24, 24) 
efficientnet_b0.features[2][1].block[0][1] = nn.GroupNorm(144, 144)
efficientnet_b0.features[2][1].block[1][1] = nn.GroupNorm(144, 144)
efficientnet_b0.features[2][1].block[3][1] = nn.GroupNorm(24, 24)
efficientnet_b0.features[3][0].block[0][1] = nn.GroupNorm(144, 144)
efficientnet_b0.features[3][0].block[1][1] = nn.GroupNorm(144, 144)
efficientnet_b0.features[3][0].block[3][1] = nn.GroupNorm(40, 40)
efficientnet_b0.features[3][1].block[0][1] = nn.GroupNorm(240, 240)
efficientnet_b0.features[3][1].block[1][1] = nn.GroupNorm(240, 240)
efficientnet_b0.features[3][1].block[3][1] = nn.GroupNorm(40, 40)
efficientnet_b0.features[4][0].block[0][1] = nn.GroupNorm(240, 240)
efficientnet_b0.features[4][0].block[1][1] = nn.GroupNorm(240, 240)
efficientnet_b0.features[4][0].block[3][1] = nn.GroupNorm(80, 80)
efficientnet_b0.features[4][1].block[0][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[4][1].block[1][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[4][1].block[3][1] = nn.GroupNorm(80, 80)
efficientnet_b0.features[4][2].block[0][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[4][2].block[1][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[4][2].block[3][1] = nn.GroupNorm(80, 80)
efficientnet_b0.features[5][0].block[0][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[5][0].block[1][1] = nn.GroupNorm(480, 480)
efficientnet_b0.features[5][0].block[3][1] = nn.GroupNorm(112, 112)
efficientnet_b0.features[5][1].block[0][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[5][1].block[1][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[5][1].block[3][1] = nn.GroupNorm(112, 112)
efficientnet_b0.features[5][2].block[0][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[5][2].block[1][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[5][2].block[3][1] = nn.GroupNorm(112, 112)
efficientnet_b0.features[6][0].block[0][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[6][0].block[1][1] = nn.GroupNorm(672, 672)
efficientnet_b0.features[6][0].block[3][1] = nn.GroupNorm(192, 192)
efficientnet_b0.features[6][1].block[0][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][1].block[1][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][1].block[3][1] = nn.GroupNorm(192, 192)
efficientnet_b0.features[6][2].block[0][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][2].block[1][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][2].block[3][1] = nn.GroupNorm(192, 192)
efficientnet_b0.features[6][3].block[0][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][3].block[1][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[6][3].block[3][1] = nn.GroupNorm(192, 192)
efficientnet_b0.features[7][0].block[0][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[7][0].block[1][1] = nn.GroupNorm(1152, 1152)
efficientnet_b0.features[7][0].block[3][1] = nn.GroupNorm(320, 320)
efficientnet_b0.features[8][1] = nn.GroupNorm(1280, 1280)
'''

#LAYER NORM
'''
efficientnet_b0.features[0][1] = nn.GroupNorm(1, 32)
efficientnet_b0.features[1][0].block[0][1] = nn.GroupNorm(1, 32)
efficientnet_b0.features[1][0].block[2][1] = nn.GroupNorm(1, 16)  
efficientnet_b0.features[2][0].block[0][1] = nn.GroupNorm(1, 96)
efficientnet_b0.features[2][0].block[1][1] = nn.GroupNorm(1, 96)
efficientnet_b0.features[2][0].block[3][1] = nn.GroupNorm(1, 24) 
efficientnet_b0.features[2][1].block[0][1] = nn.GroupNorm(1, 144)
efficientnet_b0.features[2][1].block[1][1] = nn.GroupNorm(1, 144)
efficientnet_b0.features[2][1].block[3][1] = nn.GroupNorm(1, 24)
efficientnet_b0.features[3][0].block[0][1] = nn.GroupNorm(1, 144)
efficientnet_b0.features[3][0].block[1][1] = nn.GroupNorm(1, 144)
efficientnet_b0.features[3][0].block[3][1] = nn.GroupNorm(1, 40)
efficientnet_b0.features[3][1].block[0][1] = nn.GroupNorm(1, 240)
efficientnet_b0.features[3][1].block[1][1] = nn.GroupNorm(1, 240)
efficientnet_b0.features[3][1].block[3][1] = nn.GroupNorm(1, 40)
efficientnet_b0.features[4][0].block[0][1] = nn.GroupNorm(1, 240)
efficientnet_b0.features[4][0].block[1][1] = nn.GroupNorm(1, 240)
efficientnet_b0.features[4][0].block[3][1] = nn.GroupNorm(1, 80)
efficientnet_b0.features[4][1].block[0][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[4][1].block[1][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[4][1].block[3][1] = nn.GroupNorm(1, 80)
efficientnet_b0.features[4][2].block[0][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[4][2].block[1][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[4][2].block[3][1] = nn.GroupNorm(1, 80)
efficientnet_b0.features[5][0].block[0][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[5][0].block[1][1] = nn.GroupNorm(1, 480)
efficientnet_b0.features[5][0].block[3][1] = nn.GroupNorm(1, 112)
efficientnet_b0.features[5][1].block[0][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[5][1].block[1][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[5][1].block[3][1] = nn.GroupNorm(1, 112)
efficientnet_b0.features[5][2].block[0][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[5][2].block[1][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[5][2].block[3][1] = nn.GroupNorm(1, 112)
efficientnet_b0.features[6][0].block[0][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[6][0].block[1][1] = nn.GroupNorm(1, 672)
efficientnet_b0.features[6][0].block[3][1] = nn.GroupNorm(1, 192)
efficientnet_b0.features[6][1].block[0][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][1].block[1][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][1].block[3][1] = nn.GroupNorm(1, 192)
efficientnet_b0.features[6][2].block[0][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][2].block[1][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][2].block[3][1] = nn.GroupNorm(1, 192)
efficientnet_b0.features[6][3].block[0][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][3].block[1][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[6][3].block[3][1] = nn.GroupNorm(1, 192)
efficientnet_b0.features[7][0].block[0][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[7][0].block[1][1] = nn.GroupNorm(1, 1152)
efficientnet_b0.features[7][0].block[3][1] = nn.GroupNorm(1, 320)
efficientnet_b0.features[8][1] = nn.GroupNorm(1, 1280)
'''

#BATCH RENORM

efficientnet_b0.features[0][1] = BatchRenormalization2D(32, eps=1e-05, momentum=0.9)
efficientnet_b0.features[1][0].block[0][1] = BatchRenormalization2D(32, eps=1e-05, momentum=0.9)
efficientnet_b0.features[1][0].block[2][1] = BatchRenormalization2D(16, eps=1e-05, momentum=0.9)  
efficientnet_b0.features[2][0].block[0][1] = BatchRenormalization2D(96, eps=1e-05, momentum=0.9)
efficientnet_b0.features[2][0].block[1][1] = BatchRenormalization2D(96, eps=1e-05, momentum=0.9)
efficientnet_b0.features[2][0].block[3][1] = BatchRenormalization2D(24, eps=1e-05, momentum=0.9) 
efficientnet_b0.features[2][1].block[0][1] = BatchRenormalization2D(144, eps=1e-05, momentum=0.9)
efficientnet_b0.features[2][1].block[1][1] = BatchRenormalization2D(144, eps=1e-05, momentum=0.9)
efficientnet_b0.features[2][1].block[3][1] = BatchRenormalization2D(24, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][0].block[0][1] = BatchRenormalization2D(144, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][0].block[1][1] = BatchRenormalization2D(144, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][0].block[3][1] = BatchRenormalization2D(40, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][1].block[0][1] = BatchRenormalization2D(240, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][1].block[1][1] = BatchRenormalization2D(240, eps=1e-05, momentum=0.9)
efficientnet_b0.features[3][1].block[3][1] = BatchRenormalization2D(40, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][0].block[0][1] = BatchRenormalization2D(240, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][0].block[1][1] = BatchRenormalization2D(240, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][0].block[3][1] = BatchRenormalization2D(80, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][1].block[0][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][1].block[1][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][1].block[3][1] = BatchRenormalization2D(80, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][2].block[0][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][2].block[1][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[4][2].block[3][1] = BatchRenormalization2D(80, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][0].block[0][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][0].block[1][1] = BatchRenormalization2D(480, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][0].block[3][1] = BatchRenormalization2D(112, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][1].block[0][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][1].block[1][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][1].block[3][1] = BatchRenormalization2D(112, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][2].block[0][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][2].block[1][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[5][2].block[3][1] = BatchRenormalization2D(112, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][0].block[0][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][0].block[1][1] = BatchRenormalization2D(672, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][0].block[3][1] = BatchRenormalization2D(192, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][1].block[0][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][1].block[1][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][1].block[3][1] = BatchRenormalization2D(192, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][2].block[0][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][2].block[1][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][2].block[3][1] = BatchRenormalization2D(192, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][3].block[0][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][3].block[1][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[6][3].block[3][1] = BatchRenormalization2D(192, eps=1e-05, momentum=0.9)
efficientnet_b0.features[7][0].block[0][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[7][0].block[1][1] = BatchRenormalization2D(1152, eps=1e-05, momentum=0.9)
efficientnet_b0.features[7][0].block[3][1] = BatchRenormalization2D(320, eps=1e-05, momentum=0.9)
efficientnet_b0.features[8][1] = BatchRenormalization2D(1280, eps=1e-05, momentum=0.9)


efficientnet_b0


EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchRenormalization2D()
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchRenormalization2D()
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivation(
            (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): BatchReno

In [19]:
vgg11 = torchvision.models.vgg11_bn(pretrained=False)

In [20]:
vgg11.features[0] = nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1))
vgg11.classifier[0]  = nn.Linear(in_features=25088, out_features=2048, bias=True)
vgg11.classifier[3]  = nn.Linear(in_features=2048, out_features=2048, bias=True)
vgg11.classifier[6]  = nn.Linear(in_features=2048, out_features=10, bias=True)
vgg11

VGG(
  (features): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

In [21]:
model_net = resnet18
#model_net = vgg11
#model_net = efficientnet_b0

In [22]:
model_net

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [23]:
list(model_net.parameters())

[Parameter containing:
 tensor([[[[ 0.1104, -0.0825,  0.0806,  ..., -0.0860,  0.0129, -0.0818],
           [ 0.1157, -0.1330, -0.0123,  ...,  0.0586, -0.0160, -0.0171],
           [ 0.0278,  0.0681,  0.0630,  ...,  0.0624, -0.0851,  0.0279],
           ...,
           [ 0.0336, -0.1372,  0.0101,  ..., -0.0084, -0.1281,  0.0255],
           [ 0.1191, -0.1048,  0.0253,  ...,  0.0428, -0.1027, -0.0991],
           [ 0.0810, -0.0795, -0.0245,  ...,  0.1108,  0.1388, -0.0357]]],
 
 
         [[[ 0.0029,  0.0071,  0.1354,  ...,  0.1211, -0.0552,  0.0959],
           [ 0.1349,  0.0403, -0.0887,  ..., -0.0446,  0.0292, -0.1055],
           [-0.0577,  0.0573, -0.0187,  ..., -0.0787, -0.1053,  0.0459],
           ...,
           [ 0.0572,  0.1187,  0.1403,  ...,  0.0857,  0.1323,  0.1027],
           [ 0.0077, -0.1402, -0.0480,  ...,  0.0487, -0.0592, -0.0268],
           [ 0.0454, -0.1232, -0.0501,  ..., -0.0694,  0.1049,  0.0420]]],
 
 
         [[[ 0.0439,  0.0362,  0.0914,  ..., -0.0103, -0.

In [24]:
params_to_update = []
for param in model_net.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)
'''
FEDPROX
'''        
#from openfl.utilities.optimizers.torch import FedProxAdam        
#optimizer = FedProxAdam(params_to_update, lr=1e-4, mu=0.01)

'''
ORIGINALE
'''
optimizer = optim.Adam(params_to_update, lr=1e-4)
#optimizer = optim.AdamW(params_to_update, lr=0.001, weight_decay=0.02)
#optimizer = optim.SGD(params_to_update, lr=0.01, momentum=0.9, weight_decay=0.0005)

#scheduler
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

def cross_entropy(output, target):
    """Cross-entropy metric
    """
    #return F.cross_entropy(input=output,target=target)
    #return F.binary_cross_entropy_with_logits(input=output,target=target)
    criterion = nn.CrossEntropyLoss()
    loss = criterion(output, target)
    return loss

### Register model

In [25]:
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
model_interface = ModelInterface(model=model_net, optimizer=optimizer, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model_net)

## Define and register FL tasks

In [26]:
task_interface = TaskInterface()

'''
FEDCURV
'''
#from openfl.utilities.fedcurv.torch import FedCurv
#from openfl.component.aggregation_functions import FedCurvWeightedAverage
#import tqdm

#fedcurv = FedCurv(model_interface.provide_model(), importance=1e3)

'''
FEDOPT
'''

#from openfl.component.aggregation_functions import AdagradAdaptiveAggregation    
#agg_fn = AdagradAdaptiveAggregation(model_interface=model_interface, learning_rate=0.4)     
#@task_interface.set_aggregation_function(agg_fn)


# The Interactive API supports registering functions definied in main module or imported.
def function_defined_in_notebook(some_parameter):
    print(f'Also I accept a parameter and it is {some_parameter}')

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader', \
                     device='device', optimizer='optimizer') 
#@task_interface.set_aggregation_function(FedCurvWeightedAverage())


def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    torch.manual_seed(myseed)
    #fedcurv.on_train_begin(net_model)
    device='cuda'
    function_defined_in_notebook(some_parameter)
    
    train_loader = tqdm.tqdm(train_loader, desc="train")
    net_model.train()
    net_model.to(device)

    losses = []
    epochs = 1
    
    for epoch in range(epochs):
        for data, target in train_loader:
            data, target = torch.tensor(data).to(device), torch.tensor(
                target).to(device, dtype=torch.int64)
            optimizer.zero_grad()
            #data = data.type(torch.LongTensor)
            #target = target.type(torch.LongTensor)
            output = net_model(data)
            #output = output.logits #per GOOGLENET
            loss = loss_fn(output=output, target=target) #+ fedcurv.get_penalty(net_model)
            loss.backward()
            optimizer.step()
            losses.append(loss.detach().cpu().numpy())
    #fedcurv.on_train_end(net_model, train_loader, device)    
    return {'train_loss': np.mean(losses),}


@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')     
def validate(net_model, val_loader, device):
    torch.manual_seed(myseed)
    device = torch.device('cuda')
    net_model.eval()
    net_model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            #da wine
            #_, preds = torch.max(outputs, dim=1)
            #return torch.tensor(torch.sum(preds == labels).item() / len(preds))
            
            #originale
            #pred = output.argmax(dim=1,keepdim=True)
            
            #tentativo
            _, pred = torch.max(output, dim=1)
            val_score += pred.eq(target).sum().cpu().numpy()
            
    return {'acc': val_score / total_samples,}

## Time to start a federated learning experiment

In [27]:
# create an experimnet in federation
experiment_name = 'mnist_WEAK_UNIFORM_CLIENTS2_resnetBN_Adam'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [28]:
# The following command zips the workspace and python requirements to be transfered to collaborator nodes
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=200,
    opt_treatment='CONTINUE_GLOBAL'
)



In [29]:
# If user want to stop IPython session, then reconnect and check how experiment is going
# fl_experiment.restore_experiment_state(model_interface)

fl_experiment.stream_metrics(tensorboard_logs=True)

In [30]:
list(fl_experiment.get_last_model().parameters())

  new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device)


[Parameter containing:
 tensor([[[[ 0.1130, -0.0814,  0.0799,  ..., -0.0880,  0.0100, -0.0846],
           [ 0.1197, -0.1284, -0.0120,  ...,  0.0568, -0.0182, -0.0175],
           [ 0.0308,  0.0693,  0.0627,  ...,  0.0630, -0.0841,  0.0297],
           ...,
           [ 0.0340, -0.1368,  0.0097,  ..., -0.0081, -0.1270,  0.0257],
           [ 0.1201, -0.1017,  0.0255,  ...,  0.0427, -0.1017, -0.0982],
           [ 0.0819, -0.0766, -0.0243,  ...,  0.1104,  0.1385, -0.0351]]],
 
 
         [[[ 0.0021,  0.0056,  0.1325,  ...,  0.1216, -0.0560,  0.0957],
           [ 0.1330,  0.0390, -0.0901,  ..., -0.0437,  0.0296, -0.1053],
           [-0.0604,  0.0562, -0.0195,  ..., -0.0778, -0.1053,  0.0451],
           ...,
           [ 0.0566,  0.1195,  0.1444,  ...,  0.0886,  0.1337,  0.1029],
           [ 0.0065, -0.1401, -0.0444,  ...,  0.0507, -0.0592, -0.0264],
           [ 0.0447, -0.1224, -0.0480,  ..., -0.0680,  0.1044,  0.0413]]],
 
 
         [[[ 0.0441,  0.0351,  0.0895,  ..., -0.0130, -0.