In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Introduction to MPI and `mpi4py`

MPI stands for Message Passing Interface. It is a **library** that allows you to:
- **spawn** several processes
- **address** them individually
- have them **communicate** with each other

MPI can be used in many languages (C, C++, Fortran), and is extensively used in High-Performance Computing.  
`mpi4py` is the Python interface to MPI.

### Installation

```
conda install -c conda-forge mpi4py
```
(The standard Anaconda channel for mpi4py is [broken](https://github.com/conda/conda/issues/2277). It is necessary to use the [conda-forge](https://conda-forge.github.io/) channel instead.)

### Example

Let us get a feeling for how `mpi4py` works by looking at the example below:

In [None]:
%%file example.py

from mpi4py.MPI import COMM_WORLD as communicator
import random

# Draw one random integer between 0 and 100
i = random.randint(0, 100)
print('Rank %d drew a random integer: %d' % (communicator.rank, i))

# Gather the results
integer_list = communicator.gather( i, root=0 )
if communicator.rank == 0:
    print('\nRank 0 gathered the results:')
    print(integer_list)

In [None]:
! mpirun -np 3 python example.py

### What happened?

- "`mpirun -np 3`" spawns 3 processes.

- **All processes execute the same code.** (In this case, they all execute the same Python script: `example.py`.)

- Each process gets a **unique identification number** (`communicator.rank`).

- **Based on this identifier** and e.g. **on `if` statements**, the different processes can be **addressed individually**, and perform different work.

- MPI provides functions (like `communicator.gather`) that allow processes to **communicate** data (even between **different nodes**).

NB: There are many other communication functions, e.g.:
- one-to-one communication (`send`, `recv`, `isend`, `irecv`)
- all-to-one communication (`gather`, `reduce`)
- one-to-all communication (`scatter`, `bcast`)
- all-to-all (`allgather`, `allreduce`)

See the [mpi4py documentation](https://mpi4py.readthedocs.io/en/stable/) for more information.
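To make a few of these calls concrete, here is a minimal sketch of three collectives (`bcast`, `scatter`, `allreduce`). The function name `collectives_demo` and the values it sends are assumptions chosen for illustration; only the communicator methods and the `rank`/`size` attributes come from `mpi4py`.

```python
def collectives_demo(comm):
    """Run three collective operations on the communicator `comm`
    and return what this rank ends up with."""
    rank, size = comm.rank, comm.size

    # One-to-all: rank 0 broadcasts the same Python object to every rank
    settings = {'n_steps': 10} if rank == 0 else None
    settings = comm.bcast(settings, root=0)

    # One-to-all: rank 0 hands a different list element to each rank
    # (the list must contain exactly `size` elements)
    chunks = [10 * r for r in range(size)] if rank == 0 else None
    my_chunk = comm.scatter(chunks, root=0)

    # All-to-all: every rank receives the sum over all ranks
    # (summation is the default reduction operation)
    total = comm.allreduce(my_chunk)

    return settings, my_chunk, total

# In a script launched with mpirun, this would be called as:
#     from mpi4py.MPI import COMM_WORLD
#     print(COMM_WORLD.rank, collectives_demo(COMM_WORLD))
```

Like `gather` in the example above, these lowercase methods communicate generic Python objects, so every rank must reach each call for the program to proceed.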

----
# Digit classification with `mpi4py`

## On two processes

In [None]:
%%file parallel_script.py

from classification import nearest_neighbor_prediction
import numpy as np
from mpi4py.MPI import COMM_WORLD as communicator

# Load data
train_images = np.load('./data/train_images.npy')
train_labels = np.load('./data/train_labels.npy')
test_images = np.load('./data/test_images.npy')

# Use only the data that this rank needs
# (integer division `//` keeps the slice indices integers)
N_test = len(test_images)
if communicator.rank == 0:
    i_start = 0
    i_end = N_test // 2
elif communicator.rank == 1:
    i_start = N_test // 2
    i_end = N_test
small_test_images = test_images[i_start:i_end]

# Predict the results and gather them on rank 0
small_test_labels = nearest_neighbor_prediction(small_test_images, train_images, train_labels)
test_labels_list = communicator.gather( small_test_labels, root=0 )

# Rank 0 merges the results into one array and saves them to a file
if communicator.rank == 0:
    test_labels = np.hstack( test_labels_list )
    np.save('./data/test_labels_parallel.npy', test_labels )

In [None]:
%%time
! mpirun -np 2 python parallel_script.py
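The two-rank split above is hard-coded. As a sketch of how the same index arithmetic could be extended to any number of ranks, the hypothetical helper `block_range` below (not part of the original script) computes a contiguous block per rank using integer division only.

```python
def block_range(rank, size, n_items):
    """Return (i_start, i_end) of the contiguous block owned by `rank`
    when `n_items` items are divided among `size` ranks."""
    i_start = (rank * n_items) // size
    i_end = ((rank + 1) * n_items) // size
    return i_start, i_end

# Example: 10 items divided among 4 ranks
for rank in range(4):
    print(rank, block_range(rank, 4, 10))
```

For `size=2` this reproduces the split used in the script (`0` to `N_test // 2`, then `N_test // 2` to `N_test`). The blocks cover every item exactly once and their lengths differ by at most one.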

## On more processes

In [None]:
import numpy as np

In [None]:
# Load and split the set of test images
test_images = np.load('data/test_images.npy')
split_arrays = np.array_split( test_images, 4 )

# Print the corresponding shapes
print( 'Shape of the original array:' )
print( test_images.shape )
print('Shapes of the split arrays:')
for array in split_arrays:
    print( array.shape )
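`np.array_split` also works when the array length is not a multiple of the number of pieces, and stacking the pieces back in order recovers the original ordering. The sketch below checks this on a small synthetic array; the squaring step is only a stand-in for the per-rank work and is not the classifier used in this notebook.

```python
import numpy as np

data = np.arange(10)

# 10 elements in 4 pieces: the first pieces get the extra elements
chunks = np.array_split(data, 4)
print([len(c) for c in chunks])

# Stand-in for the work done independently on each piece
results = [c ** 2 for c in chunks]

# Merging the pieces in order gives the same result as
# processing the whole array at once
merged = np.hstack(results)
print(np.array_equal(merged, data ** 2))
```

This is the property that the gather-then-`hstack` pattern relies on: `gather` returns a list ordered by rank, so the merged array keeps the order of the original data.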

In [None]:
%%file parallel_script.py

from classification import nearest_neighbor_prediction
import numpy as np
from mpi4py.MPI import COMM_WORLD as communicator

# Load data
train_images = np.load('./data/train_images.npy')
train_labels = np.load('./data/train_labels.npy')
test_images = np.load('./data/test_images.npy')

# Split the data
# Select the array that is relevant for this rank
split_arrays = np.array_split( test_images, communicator.size )
small_test_images = split_arrays[ communicator.rank ]

# Predict the results and gather them on rank 0
small_test_labels = nearest_neighbor_prediction(small_test_images, train_images, train_labels)
test_labels_list = communicator.gather( small_test_labels, root=0 )

# Rank 0 merges the results into one array and saves them to a file
if communicator.rank == 0:
    test_labels = np.hstack( test_labels_list )
    np.save('./data/test_labels_parallel.npy', test_labels )

In [None]:
%%time
! mpirun -np 4 python parallel_script.py

---
# Check the results

In [None]:
# Load the data from the file
test_images = np.load('data/test_images.npy')
test_labels_parallel = np.load('data/test_labels_parallel.npy')

# Define function to have a look at the data
def show_random_digit( images, labels=None ):
    """Show a random image out of `images`,
    with the corresponding label if available"""
    i = np.random.randint(len(images))
    image = images[i].reshape((28, 28))
    plt.imshow( image, cmap='Greys' )
    if labels is not None:
        plt.title('Label: %d' %labels[i])

In [None]:
show_random_digit( test_images, test_labels_parallel )
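Looking at single digits is a useful spot check. A complementary numerical check could confirm that there is exactly one label per image and count how often each digit was predicted. The helper `summarize_labels` below is a hypothetical sketch, assuming the labels are the integer digits 0 to 9; it is demonstrated on synthetic labels rather than on the notebook's data.

```python
import numpy as np

def summarize_labels(labels, n_images):
    """Check that there is one label per image and return the
    number of predictions for each digit 0..9."""
    labels = np.asarray(labels).astype(int)
    if len(labels) != n_images:
        raise ValueError('Expected %d labels, got %d' % (n_images, len(labels)))
    return np.bincount(labels, minlength=10)

# Synthetic stand-in for the predicted labels
fake_labels = np.array([3, 1, 3, 7, 0, 3])
print(summarize_labels(fake_labels, n_images=6))
```

With the arrays loaded above, the call would be `summarize_labels(test_labels_parallel, len(test_images))`.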