# Federated Learning

`ewan.barel`

In [None]:
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision


In [None]:
random_state = 42
# Force deterministic cuDNN kernels and seed the CPU and CUDA generators so that
# model initialisation, data shuffling and training are repeatable across runs.
torch.backends.cudnn.deterministic = True
torch.manual_seed(random_state)
torch.cuda.manual_seed(random_state)


In [None]:
# 1. Load the MNIST dataset (or any other dataset like HAM 10000).

# Standardise MNIST pixels with a single-channel mean/std after tensor conversion.
transform = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
  ]
)

# Only the training split requests a download; the test split relies on
# ``download`` keeping its default (False).
dataset_train = torchvision.datasets.MNIST(
  "./data", train=True, transform=transform, target_transform=None, download=True,
)
dataset_test = torchvision.datasets.MNIST(
  "./data", train=False, transform=transform, target_transform=None,
)


In [None]:
# 2. Extract two subsets of 600 data points each (without intersection).

def extract_subsets(dataset, subset_size=600, random_state=42):
  """Draw two disjoint random subsets of ``subset_size`` samples each.

  The indices come from a single seeded permutation of the whole dataset, so
  the two subsets never share a sample and are reproducible for a given
  ``random_state``.

  Args:
    dataset: any map-style dataset supporting ``len`` and integer indexing.
    subset_size: number of samples in each subset.
    random_state: seed of the NumPy generator used for the permutation.

  Returns:
    A pair ``(subset_1, subset_2)`` of ``torch.utils.data.Subset`` objects.

  Raises:
    ValueError: if ``subset_size`` is negative or the dataset holds fewer than
      ``2 * subset_size`` samples (the second subset would otherwise be short).
  """
  if subset_size < 0:
    raise ValueError(f"subset_size must be non-negative, got {subset_size}")
  if 2 * subset_size > len(dataset):
    raise ValueError(
      f"cannot draw two disjoint subsets of {subset_size} samples "
      f"from a dataset of {len(dataset)} samples"
    )
  gen = np.random.default_rng(random_state)
  indices = gen.permutation(len(dataset))[: 2 * subset_size]
  subset_1 = Subset(dataset, indices[:subset_size])
  subset_2 = Subset(dataset, indices[subset_size:])
  return subset_1, subset_2


subset_1, subset_2 = extract_subsets(dataset_train, subset_size=600, random_state=random_state)

# Check disjointness on the sampled indices. Intersecting the (image, label)
# samples themselves would decode every image, and because tensors hash by
# identity that intersection is effectively always empty.
assert not set(map(int, subset_1.indices)) & set(map(int, subset_2.indices))
assert len(subset_1) == 600 and len(subset_2) == 600


In [None]:
# 3. Create a simple Convolutional Neural Network (2 convolutional layers and 2 dense layers, for example).

class SimpleCNN(nn.Module):
  """Small CNN for 1x28x28 images: two conv blocks, two dense layers, 10 logits."""

  def __init__(self):
    super().__init__()
    # Each conv keeps the spatial size (padding=1); the shared 2x2 max-pool
    # halves it after each conv, so 28x28 -> 14x14 -> 7x7.
    self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    # Dense head on the flattened 64x7x7 feature map.
    self.fc1 = nn.Linear(64 * 7 * 7, 128)
    self.fc2 = nn.Linear(128, 10)

  def forward(self, x):
    for conv in (self.conv1, self.conv2):
      x = self.pool(conv(x).relu())
    hidden = self.fc1(x.flatten(start_dim=1)).relu()
    return self.fc2(hidden)

  def clone(self):
    """Return an independent deep copy of this network, parameters included."""
    return copy.deepcopy(self)


In [None]:
# 4. Create a function average_model_parameters(models: iterable, average_weight): iterable that takes
# a list of models as an argument and returns the weighted average of the parameters of each model.

def average_model_parameters(models, average_weight):
  """Return the weighted average of the parameters of several models.

  Args:
    models: iterable of ``nn.Module`` objects whose ``parameters()`` line up
      position by position (the i-th parameters of all models are averaged).
    average_weight: one weight per model; the weights must sum to 1.

  Returns:
    A list of detached tensors, one per parameter of the first model, each equal
    to ``sum_k average_weight[k] * param_k``.

  Raises:
    ValueError: if no model is given, the number of weights differs from the
      number of models, or the weights do not sum to 1 (up to float rounding).
  """
  models = list(models)
  average_weight = list(average_weight)
  if not models:
    raise ValueError("at least one model is required")
  if len(models) != len(average_weight):
    raise ValueError(f"got {len(models)} models but {len(average_weight)} weights")
  # Compare with a tolerance: e.g. sum([0.7, 0.2, 0.1]) == 0.9999999999999999.
  total = sum(average_weight)
  if abs(total - 1.0) > 1e-6:
    raise ValueError(f"average_weight must sum to 1, got {total}")

  with torch.no_grad():
    # ``weight * param`` allocates a new tensor, so no explicit clone is needed.
    average_params = [average_weight[0] * param.detach() for param in models[0].parameters()]
    for model, weight in zip(models[1:], average_weight[1:]):
      for param_idx, param in enumerate(model.parameters()):
        average_params[param_idx] += weight * param.detach()
  return average_params


