# Federated Learning (Extreme Non-IID setting)


Train a centralized model on decentralized data. Dataset used: CIFAR-10. In PyTorch, CIFAR-10 is available through the torchvision module.

# To import all the relevant packages


In [None]:
import os
import random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset
from torchvision import transforms 
from torchvision.transforms import Compose 
torch.backends.cudnn.benchmark=True

# Setting the Hyper-parameters

**classes_pc:** classes per client. It is used to turn the balanced dataset into a non-IID dataset by giving every client an unbalanced selection of classes. For example, if classes_pc=1, every client holds images from one class only, which creates an extreme imbalance among the clients.

**num_clients:** Number of clients among which images are to be distributed.

**num_selected:** Number of randomly selected clients from num_clients during the start of each communication round. To be used in the training phase of the global model. Typically, num_selected is around 30–40% of the num_clients.

**num_rounds:** Total number of communication rounds for the global model to train. In each communication round, training on individual clients takes place simultaneously.

**epochs:** Number of local training rounds on each client’s device.

**batch_size:** Number of images per batch when the data is loaded through the data loader.

**baseline_num:** Total number of baseline images to be saved on the global server for retraining of the client’s model before aggregation. This technique of retraining all the models on the global server deals with non-IID/real-world datasets.

**retrain_epochs:** Total number of retraining rounds on the global server after receiving the model weights from all the clients that participated in the communication round.


In [None]:
# Classes per client; passed to get_data_loaders to build the non-IID split.
classes_pc = 2
# Total number of clients the training images are distributed over.
num_clients = 20
# Number of clients randomly picked at the start of every communication round.
num_selected = 6
# Number of communication rounds of the global model.
num_rounds = 50
# Local training epochs on each selected client per round.
epochs = 5
# Batch size of the per-client data loaders.
batch_size = 32
# Number of baseline images kept on the server for retraining before aggregation.
baseline_num = 100
# Retraining epochs on the baseline images per round.
retrain_epochs = 10

# Creating the distribution



In [None]:
#### get cifar dataset in x and y form

#The get_cifar10 function downloads the CIFAR10 dataset and returns x_train, y_train for training, and x_test, y_test for test purposes.

def get_cifar10():
  '''Download CIFAR10 through torchvision and return it as NumPy arrays.

  Returns:
    x_train, y_train, x_test, y_test. The image arrays are the datasets'
    `data` attribute transposed with axes (0, 3, 1, 2); the label arrays
    are built from the datasets' `targets`.
  '''
  def _load_split(is_train):
    # fetches the requested split into ./data (downloading it when needed)
    dataset = torchvision.datasets.CIFAR10('./data', train=is_train, download=True)
    images = dataset.data.transpose((0, 3, 1, 2))
    return images, np.array(dataset.targets)

  # train split first, then test split
  x_train, y_train = _load_split(True)
  x_test, y_test = _load_split(False)

  return x_train, y_train, x_test, y_test

#Function to print the basic data stats
def print_image_data_stats(data_train, labels_train, data_test, labels_test):
  '''
  Print shape, value range and label range of the train and test sets.

  Input:
    data_train, data_test : numpy arrays holding the images
    labels_train, labels_test : numpy arrays holding the labels

  Each set is described from its own arrays (the test-set range used to be
  computed from the training data by mistake).
  '''
  def _describe(name, data, labels):
    # one summary line per set; same layout for train and test
    print(" - {}: ({},{}), Range: [{:.3f}, {:.3f}], Labels: {},..,{}".format(
      name, data.shape, labels.shape, np.min(data), np.max(data),
      np.min(labels), np.max(labels)))

  print("\nData: ")
  _describe("Train Set", data_train, labels_train)
  _describe("Test Set", data_test, labels_test)
  
#--------------------------------------------------------------------------------------------------------------------------------------#
#The clients_rand function creates a random distribution for the clients, such that every client has an arbitrary number of images. 
#It is one of the helper functions to be used.

def clients_rand(train_len, nclients):
  '''
  Draw a random number of images for every client.

  train_len: size of the train data
  nclients: number of clients

  Returns: a list with one image count per client; the counts add up
  to train_len.
  '''
  # one random weight in [10, 100] for every client except the last one
  weights = [random.randint(10, 100) for _ in range(nclients - 1)]
  weight_total = sum(weights)

  # turn the weights into image counts, truncating to integers
  shares = (np.array(weights) / weight_total * train_len).astype(int)

  # the last client receives whatever is left over after truncation
  leftover = train_len - shares.sum()
  return list(shares) + [leftover]

