# 2D and 3D Matrix Multiplication
## Setup
### Installation
```pip install ipyparallel```

or 

```pip install -e .[notebook]```

### Start cluster

```ipcluster start -n 4 --engines=MPI --profile mpi```

In [None]:
import ipyparallel as ipp
# Connect to the ipyparallel cluster that was started with the 'mpi' profile.
rc = ipp.Client(profile='mpi')
# Block until 4 engines have registered, matching `ipcluster start -n 4`.
rc.wait_for_engines(4)
# Notebook display value: the number of engines the client can see.
len(rc)

In [None]:
%%px
import torch
from mpi4py import MPI

# Runs on every engine: take the world communicator and this process's rank in
# it. Later cells reuse `comm` and `rank`.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

In [None]:
%%px
# Build a 2x2 Cartesian process grid, periodic in both dimensions, on top of
# the world communicator.
cart_comm = comm.Create_cart(dims=[2, 2], periods=[True, True], reorder=True)
# reorder=True allows MPI to renumber the processes inside cart_comm, so the
# coordinates must be looked up with this process's rank *in cart_comm*.
# The COMM_WORLD rank is kept only as the label in the greeting.
print(f'Hello from rank {rank}! My coordinates are {cart_comm.Get_coords(cart_comm.Get_rank())}')


In [None]:
%%px
# Show the topology information that the Cartesian communicator reports.
print(f"Topo: {cart_comm.Get_topo()}")

In [None]:
%%px
# Sub() takes one keep/drop flag per grid dimension. [0,1] drops dimension 0
# and keeps dimension 1, so row_comm connects the processes of one grid row.
row_comm = cart_comm.Sub([0,1])
# Collect the COMM_WORLD rank of every member of this row communicator.
row_global_ranks = row_comm.allgather(rank)
print(f'Hello from rank {rank}! In my row, the ranks are {row_global_ranks}')

# [1,0] keeps dimension 0 instead, so col_comm connects the processes of one
# grid column.
col_comm = cart_comm.Sub([1,0])
# Collect the COMM_WORLD rank of every member of this column communicator.
col_global_ranks = col_comm.allgather(rank)
print(f'Hello from rank {rank}! In my col, the ranks are {col_global_ranks}')


In [None]:
%%px
def as_buffer(x: torch.Tensor):
    """Expose the memory behind ``x`` as an ``MPI.buffer`` without copying.

    The buffer starts at the first element of ``x`` and extends to the end of
    the underlying storage. No layout information is attached: the caller
    describes the elements to transfer with an explicit count and MPI
    datatype, as in ``(as_buffer(t), n, MPI.LONG)``.

    Args:
        x: Tensor whose memory is handed to MPI. It must stay alive for as
            long as the returned buffer is in use, and its memory layout must
            match the count/datatype passed alongside the buffer (a
            contiguous tensor for the plain predefined datatypes).

    Returns:
        A writable ``MPI.buffer`` aliasing the tensor's memory.
    """
    # Tensor.data_ptr() is the address of the tensor's first element. The
    # storage's data_ptr() is the start of the whole allocation, which is the
    # wrong address for any view with a non-zero storage_offset.
    leading_bytes = x.storage_offset() * x.element_size()
    # Advertise the bytes that really are available from that address onwards,
    # rather than a zero-length buffer.
    available_bytes = x.untyped_storage().nbytes() - leading_bytes
    return MPI.buffer.fromaddress(x.data_ptr(), available_bytes)

## Matrix Vector

### y := A x

In [None]:
%%px
# Distributed y := A x on the 2x2 process grid.

# Every process builds the global 8x8 matrix and keeps every 2nd row and every
# 2nd column, starting at its rank in col_comm and row_comm respectively.
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

# x is dealt out one element at a time over all processes of `comm`:
# rank r owns x[r], x[r + size], ...
x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)


# Alternative without a derived datatype: gather the contributions back to
# back and interleave them locally afterwards.
# x_col = torch.zeros(4,1, dtype=torch.long)
# col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
# x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
# print("x_gather_col:", x_col)

# Gather the x entries held within this column communicator, letting MPI do
# the interleaving. The receive type is 2 elements with a stride of 2,
# resized to the extent of a single element so that consecutive contributions
# land in consecutive slots.
x_col = torch.zeros(4,1, dtype=torch.long)
long_lb, long_extent = MPI.LONG.Get_extent()
strided_type = MPI.LONG.Create_vector(2, 1, 2)
data_type = strided_type.Create_resized(long_lb, long_extent).Commit()
# The intermediate type is no longer needed; freeing it does not affect the
# resized type that was derived from it.
strided_type.Free()
try:
    col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 1, data_type))