In [None]:
# 5. Create a function that updates the parameters of a model from a list of values.

def update_model_parameters(model, values):
  """Copy ``values`` into the parameters of ``model``, in ``parameters()`` order.

  Values are copied into the existing parameter storage, so the model keeps its
  device and dtype and does not alias the caller's tensors.

  Args:
    model: the ``nn.Module`` to update in place.
    values: iterable of tensors, one per parameter, with matching shapes.

  Raises:
    ValueError: if the number of values or any shape does not match the model's
      parameters. Nothing is modified in that case.
  """
  params = list(model.parameters())
  values = list(values)
  if len(params) != len(values):
    raise ValueError(f"model has {len(params)} parameters but {len(values)} values were given")
  # Validate every shape first so a mismatch never leaves the model half-updated
  # (``copy_`` would otherwise broadcast some mismatched shapes silently).
  for idx, (param, value) in enumerate(zip(params, values)):
    if param.shape != value.shape:
      raise ValueError(
        f"shape mismatch for parameter {idx}: expected {tuple(param.shape)}, got {tuple(value.shape)}"
      )
  with torch.no_grad():
    for param, value in zip(params, values):
      param.copy_(value)


In [None]:
# 6. Create a script/code/function that reproduces Algorithm 1, considering that both models are on your machine.

class Client:
  """A federated-learning participant that trains models on its local data.

  Args:
    dataset: local training data; iterating its DataLoader yields ``(x, y)`` batches.
    batch_size: local mini-batch size.
    device: device the model and batches are moved to; ``None`` leaves them
      where they already are.
  """

  def __init__(self, dataset, batch_size=60, device=None):
    self.dataset = dataset
    self.batch_size = batch_size
    self.device = device

  def train(self, model, epochs=10, lr=1e-4):
    """Train ``model`` in place with Adam and cross-entropy on the local data.

    After each epoch, prints the per-sample mean training loss, then prints a
    blank line at the end.

    Args:
      model: the ``nn.Module`` to train (moved to ``self.device`` if set).
      epochs: number of passes over the local dataset.
      lr: Adam learning rate.

    Returns:
      List with the mean training loss of each epoch.
    """
    if self.device is not None:
      model = model.to(self.device)
    data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    history = []
    for epoch in range(epochs):
      loss_sum, n_samples = 0.0, 0
      for x, y_true in data_loader:
        if self.device is not None:
          x, y_true = x.to(self.device), y_true.to(self.device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(x), y_true)
        loss.backward()
        optimizer.step()
        # Weight each batch mean by its size so a shorter final batch does not
        # skew the epoch average.
        batch_n = y_true.size(0)
        loss_sum += loss.item() * batch_n
        n_samples += batch_n
      epoch_loss = loss_sum / n_samples
      history.append(epoch_loss)
      print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}")
    print()
    return history


class FederatedAveraging:
  """One round of Federated Averaging (Algorithm 1) without a shared initialisation.

  Every selected client trains its own freshly constructed model, so the
  clients do not start from common parameters.

  Args:
    clients: sequence of clients exposing ``train(model, epochs=..., lr=...)``.
    average_weight: one averaging weight per client, indexed like ``clients``.
    random_state: seed of the NumPy generator used to sample clients.
  """

  def __init__(self, clients, average_weight, random_state=42):
    self.clients = clients
    self.average_weight = average_weight
    self.gen = np.random.default_rng(random_state)

  def __call__(self, global_model_cls, C=1.0, epochs=10, lr=1e-4):
    """Run one round and return a new model holding the averaged parameters.

    Args:
      global_model_cls: zero-argument model constructor.
      C: fraction of clients sampled this round; at least one client is used.
      epochs: local training epochs per client.
      lr: local learning rate.

    Returns:
      A fresh ``global_model_cls()`` instance whose parameters are the weighted
      average of the trained client models.
    """
    K = len(self.clients)
    m = max(int(C * K), 1)
    client_indices = self.gen.choice(K, m, replace=False)

    # Construct every model before training starts, so initialisation and data
    # shuffling draw from the global RNG in a fixed order.
    client_models = [global_model_cls() for _ in range(m)]
    for model, client_idx in zip(client_models, client_indices):
      self.clients[client_idx].train(model, epochs=epochs, lr=lr)

    models, weights = self._weighted_models(client_models, client_indices)
    average_params = average_model_parameters(models, weights)

    global_model = global_model_cls()
    update_model_parameters(global_model, average_params)
    return global_model

  def _weighted_models(self, client_models, client_indices):
    """Pair each trained model with the weight of the client that trained it.

    ``choice`` returns clients in random order, so models are sorted by client
    index before looking up ``average_weight``. When only some clients took
    part, their weights are renormalised to sum to 1.
    """
    order = sorted(range(len(client_indices)), key=lambda j: client_indices[j])
    models = [client_models[j] for j in order]
    weights = [self.average_weight[client_indices[j]] for j in order]
    if len(weights) < len(self.clients):
      total = sum(weights)
      weights = [weight / total for weight in weights]
    return models, weights


