# Calculate CIFAR10 Image Histogram

The PyTorch example for training an image classifier follows these steps:

* **Prepare Data**
    * Load and normalize the CIFAR10 training and test datasets using torchvision
    
* **Training**
    * Define a Convolutional Neural Network
    * Define a loss function
    * Train the network on the training data
    * Test the network on the test data
    
We will add another step to calculate the data histogram and compare the local (site) histograms with the global histogram. The steps then become:


* **Prepare Data**
    * Load and normalize the CIFAR10 training and test datasets using torchvision

* **Data Statistics**
    * Calculate data statistics: image intensity histograms
    
* **Training**
    * Define a Convolutional Neural Network
    * Define a loss function
    * Train the network on the training data
    * Test the network on the test data



## Setup NVFLARE

Follow [Getting Started](https://nvflare.readthedocs.io/en/main/getting_started.html) to set up a virtual environment and install NVFLARE.

You can also follow this [notebook](https://github.com/NVIDIA/NVFlare/blob/main/examples/nvflare_setup.ipynb) to get set up.

> Make sure you have installed nvflare from a **terminal**.


## Install requirements
The commands below assume that the current directory is 'cifar10/stats'.

In [None]:
!pwd

In [None]:
%pip install -r requirements.txt

## Prepare Data

Generally, when you have to deal with image, text, audio or video data, you can use standard Python packages that load data into a NumPy array. Then you can convert this array into a torch.*Tensor. For vision, PyTorch provides a package called torchvision, which has data loaders for common datasets such as ImageNet, CIFAR10, MNIST, etc. and data transformers for images, viz., torchvision.datasets and torch.utils.data.DataLoader.

The CIFAR10 dataset has the classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’.
The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size.

![cifar10](../data/cifar10.png)


The output of torchvision datasets are PILImage images of range [0, 1]. We transform them to Tensors of normalized range [-1, 1].
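
The mapping from [0, 1] to [-1, 1] is plain arithmetic: with a mean and standard deviation of 0.5 per channel, `transforms.Normalize` computes `(x - 0.5) / 0.5`. The sketch below reproduces that arithmetic for single values in pure Python; the helper names `normalize` and `unnormalize` are made up for illustration and are not part of torchvision.

```python
def normalize(x, mean=0.5, std=0.5):
    """Map a pixel value from [0, 1] to [-1, 1] via (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y, mean=0.5, std=0.5):
    """Inverse mapping: y * std + mean, i.e. y / 2 + 0.5 when mean = std = 0.5."""
    return y * std + mean


# Darkest, middle and brightest pixel values and their normalized counterparts.
for x in (0.0, 0.5, 1.0):
    y = normalize(x)
    print(x, "->", y, "->", unnormalize(y))
```

The inverse, `y / 2 + 0.5`, is what the plotting helper in this notebook uses to bring images back to a displayable range.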





In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 6

trainset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=CIFAR10_ROOT, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Once the data archive has been downloaded and extracted, you can check the directory:

In [None]:
ls -al {CIFAR10_ROOT}

Let's explore the data first.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)

images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))

# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

# print the shapes of the image batch and the label batch
print(f"\nFeature batch shape: {images.size()}")
print(f"Labels batch shape: {labels.size()} \n")


print("train datasize =", len(trainset))
print("test datasize =", len(testset))
    

We can see that the image batch has the shape [batch, channel, rows, cols] = [6, 3, 32, 32].
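
To make the layout concrete, here is a small sketch with a NumPy array of zeros standing in for a real batch (an assumption made so the snippet runs without downloading data). It shows how a single image is selected, how one color channel is sliced out, and why the plotting helper transposes the axes to (rows, cols, channel) before display.

```python
import numpy as np

# Placeholder batch with the same layout as the loader output:
# [batch, channel, rows, cols]
batch = np.zeros((6, 3, 32, 32), dtype=np.float32)

first = batch[0]                      # one image: (channel, rows, cols)
red = first[0]                        # channel index 0: (rows, cols)
hwc = np.transpose(first, (1, 2, 0))  # layout expected by plt.imshow

print(first.shape, red.shape, hwc.shape)
```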

## Download data with a script
We have also prepared a Python script to download the data.
```
python ../data/download.py  --dataset_path <data_path>
```
If dataset_path is not specified, it defaults to CIFAR10_ROOT.

In [None]:
! python ../data/download.py


## Create Local Image Intensity Histogram Calculator

We skip all other statistics calculations (mean, stddev, etc.) because they do not apply here. All methods have default implementations, so we only override the ones we need.
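
The `histogram` method in the class below builds its result by calling `np.histogram` once per image with a fixed number of bins and a fixed range, then summing the counts. The sketch below checks that idea in isolation on synthetic data (random values in [-1, 1] standing in for one normalized color channel; this is an assumption, not CIFAR10 data). It also shows that `np.histogram` drops values that fall outside `range`, so the range passed in should cover the values the transform actually produces.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for one channel of 10 normalized 32x32 images.
images = rng.uniform(-1.0, 1.0, size=(10, 32, 32))

num_bins, low, high = 20, -1.0, 1.0

# Accumulate per-image counts over identical bin edges.
total = np.zeros(num_bins, dtype=float)
for img in images:
    counts, bin_edges = np.histogram(img, bins=num_bins, range=(low, high))
    total += counts

# A single histogram over all pixels yields the same counts.
all_counts, _ = np.histogram(images, bins=num_bins, range=(low, high))
assert np.array_equal(total, all_counts)

# Values outside the range are ignored rather than clipped into the edge bins.
narrow_counts, _ = np.histogram(images, bins=num_bins, range=(0.0, 1.0))
print(int(total.sum()), int(narrow_counts.sum()))
```

Because the bin edges are identical for every image, the per-image counts can be added safely; the same property is what makes histograms from different sites comparable.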


In [None]:
from typing import Dict, List

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms

from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.statistics_spec import Bin, DataType, Feature, Histogram, HistogramType, Statistics
 
# the dataset path    
CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"


class ImageStatistics(Statistics):

    def __init__(self, 
                 data_root: str = CIFAR10_ROOT, 
                 batch_size: int = 4):
        """local image intensity calculator.

        Args:
            data_root: directory with local image data.
            batch_size: number of images per batch when iterating over the data.
        """
        super().__init__()
        self.dataset_path = data_root
        self.batch_size = batch_size
        
        # There are three color channels (RGB). We treat each channel as one feature;
        # the feature IDs correspond to the tensor channel index.
        # The features are named "red", "green" and "blue".
        
        self.features_ids = { "red": 0, "green": 1,"blue": 2}
        self.image_features  = [Feature("red", DataType.FLOAT),
                                Feature("green", DataType.FLOAT),
                                Feature("blue", DataType.FLOAT)]
        self.dataset_lengths = {}
        self.loaders = {}

        self.client_name = None
        self.fl_ctx = None



    def initialize(self, fl_ctx: FLContext):

        # FLContext holds context information for the client-side NVFLARE engine,
        # including a lot of runtime information.
        # Here we are only interested in the client site name:
        # fl_ctx.get_identity_name() returns the client's name.
        
        self.fl_ctx = fl_ctx
        self.client_name = "local_client" if fl_ctx is None else fl_ctx.get_identity_name()
        
        transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        trainset = torchvision.datasets.CIFAR10(root=self.dataset_path, train=True, download=True, transform=transform)
        testset = torchvision.datasets.CIFAR10(root=self.dataset_path, train=False,  download=True, transform=transform)
        self.dataset_lengths = {"train": len(trainset), "test":len(testset)}
        
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=self.batch_size, shuffle=True, num_workers=2)
        testloader = torch.utils.data.DataLoader(testset, batch_size=self.batch_size, shuffle=False, num_workers=2)
        self.loaders = {"train": trainloader, "test": testloader}

    def features(self) -> Dict[str, List[Feature]]:
        return {"train": self.image_features, 
                "test":  self.image_features}

    def count(self, dataset_name: str, feature_name: str) -> int:
        return self.dataset_lengths[dataset_name]
        
 
    def histogram(self,
                  dataset_name: str,
                  feature_name: str, 
                  num_of_bins: int, 
                  global_min_value: float, 
                  global_max_value: float) -> Histogram:
     
        print(f"calculating image intensity histogram for client {self.client_name}")
        channel = self.features_ids[feature_name]
        
        # bins to return, their edges, and the running per-bin counts for the selected channel
        histogram_bins: List[Bin] = []
        bin_edges = []
        histogram = np.zeros(num_of_bins, dtype=float)

        for inputs, _ in self.loaders[dataset_name]:
            for img in inputs:
                counts, bin_edges = np.histogram(img[channel, : , :],
                                                 bins=num_of_bins,
                                                 range=(global_min_value, global_max_value))
                histogram += counts

        for i in range(num_of_bins):
            low_value = bin_edges[i]
            high_value = bin_edges[i + 1]
            bin_sample_count = histogram[i]
            histogram_bins.append(Bin(low_value=low_value, high_value=high_value, sample_count=bin_sample_count))

        return Histogram(HistogramType.STANDARD, histogram_bins)




Let's test if the code works. 

In [None]:
hist_cal = ImageStatistics()

hist_cal.initialize(fl_ctx = None)
features = hist_cal.features()

In [None]:
hist_cal.count("train", "red")

In [None]:
hist_cal.histogram("train", "red", 20, 0, 256)

The code is working. Let's set up an NVFLARE job for federated computing.

## Create Federated Histogram Job
We are going to use the NVFLARE job CLI to create the job. For detailed instructions on the job CLI, please follow the [job cli tutorial](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb).

Let's check the available job templates. We are going to use one of the existing job templates and modify it to fit our needs.
A job template is nothing more than a set of server-side and client-side job configurations.


In [None]:
! nvflare job list_templates

There is a "stats_image" job template, which is what we need, so we are going to use it. Now, use the ```nvflare job create``` command.
Let's try it with the default values.


In [None]:
! nvflare job create -w stats_image -j /tmp/nvflare/jobs/stats_image 

The defaults seem to be OK, except for a few values:
* data_root=/tmp/nvflare/image_stats/data, which is not the same as CIFAR10_ROOT
* min_clients in meta.conf should be 2, as we are going to use two clients

Now let's look closer at the configurations, in particular on the client side, and make sure they match the new class we just created.

>Note: 
In the upcoming sections, we'll utilize the 'tree' command. To install this command on a Linux system, you can use the sudo apt install tree command. As an alternative to 'tree', you can use the ls -al command.

In [None]:
! tree /tmp/nvflare/jobs/stats_image 

In [None]:
! nvflare job create -w stats_image -j /tmp/nvflare/jobs/stats_image -force \
-f meta.conf min_clients=2 \
-f config_fed_client.conf app_script=image_statistics.py data_root={CIFAR10_ROOT} \
-f config_fed_server.conf bins=20  \
-sd ./ \
-debug

In [None]:
! tree /tmp/nvflare/jobs/stats_image 

We can see that we also copied some files that don't belong there. Let's remove them.

> **Note**: if your histogram calculator file name is not "image_statistics.py" but another file name such as "image_histogram.py", and the class name is not 'ImageStatistics' but "HistogramStatistics", you will need to manually edit config_fed_client.conf to replace 'image_statistics.ImageStatistics' with 'image_histogram.HistogramStatistics'. For any other adjustments you need, you can directly edit the server and client configurations.
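
The value 'image_statistics.ImageStatistics' follows the usual Python "module.ClassName" convention, which is why both the file name and the class name matter. The sketch below shows how such a dotted path can be resolved in plain Python. It is a generic illustration only: `resolve_class` is a hypothetical helper, and how NVFLARE itself loads components is not shown in this notebook.

```python
import importlib


def resolve_class(dotted_path: str):
    """Split 'module.ClassName' at the last dot, import the module, return the class."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Demonstrated with a standard-library class so the snippet runs anywhere.
cls = resolve_class("collections.OrderedDict")
print(cls)
```

If the module part does not match a file on the Python path, or the class part does not match a name defined in that file, a lookup of this kind fails with an `ImportError` or `AttributeError`.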


Now we have created the job folder and are ready to run the job.

## Run Job in FL Simulator

**Run Job using Simulator CLI**


In [None]:
! nvflare simulator /tmp/nvflare/jobs/stats_image -w /tmp/nvflare/image_stats -n 2 -t 2



**Examine the result**

Notice that the result is written to

**/tmp/nvflare/image_stats/simulate_job/statistics/image_statistics.json**

In [None]:
!ls -al /tmp/nvflare/image_stats/simulate_job/statistics/image_statistics.json

## Visualization
We can easily visualize the results with the visualization utilities. Before we do that, we need to copy the data to the notebook directory.


In [None]:
! cp /tmp/nvflare/image_stats/simulate_job/statistics/image_statistics.json ./.

In [None]:

import json
import pandas as pd
from nvflare.app_opt.statistics.visualization.statistics_visualization import Visualization
with open('image_statistics.json', 'r') as f:
    data = json.load(f)

vis = Visualization()
vis.show_stats(data = data)

In [None]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
vis.show_histograms(data = data, plot_type="main")

There are no differences between the global and local histograms, as we are using the same dataset for all clients.
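
A small sketch can show why identical site data leads to identical histogram shapes. It assumes that the global histogram is the element-wise sum of the per-site bin counts over shared bin edges; that aggregation rule is an assumption for illustration, and the helper names and counts below are made up.

```python
def merge_counts(site_counts):
    """Element-wise sum of per-site bin counts that share the same bin edges."""
    return [sum(bins) for bins in zip(*site_counts)]


def to_fractions(counts):
    """Convert raw bin counts to fractions of the total sample count."""
    total = sum(counts)
    return [c / total for c in counts]


# Two sites holding the same data produce the same local counts.
site_1 = [5, 20, 50, 25]
site_2 = [5, 20, 50, 25]

global_counts = merge_counts([site_1, site_2])
print(global_counts)
print(to_fractions(global_counts) == to_fractions(site_1))
```

Under this assumption the global counts are simply twice the local counts, so the shapes coincide once the counts are normalized. With different data per site, the local and global shapes would generally differ.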

## We are done!
Congratulations! You have just completed the federated image histogram calculation.

If you would like to see other examples of federated statistics calculations and configurations, please check out [federated_statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) and [fed_stats with spleen_ct_segmentation](https://github.com/NVIDIA/NVFlare/tree/main/integration/monai/examples/spleen_ct_segmentation_sim).

Let's move on to the next example and see how we can train the image classifier using PyTorch with CIFAR10 data.


