# Optim Bottom-Up

## Setup


In [None]:
# Notebook setup: live-reload project code, reset logging, import deps, and
# raise the tensorcraft logger to INFO so its progress messages are shown.

# Re-import edited modules automatically before each cell runs, so changes to
# tensorcraft sources are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from importlib import reload
import logging
# Re-executing the logging module resets its module-level state before
# tensorcraft is imported (order matters: keep this above `import tensorcraft`).
# NOTE(review): presumably done to avoid stale/duplicate handlers when this
# cell is re-run in the same kernel — TODO confirm.
reload(logging)

import torch
# Plotting / pretty-printing helpers (not referenced by the redistribution
# cells in this section).
import matplotlib.pyplot as plt
from pprint import pprint

import tensorcraft as tc

# INFO level surfaces the library's compiler / redistributor log lines that
# appear in the cell outputs below.
log = logging.getLogger("tensorcraft")
log.setLevel(logging.INFO)

06-03-2025 04:41:54 : INFO : compiler : __init__ -- Grammar file loaded successfully.
06-03-2025 04:41:54 : INFO : compiler : __init__ -- Parser object created successfully.


## Redistributors

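Given a tensor shape, a starting distribution, and a target distribution, a redistributor produces a sequence of collective operations that transforms the starting distribution into the target one. Different redistributors optimize for different metrics, such as communication volume or peak memory.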

### Memory-efficient redistribution example

In [None]:
# Compare two hand-written collective sequences that start from the same
# distribution: a direct pair of all-to-alls, versus splitting a further dim
# first so that the all-to-alls move less data, then all-gathering at the end.
tensor_shape = torch.Size([8, 8, 8, 4])
mesh = torch.Size([4, 2, 4])
# Per-tensor-dim mesh axes: dim 0 is spread over mesh axes (0, 1); block size 1.
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 1)
print(dist)

# Each option is an ordered list of (label, op) pairs; every op takes the
# current distribution and returns (new_dist, comm_vol, n_procs).
option_1 = [
    ("alltoall(0,1) -", lambda d: d.alltoall(tensor_shape, 0, 1, minor=True)),
    ("alltoall(0,2) -", lambda d: d.alltoall(tensor_shape, 0, 2)),
]
option_2 = [
    ("split(3,2) ", lambda d: d.split(tensor_shape, 3, 2, 1)),
    ("alltoall(0,1) -", lambda d: d.alltoall(tensor_shape, 0, 1, minor=True)),
    ("alltoall(0,2) -", lambda d: d.alltoall(tensor_shape, 0, 2)),
    ("allgather(2) ", lambda d: d.allgather(tensor_shape, 2)),
]

for title, steps in (("Option 1", option_1), ("Option 2", option_2)):
    print(title)
    new_dist = dist  # every option restarts from the initial distribution
    for step, (label, apply_op) in enumerate(steps):
        new_dist, comm_vol, n_procs = apply_op(new_dist)
        print(f"Step {step}: {label} {new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")

D_[4,2,4]⊥{(0,1),∅,∅,∅}(1,∅,∅,∅)
Option 1
Step 0: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,∅}(2,1,∅,∅), comm_vol, n_procs: (256, 2)
Step 1: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (256, 4)
Option 2
Step 0: split(3,2)  D_[4,2,4]⊥{(0,1),∅,∅,2}(1,∅,∅,1), comm_vol, n_procs: (0, 0)
Step 1: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,2}(2,1,∅,1), comm_vol, n_procs: (64, 2)
Step 2: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,2}(∅,1,2,1), comm_vol, n_procs: (64, 4)
Step 3: allgather(2)  D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (64, 4)


### Naive Gather Split

The simplest redistributor: it allgathers the full tensor, then splits it into the target distribution. It is expected to be inefficient in both communication and memory.

In [None]:
tensor_shape = torch.Size([50, 50])
mesh = torch.Size([2, 4])
block_size = (10, 10)
# Start with tensor dim 0 distributed over mesh axis 0; the target moves that
# distribution onto tensor dim 1 instead.
dist = tc.dist.MultiAxisDist(mesh, ((0,), ()), block_size)
target_dist = tc.dist.MultiAxisDist(mesh, ((), (0,)), block_size)
naive_rdist = tc.optim.NaiveGathererRedist(tc.optim.IdealLowerBoundsCM())

sequence, cost = naive_rdist.redistribute(tensor_shape, dist, target_dist)

# One table row per step; row 0 is the starting distribution (no operation).
# `columns` is kept as a notebook-level name because later cells reuse it.
columns = ["Step #", "Operation", "Distribution", "Cost"]
header_fmt = "{:<5} | {:<15} {:<30} {}"
row_fmt = "{:<5}   {:<15} {:<30} {}"
print(header_fmt.format(*columns))
for i, (op, s_dist, s_cost) in enumerate(sequence):
    print(row_fmt.format(i, op, str(s_dist), s_cost))
print(f"Total cost: {cost}")


06-03-2025 04:49:05 : INFO : naive_gatherer : _redistribute_multi_axis -- Dist D_[2,4]⊥{∅,∅}(∅,∅), volume: 1500, n_procs 2
Step # | Operation       Distribution                   Cost
0                       D_[2,4]⊥{0,∅}(10,∅)            Cost(latency=0, bandwidth=0, computation=0, max_memory_delta=0)
1       allgather_*     D_[2,4]⊥{∅,∅}(∅,∅)             Cost(latency=1, bandwidth=1500.0, computation=0, max_memory_delta=1000)
2       split_*         D_[2,4]⊥{∅,0}(∅,10)            Cost(latency=0, bandwidth=0, computation=0, max_memory_delta=-1000)
Total cost: 2501.0


In [None]:
# Same source/target layouts as the naive example (block size 2 here), solved
# by the memory-constrained redistributor instead.
tensor_shape = torch.Size((50, 50))
mesh = torch.Size((2, 4))
dist = tc.dist.MultiAxisDist(mesh, ((0,), ()), 2)
target_dist = tc.dist.MultiAxisDist(mesh, ((), (0,)), 2)
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(tc.optim.IdealLowerBoundsCM())
sequence, cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)

# Assemble the whole table first, then emit it in a single print.
columns = ["Step #", "Operation", "Distribution", "Cost"]
table_lines = [f"{columns[0]:<5} | {columns[1]:<15} {columns[2]:<30} {columns[3]}"]
for i, (op, s_dist, s_cost) in enumerate(sequence):
    table_lines.append(f"{i:<5}   {op:<15} {str(s_dist):<30} {s_cost}")
table_lines.append(f"Total cost: {cost}")
print("\n".join(table_lines))

06-03-2025 04:49:43 : INFO : mem_const : _redistribute_multi_axis -- D_[2,4]⊥{0,∅}(2,∅)
06-03-2025 04:49:43 : INFO : mem_const : _redistribute_multi_axis -- D_[2,4]⊥{∅,0}(∅,2)


Step # | Operation       Distribution                   Cost
0                       D_[2,4]⊥{0,∅}(2,∅)             Cost(latency=0, bandwidth=0, computation=0, max_memory_delta=0)
1       alltoall_0_1_-1 D_[2,4]⊥{∅,0}(∅,2)             Cost(latency=1, bandwidth=1300.0, computation=0, max_memory_delta=0)
Total cost: 1301.0


In [None]:
# Memory-constrained redistribution of the 4-D example from the top of the
# notebook, with the target chosen to match that example's end state.
tensor_shape = torch.Size([8, 8, 8, 4])
mesh = torch.Size([4, 2, 4])
# Start: tensor dim 0 over mesh axes (0, 1), block size 1.
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 1)
# Target: dim 1 over mesh axis 1 (block 4), dim 2 over mesh axis 0 (block 2).
# NOTE(review): -1 appears to be a placeholder block size for undistributed
# dims (rendered as ∅ in the repr) — confirm against MultiAxisDist.
target_dist = tc.dist.MultiAxisDist(mesh, ((), (1,), (0,), ()), (-1, 4, 2, -1))
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(tc.optim.IdealLowerBoundsCM())
sequence, cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)

# Define the header here instead of relying on `columns` leaking from an
# earlier cell, so this cell also runs on a fresh kernel / in isolation.
columns = ["Step #", "Operation", "Distribution", "Cost"]
# Wider Operation/Distribution columns than the 2-D examples: the 4-D op
# names and distribution reprs are longer.
print(f"{columns[0]:<5} | {columns[1]:<30} {columns[2]:<50} {columns[3]}")
for i, (op, s_dist, s_cost) in enumerate(sequence):
    print(f"{i:<5}   {op:<30} {str(s_dist):<50} {s_cost}")
print(f"Total cost: {cost}")

06-03-2025 04:50:41 : INFO : mem_const : _redistribute_multi_axis -- D_[4,2,4]⊥{(0,1),∅,∅,∅}(1,∅,∅,∅)
06-03-2025 04:50:41 : INFO : mem_const : _redistribute_multi_axis -- D_[4,2,4]⊥{∅,1,0,∅}(∅,4,2,∅)


Step # | Operation                      Distribution                                       Cost
0                                      D_[4,2,4]⊥{(0,1),∅,∅,∅}(1,∅,∅,∅)                   Cost(latency=0, bandwidth=0, computation=0, max_memory_delta=0)
1       alltoall_minor_0_1_4           D_[4,2,4]⊥{0,1,∅,∅}(2,4,∅,∅)                       Cost(latency=1, bandwidth=256.0, computation=0, max_memory_delta=0)
2       split_minor_1_2_2              D_[4,2,4]⊥{0,(1,2),∅,∅}(2,1,∅,∅)                   Cost(latency=0, bandwidth=0, computation=0, max_memory_delta=-192)
3       alltoall_0_2_-1                D_[4,2,4]⊥{∅,(1,2),0,∅}(∅,1,2,∅)                   Cost(latency=2, bandwidth=192.0, computation=0, max_memory_delta=0)
4       allgather_2                    D_[4,2,4]⊥{∅,1,0,∅}(∅,4,2,∅)                       Cost(latency=2, bandwidth=192.0, computation=0, max_memory_delta=0)
Total cost: 645.0
