# Optim Bottom-Up

## Setup


In [None]:
%load_ext autoreload
%autoreload 2
from importlib import reload
import logging
reload(logging)

import torch

import tensorcraft as tc

# Logger of the "tensorcraft" package; enable the line below for INFO output.
log = logging.getLogger("tensorcraft")
# log.setLevel(logging.INFO)

# Values handed to every redistributor in this notebook as `alpha` / `beta`.
# NOTE(review): presumably the per-message latency [s] and the transfer time
# of one 64-bit element over a 200 Gbit/s link — confirm against the cost model.
ALPHA = 1e-6
BETA = 64.0 / 200e9

In [None]:
from typing import Any

# Headers of the LaTeX table emitted by print_path and the minimum width each
# column is left-justified to (one width per header).
columns = ["Step #", "Operation", "Distribution", "Cost[s]", "Memory Usage [MB]"]
columns_width = [8, 20, 40, 8, 8]

# Bytes per tensor element; used to convert element counts into megabytes.
type_size = 8


def print_path(path: list[tuple[str, Any, float]], tensor_shape: torch.Size) -> None:
    """Print a redistribution path as the rows of a LaTeX table.

    One header row is printed first, followed by one row per step. Cells are
    separated by `` & `` and every row ends with a LaTeX line break (``\\\\``).

    Args:
        path: Steps as ``(operation, distribution, cost)`` tuples. Each
            distribution must provide ``latexStr()`` and
            ``maxNumElements(tensor_shape)``.
        tensor_shape: Shape of the tensor being redistributed; forwarded to
            ``maxNumElements`` to compute the memory column.
    """
    step_w, op_w, dist_w, cost_w, mem_w = columns_width

    header = " & ".join(f"{col:<{width}}" for col, width in zip(columns, columns_width))
    print(header + " \\\\")

    for i, (op, s_dist, s_cost) in enumerate(path):
        # Element count -> bytes (type_size) -> megabytes (10**6 bytes).
        mem_mb = s_dist.maxNumElements(tensor_shape) * type_size / 10**6
        row = (
            f"{i:<{step_w}} & {op:<{op_w}} & {s_dist.latexStr():<{dist_w}}"
            f" & {s_cost:<{cost_w}} & {mem_mb:<{mem_w}}"
        )
        print(row + " \\\\")


## Redistributors

Given a tensor shape, a starting distribution, and a target distribution, a redistributor builds a sequence of collective operations that reaches the target distribution while optimizing for a chosen metric (e.g. communication cost or memory usage).

### Memory efficient redistributions example

In [None]:
# Small example: an 8x8x8x4 tensor on a 4x2x4 mesh.
tensor_shape = torch.Size([8, 8, 8, 4])
mesh = torch.Size([4, 2, 4])
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 1)
print(dist)

# Each option is a chain of collectives applied one after another, starting
# from `dist`. An entry pairs the label printed for the step with a callable
# that runs the collective on the current distribution.
redistribution_options = {
    "Option 1": [
        ("alltoall(0,1) - ", lambda d: d.alltoall(tensor_shape, 0, 1, minor=True)),
        ("alltoall(0,2) - ", lambda d: d.alltoall(tensor_shape, 0, 2)),
    ],
    "Option 2": [
        ("split(3,2)  ", lambda d: d.split(tensor_shape, 3, 2, 1)),
        ("alltoall(0,1) - ", lambda d: d.alltoall(tensor_shape, 0, 1, minor=True)),
        ("alltoall(0,2) - ", lambda d: d.alltoall(tensor_shape, 0, 2)),
        ("allgather(2)  ", lambda d: d.allgather(tensor_shape, 2)),
    ],
}

for option_name, steps in redistribution_options.items():
    print(option_name)
    new_dist = dist
    for step_idx, (label, collective) in enumerate(steps):
        # Every collective returns (new distribution, comm volume, #procs).
        new_dist, comm_vol, n_procs = collective(new_dist)
        print(f"Step {step_idx}: {label}{new_dist}, comm_vol, n_procs: {comm_vol, n_procs}")

