# Optim Bottom-Up

## Setup


In [None]:
%load_ext autoreload
%autoreload 2
from importlib import reload
import logging
# Re-execute the logging module before tensorcraft is imported.
# NOTE(review): presumably this resets logging state left over from an earlier
# run of this cell in the same kernel — confirm this is still needed.
reload(logging)

import torch

import tensorcraft as tc

# Show INFO-level records from the "tensorcraft" logger in the cell outputs.
log = logging.getLogger("tensorcraft")
log.setLevel(logging.INFO)

28-03-2025 11:26:22 : INFO : compiler : __init__ -- Grammar file loaded successfully.
28-03-2025 11:26:23 : INFO : compiler : __init__ -- Parser object created successfully.


In [None]:
from typing import Any

# Header labels and per-column minimum widths of the printed table.
columns = ["Step #", "Operation", "Distribution", "Cost", "Memory Usage [MB]"]
columns_width = [8, 20, 40, 8, 8]

# Size of a single tensor element in bytes; converts element counts to bytes.
type_size = 8


def print_path(
    path: list[tuple[str, Any, float]],
    tensor_shape: torch.Size,
    bytes_per_element: int = type_size,
) -> None:
    """Print a redistribution path as rows in LaTeX tabular syntax.

    Cells are left-aligned, padded to ``columns_width``, joined with `` & ``
    and every row ends with a double backslash, so the output can be pasted
    into a LaTeX ``tabular`` environment.

    Args:
        path: Sequence of ``(operation, distribution, cost)`` steps. Each
            distribution must provide ``maxNumElements(tensor_shape)`` and a
            meaningful ``str()`` representation.
        tensor_shape: Shape of the tensor being redistributed; forwarded to
            ``maxNumElements``.
        bytes_per_element: Element size in bytes used for the memory column.
            Defaults to the module-level ``type_size``.

    Returns:
        None. The table is written to standard output.
    """
    row_end = " \\\\"

    header = " & ".join(f"{col:<{width}}" for col, width in zip(columns, columns_width))
    print(header + row_end)

    for i, (op, s_dist, s_cost) in enumerate(path):
        # Memory column: element count reported by the distribution, in decimal megabytes.
        mem_mb = s_dist.maxNumElements(tensor_shape) * bytes_per_element / 10**6
        cells = (i, op, str(s_dist), s_cost, mem_mb)
        row = " & ".join(f"{cell:<{width}}" for cell, width in zip(cells, columns_width))
        print(row + row_end)


## Redistributors

Given a tensor shape, a starting distribution, and a target distribution, a redistributor creates a sequence of collective operations that reaches the target distribution while optimizing for a chosen metric.

### Memory efficient redistributions example

In [None]:
# Build the starting distribution and show it.
tensor_shape = torch.Size((8, 8, 8, 4))
mesh = torch.Size((4, 2, 4))
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 1)
print(dist)

# Option 1: two all-to-all steps applied directly to the starting distribution.
print("Option 1")
new_dist, comm_vol, n_procs = dist.alltoall(
    tensor_shape, 0, 1, minor=True
)
print(f"Step 0: alltoall(0,1) - {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")
new_dist, comm_vol, n_procs = new_dist.alltoall(
    tensor_shape, 0, 2
)
print(f"Step 1: alltoall(0,2) - {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")

# Option 2: split first, run the same two all-to-all steps, then allgather.
print("Option 2")
new_dist, comm_vol, n_procs = dist.split(
    tensor_shape, 3, 2, 1
)
print(f"Step 0: split(3,2)  {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")
new_dist, comm_vol, n_procs = new_dist.alltoall(
    tensor_shape, 0, 1, minor=True
)
print(f"Step 1: alltoall(0,1) - {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")
new_dist, comm_vol, n_procs = new_dist.alltoall(
    tensor_shape, 0, 2
)
print(f"Step 2: alltoall(0,2) - {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")
new_dist, comm_vol, n_procs = new_dist.allgather(
    tensor_shape, 2
)
print(f"Step 3: allgather(2)  {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")

D_[4,2,4]⊥{(0,1),∅,∅,∅}(1,∅,∅,∅)
Option 1
Step 0: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,∅}(2,1,∅,∅), comm_vol, n_procs: (512, 2)
Step 1: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (1024, 4)
Option 2
Step 0: split(3,2)  D_[4,2,4]⊥{(0,1),∅,∅,2}(1,∅,∅,1), comm_vol, n_procs: (0, 0)
Step 1: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,2}(2,1,∅,1), comm_vol, n_procs: (128, 2)
Step 2: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,2}(∅,1,2,1), comm_vol, n_procs: (256, 4)
Step 3: allgather(2)  D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (256, 4)


### Naive Gather Split

The simplest redistributor: it just allgathers, then splits. It should be both communication- and memory-inefficient.

In [None]:

# Problem setup: a 50x50 tensor on a 2x4 mesh, with a source and a target distribution.
tensor_shape = torch.Size([50, 50])
mesh = torch.Size([2, 4])
dist = tc.dist.MultiAxisDist(
    mesh,
    ((0,), ()),
    (10, 10),
)
target_dist = tc.dist.MultiAxisDist(
    mesh,
    ((), (0,)),
    (10, 10),
)

