# Optim Bottom-Up

## Setup


In [None]:
%load_ext autoreload
%autoreload 2
import logging
from importlib import reload
from typing import Any

import torch

import tensorcraft as tc

tc.set_logger_config(level=logging.INFO)
# NOTE(review): the find_routes log lines below appear 4x with nested colour codes,
# possibly because the logger was configured more than once in this kernel (e.g. by
# re-running this cell). Restart the kernel before a clean "Run All".

# --- Communication cost parameters passed to every redistributor ---
ALPHA = 1e-6  # per-message latency [s]: 1 microsecond (maybe bigger)
BITS_PER_ELEMENT = 64.0  # presumably 64-bit tensor elements — TODO confirm against the cost model
LINK_BANDWIDTH_BITS_PER_S = 200.0 * 1e9  # 200 Gbit/s link bandwidth
BETA = BITS_PER_ELEMENT / LINK_BANDWIDTH_BITS_PER_S  # transfer time per element [s]

# --- AStarRedistributor search hyper-parameters ---
NODE_LIMIT = 1000  # cap on explored nodes (the log reports "Explored 1000 nodes" when it is hit)
TOP_K = 15  # passed as top_k; presumably the number of candidate paths kept
MAX_DEPTH = 6  # passed as max_depth; presumably the maximum number of collective ops in a path
PATH_COST_W = 1.02  # presumably the weight on the accumulated path cost (A* g-term)
ESTIMATE_W = 1.00  # presumably the weight on the heuristic estimate (A* h-term)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
columns = ["Step #", "Operation", "Distribution", "Cost[s]", "Memory Usage [MB]"]
columns_width = [8, 20, 40, 8, 8]

type_size = 8  # bytes per tensor element used for the memory column (8 = 64-bit elements)

def print_path(path: list[tuple[str, Any, float]], tensor_shape: torch.Size, elem_size: int = type_size) -> None:
    """Print a redistribution path as rows of a LaTeX table.

    Columns are separated by `&` and each row ends with a LaTeX line break, so the
    output can be pasted into a `tabular` environment.

    Args:
        path: Sequence of `(operation, distribution, step_cost)` tuples, as returned
            by a redistributor's `redistribute` method. Each distribution must provide
            `latexStr()` and `maxNumElements(shape)`.
        tensor_shape: Global tensor shape used to compute memory usage.
        elem_size: Bytes per tensor element; defaults to the module-level `type_size`.
    """
    step_w, op_w, dist_w, cost_w, mem_w = columns_width

    header = " & ".join(f"{col:<{width}}" for col, width in zip(columns, columns_width))
    print(header + " \\\\")

    for step, (op, s_dist, s_cost) in enumerate(path):
        # maxNumElements is presumably the largest per-process element count; convert to MB (10**6 bytes).
        mem_mb = s_dist.maxNumElements(tensor_shape) * elem_size / 10**6
        # `.6g` keeps costs aligned and readable (0.396805 rather than 0.39680499999999996).
        print(
            f"{step:<{step_w}} & {op:<{op_w}} & {s_dist.latexStr():<{dist_w}} & "
            f"{s_cost:<{cost_w}.6g} & {mem_mb:<{mem_w}} \\\\"
        )


In [None]:

def mem_constrained_filter(
    shape: torch.Size,
    start_dist: tc.dist.MultiAxisDist,
    target_dist: tc.dist.MultiAxisDist,
    current_dist: tc.dist.MultiAxisDist,
) -> bool:
    """Flag search nodes whose distribution needs more memory than both endpoints.

    The budget is the larger of `maxNumElements(shape)` for the start and target
    distributions. The function returns True when `current_dist` exceeds that budget.

    Passed as `node_filter` to `tc.optim.AStarRedistributor` below. Presumably a True
    result prunes the node from the search; the logged explored-node count drops when
    the filter is enabled. TODO confirm in tensorcraft's route finder.
    """
    endpoint_budget = max(d.maxNumElements(shape) for d in (start_dist, target_dist))
    return current_dist.maxNumElements(shape) > endpoint_budget

## Redistributors

Given a tensor shape, a starting distribution, and a target distribution, each redistributor produces a sequence of collective operations that transforms the starting distribution into the target distribution while optimizing for different metrics.

### Problem 1 (Tiled Matrix to Row-Cyclic)

Shifting from a tiled matrix to a row-cyclic distribution.