finally:
    # Release the committed type so re-running the cell does not leak handles.
    data_type.Free()
print("x_gather_col:", x_col)

# Local partial product: this process's block of A times the gathered x.
y_l = A_l @ x_col
print("y_local:", y_l)

# Reorder the partial result so that the two elements each member of row_comm
# receives from the reduce-scatter sit next to each other, then sum the
# partial products across row_comm, leaving 2 elements per process.
y_l = y_l.reshape(2,2).T.reshape(4,1).contiguous()
y_scatter = torch.zeros(2, 1, dtype=torch.long)
row_comm.Reduce_scatter((as_buffer(y_l), 4, MPI.LONG), (as_buffer(y_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("y_scatter:", y_scatter)

# Collect all 2-element pieces on every process and undo the distribution
# order to obtain the full 8x1 result.
y_end = torch.zeros(2,2,2, dtype=torch.long)
comm.Allgather((as_buffer(y_scatter), 2, MPI.LONG), (as_buffer(y_end), 2, MPI.LONG))
y_end = y_end.permute(2,1,0).reshape(8,1).contiguous()
print("y_end:", y_end)



In [None]:
%%px
# Compare the distributed result with a single-process reference product.
expected = A_g @ x_g
print("Expected:", expected)
print("Actual:", y_end)
# Both tensors hold integers, so require exact equality. A tolerance-based
# comparison could accept integers that differ once the values are large.
torch.equal(y_end, expected)

### x := A.T y

In [None]:
%%px
# Distributed x := A.T y on the 2x2 process grid.

# Every process builds the global 8x8 matrix and keeps every 2nd row and every
# 2nd column, starting at its rank in col_comm and row_comm respectively.
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

# y is dealt out one element at a time, but the owner index `i` is built from
# the col_comm / row_comm ranks instead of the COMM_WORLD rank.
y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

# Gather the y entries held within this row communicator. The contributions
# arrive back to back, so the reshape/transpose interleaves them afterwards.
y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

# Local partial product with the transposed block of A.
x_l = A_l.T @ y_col
print("x_local:", x_l)

# Reorder so that each member's two output elements are adjacent, then sum the
# partial products across col_comm, leaving 2 elements per process.
x_l = x_l.reshape(2,2).T.reshape(4,1).contiguous()
x_scatter = torch.zeros(2, 1, dtype=torch.long)
col_comm.Reduce_scatter((as_buffer(x_l), 4, MPI.LONG), (as_buffer(x_scatter), 2, MPI.LONG), [2,2], MPI.SUM)
print("x_scatter:", x_scatter)


In [None]:
%%px
# Single-process reference result, to compare by eye with the x_scatter pieces
# printed by the previous cell.
print("expected:", A_g.T @ y_g)

### A := y * x.T + A

In [None]:
%%px
# Distributed rank-1 update Z := y x.T + A on the 2x2 process grid.

# Every process builds the global 8x8 matrix and keeps every 2nd row and every
# 2nd column, starting at its rank in col_comm and row_comm respectively.
A_g = torch.arange(64).reshape(8, 8)
A_l = A_g[col_comm.Get_rank()::2, row_comm.Get_rank()::2].contiguous()
print("A_local:", A_l)

# x is dealt out one element at a time over all processes, by COMM_WORLD rank.
# NOTE(review): this appears to rely on the COMM_WORLD rank numbering matching
# the grid numbering (i.e. Create_cart not reordering) - confirm.
x_g = torch.arange(8).reshape(8, 1)
x_l = x_g[comm.Get_rank()::comm.Get_size(), :].contiguous()
print("x_local:", x_l)

# y is dealt out one element at a time, with the owner index built from the
# col_comm / row_comm ranks.
y_g = torch.arange(8).reshape(8, 1)
i = col_comm.Get_rank() + col_comm.Get_size() * row_comm.Get_rank()
y_l = y_g[i::comm.Get_size(), :].contiguous()
print("y_local:", y_l)

# Gather the x entries held within this column communicator, then interleave
# the back-to-back contributions locally.
x_col = torch.zeros(4,1, dtype=torch.long)
col_comm.Allgather((as_buffer(x_l), 2, MPI.LONG), (as_buffer(x_col), 2, MPI.LONG))
x_col = x_col.reshape(2,2).T.reshape(4,1).contiguous()
print("x_gather_col:", x_col)

# Same for the y entries, gathered within this row communicator.
y_col = torch.zeros(4,1, dtype=torch.long)
row_comm.Allgather((as_buffer(y_l), 2, MPI.LONG), (as_buffer(y_col), 2, MPI.LONG))
y_col = y_col.reshape(2,2).T.reshape(4,1).contiguous()
print("y_gather_col:", y_col)

# Purely local update: (4x1) @ (1x4) outer product added to the local block.
Z_l = y_col @ x_col.T + A_l
print("Z_local:", Z_l)




In [None]:
%%px
# Single-process reference result, to compare by eye with the Z_local blocks
# printed by the previous cell.
print(f"Expected: {y_g @ x_g.T + A_g}")

## Create_darray

In [None]:
%%px
# Rank 0 sends, through a distributed-array datatype, the part of a 64-element
# array that such a distribution assigns to process 1 of 4. Rank 1 receives it
# as 16 plain elements. The other ranks do nothing in this cell.
if comm.Get_rank() == 0:
    A = torch.arange(64).reshape(8, 8)
    print(A)
    print(A.dtype)

    darray_type = MPI.LONG.Create_darray(
        4,                    # size: number of processes in the process grid
        1,                      # rank: the process whose portion this type selects
        # 2,                      # number of array dimensions (as well as process grid dimensions)
        [64,1],                 # size of the global array
        [MPI.DISTRIBUTE_CYCLIC, MPI.DISTRIBUTE_NONE], # distribution type
        [1, 1], # distribution argument
        [4, 1],                 # size of the process grid
        MPI.ORDER_C,            # array storage order
    ).Commit()

    # comm.Send(buf=[as_buffer(A), 8, MPI.LONG], dest=1)
    # One instance of the derived type, read from the start of A's memory.
    comm.Send([as_buffer(A), 1, darray_type], dest=1, tag=55)

    # Release the committed datatype once the send has completed.
    darray_type.Free() 

elif comm.Get_rank() == 1:
    A = torch.zeros(4, 4, dtype=torch.int64)

    # comm.Recv(buf=[as_buffer(A), 8, MPI.LONG], source=0)
    # The selected elements arrive packed, so they are received as 16
    # contiguous elements.
    comm.Recv([as_buffer(A), 16, MPI.LONG], source=0, tag=55)
    print(A)



## Large Item Count

In [None]:
%%px
import tensorcraft as tc
comm = MPI.COMM_WORLD
# Probe whether this MPI implementation can broadcast a message whose element
# count is the largest signed 32-bit value.
options = [torch.iinfo(torch.int32).max]

for possible_max in options:
    print(f"Trying {possible_max}")
    if comm.Get_rank() == 0:
        # The root broadcasts all-True so receivers can tell that data arrived.
        A = torch.ones(possible_max, dtype=torch.bool)
        print(A.dtype)
        print(A[:10])

        # torch.bool takes one byte per element, so count / 1e9 is a size in
        # gigabytes (GB), not gigabits (Gb).
        print(f"Sending {possible_max} elements, {possible_max / 10**9} GB")

    else:
        # Receivers start from all-False and are overwritten by the broadcast.
        A = torch.zeros(possible_max, dtype=torch.bool)

    comm.Bcast(buf=[tc.mpi4torch.as_buffer(A), possible_max, MPI.BOOL], root=0)

    if comm.Get_rank() == 0:
        print("Sent!")
    else:
        print("Received!")
        print(A[:10])

# Interweave allgather

In [None]:
import ipyparallel as ipp
# Connect to the ipyparallel cluster that was started with the 'mpi' profile.
rc = ipp.Client(profile='mpi')
# Block until 4 engines have registered.
rc.wait_for_engines(4)
# Notebook display value: the number of engines the client can see.
len(rc)

4

In [None]:
%%px
import torch
from mpi4py import MPI

# Runs on every engine: take the world communicator and this process's rank in
# it. The following cells use `comm` and `rank`.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
print(f'Hello from rank {rank}!')

%px:   0%|          | 0/4 [00:03<?, ?tasks/s]

[stdout:3] Hello from rank 3!


%px:  25%|██▌       | 1/4 [00:03<00:00,  9.53tasks/s]

[stdout:1] Hello from rank 1!


%px:  50%|█████     | 2/4 [00:03<00:00,  6.11tasks/s]

[stdout:0] Hello from rank 0!


[stdout:2] Hello from rank 2!


%px: 100%|██████████| 4/4 [00:03<00:00,  1.08tasks/s]


In [None]:
%%px
import logging
import tensorcraft as tc

# Turn on DEBUG logging for tensorcraft.
# NOTE(review): log_rank=True presumably tags each record with the emitting
# rank - confirm against tensorcraft.set_logger_config.
tc.set_logger_config(log_rank=True, level=logging.DEBUG)

# Global tensor of shape (2, 10, 2) holding the values 0..39.
x = torch.arange(40).reshape(2, 10, 2)
mesh = torch.Size([2,2])
# NOTE(review): presumably a 2x2 process mesh on which tensor axis 1 is
# distributed over mesh axes (0, 1) with block size 1, the other two axes left
# undistributed - confirm against the MPIMultiAxisDist signature.
dist = tc.mpi.MPIMultiAxisDist(mesh, (None, (0,1), None), 1)

# Apply the distribution to the global tensor for this rank, then inspect the
# resulting local tensor.
x_local = dist.apply(x, rank)
print(x_local)
print(x_local.shape)
print(x_local.dtype)
print(x_local.is_contiguous())

%px:   0%|          | 0/4 [00:01<?, ?tasks/s]

[stdout:1] [2025-05-02 14:18:33,045][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R1/4:[1;32mProcessor multi index: torch.Size([0, 1])[0m
[2025-05-02 14:18:33,047][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMissing elements: [0, 1, 0][0m
[2025-05-02 14:18:33,048][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R1/4:[1;32mN blocks per axis: [1, 10, 1][0m
[2025-05-02 14:18:33,050][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R1/4:[1;32mPadded tensor shape: torch.Size([2, 11, 2])[0m
[2025-05-02 14:18:33,054][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R1/4:[1;32mPermute tuple: (0, 2, 4, 1, 3, 5)[0m
[2025-05-02 14:18:33,056][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] -

[stdout:3] [2025-05-02 14:18:33,106][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R3/4:[1;32mProcessor multi index: torch.Size([1, 1])[0m
[2025-05-02 14:18:33,108][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMissing elements: [0, 1, 0][0m
[2025-05-02 14:18:33,111][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R3/4:[1;32mN blocks per axis: [1, 10, 1][0m
[2025-05-02 14:18:33,116][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R3/4:[1;32mPadded tensor shape: torch.Size([2, 11, 2])[0m
[2025-05-02 14:18:33,122][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R3/4:[1;32mPermute tuple: (0, 2, 4, 1, 3, 5)[0m
[2025-05-02 14:18:33,125][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] -

%px:  25%|██▌       | 1/4 [00:01<00:00,  9.61tasks/s]

[stdout:0] [2025-05-02 14:18:33,146][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R0/4:[1;32mProcessor multi index: torch.Size([0, 0])[0m
[2025-05-02 14:18:33,148][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMissing elements: [0, 1, 0][0m
[2025-05-02 14:18:33,150][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R0/4:[1;32mN blocks per axis: [1, 10, 1][0m
[2025-05-02 14:18:33,152][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R0/4:[1;32mPadded tensor shape: torch.Size([2, 11, 2])[0m
[2025-05-02 14:18:33,156][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R0/4:[1;32mPermute tuple: (0, 2, 4, 1, 3, 5)[0m
[2025-05-02 14:18:33,158][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] -

%px:  75%|███████▌  | 3/4 [00:01<00:00, 15.24tasks/s]

[stdout:2] [2025-05-02 14:18:33,254][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R2/4:[1;32mProcessor multi index: torch.Size([1, 0])[0m
[2025-05-02 14:18:33,259][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMissing elements: [0, 1, 0][0m
[2025-05-02 14:18:33,262][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R2/4:[1;32mN blocks per axis: [1, 10, 1][0m
[2025-05-02 14:18:33,263][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R2/4:[1;32mPadded tensor shape: torch.Size([2, 11, 2])[0m
[2025-05-02 14:18:33,268][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] - R2/4:[1;32mPermute tuple: (0, 2, 4, 1, 3, 5)[0m
[2025-05-02 14:18:33,270][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35mapply[0m][[1;32mDEBUG[0m] -

%px: 100%|██████████| 4/4 [00:01<00:00,  2.48tasks/s]


In [None]:
%%px
# Allgather the local pieces with gather_mesh_dim=0. The call returns a pair
# that is unpacked into a new distribution object and the gathered local tensor.
# NOTE(review): presumably new_dist describes x with mesh axis 0 no longer
# used for distribution - confirm against apply_allgather.
new_dist, post_gather = dist.apply_allgather(x.shape, x_local, comm, gather_mesh_dim=0)
print(post_gather)

[stdout:3] [2025-05-02 14:18:33,506][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,508][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,509][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,512][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,514][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,515][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,522][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:0] [2025-05-02 14:18:33,505][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,509][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,510][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,515][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,518][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,522][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,528][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:1] [2025-05-02 14:18:33,512][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,516][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,522][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,536][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,544][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,553][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,565][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:2] [2025-05-02 14:18:33,526][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,530][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,531][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,536][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:18:33,544][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMesh axis: 0[0m
[2025-05-02 14:18:33,548][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:18:33,553][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


  index_tensor = torch.tensor(index)


In [None]:
%%px
# Display each engine's gathered tensor as the cell's output value.
post_gather

[0;31mOut[2:4]: [0m
tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9],
         [12, 13],
         [16, 17]],

        [[20, 21],
         [24, 25],
         [28, 29],
         [32, 33],
         [36, 37]]])

[0;31mOut[0:4]: [0m
tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9],
         [12, 13],
         [16, 17]],

        [[20, 21],
         [24, 25],
         [28, 29],
         [32, 33],
         [36, 37]]])

[0;31mOut[1:4]: [0m
tensor([[[ 2,  3],
         [ 6,  7],
         [10, 11],
         [14, 15],
         [18, 19]],

        [[22, 23],
         [26, 27],
         [30, 31],
         [34, 35],
         [38, 39]]])

[0;31mOut[3:4]: [0m
tensor([[[ 2,  3],
         [ 6,  7],
         [10, 11],
         [14, 15],
         [18, 19]],

        [[22, 23],
         [26, 27],
         [30, 31],
         [34, 35],
         [38, 39]]])

In [None]:
%%px
# Same allgather from the original local pieces, this time with
# gather_mesh_dim=1. new_dist and post_gather are overwritten.
new_dist, post_gather = dist.apply_allgather(x.shape, x_local, comm, gather_mesh_dim=1)
print(post_gather)

[stdout:0] [2025-05-02 14:19:16,455][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,457][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,458][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,461][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,464][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,465][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R0/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,466][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:3] [2025-05-02 14:19:16,456][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,458][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,459][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,461][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,463][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,464][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R3/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,465][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:1] [2025-05-02 14:19:16,457][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,460][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,461][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,465][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,465][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,465][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R1/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,467][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

[stdout:2] [2025-05-02 14:19:16,465][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,471][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,472][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,473][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mTensor axis: 1[0m
[2025-05-02 14:19:16,473][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMesh axis: 1[0m
[2025-05-02 14:19:16,474][[1;36mtensorcraft.distributions.multi_axis[0m][[1;35mallgather[0m][[1;32mDEBUG[0m] - R2/4:[1;32mMappings: (0, 1)[0m
[2025-05-02 14:19:16,475][[1;36mtensorcraft.mpi.distributions.multi_axis_dist[0m][[1;35m_apply_sin

In [None]:
%%px
# Print each engine's gathered tensor again, without the debug log output.
print(post_gather)

[stdout:2] tensor([[[ 4,  5],
         [ 6,  7],
         [12, 13],
         [14, 15]],

        [[24, 25],
         [26, 27],
         [32, 33],
         [34, 35]]])


[stdout:0] tensor([[[ 0,  1],
         [ 2,  3],
         [ 8,  9],
         [10, 11],
         [16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23],
         [28, 29],
         [30, 31],
         [36, 37],
         [38, 39]]])


[stdout:3] tensor([[[ 4,  5],
         [ 6,  7],
         [12, 13],
         [14, 15]],

        [[24, 25],
         [26, 27],
         [32, 33],
         [34, 35]]])


[stdout:1] tensor([[[ 0,  1],
         [ 2,  3],
         [ 8,  9],
         [10, 11],
         [16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23],
         [28, 29],
         [30, 31],
         [36, 37],
         [38, 39]]])