#--------------------------------------------------------------------------------------------------------------------------------------#
#The split_image_data function splits the given images into n_clients. 
#It returns a split which is further used to create the real-world dataset. 
# verbose: True/False flag; True prints some information about the split, False stays silent.

def split_image_data(data, labels, n_clients=100, classes_per_client=10, shuffle=True, verbose=True):
  '''
  Splits (data, labels) among 'n_clients' so that every client draws its
  samples in chunks of roughly (client size / classes_per_client) from
  consecutive classes, starting at a random class.
  Input:
    data : [n_data x shape] numpy array
    labels : [n_data (x 1)] numpy array with values from 0 to n_labels - 1
    n_clients : number of clients
    classes_per_client : number of classes per client
    shuffle : True/False => True for shuffling the indices of every class first
    verbose : True/False => True for printing the per-client class counts
  Output:
    clients_split : object array of shape (n_clients, 2); row i holds the
                    data array and the label array of client i
  '''
  n_labels = np.max(labels) + 1

  # number of samples per client, and the chunk taken from one class at a time
  data_per_client = clients_rand(len(data), n_clients)
  data_per_client_per_class = [np.maximum(1, nd // classes_per_client) for nd in data_per_client]

  # group the sample indices by label
  data_idcs = [[] for _ in range(n_labels)]
  for j, label in enumerate(labels):
    data_idcs[label].append(j)
  if shuffle:
    for idcs in data_idcs:
      np.random.shuffle(idcs)

  # split data among clients
  clients_split = []
  for i in range(n_clients):
    client_idcs = []

    budget = data_per_client[i]
    c = np.random.randint(n_labels)
    while budget > 0:
      # an exhausted class yields take == 0 and the loop moves on to the next class
      take = min(data_per_client_per_class[i], len(data_idcs[c]), budget)

      client_idcs += data_idcs[c][:take]
      data_idcs[c] = data_idcs[c][take:]

      budget -= take
      c = (c + 1) % n_labels

    clients_split.append((data[client_idcs], labels[client_idcs]))

  # The report used to sit inside an unreachable nested helper, so
  # verbose=True printed nothing; it now runs once after the split.
  if verbose:
    print("Data split:")
    for i, client in enumerate(clients_split):
      split = np.sum(client[1].reshape(1,-1)==np.arange(n_labels).reshape(-1,1), axis=1)
      print(" - Client {}: {}".format(i,split))
    print()

  # The clients hold different numbers of samples, so np.array() on the list
  # of tuples would have to build a ragged array, which recent NumPy versions
  # refuse. Fill an object array explicitly instead.
  packed = np.empty((len(clients_split), 2), dtype=object)
  for i, (client_x, client_y) in enumerate(clients_split):
    packed[i, 0] = client_x
    packed[i, 1] = client_y

  return packed

#To shuffle the images of each client respectively
def shuffle_list(data):
  '''
  Shuffle the samples of every client and return `data`.

  data: indexable collection whose entries hold the images at position 0
        and the labels at position 1; the entries are overwritten in place.

  The shuffling itself is delegated to shuffle_list_data, which keeps
  images and labels aligned. (An index list that was built and shuffled
  here without ever being used has been removed.)
  '''
  for client in data:
    client[0], client[1] = shuffle_list_data(client[0], client[1])
  return data

#Helper that shuffles x and y with the same random permutation, so every image keeps its label
def shuffle_list_data(x, y):
  '''
  Helper that reorders x and y with one shared random permutation,
  so that the i-th entry of x still belongs to the i-th entry of y.
  '''
  order = [*range(len(x))]
  random.shuffle(order)
  shuffled_x = x[order]
  shuffled_y = y[order]
  return shuffled_x, shuffled_y

#The below code snippet converts the split into a data loader (image augmentation is done in this part) 
#for giving this as an input to the model for training.
class CustomImageDataset(Dataset):
  '''
  Dataset over in-memory images and their labels.
  inputs : numpy array [n_data x shape]
  labels : numpy array [n_data (x 1)]
  transforms : optional callable applied to every image on access
  '''
  def __init__(self, inputs, labels, transforms=None):
    # both arrays must describe the same number of samples
    assert inputs.shape[0] == labels.shape[0]
    # torch.Tensor(...) produces float tensors; the labels are cast to integers afterwards
    # NOTE(review): the images keep their original numeric range after the
    # float conversion - confirm that the transforms passed in expect that range.
    self.inputs = torch.Tensor(inputs)
    self.labels = torch.Tensor(labels).to(torch.long)
    self.transforms = transforms

  def __len__(self):
    return self.inputs.size(0)

  def __getitem__(self, index):
    img = self.inputs[index]
    label = self.labels[index]

    if self.transforms is None:
      return (img, label)

    return (self.transforms(img), label)
          

def get_default_data_transforms(train=True, verbose=True):
  '''
  Build the CIFAR10 preprocessing pipelines.

  train : accepted but not used; both pipelines are always returned
  verbose : True/False => True for printing the steps of the training pipeline

  Returns: (training pipeline, evaluation pipeline)
  '''
  # per-channel normalisation constants shared by both pipelines
  # (an alternative std of (0.24703223, 0.24348513, 0.26158784) was noted in the original)
  mean = (0.4914, 0.4822, 0.4465)
  std = (0.2023, 0.1994, 0.2010)

  # training: random crop with padding and random horizontal flip as augmentation
  train_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
  ])

  # evaluation: conversion and normalisation only
  eval_pipeline = transforms.Compose([
    transforms.ToPILImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
  ])

  if verbose:
    print("\nData preprocessing: ")
    for step in train_pipeline.transforms:
      print(' -', step)
    print()

  return (train_pipeline, eval_pipeline)