In [None]:
tensor_shape = torch.Size([10000000, 512])
mesh = torch.Size([2, 4])
# Start: tiled matrix — tensor dim 0 over mesh axis 0, dim 1 over mesh axis 1 (blocks 5000000 x 128).
dist = tc.dist.MultiAxisDist(mesh, ((0,), (1,)), (5000000, 128))
# Target: tensor dim 0 over the combined mesh axes (0, 1), dim 1 not distributed.
target_dist = tc.dist.MultiAxisDist(mesh, ((0, 1), None), 1250000)

# Fail fast if either layout does not fit the tensor shape.
assert dist.compatible(tensor_shape), "start distribution is incompatible with tensor_shape"
assert target_dist.compatible(tensor_shape), "target distribution is incompatible with tensor_shape"

True

In [None]:

naive_rdist = tc.optim.NaiveGathererRedist(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA)

sequence, total_cost = naive_rdist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")


Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ 0,1 \}(5000000,128)}$        & 0        & 5120.0   \\
1        & allgather_*          & $T_{\perp\{ \emptyset,\emptyset \}(\emptyset,\emptyset)}$ & 1.433603 & 40960.0  \\
2        & split_*              & $T_{\perp\{ (0,1),\emptyset \}(1250000,\emptyset)}$ & 0        & 5120.0   \\
Total cost: 1.43s


