In [61]:
import pulp 

In [62]:
from flexgen.utils import GB, T

# constants
b, s, h = 64, 2048, 7168 
l = 96

class Bandwidths:
    c2g = 32 * GB 
    g2c = 32 * GB 
    c2d = 3 * GB 
    d2c = 18 * GB 

B_type = 2 # for fp16
class FLOPS:
    mm = {2: 200 * T}

class MemCap: 
    gmem: int = 24 * GB     # Nvidia 4090
    cmem: int = 220 * GB    # Intel Xeon @ 2.3 GHz
    dmem: int = 1500 * GB 



In [63]:
prob = pulp.LpProblem(name='SARecomp', sense=pulp.LpMinimize)


In [64]:
# w, x, y, z refer to weights, kv cache, actv cache, recomputed actv

# percents
pxg = pulp.LpVariable('pxg', 0, 1)
pxc = pulp.LpVariable('pxc', 0, 1)
pxd = pulp.LpVariable('pxd', 0, 1)
pyg = pulp.LpVariable('pyg', 0, 1)
pyc = pulp.LpVariable('pyc', 0, 1)
pyd = pulp.LpVariable('pyd', 0, 1)
pz = pulp.LpVariable('pz', 0, 1)
prob += pxg + pxc + pxd + pyg + pyc + pyd + pz == 1 

pwg = pulp.LpVariable('pwg', 0, 1)
pwc = pulp.LpVariable('pwc', 0, 1)
pwd = pulp.LpVariable('pwd', 0, 1)
prob += pwg + pwc + pwd == 1 

In [65]:
# c2g, g2c, c2d, d2c, comp

# dm: data movement, T: time 
dm_c2g = 12 * h ** 2 * (pwc + pwd) + 2 * b * s * h * (pxc + pxd) + b * s * h * (pyc + pyd) 
T_c2g = dm_c2g * B_type / Bandwidths.c2g 

dm_g2c = 2 * b * 1 * h * (pxc + pxd) + b * 1 * h * (pyc + pyd) 
T_g2c = dm_g2c * B_type / Bandwidths.g2c 

dm_d2c = 12 * h ** 2 * (0 + pwd) + 2 * b * s * h * (0 + pxd) + b * s * h * (0 + pyd) 
T_d2c = dm_d2c * B_type / Bandwidths.d2c 

dm_c2d = 2 * b * 1 * h * (0 + pxd) + b * 1 * h * (0 + pyd) 
T_c2d = dm_c2d * B_type / Bandwidths.c2d 

y = s * (pyg + pyc + pyd)
z = s * pz 
comp = 3 * b * y * h ** 2 + 2 * b * s * h * (z + 1) + 12 * b * (z + 1) * h ** 2
T_comp = comp / FLOPS.mm[B_type] 

T_ovl = pulp.LpVariable('T_ovl', 0)
prob += (T_ovl >= T_c2g) and (T_ovl >= T_g2c) and (T_ovl >= T_d2c) and (T_ovl >= T_c2d) and (T_ovl >= T_comp) 

In [66]:
# mem cap constraints
G_home = l * (12 * h ** 2 * pwg + 2 * b * s * h * pxg + b * s * h * pyg) * B_type
C_home = l * (12 * h ** 2 * pwc + 2 * b * s * h * pxc + b * s * h * pyc) * B_type
D_home = l * (12 * h ** 2 * pwd + 2 * b * s * h * pxd + b * s * h * pyd) * B_type

G_tmp1 = 2 * b * s * h * B_type # KV
G_tmp2 = 4 * b * z * h * B_type # ffn hidden
G_tmp = pulp.LpVariable('G_tmp', 0)
prob += (G_tmp >= G_tmp1) and (G_tmp >= G_tmp2) 
prob += (G_home + G_tmp <= MemCap.gmem) and (C_home <= MemCap.cmem) and (D_home <= MemCap.dmem) 


In [67]:
prob += T_ovl 

In [68]:
status = prob.solve() 

Welcome to the CBC MILP Solver 
Version: 2.10.3 
Build Date: Dec 15 2019 

command line - /home/dingfangyu/miniconda3/lib/python3.8/site-packages/pulp/solverdir/cbc/linux/64/cbc /tmp/2571b48c201b4a00834ef72c128fb973-pulp.mps -timeMode elapsed -branch -printingOptions all -solution /tmp/2571b48c201b4a00834ef72c128fb973-pulp.sol (default strategy 1)
At line 2 NAME          MODEL
At line 3 ROWS
At line 10 COLUMNS
At line 32 RHS
At line 38 BOUNDS
At line 49 ENDATA
Problem MODEL has 5 rows, 12 columns and 20 elements
Coin0008I MODEL read with 0 errors
Option for timeMode changed from cpu to elapsed
Presolve 2 (-3) rows, 3 (-9) columns and 5 (-15) elements
0  Obj 0 Primal inf 0.0008052447 (1)
1  Obj 0.0002066953
Optimal - objective value 0.0002066953
After Postsolve, objective 0.0002066953, infeasibilities - dual 0 (0), primal 0 (0)
Optimal objective 0.0002066953011 - 1 iterations time 0.002, Presolve 0.00
Option for printingOptions changed from normal to all
Total time (CPU seconds):       

In [69]:
for x in [pwg , pwc , pwd,pxg , pxc , pxd , pyg , pyc , pyd , pz]:
    print(f"{x.name} = {x.value()}", end=', ')

pwg = 1.0, pwc = 0.0, pwd = 0.0, pxg = 1.0, pxc = 0.0, pxd = 0.0, pyg = 0.0, pyc = 0.0, pyd = 0.0, pz = 0.0, 