# FedAvg weights for the two clients: an equal split, i.e. [1/2, 1/2].
average_weight = [0.5, 0.5]

# Local mini-batch size of 50 samples, matching the article's setup.
local_batch_size = 50


In [None]:
# 7. Train your models without initializing the common parameters ...

device = "cuda" if torch.cuda.is_available() else "cpu"

# One client per disjoint subset, each training with local_batch_size samples per step.
clients = [
  Client(subset, batch_size=local_batch_size, device=device)
  for subset in (subset_1, subset_2)
]
client_1, client_2 = clients

# Single round with every client taking part (C=1.0); each client starts from
# its own freshly constructed SimpleCNN.
algorithm = FederatedAveraging(clients, average_weight, random_state=random_state)
global_model = algorithm(SimpleCNN, C=1.0)


Epoch 1/10, Loss: 2.260442634423574
Epoch 2/10, Loss: 2.13232151667277
Epoch 3/10, Loss: 1.950896253188451
Epoch 4/10, Loss: 1.7107790509859722
Epoch 5/10, Loss: 1.4360344807306926
Epoch 6/10, Loss: 1.1656321982542674
Epoch 7/10, Loss: 0.9339665820201238
Epoch 8/10, Loss: 0.7675834347804388
Epoch 9/10, Loss: 0.6465268582105637
Epoch 10/10, Loss: 0.5582850947976112

Epoch 1/10, Loss: 2.2641897002855935
Epoch 2/10, Loss: 2.137272854646047
Epoch 3/10, Loss: 1.9561141431331635
Epoch 4/10, Loss: 1.711478481690089
Epoch 5/10, Loss: 1.4293417235215504
Epoch 6/10, Loss: 1.148185799519221
Epoch 7/10, Loss: 0.9066081643104553
Epoch 8/10, Loss: 0.7409565349419912
Epoch 9/10, Loss: 0.6186471978823344
Epoch 10/10, Loss: 0.5460866168141365



In [None]:
# 7. ... and measure the performance on the entire dataset.

def evaluate(model, dataset, target_names=None, batch_size=1000):
  """Print a classification report and the mean cross-entropy of ``model`` on ``dataset``.

  The dataset is processed in mini-batches on the device of the model's
  parameters, which keeps peak memory bounded. Predictions and labels are
  gathered on the CPU for the report.

  Args:
    model: classifier returning one logit per class.
    dataset: map-style dataset yielding ``(x, y)`` pairs.
    target_names: optional display names for the classes in the report.
    batch_size: number of samples per forward pass.

  Raises:
    ValueError: if ``dataset`` is empty.
  """
  if len(dataset) == 0:
    raise ValueError("cannot evaluate on an empty dataset")
  first_param = next(model.parameters(), None)
  device = first_param.device if first_param is not None else torch.device("cpu")

  data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
  model.eval()
  y_true_parts, y_pred_parts = [], []
  loss_sum, n_samples = 0.0, 0
  with torch.no_grad():
    for x, y_true in data_loader:
      x, y_true = x.to(device), y_true.to(device)
      logits = model(x)
      # Sum per-sample losses so the final mean is exact even with a short last batch.
      loss_sum += F.cross_entropy(logits, y_true, reduction="sum").item()
      n_samples += y_true.size(0)
      y_true_parts.append(y_true.cpu())
      y_pred_parts.append(logits.argmax(dim=1).cpu())

  y_true_all = torch.cat(y_true_parts).numpy()
  y_pred_all = torch.cat(y_pred_parts).numpy()
  print(classification_report(y_true_all, y_pred_all, target_names=target_names, zero_division=0.0))
  print(f"test loss = {loss_sum / n_samples}")