#The get_data_loaders function uses the above helper functions and converts the CIFAR10 dataset into non-IID type

def get_data_loaders(nclients, batch_size, classes_pc=10, verbose=True):
  '''
  Build one training loader per client plus a shared test loader.

  nclients : number of clients the training data is split among
  batch_size : batch size of the client loaders
  classes_pc : classes per client, forwarded to split_image_data
  verbose : True/False => True for printing data statistics and split info

  Returns: (list of client loaders, test loader)
  '''
  x_train, y_train, x_test, y_test = get_cifar10()
  if verbose:
    print_image_data_stats(x_train, y_train, x_test, y_test)

  transforms_train, transforms_eval = get_default_data_transforms(verbose=False)

  # distribute the training data among the clients, then shuffle within each client
  per_client = split_image_data(x_train, y_train, n_clients=nclients,
                                classes_per_client=classes_pc, verbose=verbose)
  per_client = shuffle_list(per_client)

  client_loaders = []
  for x, y in per_client:
    client_set = CustomImageDataset(x, y, transforms_train)
    client_loaders.append(
      torch.utils.data.DataLoader(client_set, batch_size=batch_size, shuffle=True))

  # the test loader uses a fixed batch size of 100 and keeps the sample order
  test_set = CustomImageDataset(x_test, y_test, transforms_eval)
  test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False)

  return client_loaders, test_loader

# Building the Neural Network Model
**VGG:** It was proposed by the Visual Geometry Group of Oxford University in 2014 and obtained accurate classification performance on the ImageNet dataset.

**VGG19:** 16 convolution layers, 3 Fully Connected layers, 5 MaxPool layers (Summarizing the output of Convolution Layer), and 1 SoftMax layer (Softmax is implemented through a neural network layer just before the output layer. The Softmax layer must have the same number of nodes as the output layer).

In [None]:
#################################
##### Neural Network model #####
#################################

#VGG is a deep CNN used to classify images.