D_[4,2,4]⊥{(0,1),∅,∅,∅}(1,∅,∅,∅)
Option 1
Step 0: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,∅}(2,1,∅,∅), comm_vol, n_procs: (512, 2)
Step 1: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (1024, 4)
Option 2
Step 0: split(3,2)  D_[4,2,4]⊥{(0,1),∅,∅,2}(1,∅,∅,1), comm_vol, n_procs: (0, 0)
Step 1: alltoall(0,1) - D_[4,2,4]⊥{0,1,∅,2}(2,1,∅,1), comm_vol, n_procs: (128, 2)
Step 2: alltoall(0,2) - D_[4,2,4]⊥{∅,1,0,2}(∅,1,2,1), comm_vol, n_procs: (256, 4)
Step 3: allgather(2)  D_[4,2,4]⊥{∅,1,0,∅}(∅,1,2,∅), comm_vol, n_procs: (256, 4)


### Problem 1

In [None]:
# Problem 1: an 800x400x800x4 tensor on a 4x2x4 mesh.
mesh = torch.Size([4, 2, 4])
tensor_shape = torch.Size([800, 400, 800, 4])

# Start and target layouts used by the redistributor cells of this problem.
# NOTE(review): presumably the second argument lists, per tensor dimension,
# the mesh axes it is distributed over and the third gives block sizes
# (-1 for undistributed dimensions) — confirm against the tensorcraft docs.
dist = tc.dist.MultiAxisDist(mesh, ((0, 1), (), (), ()), 100)
target_dist = tc.dist.MultiAxisDist(
    mesh,
    ((), (1,), (0,), ()),
    (-1, 100, 100, -1),
)

#### Naive Gather Split

The simplest redistributor: it allgathers the full tensor and then splits it into the target distribution. It is expected to be inefficient in both communication and memory.

In [None]:
%%time
# Naive redistributor (allgather, then split) driven by an IdealLowerBoundsCM instance.
cost_model = tc.optim.IdealLowerBoundsCM()
naive_rdist = tc.optim.NaiveGathererRedist(cost_model, alpha=ALPHA, beta=BETA)

# Plan the path from `dist` to `target_dist`, then print it as a LaTeX table
# followed by its total cost.
sequence, total_cost = naive_rdist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")


Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ (0,1),\emptyset,\emptyset,\emptyset \}(100,\emptyset,\emptyset,\emptyset)}$ & 0        & 1024.0   \\
1        & allgather_*          & $T_{\perp\{ \emptyset,\emptyset,\emptyset,\emptyset \}(\emptyset,\emptyset,\emptyset,\emptyset)}$ & 0.28672299999999995 & 8192.0   \\
2        & split_*              & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,100,\emptyset)}$ & 0        & 1024.0   \\
Total cost: 0.29s
CPU times: user 1.63 ms, sys: 0 ns, total: 1.63 ms
Wall time: 1.45 ms


#### Memory Constrained

In [None]:
%%time
# Memory-constrained redistributor with its default search settings.
cost_model = tc.optim.IdealLowerBoundsCM()
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(
    cost_model,
    alpha=ALPHA,
    beta=BETA,
)

# Plan the path from `dist` to `target_dist`, then print it as a LaTeX table
# followed by its total cost.
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Explored 1000 nodes, found 20 possible paths.
Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ (0,1),\emptyset,\emptyset,\emptyset \}(100,\emptyset,\emptyset,\emptyset)}$ & 0        & 1024.0   \\
1        & split_3_2_1          & $T_{\perp\{ (0,1),\emptyset,\emptyset,2 \}(100,\emptyset,\emptyset,1)}$ & 0.0      & 256.0    \\
2        & alltoall_minor_0_1_-1 & $T_{\perp\{ 0,1,\emptyset,2 \}(200,100,\emptyset,1)}$ & 0.010240999999999998 & 256.0    \\
3        & alltoall_0_2_100     & $T_{\perp\{ \emptyset,1,0,2 \}(\emptyset,100,100,1)}$ & 0.030721999999999996 & 256.0    \\
4        & allgather_2          & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,100,\emptyset)}$ & 0.030721999999999996 & 1024.0   \\
Total cost: 0.07s
CPU times: user 1min 40s, sys: 29.4 ms, total: 1min 40s
Wall time: 1min 41s