# Move the averaged model back to the CPU, then score it on the test split.
evaluate(global_model.to("cpu"), dataset_test)


              precision    recall  f1-score   support

           0       0.71      0.97      0.82       980
           1       0.55      0.99      0.71      1135
           2       0.88      0.75      0.81      1032
           3       0.88      0.63      0.73      1010
           4       0.55      0.88      0.68       982
           5       0.63      0.80      0.71       892
           6       0.97      0.71      0.82       958
           7       0.57      0.89      0.70      1028
           8       0.00      0.00      0.00       974
           9       0.00      0.00      0.00      1009

    accuracy                           0.67     10000
   macro avg       0.58      0.66      0.60     10000
weighted avg       0.57      0.67      0.60     10000

test loss = 1.8237099647521973


In [None]:
# 8. Train your models with the initialization of common parameters and verify that the performance is better.

class FederatedAveraging:
  """One round of Federated Averaging (Algorithm 1) from a shared initialisation.

  Every selected client trains a deep copy of the same global model, so all
  clients start from common parameters.

  Args:
    clients: sequence of clients exposing ``train(model, epochs=..., lr=...)``.
    average_weight: one averaging weight per client, indexed like ``clients``.
    random_state: seed of the NumPy generator used to sample clients.
  """

  def __init__(self, clients, average_weight, random_state=42):
    self.clients = clients
    self.average_weight = average_weight
    self.gen = np.random.default_rng(random_state)

  def __call__(self, global_model, C=1.0, epochs=10, lr=1e-4):
    """Run one round and write the averaged parameters into ``global_model``.

    Args:
      global_model: model exposing ``clone()``; updated in place.
      C: fraction of clients sampled this round; at least one client is used.
      epochs: local training epochs per client.
      lr: local learning rate.

    Returns:
      ``global_model`` itself, after the update.
    """
    K = len(self.clients)
    m = max(int(C * K), 1)
    client_indices = self.gen.choice(K, m, replace=False)

    client_models = [global_model.clone() for _ in range(m)]
    for model, client_idx in zip(client_models, client_indices):
      self.clients[client_idx].train(model, epochs=epochs, lr=lr)

    models, weights = self._weighted_models(client_models, client_indices)
    average_params = average_model_parameters(models, weights)

    update_model_parameters(global_model, average_params)
    return global_model

  def _weighted_models(self, client_models, client_indices):
    """Pair each trained model with the weight of the client that trained it.

    ``choice`` returns clients in random order, so models are sorted by client
    index before looking up ``average_weight``. When only some clients took
    part, their weights are renormalised to sum to 1.
    """
    order = sorted(range(len(client_indices)), key=lambda j: client_indices[j])
    models = [client_models[j] for j in order]
    weights = [self.average_weight[client_indices[j]] for j in order]
    if len(weights) < len(self.clients):
      total = sum(weights)
      weights = [weight / total for weight in weights]
    return models, weights


algorithm = FederatedAveraging(clients, average_weight, random_state=random_state)
# Shared initialisation: every client trains a copy of this single model.
global_model = SimpleCNN()
algorithm(global_model, C=1.0)  # writes the averaged parameters into global_model

evaluate(global_model.to("cpu"), dataset_test)


Epoch 1/10, Loss: 2.242070992787679
Epoch 2/10, Loss: 2.050161898136139
Epoch 3/10, Loss: 1.8017733693122864
Epoch 4/10, Loss: 1.495373805363973
Epoch 5/10, Loss: 1.1988696853319805
Epoch 6/10, Loss: 0.9511296500762304
Epoch 7/10, Loss: 0.7785952091217041
Epoch 8/10, Loss: 0.6539307683706284
Epoch 9/10, Loss: 0.5606670354803404
Epoch 10/10, Loss: 0.49569424490133923

Epoch 1/10, Loss: 2.2349798480669656
Epoch 2/10, Loss: 2.0342212915420532
Epoch 3/10, Loss: 1.7736465533574421
Epoch 4/10, Loss: 1.460756778717041
Epoch 5/10, Loss: 1.1492241472005844
Epoch 6/10, Loss: 0.8984330246845881
Epoch 7/10, Loss: 0.7192040284474691
Epoch 8/10, Loss: 0.5979145144422849
Epoch 9/10, Loss: 0.5090892165899277
Epoch 10/10, Loss: 0.45369911193847656

              precision    recall  f1-score   support

           0       0.88      0.96      0.92       980
           1       0.92      0.96      0.94      1135
           2       0.91      0.85      0.88      1032
           3       0.76      0.91      0.

In [None]:
# 9. Reduce the number of data points in each subset. What is the minimum number of data points
# per subset needed for the final model to reach acceptable performance?