# Layer layout per VGG variant: a number is the output channel count of a
# 3x3 convolution, 'M' stands for a 2x2 max-pooling layer.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG network with batch normalisation and a 10-way log-softmax output.

    vgg_name selects one of the layouts in `cfg`. The classifier expects
    512 flattened features coming out of the convolutional part.
    """

    def __init__(self, vgg_name):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        feats = self.features(x)
        # flatten everything except the batch dimension
        flat = feats.view(feats.size(0), -1)
        return F.log_softmax(self.classifier(flat), dim=1)

    def _make_layers(self, cfg):
        """Translate a layout list into the convolutional feature extractor."""
        layers = []
        channels = 3
        for spec in cfg:
            if spec != 'M':
                # convolution -> batch norm -> ReLU, then track the new channel count
                layers.append(nn.Conv2d(channels, spec, kernel_size=3, padding=1))
                layers.append(nn.BatchNorm2d(spec))
                layers.append(nn.ReLU(inplace=True))
                channels = spec
                continue
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        layers.append(nn.AvgPool2d(kernel_size=1, stride=1))
        return nn.Sequential(*layers)

# Helper functions for Federated Learning
The **baseline_data** function creates a loader for the baseline data on which the client’s model is retrained before the aggregation of weights on the global server. 

**‘num’** is the number of images on which the retraining of client’s model on the global server is supposed to take place. 

In [None]:
def baseline_data(num):
  '''
  Build the loader over the baseline images used for retraining on the
  global server.
  Input:
        num : size of baseline data
  Output:
        loader: baseline data loader (batch size 16, shuffled)
  '''
  # only the training split is used; the test split is discarded
  x_train, y_train, _, _ = get_cifar10()
  x_shuffled, y_shuffled = shuffle_list_data(x_train, y_train)

  # keep the first `num` samples of the shuffled training data
  train_transform, _ = get_default_data_transforms(train=True, verbose=False)
  subset = CustomImageDataset(x_shuffled[:num], y_shuffled[:num], train_transform)

  return torch.utils.data.DataLoader(subset, batch_size=16, shuffle=True)

The **client_update** function trains the client model on the given private client data. This is the local training round that takes place for every selected client, i.e. num_selected (6 in our case).


In [None]:
def client_update(client_model, optimizer, train_loader, epoch=5):
    """
    This function updates/trains client model on client data

    Input:
        client_model : model to train; it is switched to training mode
        optimizer : optimizer bound to the parameters of client_model
        train_loader : iterable yielding (data, target) batches
        epoch : number of passes over train_loader
    Output:
        loss of the last processed batch as a Python float, or 0.0 when no
        batch was processed (empty loader or epoch == 0)
    """
    # Switch the model that is actually being trained to training mode
    # (this used to be called on an unrelated module-level `model`).
    client_model.train()

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(client_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    loss = None
    for _ in range(epoch):
        for data, target in train_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

    return loss.item() if loss is not None else 0.0

The **client_syn** function synchronizes the client model (before training) with the global weights. This matters when a particular client has not participated in the previous communication rounds: it makes sure that all the selected clients start from the previously trained weights of the global model.

In [None]:
def client_syn(client_model, global_model):
  '''
  Bring the client model in line with the global model by loading the
  global model's state dict into it.
  '''
  global_weights = global_model.state_dict()
  client_model.load_state_dict(global_weights)

The **server_aggregate** function aggregates the model weights received from every client and updates the global model with updated weights. Here, the weighted mean of the weights is calculated. In IID part of this code, instead of the weighted mean, the mean is used as an aggregation method.

In [None]:
def server_aggregate(global_model, client_models,client_lens):
    """
    This function has aggregation method 'wmean'
    wmean takes the weighted mean of the weights of models

    Input:
        global_model : model that receives the aggregated state
        client_models : list of client models to aggregate
        client_lens : one weight per client model (e.g. its amount of data)
    Effect:
        every entry of the global state dict becomes
        sum_i(client_i_entry * client_lens[i] / sum(client_lens)); the result
        is loaded into global_model and then into every client model.
    Raises:
        ValueError : when the two lists differ in length or the weights do
                     not add up to a positive number
    """
    if len(client_models) != len(client_lens):
        raise ValueError("client_models and client_lens must have the same length")
    total = sum(client_lens)
    if total <= 0:
        raise ValueError("client_lens must add up to a positive number")

    # Build every client's state dict once instead of once per parameter key.
    client_states = [client.state_dict() for client in client_models]
    shares = [length / total for length in client_lens]

    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        # entries are cast to float before weighting, as integer buffers cannot be scaled directly
        weighted = [state[k].float() * share for state, share in zip(client_states, shares)]
        global_dict[k] = torch.stack(weighted, 0).sum(0)
    global_model.load_state_dict(global_dict)

    # hand the aggregated state back to every participating client
    aggregated = global_model.state_dict()
    for model in client_models:
        model.load_state_dict(aggregated)

The **test** function is the standard function for evaluating the global model with the test dataset. It returns the test loss and test accuracy, which is used for a comparative study of different approaches.

In [None]:
def test(global_model, test_loader):
    """
    This function test the global model on test
    data and returns test loss and test accuracy

    Input:
        global_model : model to evaluate; it is switched to evaluation mode
        test_loader : loader yielding (data, target) batches and exposing
                      the underlying data through `.dataset`
    Output:
        (average negative log-likelihood per sample, fraction of correct predictions)
    Raises:
        ValueError : when test_loader.dataset is empty
    """
    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        raise ValueError("test_loader.dataset is empty")

    # Put the model under evaluation into eval mode (this used to be called
    # on an unrelated module-level `model`, leaving global_model in train mode).
    global_model.eval()

    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(global_model.parameters(), None)
    device = first_param.device if first_param is not None else None

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= n_samples
    acc = correct / n_samples

    return test_loss, acc


# The name starts with "test": tell pytest this helper is not a test case.
test.__test__ = False

# Training the Model
Global model, client’s models are initialized with the VGG19, and training is done on a GPU. 

In [None]:
############################################
#### Initializing models and optimizer  ####
############################################

#### global model ##########
# A single VGG19 instance on the GPU acts as the global (server) model.
global_model =  VGG('VGG19').cuda()

############# client models ###############################
# One VGG19 instance on the GPU per selected client (num_selected models in total).
client_models = [ VGG('VGG19').cuda() for _ in range(num_selected)]
# NOTE: the loop variable `model` stays bound at module scope after this loop
# (it ends up referring to the last client model).
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronization with the global model

###### optimizers ################
# One SGD optimizer (learning rate 0.1) per client model; opt[i] belongs to client_models[i].
opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]

####### baseline data ############
# Loader over `baseline_num` baseline images, used for retraining on the global server.
loader_fixed = baseline_data(baseline_num)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


The non-IID data is loaded into a train_loader using the above functions, which ensures the data is non-IID. Classes_pc=2, num_clients=20, batch_size=32.

In [None]:
###### Loading the data using the above function ######
# train_loader is a list holding one data loader per client (num_clients entries);
# test_loader is the single loader over the test set.
train_loader, test_loader = get_data_loaders(classes_pc=classes_pc, nclients= num_clients,
                                                      batch_size=batch_size,verbose=True)

Files already downloaded and verified
Files already downloaded and verified

Data: 
 - Train Set: ((50000, 3, 32, 32),(50000,)), Range: [0.000, 255.000], Labels: 0,..,9
 - Test Set: ((10000, 3, 32, 32),(10000,)), Range: [0.000, 255.000], Labels: 0,..,9




In [None]:
# Lists for keeping track of the losses and the accuracy over the communication rounds.
losses_train = []
losses_test = []
acc_test = []
losses_retrain = []

# Running FL
# Starts the training of individual clients in communication rounds (num_rounds).
# In every communication round, first, the selected clients are updated with the global weights.
# Then the local model is trained on the client's device itself, following which the retraining
# round takes place on the global server.
# After retraining the client's model, the aggregation of weights takes place.

for r in range(num_rounds):    # communication round
    # Randomly select num_selected of the num_clients clients for this round.
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # Aggregation weights: number of training samples held by each selected client.
    # len() of a data loader is its number of batches, so the size of the
    # underlying dataset is used instead.
    client_lens = [len(train_loader[idx].dataset) for idx in client_idx]

    # client update
    loss = 0
    for i in tqdm(range(num_selected)):
      # copy the global weights into the local model, then train it on the client's own data
      client_syn(client_models[i], global_model)
      loss += client_update(client_models[i], opt[i], train_loader[client_idx[i]], epochs)
    losses_train.append(loss)

    # Once the local models are trained on the client's data they are sent to the global server.
    # First, these models are retrained on the baseline data. This is followed by the
    # aggregation of the local models (weights) into one global model. The updated global
    # model is then evaluated with the test function defined before.

    #### retraining on the global server
    loss_retrain = 0
    for i in tqdm(range(num_selected)):
      loss_retrain += client_update(client_models[i], opt[i], loader_fixed, epoch=retrain_epochs)
    losses_retrain.append(loss_retrain)

    ### Aggregating the models
    server_aggregate(global_model, client_models, client_lens)
    test_loss, acc = test(global_model, test_loader)
    losses_test.append(test_loss)
    acc_test.append(acc)
    print('%d-th round' % r)
    # NOTE(review): the value labelled "average train loss" is the mean retraining
    # loss (loss_retrain), not the local training loss - confirm this is intended.
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss_retrain / num_selected, test_loss, acc))

#This process continues for num_rounds, i.e. 50 communication rounds in our case. 
#6 selected clients, each running 5 local epochs and retraining on the global server with 10 epochs, repeated over the 50 communication rounds

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
100%|██████████| 6/6 [01:09<00:00, 11.52s/it]
100%|██████████| 6/6 [00:08<00:00,  1.37s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

0-th round
average train loss 2.34 | test loss 2.32 | test acc: 0.100


100%|██████████| 6/6 [01:14<00:00, 12.41s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

1-th round
average train loss 2.07 | test loss 2.31 | test acc: 0.100


100%|██████████| 6/6 [00:54<00:00,  9.16s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

2-th round
average train loss 2.52 | test loss 2.32 | test acc: 0.106


100%|██████████| 6/6 [01:20<00:00, 13.35s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

3-th round
average train loss 2.05 | test loss 2.22 | test acc: 0.120


100%|██████████| 6/6 [01:11<00:00, 11.95s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

4-th round
average train loss 2.18 | test loss 2.31 | test acc: 0.091


100%|██████████| 6/6 [00:57<00:00,  9.53s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

5-th round
average train loss 2.19 | test loss 2.19 | test acc: 0.203


100%|██████████| 6/6 [01:03<00:00, 10.66s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

6-th round
average train loss 1.8 | test loss 2.35 | test acc: 0.079


100%|██████████| 6/6 [01:16<00:00, 12.81s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

7-th round
average train loss 1.95 | test loss 2.19 | test acc: 0.139


100%|██████████| 6/6 [01:03<00:00, 10.64s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

8-th round
average train loss 1.94 | test loss 1.98 | test acc: 0.225


100%|██████████| 6/6 [00:59<00:00,  9.84s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

9-th round
average train loss 2.18 | test loss 1.99 | test acc: 0.243


100%|██████████| 6/6 [01:15<00:00, 12.62s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

10-th round
average train loss 2.06 | test loss 1.88 | test acc: 0.235


100%|██████████| 6/6 [01:04<00:00, 10.67s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

11-th round
average train loss 2.15 | test loss 1.93 | test acc: 0.264


100%|██████████| 6/6 [01:09<00:00, 11.62s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

12-th round
average train loss 1.87 | test loss 1.86 | test acc: 0.270


100%|██████████| 6/6 [01:04<00:00, 10.73s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

13-th round
average train loss 1.8 | test loss 2.18 | test acc: 0.150


100%|██████████| 6/6 [01:09<00:00, 11.63s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

14-th round
average train loss 1.75 | test loss 1.89 | test acc: 0.298


100%|██████████| 6/6 [01:10<00:00, 11.69s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

15-th round
average train loss 1.88 | test loss 1.88 | test acc: 0.263


100%|██████████| 6/6 [01:07<00:00, 11.26s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

16-th round
average train loss 1.72 | test loss 1.95 | test acc: 0.285


100%|██████████| 6/6 [00:41<00:00,  6.88s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

17-th round
average train loss 2.38 | test loss 1.74 | test acc: 0.295


100%|██████████| 6/6 [00:55<00:00,  9.21s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

18-th round
average train loss 1.68 | test loss 1.86 | test acc: 0.287


100%|██████████| 6/6 [01:00<00:00, 10.14s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

19-th round
average train loss 1.64 | test loss 2.01 | test acc: 0.246


100%|██████████| 6/6 [01:02<00:00, 10.34s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

20-th round
average train loss 1.66 | test loss 1.64 | test acc: 0.385


100%|██████████| 6/6 [00:43<00:00,  7.21s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

21-th round
average train loss 1.37 | test loss 1.85 | test acc: 0.327


100%|██████████| 6/6 [00:56<00:00,  9.48s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

22-th round
average train loss 2.19 | test loss 1.69 | test acc: 0.368


100%|██████████| 6/6 [00:54<00:00,  9.02s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

23-th round
average train loss 1.34 | test loss 1.63 | test acc: 0.421


100%|██████████| 6/6 [00:37<00:00,  6.26s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

24-th round
average train loss 1.32 | test loss 1.66 | test acc: 0.396


100%|██████████| 6/6 [01:16<00:00, 12.73s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

25-th round
average train loss 0.781 | test loss 1.75 | test acc: 0.429


100%|██████████| 6/6 [01:08<00:00, 11.50s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

26-th round
average train loss 1.19 | test loss 1.82 | test acc: 0.408


100%|██████████| 6/6 [00:58<00:00,  9.81s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

27-th round
average train loss 1.19 | test loss 1.65 | test acc: 0.448


100%|██████████| 6/6 [01:01<00:00, 10.27s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

28-th round
average train loss 1.3 | test loss 2.07 | test acc: 0.404


100%|██████████| 6/6 [00:48<00:00,  8.00s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

29-th round
average train loss 1.42 | test loss 1.83 | test acc: 0.432


100%|██████████| 6/6 [01:06<00:00, 11.14s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

30-th round
average train loss 1.4 | test loss 1.67 | test acc: 0.481


100%|██████████| 6/6 [01:03<00:00, 10.51s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

31-th round
average train loss 1.24 | test loss 1.72 | test acc: 0.479


100%|██████████| 6/6 [01:00<00:00, 10.02s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

32-th round
average train loss 0.77 | test loss 1.9 | test acc: 0.467


100%|██████████| 6/6 [00:38<00:00,  6.39s/it]
100%|██████████| 6/6 [00:08<00:00,  1.35s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

33-th round
average train loss 0.977 | test loss 1.81 | test acc: 0.475


100%|██████████| 6/6 [00:56<00:00,  9.38s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

34-th round
average train loss 0.557 | test loss 2.03 | test acc: 0.462


100%|██████████| 6/6 [01:08<00:00, 11.43s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

35-th round
average train loss 0.723 | test loss 1.86 | test acc: 0.473


100%|██████████| 6/6 [00:53<00:00,  8.98s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

36-th round
average train loss 1.35 | test loss 2.11 | test acc: 0.494


100%|██████████| 6/6 [00:48<00:00,  8.15s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

37-th round
average train loss 1.03 | test loss 1.72 | test acc: 0.499


100%|██████████| 6/6 [01:01<00:00, 10.29s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

38-th round
average train loss 1 | test loss 1.79 | test acc: 0.470


100%|██████████| 6/6 [01:07<00:00, 11.32s/it]
100%|██████████| 6/6 [00:08<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

39-th round
average train loss 1.11 | test loss 1.79 | test acc: 0.536


100%|██████████| 6/6 [01:20<00:00, 13.46s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

40-th round
average train loss 0.302 | test loss 2.02 | test acc: 0.487


100%|██████████| 6/6 [01:06<00:00, 11.08s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

41-th round
average train loss 0.748 | test loss 1.86 | test acc: 0.527


100%|██████████| 6/6 [01:09<00:00, 11.56s/it]
100%|██████████| 6/6 [00:07<00:00,  1.33s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

42-th round
average train loss 0.773 | test loss 1.96 | test acc: 0.526


100%|██████████| 6/6 [00:52<00:00,  8.73s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

43-th round
average train loss 0.608 | test loss 1.84 | test acc: 0.555


100%|██████████| 6/6 [01:15<00:00, 12.64s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

44-th round
average train loss 0.429 | test loss 1.85 | test acc: 0.564


100%|██████████| 6/6 [00:53<00:00,  8.90s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

45-th round
average train loss 0.762 | test loss 1.72 | test acc: 0.555


100%|██████████| 6/6 [01:13<00:00, 12.24s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

46-th round
average train loss 0.435 | test loss 1.66 | test acc: 0.572


100%|██████████| 6/6 [00:57<00:00,  9.60s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

47-th round
average train loss 0.96 | test loss 1.81 | test acc: 0.420


100%|██████████| 6/6 [01:00<00:00, 10.02s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]
  0%|          | 0/6 [00:00<?, ?it/s]

48-th round
average train loss 0.392 | test loss 1.81 | test acc: 0.556


100%|██████████| 6/6 [01:07<00:00, 11.28s/it]
100%|██████████| 6/6 [00:08<00:00,  1.34s/it]


49-th round
average train loss 0.458 | test loss 1.77 | test acc: 0.554