In [None]:
%%time
# Memory-constrained redistributor configured with max_depth=3.
cost_model = tc.optim.IdealLowerBoundsCM()
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(
    cost_model,
    alpha=ALPHA,
    beta=BETA,
    max_depth=3,
)

# Plan the path from `dist` to `target_dist`, then print it as a LaTeX table
# followed by its total cost.
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Explored 547 nodes, found 10 possible paths.
Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ (0,1),\emptyset,\emptyset,\emptyset \}(100,\emptyset,\emptyset,\emptyset)}$ & 0        & 1024.0   \\
1        & alltoall_minor_0_1_-1 & $T_{\perp\{ 0,1,\emptyset,\emptyset \}(200,100,\emptyset,\emptyset)}$ & 0.040961 & 1024.0   \\
2        & alltoall_0_2_100     & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,100,\emptyset)}$ & 0.12288199999999999 & 1024.0   \\
Total cost: 0.16s
CPU times: user 9.55 s, sys: 3.35 ms, total: 9.56 s
Wall time: 9.58 s


In [None]:
%%time
# Memory-constrained redistributor configured with top_k=1.
cost_model = tc.optim.IdealLowerBoundsCM()
mem_constrained_dist = tc.optim.MemoryConstrainedRedist(
    cost_model,
    alpha=ALPHA,
    beta=BETA,
    top_k=1,
)

# Plan the path from `dist` to `target_dist`, then print it as a LaTeX table
# followed by its total cost.
sequence, total_cost = mem_constrained_dist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Explored 8 nodes, found 1 possible paths.
Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ (0,1),\emptyset,\emptyset,\emptyset \}(100,\emptyset,\emptyset,\emptyset)}$ & 0        & 1024.0   \\
1        & alltoall_minor_0_1_-1 & $T_{\perp\{ 0,1,\emptyset,\emptyset \}(200,100,\emptyset,\emptyset)}$ & 0.040961 & 1024.0   \\
2        & alltoall_0_2_100     & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,100,\emptyset)}$ & 0.12288199999999999 & 1024.0   \\
Total cost: 0.16s
CPU times: user 324 ms, sys: 4 μs, total: 324 ms
Wall time: 324 ms


#### A*

In [None]:
%%time
# A* redistributor; path_cost_w and estimate_w are passed through unchanged.
# NOTE(review): presumably they weight the accumulated path cost against the
# heuristic estimate — confirm against the AStarRedistributor documentation.
cost_model = tc.optim.IdealLowerBoundsCM()
a_star_redist = tc.optim.AStarRedistributor(
    cost_model,
    alpha=ALPHA,
    beta=BETA,
    path_cost_w=100.0,
    estimate_w=1.0,
)

# Plan the path from `dist` to `target_dist`, then print it as a LaTeX table
# followed by its total cost.
sequence, total_cost = a_star_redist.redistribute(tensor_shape, dist, target_dist)
print_path(sequence, tensor_shape)
print(f"Total cost: {total_cost:.2f}s")

Explored 8 nodes, found 1 possible paths.
Step #   & Operation            & Distribution                             & Cost[s]  & Memory Usage [MB] \\
0        &                      & $T_{\perp\{ (0,1),\emptyset,\emptyset,\emptyset \}(100,\emptyset,\emptyset,\emptyset)}$ & 0        & 1024.0   \\
1        & alltoall_0_2_-1      & $T_{\perp\{ \emptyset,\emptyset,(0,1),\emptyset \}(\emptyset,\emptyset,100,\emptyset)}$ & 0.28672299999999995 & 1024.0   \\
2        & alltoall_minor_2_1_-1 & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,200,\emptyset)}$ & 0.040961 & 1024.0   \\
3        & changeBlockSize_2_100 & $T_{\perp\{ \emptyset,1,0,\emptyset \}(\emptyset,100,100,\emptyset)}$ & 0.030721999999999996 & 1024.0   \\
Total cost: 0.36s
CPU times: user 555 ms, sys: 0 ns, total: 555 ms
Wall time: 559 ms