def evaluate_subset_size(dataset_train, dataset_test, model_cls, subset_size, random_state=42, target_names=None):
  """Run one shared-initialisation FedAvg round with two clients and report test metrics.

  Two disjoint subsets of ``subset_size`` training samples are drawn, one per
  client. A single ``model_cls()`` is averaged with equal weights, and the
  result is evaluated on ``dataset_test``.

  Args:
    dataset_train: dataset the client subsets are drawn from.
    dataset_test: dataset used for the final evaluation.
    model_cls: zero-argument model constructor whose instances expose ``clone()``.
    subset_size: number of samples held by each client.
    random_state: seed for subset extraction and client sampling.
    target_names: optional class names for the classification report.
  """
  subset_1, subset_2 = extract_subsets(dataset_train, subset_size=subset_size, random_state=random_state)

  # Check disjointness on the sampled indices. Intersecting the samples
  # themselves would decode every image, and because tensors hash by identity
  # that intersection is effectively always empty.
  assert not set(map(int, subset_1.indices)) & set(map(int, subset_2.indices))
  assert len(subset_1) == subset_size and len(subset_2) == subset_size

  average_weight = [1/2, 1/2]
  local_batch_size = 50

  device = "cuda" if torch.cuda.is_available() else "cpu"

  client_1 = Client(subset_1, batch_size=local_batch_size, device=device)
  client_2 = Client(subset_2, batch_size=local_batch_size, device=device)
  clients = [client_1, client_2]

  global_model = model_cls()

  algorithm = FederatedAveraging(clients, average_weight, random_state=random_state)
  algorithm(global_model, C=1.0)

  global_model = global_model.to("cpu")

  evaluate(global_model, dataset_test, target_names=target_names)


In [None]:
# Subset size: 300
evaluate_subset_size(
  dataset_train,
  dataset_test,
  SimpleCNN,
  subset_size=300,
  random_state=random_state,
)

Epoch 1/10, Loss: 2.2740848859151206
Epoch 2/10, Loss: 2.199165423711141
Epoch 3/10, Loss: 2.1264237562815347
Epoch 4/10, Loss: 2.0396344661712646
Epoch 5/10, Loss: 1.9353689352671306
Epoch 6/10, Loss: 1.8142779469490051
Epoch 7/10, Loss: 1.6792110006014507
Epoch 8/10, Loss: 1.5358028809229534
Epoch 9/10, Loss: 1.3876149257024128
Epoch 10/10, Loss: 1.24173508087794

Epoch 1/10, Loss: 2.2814810276031494
Epoch 2/10, Loss: 2.2043351332346597
Epoch 3/10, Loss: 2.1314351161321006
Epoch 4/10, Loss: 2.0367972453435264
Epoch 5/10, Loss: 1.924130916595459
Epoch 6/10, Loss: 1.7990633249282837
Epoch 7/10, Loss: 1.6503669222195942
Epoch 8/10, Loss: 1.4985593756039937
Epoch 9/10, Loss: 1.3339777787526448
Epoch 10/10, Loss: 1.181776722272237

              precision    recall  f1-score   support

           0       0.89      0.88      0.88       980
           1       0.77      0.97      0.86      1135
           2       0.91      0.77      0.84      1032
           3       0.57      0.89      0.70 

In [None]:
# Subset size: 225
evaluate_subset_size(
  dataset_train,
  dataset_test,
  SimpleCNN,
  subset_size=225,
  random_state=random_state,
)

Epoch 1/10, Loss: 2.299188566207886
Epoch 2/10, Loss: 2.2145005226135255
Epoch 3/10, Loss: 2.143722724914551
Epoch 4/10, Loss: 2.064178490638733
Epoch 5/10, Loss: 1.9833114862442016
Epoch 6/10, Loss: 1.885623860359192
Epoch 7/10, Loss: 1.7870988607406617
Epoch 8/10, Loss: 1.6532409906387329
Epoch 9/10, Loss: 1.5299564838409423
Epoch 10/10, Loss: 1.3875293731689453

Epoch 1/10, Loss: 2.2792191028594972
Epoch 2/10, Loss: 2.203948640823364
Epoch 3/10, Loss: 2.118944215774536
Epoch 4/10, Loss: 2.0378366470336915
Epoch 5/10, Loss: 1.9357996225357055
Epoch 6/10, Loss: 1.8372415542602538
Epoch 7/10, Loss: 1.7240669965744018
Epoch 8/10, Loss: 1.572980785369873
Epoch 9/10, Loss: 1.4463969707489013
Epoch 10/10, Loss: 1.34232439994812

              precision    recall  f1-score   support

           0       0.80      0.94      0.86       980
           1       0.88      0.92      0.90      1135
           2       0.88      0.82      0.85      1032
           3       0.58      0.87      0.69     

