# 2D and 3D Matrix Multiplication
## Setup
### Installation
```pip install ipyparallel```

or 

```pip install -e .[notebook]```

### Start cluster

```ipcluster start -n 4 --engines=MPI --profile mpi```

In [None]:
import ipyparallel as ipp
rc = ipp.Client(profile='mpi')
rc.wait_for_engines(4)
len(rc)

In [None]:
%%px
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

In [None]:
%%px
cart_comm = comm.Create_cart(dims=[2, 2], periods=[True, True], reorder=True)
print(f'Hello from rank {rank}! My coordinates are {cart_comm.Get_coords(rank)}')


In [None]:
%%px
print(f"Topo: {cart_comm.Get_topo()}")

In [None]:
%%px
row_comm = cart_comm.Sub([0,1])
row_global_ranks = row_comm.allgather(rank)
print(f'Hello from rank {rank}! In my row, the ranks are {row_global_ranks}')

col_comm = cart_comm.Sub([1,0])
col_global_ranks = col_comm.allgather(rank)
print(f'Hello from rank {rank}! In my col, the ranks are {col_global_ranks}')


In [None]:
%%px
def as_buffer(x: torch.Tensor):
    return MPI.buffer.fromaddress(x.untyped_storage().data_ptr(), 0)

## Matrix Vector

### y:= Ax

In [None]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)


# x_col = torch.zeros(4,1, dtype=torch.long)
# col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
# x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
# print("x_gather_col:", x_col)

x_col = torch.zeros(4,1, dtype=torch.long)
data_type = MPI.LONG.Create_vector(2, 1, 2).Create_resized(MPI.LONG.Get_extent()[0], MPI.LONG.Get_extent()[1]).Commit()
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 1, data_type))
print("x_gather_col:", x_col)

y_l = A_l @ x_col
print("y_local:", y_l)

y_l = y_l.reshape(2,2).T.reshape(4,1).contiguous()
y_scatter = torch.zeros(2, 1, dtype=torch.long)
row_comm.Reduce_scatter((as_buffer(y_l), 4, MPI.LONG), (as_buffer(y_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("y_scatter:", y_scatter)

y_end = torch.zeros(2,2,2, dtype=torch.long)
comm.Allgather((as_buffer(y_scatter), 2, MPI.LONG), (as_buffer(y_end), 2, MPI.LONG))
y_end = y_end.permute(2,1,0).reshape(8,1).contiguous()
print("y_end:", y_end)



In [None]:
%%px
print("Expected:", A_g @ x_g)
print("Actual:", y_end)
torch.allclose(y_end, A_g @ x_g)

### x = A.T * y

In [None]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

x_l = A_l.T @ y_col
print("x_local:", x_l)

x_l = x_l.reshape(2,2).T.reshape(4,1).contiguous()
x_scatter = torch.zeros(2, 1, dtype=torch.long)
col_comm.Reduce_scatter((as_buffer(x_l), 4, MPI.LONG), (as_buffer(x_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("x_scatter:", x_scatter)


In [None]:
%%px
print("expected:", A_g.T @ y_g)

### A := y * x.T + A

In [None]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)

y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

x_col = torch.zeros(4,1, dtype=torch.long)
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
print("x_gather_col:", x_col)

y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

Z_l = y_col @ x_col.T + A_l
print("Z_local:", Z_l)




In [None]:
%%px
print(f"Expected: {y_g @ x_g.T + A_g}")

## Create_darray

In [None]:
%%px
if comm.Get_rank() == 0:
    A = torch.arange(64).reshape(8, 8)
    print(A)
    print(A.dtype)

    darray_type = MPI.LONG.Create_darray(
        4,                    # Size
        1,                      # Rank
        # 2,                      # number of array dimensions (as well as process grid dimensions)
        [64,1],                 # size of the global array
        [MPI.DISTRIBUTE_CYCLIC, MPI.DISTRIBUTE_NONE], # distribution type
        [1, 1], # distribution argument
        [4, 1],                 # size of the process grid
        MPI.ORDER_C,            # array storage order
    ).Commit()

    # comm.Send(buf=[as_buffer(A), 8, MPI.LONG], dest=1)
    comm.Send([as_buffer(A), 1, darray_type], dest=1, tag=55)

    darray_type.Free() 

elif comm.Get_rank() == 1:
    A = torch.zeros(4, 4, dtype=torch.int64)

    # comm.Recv(buf=[as_buffer(A), 8, MPI.LONG], source=0)
    comm.Recv([as_buffer(A), 16, MPI.LONG], source=0, tag=55)
    print(A)



## Large Item Count

In [None]:
%%px
import tensorcraft as tc
comm = MPI.COMM_WORLD
## Let's find the maximum element count on my mpi implementation
options = [torch.iinfo(torch.int32).max]

for possible_max in options:
    print(f"Trying {possible_max}")
    if comm.Get_rank() == 0:
        A = torch.ones(possible_max, dtype=torch.bool)
        print(A.dtype)
        print(A[:10])

        print(f"Sending {possible_max} elements, {possible_max / 10**9} Gb")

    else:
        A = torch.zeros(possible_max, dtype=torch.bool)

    comm.Bcast(buf=[tc.mpi4torch.as_buffer(A), possible_max, MPI.BOOL], root=0)

    if comm.Get_rank() == 0:
        print("Sent!")
    else:
        print("Received!")
        print(A[:10])

# Interweave allgather

In [1]:
import ipyparallel as ipp
rc = ipp.Client(profile='mpi')
rc.wait_for_engines(2)
len(rc)

  0%|          | 0/2 [00:00<?, ?engine/s]

4

In [2]:
%%px
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

[stdout:1] Hello from rank 1!


[stdout:2] Hello from rank 2!


[stdout:3] Hello from rank 3!


[stdout:0] Hello from rank 0!


In [8]:
%%px
import tensorcraft as tc

x = torch.arange(90).reshape(18, 5)
mesh = torch.Size([4])
dist = tc.dist.MultiAxisDist(mesh, ((0,), None), 3)

x_local = dist.apply(x, rank)
print(x_local)
print(x_local.shape)
print(x_local.dtype)
print(x_local.is_contiguous())

[stdout:0] tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74]])
torch.Size([6, 5])
torch.int64
True


[stdout:2] tensor([[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44]])
torch.Size([3, 5])
torch.int64
True


[stdout:1] tensor([[15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89]])
torch.Size([6, 5])
torch.int64
True


[stdout:3] tensor([[45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]])
torch.Size([3, 5])
torch.int64
True


In [14]:
%%px
gathered = torch.zeros(4, 2, 3, 5, dtype=torch.int64)

send_buf = tc.util.tensor2mpiBuffer(x_local)
print(send_buf)
recv_buf = tc.util.as_buffer(gathered)
print(recv_buf)

comm.Allgather(send_buf, (recv_buf, 30, MPI.LONG))
print(gathered)


[stdout:3] (<mpi4py.MPI.buffer object at 0x7beeb8cd6b80>, 15, <mpi4py.MPI.Datatype object at 0x7befb9d10b70>)
<mpi4py.MPI.buffer object at 0x7beeb8cd6aa0>
tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14]],

         [[60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69],
          [70, 71, 72, 73, 74]]],


        [[[15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29]],

         [[75, 76, 77, 78, 79],
          [80, 81, 82, 83, 84],
          [85, 86, 87, 88, 89]]],


        [[[30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]],


        [[[45, 46, 47, 48, 49],
          [50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]]])


