# 2D and 3D Matrix Multiplication
## Setup
### Installation
```pip install ipyparallel```

or 

```pip install -e .[notebook]```

### Start cluster

```ipcluster start -n 4 --engines=MPI --profile mpi```

In [1]:
import ipyparallel as ipp
rc = ipp.Client(profile='mpi')
rc.wait_for_engines(4)
len(rc)

4

In [2]:
%%px
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

%px:   0%|          | 0/4 [00:00<?, ?tasks/s]

[stdout:1] Hello from rank 1!


[stdout:2] Hello from rank 2!


[stdout:0] Hello from rank 0!


[stdout:3] Hello from rank 3!


In [3]:
%%px
cart_comm = comm.Create_cart(dims=[2, 2], periods=[True, True], reorder=True)
print(f'Hello from rank {rank}! My coordinates are {cart_comm.Get_coords(rank)}')


[stdout:0] Hello from rank 0! My coordinates are [0, 0]


[stdout:2] Hello from rank 2! My coordinates are [1, 0]


[stdout:3] Hello from rank 3! My coordinates are [1, 1]


[stdout:1] Hello from rank 1! My coordinates are [0, 1]


In [4]:
%%px
print(f"Topo: {cart_comm.Get_topo()}")

[stdout:0] Topo: ([2, 2], [1, 1], [0, 0])


[stdout:2] Topo: ([2, 2], [1, 1], [1, 0])


[stdout:1] Topo: ([2, 2], [1, 1], [0, 1])


[stdout:3] Topo: ([2, 2], [1, 1], [1, 1])


In [5]:
%%px
row_comm = cart_comm.Sub([0,1])
row_global_ranks = row_comm.allgather(rank)
print(f'Hello from rank {rank}! In my row, the ranks are {row_global_ranks}')

col_comm = cart_comm.Sub([1,0])
col_global_ranks = col_comm.allgather(rank)
print(f'Hello from rank {rank}! In my col, the ranks are {col_global_ranks}')


[stdout:1] Hello from rank 1! In my row, the ranks are [0, 1]
Hello from rank 1! In my col, the ranks are [1, 3]


[stdout:0] Hello from rank 0! In my row, the ranks are [0, 1]
Hello from rank 0! In my col, the ranks are [0, 2]


[stdout:2] Hello from rank 2! In my row, the ranks are [2, 3]
Hello from rank 2! In my col, the ranks are [0, 2]


[stdout:3] Hello from rank 3! In my row, the ranks are [2, 3]
Hello from rank 3! In my col, the ranks are [1, 3]


In [6]:
%%px
def as_buffer(x: torch.Tensor):
    return MPI.buffer.fromaddress(x.untyped_storage().data_ptr(), 0)

## Matrix Vector

### y:= Ax

In [7]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)


x_col = torch.zeros(4,1, dtype=torch.long)
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
print("x_gather_col:", x_col)

y_l = A_l @ x_col
print("y_local:", y_l)

y_l = y_l.reshape(2,2).T.reshape(4,1).contiguous()
y_scatter = torch.zeros(2, 1, dtype=torch.long)
row_comm.Reduce_scatter((as_buffer(y_l), 4, MPI.LONG), (as_buffer(y_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("y_scatter:", y_scatter)

y_end = torch.zeros(2,2,2, dtype=torch.long)
comm.Allgather((as_buffer(y_scatter), 2, MPI.LONG), (as_buffer(y_end), 2, MPI.LONG))
y_end = y_end.permute(2,1,0).reshape(8,1).contiguous()
print("y_end:", y_end)



[stdout:2] A_local: tensor([[ 8, 10, 12, 14],
        [24, 26, 28, 30],
        [40, 42, 44, 46],
        [56, 58, 60, 62]])
x_local: tensor([[2],
        [6]])
x_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
y_local: tensor([[152],
        [344],
        [536],
        [728]])
y_scatter: tensor([[ 364],
        [1260]])
y_end: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[stdout:1] A_local: tensor([[ 1,  3,  5,  7],
        [17, 19, 21, 23],
        [33, 35, 37, 39],
        [49, 51, 53, 55]])
x_local: tensor([[1],
        [5]])
x_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
y_local: tensor([[ 84],
        [340],
        [596],
        [852]])
y_scatter: tensor([[ 588],
        [1484]])
y_end: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[stdout:3] A_local: tensor([[ 9, 11, 13, 15],
        [25, 27, 29, 31],
        [41, 43, 45, 47],
        [57, 59, 61, 63]])
x_local: tensor([[3],
        [7]])
x_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
y_local: tensor([[212],
        [468],
        [724],
        [980]])
y_scatter: tensor([[ 812],
        [1708]])
y_end: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[stdout:0] A_local: tensor([[ 0,  2,  4,  6],
        [16, 18, 20, 22],
        [32, 34, 36, 38],
        [48, 50, 52, 54]])
x_local: tensor([[0],
        [4]])
x_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
y_local: tensor([[ 56],
        [248],
        [440],
        [632]])
y_scatter: tensor([[ 140],
        [1036]])
y_end: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


In [8]:
%%px
print("Expected:", A_g @ x_g)
print("Actual:", y_end)
torch.allclose(y_end, A_g @ x_g)

[stdout:0] Expected: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])
Actual: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[0;31mOut[0:7]: [0mTrue

[stdout:3] Expected: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])
Actual: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[stdout:1] Expected: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])
Actual: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[stdout:2] Expected: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])
Actual: tensor([[ 140],
        [ 364],
        [ 588],
        [ 812],
        [1036],
        [1260],
        [1484],
        [1708]])