In [None]:
# Subset size: 150
evaluate_subset_size(
  dataset_train,
  dataset_test,
  SimpleCNN,
  subset_size=150,
  random_state=random_state,
)

Epoch 1/10, Loss: 2.287604649861654
Epoch 2/10, Loss: 2.2228218714396157
Epoch 3/10, Loss: 2.173189163208008
Epoch 4/10, Loss: 2.121098279953003
Epoch 5/10, Loss: 2.066273053487142
Epoch 6/10, Loss: 2.0017035404841104
Epoch 7/10, Loss: 1.9335798422495525
Epoch 8/10, Loss: 1.8602807521820068
Epoch 9/10, Loss: 1.783068060874939
Epoch 10/10, Loss: 1.7008123397827148

Epoch 1/10, Loss: 2.288973013559977
Epoch 2/10, Loss: 2.2297961711883545
Epoch 3/10, Loss: 2.176190217336019
Epoch 4/10, Loss: 2.127875645955404
Epoch 5/10, Loss: 2.0736462275187173
Epoch 6/10, Loss: 2.0125320752461753
Epoch 7/10, Loss: 1.946573297182719
Epoch 8/10, Loss: 1.8765028317769368
Epoch 9/10, Loss: 1.7989062070846558
Epoch 10/10, Loss: 1.7127409378687541

              precision    recall  f1-score   support

           0       0.45      0.98      0.62       980
           1       0.93      0.87      0.90      1135
           2       1.00      0.25      0.40      1032
           3       0.47      0.89      0.61     

In [None]:
# Subset size: 75
evaluate_subset_size(
  dataset_train,
  dataset_test,
  SimpleCNN,
  subset_size=75,
  random_state=random_state,
)

Epoch 1/10, Loss: 2.2963664531707764
Epoch 2/10, Loss: 2.2229732275009155
Epoch 3/10, Loss: 2.184417486190796
Epoch 4/10, Loss: 2.121456503868103
Epoch 5/10, Loss: 2.0725066661834717
Epoch 6/10, Loss: 2.01925790309906
Epoch 7/10, Loss: 1.990318477153778
Epoch 8/10, Loss: 1.9431979060173035
Epoch 9/10, Loss: 1.8549925684928894
Epoch 10/10, Loss: 1.809952735900879

Epoch 1/10, Loss: 2.319611430168152
Epoch 2/10, Loss: 2.2495293617248535
Epoch 3/10, Loss: 2.194462776184082
Epoch 4/10, Loss: 2.171608567237854
Epoch 5/10, Loss: 2.113717198371887
Epoch 6/10, Loss: 2.1115788221359253
Epoch 7/10, Loss: 2.0749809741973877
Epoch 8/10, Loss: 1.9912885427474976
Epoch 9/10, Loss: 1.9675783514976501
Epoch 10/10, Loss: 1.9404996037483215

              precision    recall  f1-score   support

           0       0.28      0.98      0.43       980
           1       1.00      0.56      0.71      1135
           2       0.95      0.33      0.49      1032
           3       0.36      0.70      0.48      

In [None]:
# Repeat the study on CIFAR-10.

In [None]:
# Load the CIFAR-10 dataset.

# Standardise each RGB channel with a fixed per-channel mean/std.
# NOTE(review): these std values are the widely copied ones; the per-channel std
# of the CIFAR-10 training set is usually quoted as about (0.247, 0.243, 0.261) — confirm.
transform = transforms.Compose(
  [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
  ]
)

# Only the training split requests a download; the test split relies on
# ``download`` keeping its default (False).
dataset_train = torchvision.datasets.CIFAR10(
  "./data", train=True, transform=transform, target_transform=None, download=True,
)
dataset_test = torchvision.datasets.CIFAR10(
  "./data", train=False, transform=transform, target_transform=None,
)

# Class display names for the classification report, in label order 0-9.
target_names = "airplane automobile bird cat deer dog frog horse ship truck".split()


Files already downloaded and verified


In [None]:
# Create a simple Convolutional Neural Network.

class SimpleCNN(nn.Module):
  """Small CNN for 3x32x32 images: two conv blocks, two dense layers, 10 logits."""

  def __init__(self):
    super().__init__()
    # Each conv keeps the spatial size (padding=1); the shared 2x2 max-pool
    # halves it after each conv, so 32x32 -> 16x16 -> 8x8.
    self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    # Dense head on the flattened 64x8x8 feature map.
    self.fc1 = nn.Linear(64 * 8 * 8, 256)
    self.fc2 = nn.Linear(256, 10)

  def forward(self, x):
    for conv in (self.conv1, self.conv2):
      x = self.pool(conv(x).relu())
    hidden = self.fc1(x.flatten(start_dim=1)).relu()
    return self.fc2(hidden)

  def clone(self):
    """Return an independent deep copy of this network, parameters included."""
    return copy.deepcopy(self)


