In [5]:
import pulp 

In [6]:
# Unit multipliers.
GB = 2 ** 30   # 1 GiB expressed in bytes
T = 1e12       # "tera" scale factor

# Default problem size; lp() takes its own b and s, and the sweep loop rebinds these.
# Presumably batch size and sequence length -- confirm.
b = 64
s = 512

# constants
# Presumably model hidden size (h) and number of layers (l) -- confirm.
h = 7168
l = 96

class Bandwidths:
    """Per-direction link bandwidths, expressed as multiples of GB.

    The LP divides bytes moved by these values to obtain a transfer time.
    Naming is source2destination; c/g/d presumably stand for CPU, GPU and
    disk -- confirm.
    """
    c2g = g2c = GB * 32   # both directions of the c<->g link share one value
    c2d = GB * 3
    d2c = GB * 18

# Bytes per element; 2 corresponds to fp16.
B_type = 2
class FLOPS:
    """Throughput table, looked up by bytes-per-element (see B_type)."""
    # mm: presumably matrix-multiply rate -- confirm. Only the fp16 entry exists.
    mm = {2: T * 200}

class MemCap:
    """Capacity of each memory tier, as a multiple of GB.

    lp() uses these as upper bounds on what may reside in each tier.
    """
    gmem: int = GB * 24      # Nvidia 4090
    cmem: int = GB * 220     # Intel Xeon @ 2.3 GHz
    dmem: int = GB * 1500



In [17]:
class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes.

    Construction accepts exactly what ``dict`` accepts. Attribute access and
    item access operate on the same storage, so a value set one way is
    visible the other way.

    Examples:
        >>> cfg = AttrDict({'a': 1})
        >>> cfg.a, cfg['a'], cfg.get('a')
        (1, 1, 1)
        >>> cfg.b = 2
        >>> cfg.b, cfg['b']
        (2, 2)
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute
        # lookups and item lookups hit the same entries.
        self.__dict__ = self

