# MPI Custom Datatypes + Torch 

## Setup
### Installation
```pip install ipyparallel```

or 

```pip install -e ".[notebook]"```

### Start cluster

```ipcluster start -n 4 --engines=MPI --profile mpi```

In [1]:
import ipyparallel as ipp
# Connect to the controller of the `mpi` profile (the one started by `ipcluster start`).
rc = ipp.Client(profile='mpi')
# Wait until the 4 engines requested at cluster start are available.
rc.wait_for_engines(4)
# Last expression of the cell: number of engines visible to the client.
len(rc)

4

In [2]:
%%px
# Everything in this cell runs on the engines (one MPI rank each), not in the local kernel.
import torch
from mpi4py import MPI

# World communicator and this engine's rank; both names are reused by later %%px cells.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

[stdout:1] Hello from rank 1!


[stdout:3] Hello from rank 3!


[stdout:0] Hello from rank 0!


[stdout:2] Hello from rank 2!


## Large item count

In [None]:
%%px
comm = MPI.COMM_WORLD
## Let's find the maximum element count on my mpi implementation
options = [torch.iinfo(torch.int32).max, torch.iinfo(torch.uint32).max, torch.iinfo(torch.int64).max, torch.iinfo(torch.uint64).max]

for possible_max in options:
    print(f"Trying {possible_max}")
    if comm.Get_rank() == 0:
        A = torch.randn(possible_max)
        print(A.dtype)
        print(A[:10])

        print(f"Sent {possible_max} elements")

    else:
        A = torch.zeros(possible_max, dtype=torch.float32)

    MPI.broadcast(buf=[as_buffer(A), possible_max, MPI.FLOAT], root=0)

    if comm.Get_rank() == 0:
        print("Sent!")
    else:
        print("Received!")
        print(A[:10])


[stdout:3] Trying 2147483647


[stdout:2] Trying 2147483647


[stdout:1] Trying 2147483647


[stdout:0] Trying 2147483647


[0:execute]
[0;31m---------------------------------------------------------------------------[0m
[0;31mRuntimeError[0m                              Traceback (most recent call last)
Cell [0;32mIn[3], line 8[0m
[1;32m      6[0m [38;5;28mprint[39m([38;5;124mf[39m[38;5;124m"[39m[38;5;124mTrying [39m[38;5;132;01m{[39;00mpossible_max[38;5;132;01m}[39;00m[38;5;124m"[39m)
[1;32m      7[0m [38;5;28;01mif[39;00m comm[38;5;241m.[39mGet_rank() [38;5;241m==[39m [38;5;241m0[39m:
[0;32m----> 8[0m     A [38;5;241m=[39m [43mtorch[49m[38;5;241;43m.[39;49m[43mrandn[49m[43m([49m[43mpossible_max[49m[43m,[49m[43m [49m[43mdtype[49m[38;5;241;43m=[39;49m[43mtorch[49m[38;5;241;43m.[39;49m[43mint32[49m[43m)[49m
[1;32m      9[0m     [38;5;28mprint[39m(A[38;5;241m.[39mdtype)
[1;32m     10[0m     [38;5;28mprint[39m(A[:[38;5;241m10[39m])

[0;31mRuntimeError[0m: "normal_kernel_cpu" not implemented for 'Int'


%px:   0%|          | 0/4 [00:00<?, ?tasks/s]

AlreadyDisplayedError: 4 errors

## Torch 2 Datatype

Attempt to create a generic function that, given any torch tensor (contiguous, non-contiguous, view, strided, ...), returns an MPI datatype that can be used to send/recv the tensor's data.

In [None]:
%load_ext autoreload
%autoreload 2

# Contiguous
x = torch.arange(64).reshape(8, 8)
print(x)
_ = tc.mpi4torch.tensor2mpiBuffer(x)



[juan-20w000p2ge:439968] shmem: mmap: an error occurred while determining whether or not /tmp/ompi.juan-20w000p2ge.1000/jf.0/3508994048/shared_mem_cuda_pool.juan-20w000p2ge could be created.
[juan-20w000p2ge:439968] create_and_attach: unable to create shared memory BTL coordinating structure :: size 134217728 


AttributeError: 'torch.dtype' object has no attribute 'max'

In [None]:
# Non-Contiguous, permute
x = torch.arange(64).reshape(8, 8).permute(1, 0)
print(x)
_ = tc.mpi4torch.tensor2mpiBuffer(x)

In [None]:
# Non-Contiguous, slicing
x = torch.arange(64).reshape(8, 8)[:, 1::2]
print(x)
_ = tc.mpi4torch.tensor2mpiBuffer(x)

In [None]:
# Non-Contiguous, slicing and permute
x = torch.arange(64).reshape(8, 8)[4:, 1::2].permute(1, 0)
print(x)
_ = tc.mpi4torch.tensor2mpiBuffer(x)

In [None]:
# Runs in the local kernel (no %%px), unlike the engine-side import in the setup cell.
import torch
# Quick sanity check: display a 10x10 tensor of standard-normal samples (unseeded).
torch.randn(10,10)