[0;31mOut[3:7]: [0mTrue

[0;31mOut[1:7]: [0mTrue

[0;31mOut[2:7]: [0mTrue

### x = A.T * y

In [9]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

x_l = A_l.T @ y_col
print("x_local:", x_l)

x_l = x_l.reshape(2,2).T.reshape(4,1).contiguous()
x_scatter = torch.zeros(2, 1, dtype=torch.long)
col_comm.Reduce_scatter((as_buffer(x_l), 4, MPI.LONG), (as_buffer(x_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("x_scatter:", x_scatter)


[stdout:3] A_local: tensor([[ 9, 11, 13, 15],
        [25, 27, 29, 31],
        [41, 43, 45, 47],
        [57, 59, 61, 63]])
y_local: tensor([[3],
        [7]])
y_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
x_local: tensor([[688],
        [720],
        [752],
        [784]])
x_scatter: tensor([[1204],
        [1316]])


[stdout:0] A_local: tensor([[ 0,  2,  4,  6],
        [16, 18, 20, 22],
        [32, 34, 36, 38],
        [48, 50, 52, 54]])
y_local: tensor([[0],
        [4]])
y_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
x_local: tensor([[448],
        [472],
        [496],
        [520]])
x_scatter: tensor([[1120],
        [1232]])


[stdout:1] A_local: tensor([[ 1,  3,  5,  7],
        [17, 19, 21, 23],
        [33, 35, 37, 39],
        [49, 51, 53, 55]])
y_local: tensor([[2],
        [6]])
y_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
x_local: tensor([[460],
        [484],
        [508],
        [532]])
x_scatter: tensor([[1148],
        [1260]])


[stdout:2] A_local: tensor([[ 8, 10, 12, 14],
        [24, 26, 28, 30],
        [40, 42, 44, 46],
        [56, 58, 60, 62]])
y_local: tensor([[1],
        [5]])
y_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
x_local: tensor([[672],
        [704],
        [736],
        [768]])
x_scatter: tensor([[1176],
        [1288]])


In [10]:
%%px
print("expected:", A_g.T @ y_g)

[stdout:0] expected: tensor([[1120],
        [1148],
        [1176],
        [1204],
        [1232],
        [1260],
        [1288],
        [1316]])


[stdout:1] expected: tensor([[1120],
        [1148],
        [1176],
        [1204],
        [1232],
        [1260],
        [1288],
        [1316]])


[stdout:2] expected: tensor([[1120],
        [1148],
        [1176],
        [1204],
        [1232],
        [1260],
        [1288],
        [1316]])


[stdout:3] expected: tensor([[1120],
        [1148],
        [1176],
        [1204],
        [1232],
        [1260],
        [1288],
        [1316]])


### A := y * x.T + A

In [14]:
%%px
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)

y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

x_col = torch.zeros(4,1, dtype=torch.long)
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
print("x_gather_col:", x_col)

y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

Z_l = y_col @ x_col.T + A_l
print("Z_local:", Z_l)




[stdout:3] A_local: tensor([[ 9, 11, 13, 15],
        [25, 27, 29, 31],
        [41, 43, 45, 47],
        [57, 59, 61, 63]])
x_local: tensor([[3],
        [7]])
y_local: tensor([[3],
        [7]])
x_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
y_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
Z_local: tensor([[ 10,  14,  18,  22],
        [ 28,  36,  44,  52],
        [ 46,  58,  70,  82],
        [ 64,  80,  96, 112]])


[stdout:2] A_local: tensor([[ 8, 10, 12, 14],
        [24, 26, 28, 30],
        [40, 42, 44, 46],
        [56, 58, 60, 62]])
x_local: tensor([[2],
        [6]])
y_local: tensor([[1],
        [5]])
x_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
y_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
Z_local: tensor([[  8,  12,  16,  20],
        [ 24,  32,  40,  48],
        [ 40,  52,  64,  76],
        [ 56,  72,  88, 104]])


[stdout:1] A_local: tensor([[ 1,  3,  5,  7],
        [17, 19, 21, 23],
        [33, 35, 37, 39],
        [49, 51, 53, 55]])
x_local: tensor([[1],
        [5]])
y_local: tensor([[2],
        [6]])
x_gather_col: tensor([[1],
        [3],
        [5],
        [7]])
y_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
Z_local: tensor([[ 1,  3,  5,  7],
        [19, 25, 31, 37],
        [37, 47, 57, 67],
        [55, 69, 83, 97]])


[stdout:0] A_local: tensor([[ 0,  2,  4,  6],
        [16, 18, 20, 22],
        [32, 34, 36, 38],
        [48, 50, 52, 54]])
x_local: tensor([[0],
        [4]])
y_local: tensor([[0],
        [4]])
x_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
y_gather_col: tensor([[0],
        [2],
        [4],
        [6]])
Z_local: tensor([[ 0,  2,  4,  6],
        [16, 22, 28, 34],
        [32, 42, 52, 62],
        [48, 62, 76, 90]])


In [12]:
%%px
print(f"Expected: {y_g @ x_g.T + A_g}")

[stdout:3] Expected: tensor([[  0,   1,   2,   3,   4,   5,   6,   7],
        [  8,  10,  12,  14,  16,  18,  20,  22],
        [ 16,  19,  22,  25,  28,  31,  34,  37],
        [ 24,  28,  32,  36,  40,  44,  48,  52],
        [ 32,  37,  42,  47,  52,  57,  62,  67],
        [ 40,  46,  52,  58,  64,  70,  76,  82],
        [ 48,  55,  62,  69,  76,  83,  90,  97],
        [ 56,  64,  72,  80,  88,  96, 104, 112]])


[stdout:0] Expected: tensor([[  0,   1,   2,   3,   4,   5,   6,   7],
        [  8,  10,  12,  14,  16,  18,  20,  22],
        [ 16,  19,  22,  25,  28,  31,  34,  37],
        [ 24,  28,  32,  36,  40,  44,  48,  52],
        [ 32,  37,  42,  47,  52,  57,  62,  67],
        [ 40,  46,  52,  58,  64,  70,  76,  82],
        [ 48,  55,  62,  69,  76,  83,  90,  97],
        [ 56,  64,  72,  80,  88,  96, 104, 112]])


[stdout:1] Expected: tensor([[  0,   1,   2,   3,   4,   5,   6,   7],
        [  8,  10,  12,  14,  16,  18,  20,  22],
        [ 16,  19,  22,  25,  28,  31,  34,  37],
        [ 24,  28,  32,  36,  40,  44,  48,  52],
        [ 32,  37,  42,  47,  52,  57,  62,  67],
        [ 40,  46,  52,  58,  64,  70,  76,  82],
        [ 48,  55,  62,  69,  76,  83,  90,  97],
        [ 56,  64,  72,  80,  88,  96, 104, 112]])


[stdout:2] Expected: tensor([[  0,   1,   2,   3,   4,   5,   6,   7],
        [  8,  10,  12,  14,  16,  18,  20,  22],
        [ 16,  19,  22,  25,  28,  31,  34,  37],
        [ 24,  28,  32,  36,  40,  44,  48,  52],
        [ 32,  37,  42,  47,  52,  57,  62,  67],
        [ 40,  46,  52,  58,  64,  70,  76,  82],
        [ 48,  55,  62,  69,  76,  83,  90,  97],
        [ 56,  64,  72,  80,  88,  96, 104, 112]])