In [None]:
# Subset size: 600
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=600,
  random_state=random_state,
  target_names=target_names,
)


Epoch 1/10, Loss: 2.2762206196784973
Epoch 2/10, Loss: 2.164184014002482
Epoch 3/10, Loss: 2.0451194445292153
Epoch 4/10, Loss: 1.923590620358785
Epoch 5/10, Loss: 1.825722297032674
Epoch 6/10, Loss: 1.7423799733320873
Epoch 7/10, Loss: 1.669473260641098
Epoch 8/10, Loss: 1.5901649991671245
Epoch 9/10, Loss: 1.5202670494715373
Epoch 10/10, Loss: 1.4884677628676097

Epoch 1/10, Loss: 2.275067925453186
Epoch 2/10, Loss: 2.161841571331024
Epoch 3/10, Loss: 2.036328057448069
Epoch 4/10, Loss: 1.9159430960814159
Epoch 5/10, Loss: 1.7957518100738525
Epoch 6/10, Loss: 1.7069916625817616
Epoch 7/10, Loss: 1.6355561117331188
Epoch 8/10, Loss: 1.5722739199797313
Epoch 9/10, Loss: 1.5193351010481517
Epoch 10/10, Loss: 1.4375033378601074

              precision    recall  f1-score   support

    airplane       0.48      0.44      0.46      1000
  automobile       0.35      0.54      0.42      1000
        bird       0.33      0.09      0.14      1000
         cat       0.27      0.16      0.20   

In [None]:
# Subset size: 1200
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=1200,
  random_state=random_state,
  target_names=target_names,
)


Epoch 1/10, Loss: 2.2445366283257804
Epoch 2/10, Loss: 2.0492428988218307
Epoch 3/10, Loss: 1.8842016359170277
Epoch 4/10, Loss: 1.78955739736557
Epoch 5/10, Loss: 1.6870730568965275
Epoch 6/10, Loss: 1.622031773130099
Epoch 7/10, Loss: 1.550882250070572
Epoch 8/10, Loss: 1.5009870181481044
Epoch 9/10, Loss: 1.4624215215444565
Epoch 10/10, Loss: 1.4100553294022877

Epoch 1/10, Loss: 2.24455855290095
Epoch 2/10, Loss: 2.077046404282252
Epoch 3/10, Loss: 1.917370542883873
Epoch 4/10, Loss: 1.8110644221305847
Epoch 5/10, Loss: 1.7192897995313008
Epoch 6/10, Loss: 1.6673217912515004
Epoch 7/10, Loss: 1.5926808963219325
Epoch 8/10, Loss: 1.5369653056065242
Epoch 9/10, Loss: 1.4779817312955856
Epoch 10/10, Loss: 1.4092650959889095

              precision    recall  f1-score   support

    airplane       0.63      0.33      0.44      1000
  automobile       0.34      0.78      0.47      1000
        bird       0.36      0.14      0.21      1000
         cat       0.32      0.16      0.21    

In [None]:
# Subset size: 2400
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=2400,
  random_state=random_state,
  target_names=target_names,
)


Epoch 1/10, Loss: 2.160722034672896
Epoch 2/10, Loss: 1.8905263418952625
Epoch 3/10, Loss: 1.7529161646962166
Epoch 4/10, Loss: 1.6461766511201859
Epoch 5/10, Loss: 1.5636423428853352
Epoch 6/10, Loss: 1.478153757750988
Epoch 7/10, Loss: 1.4234377294778824
Epoch 8/10, Loss: 1.3802245905001957
Epoch 9/10, Loss: 1.3284379616379738
Epoch 10/10, Loss: 1.2715379695097606

Epoch 1/10, Loss: 2.1589148392279944
Epoch 2/10, Loss: 1.8853613312045734
Epoch 3/10, Loss: 1.747187706331412
Epoch 4/10, Loss: 1.6396285966038704
Epoch 5/10, Loss: 1.5527827988068263
Epoch 6/10, Loss: 1.4695507536331813
Epoch 7/10, Loss: 1.404315394659837
Epoch 8/10, Loss: 1.3612124174833298
Epoch 9/10, Loss: 1.2895467355847359
Epoch 10/10, Loss: 1.2641845345497131

              precision    recall  f1-score   support

    airplane       0.62      0.46      0.53      1000
  automobile       0.46      0.76      0.57      1000
        bird       0.32      0.42      0.37      1000
         cat       0.39      0.10      0.16

