# 2D and 3D Matrix Multiplication
## Setup
### Installation
```pip install ipyparallel```

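or, for an editable install of this project including its `notebook` extras (run from the project root):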

```pip install -e .[notebook]```

### Start cluster

```ipcluster start -n 4 --engines=MPI --profile mpi```

In [None]:
# Connect to the ipyparallel cluster started with `--profile mpi` (see
# "Start cluster") and block until 4 engines have registered.
import ipyparallel as ipp
rc = ipp.Client(profile='mpi')
rc.wait_for_engines(4)
# Last expression: displayed as the cell output (number of connected engines).
len(rc)

In [None]:
%%px
# Runs on every engine (%%px); with --engines=MPI each engine is one MPI process.
import torch
from mpi4py import MPI

# World communicator spanning all engines, and this process's rank in it.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

In [None]:
%%px
# Arrange the world processes as a periodic 2 x 2 Cartesian grid.
# NOTE: reorder=True permits MPI to renumber processes inside cart_comm, so a
# process's COMM_WORLD rank is not guaranteed to equal its Cartesian rank.
# Coordinates must therefore be looked up with the rank in cart_comm.
cart_comm = comm.Create_cart(dims=[2, 2], periods=[True, True], reorder=True)
print(f'Hello from rank {rank}! My coordinates are {cart_comm.Get_coords(cart_comm.Get_rank())}')


In [None]:
%%px
# Get_topo() returns (dims, periods, coords) of this process in the grid.
print(f"Topo: {cart_comm.Get_topo()}")

In [None]:
%%px
# Split the 2 x 2 grid into 1-D sub-communicators. Sub(remain_dims) keeps the
# grid dimensions marked 1. Processes that agree on the dropped coordinate
# share a sub-communicator, ranked by the kept coordinate.
# row_comm keeps dimension 1: members share their dimension-0 coordinate (one grid row).
row_comm = cart_comm.Sub([0,1])
# Collect the COMM_WORLD ranks of all members of this row (collective call).
row_global_ranks = row_comm.allgather(rank)
print(f'Hello from rank {rank}! In my row, the ranks are {row_global_ranks}')

# col_comm keeps dimension 0: members share their dimension-1 coordinate (one grid column).
col_comm = cart_comm.Sub([1,0])
col_global_ranks = col_comm.allgather(rank)
print(f'Hello from rank {rank}! In my col, the ranks are {col_global_ranks}')


In [None]:
%%px
def as_buffer(x: torch.Tensor):
    """Wrap the memory of ``x`` as an mpi4py buffer without copying.

    The buffer starts at the tensor's first element and spans
    ``x.numel() * x.element_size()`` bytes. MPI reads/writes that memory
    directly, so ``x`` should be contiguous (a non-contiguous tensor's elements
    are not laid out in that span) and must stay alive while the buffer is used.

    Args:
        x: Tensor whose data is sent or received by an MPI call.

    Returns:
        An ``MPI.buffer`` aliasing the tensor's data.
    """
    # x.data_ptr() already includes storage_offset; untyped_storage().data_ptr()
    # points at the start of the whole storage and is wrong for offset views.
    return MPI.buffer.fromaddress(x.data_ptr(), x.numel() * x.element_size())

## Matrix-Vector Operations

### y := Ax

In [None]:
%%px
# Distributed y := A x on the 2 x 2 grid.
# Notation: p = col_comm.Get_rank() (grid coordinate 0) and
# q = row_comm.Get_rank() (grid coordinate 1). The index bookkeeping below
# assumes cart_comm kept the world ordering, i.e. world rank r = 2 * p + q.
A_g = torch.arange(64).reshape(8, 8)
# 2-D cyclic layout: this process owns rows p::2 and columns q::2 of A.
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
# 1-D cyclic layout over all ranks: world rank r owns x[r::4].
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)

# Gather x over col_comm so this process holds x[q::2], the entries matching
# A_l's columns. The receive type is a vector of 2 LONGs with stride 2,
# resized to the extent of a single LONG. Consecutive members' contributions
# therefore interleave: [x(q), x(q+2), x(q+4), x(q+6)].
# Equivalent without a derived datatype: Allgather 2 LONGs per member into a
# contiguous buffer, then reorder with x_col.reshape(2, 2).T.reshape(4, 1).
x_col = torch.zeros(4,1, dtype=torch.long)
lb, extent = MPI.LONG.Get_extent()
vector_type = MPI.LONG.Create_vector(2, 1, 2)
data_type = vector_type.Create_resized(lb, extent).Commit()
# Freeing the intermediate type does not affect data_type, which was derived from it.
vector_type.Free()
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 1, data_type))
# Release the derived datatype once the blocking collective has completed.
data_type.Free()
print("x_gather_col:", x_col)

# Local partial product: rows p::2 of y, summed over columns q::2 only.
y_l = A_l @ x_col
print("y_local:", y_l)

# Reorder rows [p, p+2, p+4, p+6] -> [p, p+4, p+2, p+6] so each row_comm
# member's share is contiguous. Reduce_scatter then sums across the row, and
# member q receives the fully reduced rows p + 2*q and p + 2*q + 4.
y_l = y_l.reshape(2,2).T.reshape(4,1).contiguous()
y_scatter = torch.zeros(2, 1, dtype=torch.long)
row_comm.Reduce_scatter((as_buffer(y_l), 4, MPI.LONG), (as_buffer(y_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("y_scatter:", y_scatter)

# Gather every piece on every rank: y_end[p, q, j] = y[p + 2*q + 4*j].
# Permuting to (j, q, p) and flattening restores natural order 4*j + 2*q + p.
y_end = torch.zeros(2,2,2, dtype=torch.long)
comm.Allgather((as_buffer(y_scatter), 2, MPI.LONG), (as_buffer(y_end), 2, MPI.LONG))
y_end = y_end.permute(2,1,0).reshape(8,1).contiguous()
print("y_end:", y_end)



In [None]:
%%px
# A_g and x_g are built in full on every rank, so each engine checks y_end locally.
print("Expected:", A_g @ x_g)
print("Actual:", y_end)
# Last expression: displayed as the cell output on every engine.
torch.allclose(y_end, A_g @ x_g)

### x := A.T * y

In [None]:
%%px
# Distributed x := A^T y on the same 2 x 2 grid.
# Notation: p = col_comm.Get_rank(), q = row_comm.Get_rank(). The index
# bookkeeping assumes world rank r = 2 * p + q (no reordering in cart_comm).
A_g = torch.arange(64).reshape(8, 8)
# 2-D cyclic layout: rows p::2, columns q::2.
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

y_g = torch.arange(8).reshape(8, 1)
# This process owns y[i::4] with i = p + 2*q (the layout that the y := Ax
# section's Reduce_scatter leaves y in).
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

# Gather y over row_comm. The buffer arrives as [y(p), y(p+4), y(p+2), y(p+6)],
# and the transpose-reshape reorders it to y[p::2], matching A_l's rows.
y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

# Local partial product: entries q::2 of x, summed over rows p::2 only.
x_l = A_l.T @ y_col
print("x_local:", x_l)

# Reorder [q, q+2, q+4, q+6] -> [q, q+4, q+2, q+6], then sum over col_comm
# and scatter. Member p receives x[q + 2*p] and x[q + 2*p + 4], so world
# rank r ends up holding x[r::4].
x_l = x_l.reshape(2,2).T.reshape(4,1).contiguous()
x_scatter = torch.zeros(2, 1, dtype=torch.long)
col_comm.Reduce_scatter((as_buffer(x_l), 4, MPI.LONG), (as_buffer(x_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("x_scatter:", x_scatter)


In [None]:
%%px
# Reference result; A_g and y_g are built in full on every rank.
x_expected = A_g.T @ y_g
print("expected:", x_expected)
# After the Reduce_scatter, world rank r should hold x_expected[r::4]
# (assuming cart_comm did not reorder ranks). Compare the local piece exactly.
print("actual (local):", x_scatter)
# Last expression: displayed as the cell output on every engine.
torch.equal(x_scatter, x_expected[comm.Get_rank()::comm.Get_size(), :])

### A := y * x.T + A

In [None]:
%%px
# Distributed rank-1 update A := y x^T + A on the 2 x 2 grid.
# Notation: p = col_comm.Get_rank(), q = row_comm.Get_rank(). This assumes
# world rank r = 2 * p + q (no reordering in cart_comm).
A_g = torch.arange(64).reshape(8, 8)
# 2-D cyclic layout: rows p::2, columns q::2.
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

x_g = torch.arange(8).reshape(8, 1)
# World rank r owns x[r::4].
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)

y_g = torch.arange(8).reshape(8, 1)
# This process owns y[i::4] with i = p + 2*q.
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

# Gather x over col_comm ([x(q), x(q+4), x(q+2), x(q+6)]) and reorder it to
# x[q::2], matching A_l's columns.
x_col = torch.zeros(4,1, dtype=torch.long)
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
print("x_gather_col:", x_col)

# Gather y over row_comm ([y(p), y(p+4), y(p+2), y(p+6)]) and reorder it to
# y[p::2], matching A_l's rows.
y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

# Local block of the result, (y x^T + A)[p::2, q::2]; no further communication.
Z_l = y_col @ x_col.T + A_l
print("Z_local:", Z_l)




In [None]:
%%px
# Reference result; A_g, x_g and y_g are built in full on every rank.
Z_expected = y_g @ x_g.T + A_g
print(f"Expected: {Z_expected}")
# Each rank's Z_l should equal its 2-D cyclic block of the reference
# (rows col_comm rank ::2, columns row_comm rank ::2).
# Last expression: displayed as the cell output on every engine.
torch.equal(Z_l, Z_expected[col_comm.Get_rank()::2, row_comm.Get_rank()::2])

## Create_darray

In [None]:
%%px
# Demonstrates Create_darray: rank 0 sends rank 1's share of a cyclically
# distributed array using one derived datatype; rank 1 receives it as plain LONGs.
# NOTE(review): MPI.LONG matches torch.int64 only where C long is 64-bit
# (e.g. LP64 Linux); MPI.INT64_T would be portable.
if comm.Get_rank() == 0:
    A = torch.arange(64).reshape(8, 8)
    print(A)
    print(A.dtype)

    # Treat the 8 x 8 matrix as a 64 x 1 array and deal dimension 0 cyclically,
    # one element at a time, over 4 processes. The type describes process 1's
    # share: flat elements 1, 5, 9, ..., 61 (16 elements).
    darray_type = MPI.LONG.Create_darray(
        4,                    # size: number of processes in the process grid
        1,                      # rank: process whose share this type describes
        # (no ndims argument in mpi4py: it is len(gsizes) == 2)
        [64,1],                 # size of the global array
        [MPI.DISTRIBUTE_CYCLIC, MPI.DISTRIBUTE_NONE], # distribution type
        [1, 1], # distribution argument (cyclic block size 1 along dimension 0)
        [4, 1],                 # size of the process grid
        MPI.ORDER_C,            # array storage order
    ).Commit()

    # Earlier plain-count variant, kept for reference:
    # comm.Send(buf=[as_buffer(A), 8, MPI.LONG], dest=1)
    # One darray_type element selects the 16 strided values out of A's buffer.
    comm.Send([as_buffer(A), 1, darray_type], dest=1, tag=55)

    darray_type.Free() 

elif comm.Get_rank() == 1:
    A = torch.zeros(4, 4, dtype=torch.int64)

    # Earlier plain-count variant, kept for reference:
    # comm.Recv(buf=[as_buffer(A), 8, MPI.LONG], source=0)
    # The receiver sees 16 contiguous LONGs, filling the 4 x 4 tensor row-major.
    comm.Recv([as_buffer(A), 16, MPI.LONG], source=0, tag=55)
    print(A)



## Large Item Count

In [None]:
%%px
import tensorcraft as tc
comm = MPI.COMM_WORLD
## Let's find the maximum element count on my mpi implementation
# Candidate element counts. torch.iinfo(torch.int32).max is the largest value
# a C int (the classic MPI count type) can hold.
options = [torch.iinfo(torch.int32).max]

for possible_max in options:
    print(f"Trying {possible_max}")
    if comm.Get_rank() == 0:
        # The root sends ones so receivers can tell the data arrived.
        A = torch.ones(possible_max, dtype=torch.bool)
        print(A.dtype)
        print(A[:10])

        # Payload size in gigabytes: count * bytes per element (1 for
        # torch.bool). The original label "Gb" (gigabits) was off by 8x.
        print(f"Sending {possible_max} elements, {possible_max * A.element_size() / 10**9} GB")

    else:
        # Receivers start from zeros; after the broadcast they should hold ones.
        A = torch.zeros(possible_max, dtype=torch.bool)

    # Broadcast all elements from rank 0 in a single message.
    comm.Bcast(buf=[tc.mpi4torch.as_buffer(A), possible_max, MPI.BOOL], root=0)

    if comm.Get_rank() == 0:
        print("Sent!")
    else:
        print("Received!")
        print(A[:10])

# Interleaved Allgather

In [None]:
# Repeats the cluster connection from the Setup section, presumably so this
# section can be run in a fresh kernel. Waits for 4 engines.
import ipyparallel as ipp
rc = ipp.Client(profile='mpi')
rc.wait_for_engines(4)
# Last expression: displayed as the cell output (number of connected engines).
len(rc)

4

In [None]:
%%px
# Repeats the per-engine setup for this section: runs on every engine (%%px).
import torch
from mpi4py import MPI

# World communicator and this process's rank in it.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

%px:   0%|          | 0/4 [00:36<?, ?tasks/s]

In [None]:
%%px
import logging
import tensorcraft as tc

# Enable INFO-level messages from the 'tensorcraft' logger (visible in the output below).
log = logging.getLogger('tensorcraft')
log.setLevel(logging.INFO)

# Global tensor of shape (2, 10, 2); every rank builds the same full tensor.
x = torch.arange(40).reshape(2, 10, 2)
# 2 x 2 processor mesh.
mesh = torch.Size([2,2])
# NOTE(review): presumably tensor axis 1 is split over both mesh dimensions
# (0, 1) with block size 1. The logged tile slices (start = rank, step 4 on
# axis 1) are consistent with that; confirm against tc.dist.MultiAxisDist.
dist = tc.dist.MultiAxisDist(mesh, (None, (0,1), None), 1)

# This rank's local piece. Per the logs, axis 1 is padded from 10 to 11 before tiling.
x_local = dist.apply(x, rank)
print(x_local)
print(x_local.shape)
print(x_local.dtype)
print(x_local.is_contiguous())

%px:   0%|          | 0/4 [00:02<?, ?tasks/s]

[stdout:0] 22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Processor multi index: torch.Size([0, 0])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Missing elements: [0, 1, 0]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: N blocks per axis: [1, 10, 1]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Padded tensor shape: torch.Size([2, 11, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Permute tuple: (0, 2, 4, 1, 3, 5)
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Reshape tuple: [1, 2, 11, 1, 1, 2]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Tile Slices: [slice(None, None, None), slice(None, None, None), slice(tensor(0), None, 4), slice(None, None, None), slice(None, None, None), slice(None, None, None)]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Local tensor shape: torch.Size([1, 2, 3, 1, 1, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R0: Target local tensor shape: [2, 3, 2]
22-04-2025 11:42:18 : INFO 

[stdout:1] 22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Processor multi index: torch.Size([0, 1])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Missing elements: [0, 1, 0]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: N blocks per axis: [1, 10, 1]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Padded tensor shape: torch.Size([2, 11, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Permute tuple: (0, 2, 4, 1, 3, 5)
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Reshape tuple: [1, 2, 11, 1, 1, 2]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Tile Slices: [slice(None, None, None), slice(None, None, None), slice(tensor(1), None, 4), slice(None, None, None), slice(None, None, None), slice(None, None, None)]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Local tensor shape: torch.Size([1, 2, 3, 1, 1, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R1: Target local tensor shape: [2, 3, 2]
22-04-2025 11:42:18 : INFO 

%px:   0%|          | 0/4 [00:02<?, ?tasks/s]

[stdout:3] 22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Processor multi index: torch.Size([1, 1])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Missing elements: [0, 1, 0]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: N blocks per axis: [1, 10, 1]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Padded tensor shape: torch.Size([2, 11, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Permute tuple: (0, 2, 4, 1, 3, 5)
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Reshape tuple: [1, 2, 11, 1, 1, 2]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Tile Slices: [slice(None, None, None), slice(None, None, None), slice(tensor(3), None, 4), slice(None, None, None), slice(None, None, None), slice(None, None, None)]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Local tensor shape: torch.Size([1, 2, 2, 1, 1, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R3: Target local tensor shape: [2, 2, 2]
22-04-2025 11:42:18 : INFO 

[stdout:2] 22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Processor multi index: torch.Size([1, 0])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Missing elements: [0, 1, 0]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: N blocks per axis: [1, 10, 1]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Padded tensor shape: torch.Size([2, 11, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Permute tuple: (0, 2, 4, 1, 3, 5)
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Reshape tuple: [1, 2, 11, 1, 1, 2]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Tile Slices: [slice(None, None, None), slice(None, None, None), slice(tensor(2), None, 4), slice(None, None, None), slice(None, None, None), slice(None, None, None)]
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Local tensor shape: torch.Size([1, 2, 3, 1, 1, 2])
22-04-2025 11:42:18 : INFO : multi_axis : apply -- R2: Target local tensor shape: [2, 3, 2]
22-04-2025 11:42:18 : INFO 

%px: 100%|██████████| 4/4 [00:02<00:00,  1.82tasks/s]


In [None]:
%%px
# All-gather along mesh dimension 0. Judging by the unpacking and the logs
# below, this returns the gathered data and the resulting distribution
# (logged as axis 1 still split over mesh dimension 1, global shape (2, 5, 2)).
post_gather, new_dist = dist.apply_allgather(x.shape, x_local, comm, mesh_dim=0)
print(post_gather)

[stdout:0] 22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Local tensor shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Expected local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Changed tensor axis: 1, minor: False
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: New distribution: D_[2,2]⊥{∅,1,∅}(∅,1,∅), new shape: torch.Size([2, 5, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Processor multi index: torch.Size([0, 0])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: N procs: 2
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Rank of largest tensor in the subcommunicator: [0, 0] 0
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: N elements: 12
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R0: Max local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : a

[stdout:2] 22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Local tensor shape: torch.Size([2, 2, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Expected local shape: torch.Size([2, 2, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Changed tensor axis: 1, minor: False
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: New distribution: D_[2,2]⊥{∅,1,∅}(∅,1,∅), new shape: torch.Size([2, 5, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Processor multi index: torch.Size([1, 0])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: N procs: 2
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Rank of largest tensor in the subcommunicator: [0, 0] 0
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: N elements: 12
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R2: Max local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : a

[stdout:1] 22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Local tensor shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Expected local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Changed tensor axis: 1, minor: False
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: New distribution: D_[2,2]⊥{∅,1,∅}(∅,1,∅), new shape: torch.Size([2, 5, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Processor multi index: torch.Size([0, 1])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: N procs: 2
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Rank of largest tensor in the subcommunicator: [0, 1] 1
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: N elements: 12
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R1: Max local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : a

[stdout:3] 22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Local tensor shape: torch.Size([2, 2, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Expected local shape: torch.Size([2, 2, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Changed tensor axis: 1, minor: False
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: New distribution: D_[2,2]⊥{∅,1,∅}(∅,1,∅), new shape: torch.Size([2, 5, 2])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Processor multi index: torch.Size([1, 1])
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: N procs: 2
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Rank of largest tensor in the subcommunicator: [0, 1] 1
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: N elements: 12
22-04-2025 11:42:19 : INFO : multi_axis : apply_allgather -- R3: Max local shape: torch.Size([2, 3, 2])
22-04-2025 11:42:19 : INFO : multi_axis : a

  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


In [None]:
%%px
# Inspect the raw gathered layout. The output below shows a 5-D tensor of
# shape (2, 2, 3, 1, 2) rather than a (2, 5, 2) tensor, and its trailing
# [0, 0] blocks appear to be padding.
print(post_gather.shape)
# Last expression: displayed as the cell output on every engine.
post_gather.permute(1,2,0,3,4)

[stdout:1] torch.Size([2, 2, 3, 1, 2])


[stdout:0] torch.Size([2, 2, 3, 1, 2])


[0;31mOut[0:7]: [0m
tensor([[[[[ 0,  1]],

          [[ 4,  5]]],


         [[[ 8,  9]],

          [[12, 13]]],


         [[[16, 17]],

          [[ 0,  0]]]],



        [[[[20, 21]],

          [[24, 25]]],


         [[[28, 29]],

          [[32, 33]]],


         [[[36, 37]],

          [[ 0,  0]]]]])

[0;31mOut[1:7]: [0m
tensor([[[[[ 2,  3]],

          [[ 6,  7]]],


         [[[10, 11]],

          [[14, 15]]],


         [[[18, 19]],

          [[ 0,  0]]]],



        [[[[22, 23]],

          [[26, 27]]],


         [[[30, 31]],

          [[34, 35]]],


         [[[38, 39]],

          [[ 0,  0]]]]])

[stdout:2] torch.Size([2, 2, 3, 1, 2])


[0;31mOut[2:7]: [0m
tensor([[[[[ 0,  1]],

          [[ 4,  5]]],


         [[[ 8,  9]],

          [[12, 13]]],


         [[[16, 17]],

          [[ 0,  0]]]],



        [[[[20, 21]],

          [[24, 25]]],


         [[[28, 29]],

          [[32, 33]]],


         [[[36, 37]],

          [[ 0,  0]]]]])

[stdout:3] torch.Size([2, 2, 3, 1, 2])


[0;31mOut[3:7]: [0m
tensor([[[[[ 2,  3]],

          [[ 6,  7]]],


         [[[10, 11]],

          [[14, 15]]],


         [[[18, 19]],

          [[ 0,  0]]]],



        [[[[22, 23]],

          [[26, 27]]],


         [[[30, 31]],

          [[34, 35]]],


         [[[38, 39]],

          [[ 0,  0]]]]])