In [29]:
def lp(b, s, stage='z', verbose=False):
    """Build and solve the 'SARecomp' linear program with PuLP/CBC.

    The objective is to minimize ``T_ovl``, which is constrained to be at
    least each of the four transfer times (c2g, g2c, d2c, c2d) and the
    compute time. Decision variables are placement fractions in [0, 1];
    w, x, y, z refer to weights, kv cache, actv cache, recomputed actv, and
    the g/c/d suffix selects the memory tier (see MemCap).

    Args:
        b: Multiplier used throughout the data-movement, compute and memory
            terms (presumably batch size -- confirm).
        s: Multiplier used alongside ``b`` (presumably sequence length --
            confirm).
        stage: One of 'x', 'y', 'z'.
            'x' pins pyg, pyc, pyd and pz to 0 (only x placements allowed);
            'y' pins pz to 0; 'z' leaves every fraction free.
        verbose: When true, enables solver messages and prints the status,
            (b, s), each placement variable, G_tmp and the objective value.

    Returns:
        AttrDict with:
            status: PuLP status string for the solve.
            percents: values of [pwg, pwc, pwd, pxg, pxc, pxd, pyg, pyc,
                pyd, pz], in that order.
            T: value of T_ovl.
        Check ``status`` before using ``percents``/``T``; they are returned
        regardless of whether the solve succeeded.

    Raises:
        AssertionError: if ``stage`` is not 'x', 'y' or 'z'.
    """
    assert stage in ['x', 'y', 'z']

    prob = pulp.LpProblem(name='SARecomp', sense=pulp.LpMinimize)

    # w, x, y, z refer to weights, kv cache, actv cache, recomputed actv

    # percents
    pxg = pulp.LpVariable('pxg', 0, 1)
    pxc = pulp.LpVariable('pxc', 0, 1)
    pxd = pulp.LpVariable('pxd', 0, 1)
    pyg = pulp.LpVariable('pyg', 0, 1)
    pyc = pulp.LpVariable('pyc', 0, 1)
    pyd = pulp.LpVariable('pyd', 0, 1)
    pz = pulp.LpVariable('pz', 0, 1)
    prob += pxg + pxc + pxd + pyg + pyc + pyd + pz == 1

    if stage == 'x':
        # Only x placements are allowed: every y fraction and pz must be 0.
        prob += pyg == 0
        prob += pyc == 0
        prob += pyd == 0  # was a second `pyc == 0`, which left pyd unconstrained
        prob += pz == 0
    elif stage == 'y':
        prob += pz == 0

    pwg = pulp.LpVariable('pwg', 0, 1)
    pwc = pulp.LpVariable('pwc', 0, 1)
    pwd = pulp.LpVariable('pwd', 0, 1)
    prob += pwg + pwc + pwd == 1

    # c2g, g2c, c2d, d2c, comp

    # dm: data movement, T: time
    # Anything homed on c or d has to cross the c->g link.
    dm_c2g = pulp.LpVariable('dm_c2g', 0)
    T_c2g = pulp.LpVariable('T_c2g', 0)
    prob += dm_c2g == 12 * h ** 2 * (pwc + pwd) + 2 * b * s * h * (pxc + pxd) + b * s * h * (pyc + pyd)
    prob += T_c2g == dm_c2g * B_type / Bandwidths.c2g

    # The return direction uses a factor of 1 where the load direction uses s.
    dm_g2c = pulp.LpVariable('dm_g2c', 0)
    T_g2c = pulp.LpVariable('T_g2c', 0)
    prob += dm_g2c == 2 * b * 1 * h * (pxc + pxd) + b * 1 * h * (pyc + pyd)
    prob += T_g2c == dm_g2c * B_type / Bandwidths.g2c

    # Only the d-homed fractions touch the d<->c link.
    dm_d2c = pulp.LpVariable('dm_d2c', 0)
    T_d2c = pulp.LpVariable('T_d2c', 0)
    prob += dm_d2c == 12 * h ** 2 * (0 + pwd) + 2 * b * s * h * (0 + pxd) + b * s * h * (0 + pyd)
    prob += T_d2c == dm_d2c * B_type / Bandwidths.d2c

    dm_c2d = pulp.LpVariable('dm_c2d', 0)
    T_c2d = pulp.LpVariable('T_c2d', 0)
    prob += dm_c2d == 2 * b * 1 * h * (0 + pxd) + b * 1 * h * (0 + pyd)
    prob += T_c2d == dm_c2d * B_type / Bandwidths.c2d

    # y and z are the y / z fractions scaled by s.
    y = pulp.LpVariable('y', 0)
    z = pulp.LpVariable('z', 0)
    prob += y == s * (pyg + pyc + pyd)
    prob += z == s * pz

    comp = pulp.LpVariable('comp', 0)
    T_comp = pulp.LpVariable('T_comp', 0)
    prob += comp == 3 * b * y * h ** 2 + 2 * b * s * h * (z + 1) + 12 * b * (z + 1) * h ** 2
    prob += T_comp == comp * 1 / FLOPS.mm[B_type]

    # T_ovl >= each component; minimizing it makes it the maximum of them.
    T_ovl = pulp.LpVariable('T_ovl', 0)
    prob += (T_ovl >= T_c2g)
    prob += (T_ovl >= T_g2c)
    prob += (T_ovl >= T_d2c)
    prob += (T_ovl >= T_c2d)
    prob += (T_ovl >= T_comp)

    # mem cap constraints
    G_home = pulp.LpVariable('G_home', 0)
    C_home = pulp.LpVariable('C_home', 0)
    D_home = pulp.LpVariable('D_home', 0)
    prob += G_home == l * (12 * h ** 2 * pwg + 2 * b * s * h * pxg + b * s * h * pyg) * B_type
    prob += C_home == l * (12 * h ** 2 * pwc + 2 * b * s * h * pxc + b * s * h * pyc) * B_type
    prob += D_home == l * (12 * h ** 2 * pwd + 2 * b * s * h * pxd + b * s * h * pyd) * B_type

    G_tmp1 = pulp.LpVariable('G_tmp1', 0)
    G_tmp2 = pulp.LpVariable('G_tmp2', 0)
    prob += G_tmp1 == 2 * b * s * h * B_type # KV
    prob += G_tmp2 == 4 * b * z * h * B_type # ffn hidden

    # G_tmp is bounded below by both temporaries, i.e. it covers the larger one.
    G_tmp = pulp.LpVariable('G_tmp', 0)
    prob += (G_tmp >= G_tmp1)
    prob += (G_tmp >= G_tmp2)

    prob += (G_home + G_tmp <= MemCap.gmem)
    prob += (C_home <= MemCap.cmem)
    prob += (D_home <= MemCap.dmem)

    # obj
    prob += T_ovl
    status = prob.solve(pulp.PULP_CBC_CMD(msg=verbose))

    # Reporting order for the placement fractions (weights first, then x, y, z).
    percent_vars = [pwg, pwc, pwd, pxg, pxc, pxd, pyg, pyc, pyd, pz]

    if verbose:
        print(pulp.LpStatus[status])
        print(b, s)
        for var in percent_vars + [G_tmp]:
            print(f"{var.name} = {var.value()}")
        print(T_ovl.value())

    out = AttrDict()
    out.status = pulp.LpStatus[status]
    out.percents = [var.value() for var in percent_vars]
    out.T = T_ovl.value()

    return out

# Sweep the second argument of lp() at a fixed first argument of 64, printing one result per line.
for b, s in ((64, size) for size in (512, 2048, 4096, 7000, 9162, 100000)):
    result = lp(b, s)
    print(b, s, result)

64 512 {'status': 'Optimal', 'percents': [0.20181406, 0.79818594, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 'T': 0.028645833}
64 2048 {'status': 'Optimal', 'percents': [0.15419501, 0.84580499, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 'T': 0.030354818}
64 4096 {'status': 'Optimal', 'percents': [0.090702948, 0.39781746, 0.51147959, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 'T': 0.032633464}
64 7000 {'status': 'Optimal', 'percents': [0.00067318594, 0.99932681, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0], 'T': 0.035864512}
64 9162 {'status': 'Optimal', 'percents': [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.019687295, 0.21391734, 0.76639536], 'T': 0.093040466}
64 100000 {'status': 'Infeasible', 'percents': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'T': 0.0}