In [None]:
# Subset size: 4800
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=4800,
  random_state=random_state,
  target_names=target_names,
)


Epoch 1/10, Loss: 2.0514990389347076
Epoch 2/10, Loss: 1.7418601860602696
Epoch 3/10, Loss: 1.6064197445909183
Epoch 4/10, Loss: 1.5091173065205414
Epoch 5/10, Loss: 1.4326312926908333
Epoch 6/10, Loss: 1.3713121190667152
Epoch 7/10, Loss: 1.3099435611317556
Epoch 8/10, Loss: 1.2604158998777468
Epoch 9/10, Loss: 1.2185252190877993
Epoch 10/10, Loss: 1.176115450138847

Epoch 1/10, Loss: 2.043006233870983
Epoch 2/10, Loss: 1.7506167404353619
Epoch 3/10, Loss: 1.6235065770645936
Epoch 4/10, Loss: 1.5230095970133941
Epoch 5/10, Loss: 1.4431847284237544
Epoch 6/10, Loss: 1.3776390440762043
Epoch 7/10, Loss: 1.317013497153918
Epoch 8/10, Loss: 1.2692939701179664
Epoch 9/10, Loss: 1.226432188724478
Epoch 10/10, Loss: 1.1719137895852327

              precision    recall  f1-score   support

    airplane       0.52      0.65      0.58      1000
  automobile       0.64      0.64      0.64      1000
        bird       0.39      0.41      0.40      1000
         cat       0.41      0.24      0.30

In [None]:
# Subset size: 9600
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=9600,
  random_state=random_state,
  target_names=target_names,
)


Epoch 1/10, Loss: 1.9340408239513636
Epoch 2/10, Loss: 1.6021019003043573
Epoch 3/10, Loss: 1.4707127753645182
Epoch 4/10, Loss: 1.3836818511287372
Epoch 5/10, Loss: 1.3255397975444794
Epoch 6/10, Loss: 1.2604104336351156
Epoch 7/10, Loss: 1.2169337114319205
Epoch 8/10, Loss: 1.1670714548478525
Epoch 9/10, Loss: 1.1206290638074279
Epoch 10/10, Loss: 1.0837577944621444

Epoch 1/10, Loss: 1.9122319320837657
Epoch 2/10, Loss: 1.5877839637299378
Epoch 3/10, Loss: 1.4599641760190327
Epoch 4/10, Loss: 1.3628605616589387
Epoch 5/10, Loss: 1.299198083889981
Epoch 6/10, Loss: 1.2380698326354225
Epoch 7/10, Loss: 1.1848011485611398
Epoch 8/10, Loss: 1.1482244307796161
Epoch 9/10, Loss: 1.103395421989262
Epoch 10/10, Loss: 1.0635811525086563

              precision    recall  f1-score   support

    airplane       0.70      0.53      0.61      1000
  automobile       0.66      0.72      0.69      1000
        bird       0.50      0.32      0.39      1000
         cat       0.43      0.29      0.

In [None]:
# Subset size: 19200
evaluate_subset_size(
  dataset_train=dataset_train,
  dataset_test=dataset_test,
  model_cls=SimpleCNN,
  subset_size=19200,
  random_state=random_state,
  target_names=target_names,
)

Epoch 1/10, Loss: 1.7459366113568346
Epoch 2/10, Loss: 1.4363142453754942
Epoch 3/10, Loss: 1.3133919006213546
Epoch 4/10, Loss: 1.2221329243232806
Epoch 5/10, Loss: 1.1604999668585758
Epoch 6/10, Loss: 1.0993118326490123
Epoch 7/10, Loss: 1.0485400989030798
Epoch 8/10, Loss: 1.0020667305216193
Epoch 9/10, Loss: 0.9654473218445977
Epoch 10/10, Loss: 0.9275869346844653

Epoch 1/10, Loss: 1.7510980650161703
Epoch 2/10, Loss: 1.4339016117155552
Epoch 3/10, Loss: 1.3069589499694605
Epoch 4/10, Loss: 1.2174223546559613
Epoch 5/10, Loss: 1.145253004040569
Epoch 6/10, Loss: 1.0875080567784607
Epoch 7/10, Loss: 1.0366560622739296
Epoch 8/10, Loss: 0.9900819258764386
Epoch 9/10, Loss: 0.9484567805193365
Epoch 10/10, Loss: 0.9095035861246288

              precision    recall  f1-score   support

    airplane       0.66      0.71      0.69      1000
  automobile       0.71      0.81      0.76      1000
        bird       0.52      0.48      0.50      1000
         cat       0.55      0.31      0