In [None]:
%%time
astar_redist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = astar_redist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 15:11:18,952][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 69 nodes, found 15 possible paths.[0m
[2025-05-07 15:11:18,952][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 69 nodes, found 15 possible paths.[0m[0m
[2025-05-07 15:11:18,952][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 69 nodes, found 15 possible paths.[0m[0m[0m
[2025-05-07 15:11:18,952][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 69 nodes, found 15 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                          

In [None]:
%%time
astar_redist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), node_filter=mem_constrained_filter, alpha=ALPHA, beta=BETA, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = astar_redist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 15:11:32,025][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 33 nodes, found 9 possible paths.[0m
[2025-05-07 15:11:32,025][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 33 nodes, found 9 possible paths.[0m[0m
[2025-05-07 15:11:32,025][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 33 nodes, found 9 possible paths.[0m[0m[0m
[2025-05-07 15:11:32,025][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 33 nodes, found 9 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                             &

In [None]:
%%time
astar_redist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=1, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = astar_redist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 15:11:32,232][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 2 nodes, found 2 possible paths.[0m
[2025-05-07 15:11:32,232][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 2 nodes, found 2 possible paths.[0m[0m
[2025-05-07 15:11:32,232][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 2 nodes, found 2 possible paths.[0m[0m[0m
[2025-05-07 15:11:32,232][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 2 nodes, found 2 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                             & Cos

### Problem 2

In [None]:
tensor_shape = torch.Size([1000, 400, 400, 8])
# Start: tensor dim 0 over mesh axis 0 and dim 3 over mesh axes (1, 2), block size 1.
dist = tc.dist.MultiAxisDist(torch.Size([4, 4, 2]), ((0,), None, None, (1, 2)), 1)
# Target: tensor dims 0, 1, 2 over mesh axes 0, 1, 2 respectively; dim 3 not distributed.
target_dist = tc.dist.MultiAxisDist(
    torch.Size([4, 4, 2]), ((0,), (1,), (2,), None), (1, 100, 100, None)
)

# Fail fast if either layout does not fit the tensor shape.
assert dist.compatible(tensor_shape), "start distribution is incompatible with tensor_shape"
assert target_dist.compatible(tensor_shape), "target distribution is incompatible with tensor_shape"

In [None]:
%%time
naive_rdist = tc.optim.NaiveGathererRedist(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA)

sequence, total_cost = naive_rdist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ 0,\emptyset,\emptyset,(1,2) \}(1,\emptyset,\emptyset,1)}$ & 0        & 320.0    \\
1        & allgather_*          & $T_{\perp\{ \emptyset,\emptyset,\emptyset,\emptyset \}(\emptyset,\emptyset,\emptyset,\emptyset)}$ & 0.39680499999999996 & 10240.0  \\
2        & split_*              & $T_{\perp\{ 0,1,2,\emptyset \}(1,100,100,\emptyset)}$ & 0        & 320.0    \\
Total cost: 0.40s
CPU times: user 5.67 ms, sys: 887 μs, total: 6.55 ms
Wall time: 5.99 ms


In [None]:
%%time
mem_constrained_dist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, node_filter=mem_constrained_filter, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 15:41:06,822][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 891 nodes, found 9 possible paths.[0m
[2025-05-07 15:41:06,822][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 891 nodes, found 9 possible paths.[0m[0m
[2025-05-07 15:41:06,822][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 891 nodes, found 9 possible paths.[0m[0m[0m
[2025-05-07 15:41:06,822][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 891 nodes, found 9 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                          

In [None]:
%%time
astar_redist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = astar_redist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 15:42:02,718][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 1000 nodes, found 13 possible paths.[0m
[2025-05-07 15:42:02,718][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 1000 nodes, found 13 possible paths.[0m[0m
[2025-05-07 15:42:02,718][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 1000 nodes, found 13 possible paths.[0m[0m[0m
[2025-05-07 15:42:02,718][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 1000 nodes, found 13 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                  

### Problem 3

In [None]:
tensor_shape = torch.Size([1000, 1000, 1000])
# Start: tensor dims 0, 1, 2 over mesh axes 0, 1, 2 respectively, block size 1.
dist = tc.dist.MultiAxisDist(torch.Size([2, 2, 2]), ((0,), (1,), (2,)), 1)
# Target: tensor dim 1 over all three mesh axes (0, 1, 2); dims 0 and 2 not distributed.
target_dist = tc.dist.MultiAxisDist(
    torch.Size([2, 2, 2]), ((), (0, 1, 2), ()), 125
)

# Fail fast if either layout does not fit the tensor shape.
assert dist.compatible(tensor_shape), "start distribution is incompatible with tensor_shape"
assert target_dist.compatible(tensor_shape), "target distribution is incompatible with tensor_shape"

In [None]:
%%time
naive_rdist = tc.optim.NaiveGathererRedist(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA)

sequence, total_cost = naive_rdist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ 0,1,2 \}(1,1,1)}$            & 0        & 1000.0   \\
1        & allgather_*          & $T_{\perp\{ \emptyset,\emptyset,\emptyset \}(\emptyset,\emptyset,\emptyset)}$ & 0.28000299999999995 & 8000.0   \\
2        & split_*              & $T_{\perp\{ \emptyset,(0,1,2),\emptyset \}(\emptyset,125,\emptyset)}$ & 0        & 1000.0   \\
Total cost: 0.28s
CPU times: user 4.35 ms, sys: 844 μs, total: 5.19 ms
Wall time: 4.66 ms


In [None]:
%%time
mem_constrained_dist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, node_filter=mem_constrained_filter, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 17:05:34,749][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 31 nodes, found 16 possible paths.[0m
[2025-05-07 17:05:34,749][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 31 nodes, found 16 possible paths.[0m[0m
[2025-05-07 17:05:34,749][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 31 nodes, found 16 possible paths.[0m[0m[0m
[2025-05-07 17:05:34,749][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 31 nodes, found 16 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                          

In [None]:
%%time
astar_redist = tc.optim.AStarRedistributor(tc.optim.IdealLowerBoundsCM(), alpha=ALPHA, beta=BETA, path_cost_w=PATH_COST_W, estimate_w=ESTIMATE_W, top_k=TOP_K, node_limit=NODE_LIMIT, max_depth=MAX_DEPTH)
sequence, total_cost = astar_redist.redistribute(tensor_shape, dist, target_dist)

print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.3f}s")

[2025-05-07 17:05:37,227][[1;36mtensorcraft.util.route_finder[0m][[1;35mfind_routes[0m][[1;37mINFO[0m] - [1;37mExplored 52 nodes, found 17 possible paths.[0m
[2025-05-07 17:05:37,227][[1;36m[1;36mtensorcraft.util.route_finder[0m[0m][[1;35m[1;35mfind_routes[0m[0m][[1;30m[1;37mINFO[0m[0m] - [1;30m[1;37mExplored 52 nodes, found 17 possible paths.[0m[0m
[2025-05-07 17:05:37,227][[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m][[1;35m[1;35m[1;35mfind_routes[0m[0m[0m][[1;30m[1;30m[1;37mINFO[0m[0m[0m] - [1;30m[1;30m[1;37mExplored 52 nodes, found 17 possible paths.[0m[0m[0m
[2025-05-07 17:05:37,227][[1;36m[1;36m[1;36m[1;36mtensorcraft.util.route_finder[0m[0m[0m[0m][[1;35m[1;35m[1;35m[1;35mfind_routes[0m[0m[0m[0m][[1;30m[1;30m[1;30m[1;37mINFO[0m[0m[0m[0m] - [1;30m[1;30m[1;30m[1;37mExplored 52 nodes, found 17 possible paths.[0m[0m[0m[0m
Step #   & Operation            & Distribution                          