# Setup and Preparation

This example uses [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) to train an image classifier with federated averaging ([FedAvg](https://arxiv.org/abs/1602.05629)),
with [PyTorch](https://pytorch.org/) as the deep learning training framework.


We will use the train script [cifar10_fl.py](src/cifar10_fl.py) and network [net.py](src/net.py) from the src directory.

The dataset is [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html); the data is loaded within the client training code.

## Install NVIDIA FLARE and dependencies

Install nvflare and the other requirements:


In [None]:
! pip install nvflare

In [None]:
! pip install -r code/requirements.txt
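
To confirm that the installation succeeded, a quick check like the following can be run in a notebook cell. This is a minimal sketch using only the Python standard library; the helper name `installed_version` is hypothetical and is not part of NVIDIA FLARE.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is not installed."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


# "nvflare" is the package name used in the pip command above.
print("nvflare:", installed_version("nvflare"))
```

If the printed value is `None`, the package was not found in the environment the notebook kernel is using.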

## Prepare Data

The CIFAR10 dataset has the following classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’.
The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

![image](code/img/cifar10.png)
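
The 3x32x32 shape means each image holds 3 × 32 × 32 = 3072 values. As a small illustration, the position of one pixel value inside a flattened image can be computed as below. The helper `flat_index` is hypothetical, and the channel-first, row-major layout is an assumption based on the description of the Python version of the dataset on the CIFAR-10 page.

```python
# Image dimensions as stated in the text: 3 channels, 32x32 pixels.
CHANNELS, HEIGHT, WIDTH = 3, 32, 32
VALUES_PER_IMAGE = CHANNELS * HEIGHT * WIDTH


def flat_index(channel, row, col):
    """Map (channel, row, col) to an offset in a flattened channel-first image."""
    return channel * HEIGHT * WIDTH + row * WIDTH + col


print(VALUES_PER_IMAGE)        # 3072
print(flat_index(0, 0, 0))     # 0
print(flat_index(2, 31, 31))   # 3071
```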


Before we start training, we first need to prepare the data.

### Download the data

To avoid each job having to download and split the data, we add a step to prepare the data for all the CIFAR-10 jobs.

The CIFAR-10 data will be downloaded to a common location. To make this process easier, we wrote a simple download program like the following:

```python
import argparse

import torchvision.datasets as datasets

# default dataset path
CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"


def define_parser():
    parser = argparse.ArgumentParser()
    # --dataset_path is optional; it falls back to CIFAR10_ROOT
    parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?")
    args = parser.parse_args()
    return args


def main(args):
    # download the training split, then the test split, into the same root
    datasets.CIFAR10(root=args.dataset_path, train=True, download=True)
    datasets.CIFAR10(root=args.dataset_path, train=False, download=True)


if __name__ == "__main__":
    main(define_parser())
```


The program takes a root `dataset_path` and uses the torchvision `CIFAR10` dataset class to download the training and test datasets to that root directory. Let's run the code.

In [1]:
!python3 code/data/download.py

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/nvflare/data/cifar10/cifar-10-python.tar.gz
100%|████████████████████████████████████████| 170M/170M [00:05<00:00, 28.6MB/s]
Extracting /tmp/nvflare/data/cifar10/cifar-10-python.tar.gz to /tmp/nvflare/data/cifar10
Files already downloaded and verified
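
The run above used the default location. Based on the parser shown earlier, a different location could presumably be chosen by passing `--dataset_path` on the command line. The sketch below reproduces only the argument definition to show how the default and an override are resolved; `build_parser` and the path `/data/cifar10` are hypothetical names used for illustration, and nothing is downloaded.

```python
import argparse

# Same default as in the download program.
CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"


def build_parser():
    """Build a parser with the same --dataset_path argument as the download program."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?")
    return parser


# Passing an explicit list avoids reading the notebook kernel's own sys.argv.
default_args = build_parser().parse_args([])
custom_args = build_parser().parse_args(["--dataset_path", "/data/cifar10"])

print(default_args.dataset_path)  # /tmp/nvflare/data/cifar10
print(custom_args.dataset_path)   # /data/cifar10
```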


We can examine the downloaded data:

In [2]:
!tree /tmp/nvflare/data/cifar10/cifar-10-batches-py/

/tmp/nvflare/data/cifar10/cifar-10-batches-py/
├── batches.meta
├── data_batch_1
├── data_batch_2
├── data_batch_3
├── data_batch_4
├── data_batch_5
├── readme.html
└── test_batch

0 directories, 8 files
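
If `tree` is not available, the same check can be done from Python. The following is a minimal sketch that compares a directory against the eight file names listed above; the helper `missing_files` is hypothetical, and the directory path assumes the default download location.

```python
from pathlib import Path

# File names taken from the directory listing above.
EXPECTED_FILES = (
    ["batches.meta"]
    + [f"data_batch_{i}" for i in range(1, 6)]
    + ["readme.html", "test_batch"]
)


def missing_files(batch_dir):
    """Return the expected CIFAR-10 file names that are not present in batch_dir."""
    batch_dir = Path(batch_dir)
    return [name for name in EXPECTED_FILES if not (batch_dir / name).is_file()]


missing = missing_files("/tmp/nvflare/data/cifar10/cifar-10-batches-py")
print("missing files:", missing or "none")
```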


### Split the data

In real-world scenarios, the data is distributed among different clients/sites. Since we are simulating that setting, we would normally need to split the data among the clients/sites. How to split the data
depends on the type of problem and the type of data. For simplicity, in this example we assume all clients have the same data, as a simple horizontal federated learning case.
Thus we do not split the data, but rather point all clients to the same data location.
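
For reference only, if a split were needed, one simple option is a uniform random (IID) partition of the sample indices across clients. The sketch below is not used in this example and is not part of NVIDIA FLARE; the function `split_indices` and its parameters are hypothetical, and 50000 is the size of the CIFAR-10 training set.

```python
import random


def split_indices(num_samples, num_clients, seed=0):
    """Shuffle sample indices and deal them out to clients in round-robin order."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    # Client i takes every num_clients-th index starting at position i.
    return [indices[i::num_clients] for i in range(num_clients)]


parts = split_indices(num_samples=50000, num_clients=2)
print([len(p) for p in parts])  # [25000, 25000]
```

Each client would then load only the samples whose indices appear in its own list. Non-IID splits, where clients see different class distributions, need a different strategy.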












Next, we will run the training in simulation: [run pytorch federated learning job](../01.1_running_federated_learning_job/running_pytorch_fl_job.ipynb)