# Redistributor under test, driven by the IdealLowerBoundsCM cost model.
naive_rdist = tc.optim.NaiveGathererRedist(
    tc.optim.IdealLowerBoundsCM()
)

# Compute the operation sequence and report it.
sequence, total_cost = naive_rdist.redistribute(
    tensor_shape, dist, target_dist
)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost}")


28-03-2025 11:30:20 : INFO : naive_gatherer : _redistribute_multi_axis -- Dist D_[2,4]⊥{∅,∅}(∅,∅), volume: 3000, n_procs 2
Step #   & Operation            & Distribution                             & Cost     & Memory Usage [MB] \\
0        &                      & D_[2,4]⊥{0,∅}(10,∅)                      & 0        & 0.012    \\
1        & allgather_*          & D_[2,4]⊥{∅,∅}(∅,∅)                       & 2501.0   & 0.02     \\
2        & split_*              & D_[2,4]⊥{∅,0}(∅,10)                      & 0        & 0.012    \\
Total cost: 2501.0


## Memory Constrained

In [None]:
# Problem setup: a 100x100 tensor on a 2x4 mesh, with a source and a target distribution.
tensor_shape = torch.Size((100, 100))
mesh = torch.Size((2, 4))
dist = tc.dist.MultiAxisDist(
    mesh,
    ((0,), ()),
    2,
)
target_dist = tc.dist.MultiAxisDist(
    mesh,
    ((), (0,)),
    2,
)

# Memory-constrained redistributor with its default settings.
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(
    tc.optim.IdealLowerBoundsCM()
)
sequence, total_cost = mem_constrained_dist.redistribute(
    tensor_shape, dist, target_dist
)

# Report the chosen operation sequence and its total cost.
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost}")

28-03-2025 11:30:27 : INFO : mem_const : _redistribute_multi_axis -- D_[2,4]⊥{0,∅}(2,∅)
28-03-2025 11:30:27 : INFO : mem_const : _redistribute_multi_axis -- D_[2,4]⊥{∅,0}(∅,2)


Step #   & Operation            & Distribution                             & Cost     & Memory Usage [MB] \\
0        &                      & D_[2,4]⊥{0,∅}(2,∅)                       & 0        & 0.04     \\
1        & alltoall_0_1_-1      & D_[2,4]⊥{∅,0}(∅,2)                       & 5001.0   & 0.04     \\
Total cost: 5001.0


In [None]:
# Problem setup: a 4-D tensor on a 4x2x4 mesh, with a source and a target distribution.
tensor_shape = torch.Size([800, 400, 800, 4])
mesh = torch.Size([4, 2, 4])
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 100)
target_dist = tc.dist.MultiAxisDist(mesh, ((), (1,), (0,), ()), (-1, 100, 100, -1))

# Cost-model coefficient passed as `beta`.
# NOTE(review): presumably 64 bits per element over a 200 Gbit/s link — confirm units.
# This was previously written as `64/(200 * 1000^3)`. In Python `^` is bitwise
# XOR and binds looser than `*`, so the denominator evaluated to
# (200 * 1000) ^ 3 == 200003 instead of 200 * 10**9. Costs printed by runs made
# before this fix were computed with that value; re-run the cell to refresh them.
beta = 64 / (200 * 1000**3)

mem_constrained_dist = tc.optim.MemoryConstrainedRedist(
    tc.optim.IdealLowerBoundsCM(), max_depth=5, alpha=0.01, beta=beta
)
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)

# Report the chosen operation sequence and its total cost.
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost}")


28-03-2025 11:32:09 : INFO : mem_const : _redistribute_multi_axis -- D_[4,2,4]⊥{(0,1),∅,∅,∅}(100,∅,∅,∅)
28-03-2025 11:32:09 : INFO : mem_const : _redistribute_multi_axis -- D_[4,2,4]⊥{∅,1,0,∅}(∅,100,100,∅)


Step #   & Operation            & Distribution                             & Cost     & Memory Usage [MB] \\
0        &                      & D_[4,2,4]⊥{(0,1),∅,∅,∅}(100,∅,∅,∅)       & 0        & 1024.0   \\
1        & split_3_2_1          & D_[4,2,4]⊥{(0,1),∅,∅,2}(100,∅,∅,1)       & 0.0      & 256.0    \\
2        & alltoall_minor_0_1_-1 & D_[4,2,4]⊥{0,1,∅,2}(200,100,∅,1)         & 10239.856402303965 & 256.0    \\
3        & alltoall_0_2_100     & D_[4,2,4]⊥{∅,1,0,2}(∅,100,100,1)         & 30719.559206911894 & 256.0    \\
4        & allgather_2          & D_[4,2,4]⊥{∅,1,0,∅}(∅,100,100,∅)         & 30719.559206911894 & 1024.0   \\
Total cost: 71678.97481612775


Tuteja, Keshvi (SCC) <keshvi.tuteja@kit.edu>