[stdout:0] (<mpi4py.MPI.buffer object at 0x7b8c937cebf0>, 30, <mpi4py.MPI.Datatype object at 0x7b8d949e4c60>)
<mpi4py.MPI.buffer object at 0x7b8c937cea30>
tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14]],

         [[60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69],
          [70, 71, 72, 73, 74]]],


        [[[15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29]],

         [[75, 76, 77, 78, 79],
          [80, 81, 82, 83, 84],
          [85, 86, 87, 88, 89]]],


        [[[30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]],


        [[[45, 46, 47, 48, 49],
          [50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]]])


[stdout:1] (<mpi4py.MPI.buffer object at 0x7591129f6790>, 30, <mpi4py.MPI.Datatype object at 0x759213b649f0>)
<mpi4py.MPI.buffer object at 0x7591129f69c0>
tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14]],

         [[60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69],
          [70, 71, 72, 73, 74]]],


        [[[15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29]],

         [[75, 76, 77, 78, 79],
          [80, 81, 82, 83, 84],
          [85, 86, 87, 88, 89]]],


        [[[30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]],


        [[[45, 46, 47, 48, 49],
          [50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]]])


[stdout:2] (<mpi4py.MPI.buffer object at 0x7f466f792640>, 15, <mpi4py.MPI.Datatype object at 0x7f476e1dca50>)
<mpi4py.MPI.buffer object at 0x7f466cfca9c0>
tensor([[[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [10, 11, 12, 13, 14]],

         [[60, 61, 62, 63, 64],
          [65, 66, 67, 68, 69],
          [70, 71, 72, 73, 74]]],


        [[[15, 16, 17, 18, 19],
          [20, 21, 22, 23, 24],
          [25, 26, 27, 28, 29]],

         [[75, 76, 77, 78, 79],
          [80, 81, 82, 83, 84],
          [85, 86, 87, 88, 89]]],


        [[[30, 31, 32, 33, 34],
          [35, 36, 37, 38, 39],
          [40, 41, 42, 43, 44]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]],


        [[[45, 46, 47, 48, 49],
          [50, 51, 52, 53, 54],
          [55, 56, 57, 58, 59]],

         [[ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0],
          [ 0,  0,  0,  0,  0]]]])


In [12]:
%%px
print(gathered.permute(1, 0, 2, 3).reshape(24, 5))

[stdout:1] tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]])


[stdout:3] tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]])


[stdout:0] tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]])


[stdout:2] tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59],
        [60, 61, 62, 63, 64],
        [65, 66, 67, 68, 69],
        [70, 71, 72, 73, 74],
        [75, 76, 77, 78, 79],
        [80, 81, 82, 83, 84],
        [85, 86, 87, 88, 89],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0]])
