In [11]:
import pickle

from pyomo.environ import *
import numpy as np
import time
import json
from scipy.interpolate import RegularGridInterpolator

In [7]:
# Run configuration. BS and NUM_TOKENS select which pickled LUT files are
# loaded (they are embedded in the LUT file names); presumably BS is the batch
# size -- TODO confirm. WARMUP/TOTAL are not referenced in the cells shown
# (presumably benchmark iteration counts).
BS, NUM_TOKENS = 256, 198
WARMUP, TOTAL = 20, 35

In [19]:
EMB, HEAD, QK, V, MLP = 768,12,64,64,3072
all_variable_specs = {
    "EMB": [768, 16, 768//16+1],
    "HEAD": [12, 1, 12//1+1],
    "QK": [64, 2, 64//2+1],
    "V": [64, 2, 64//2+1],
    "MLP": [3072, 32, 3072//32+1],
}

In [20]:
# Load measured latencies and build a 5-D linear interpolator over the
# (EMB, head, QK, V, MLP) measurement grid.
latency_file = "latency_head.json"
with open(latency_file) as json_file:
    measurement = json.load(json_file)

# Measurement grid along each axis.
EMB = np.arange(4) * 256
head = np.array([1, 3, 6, 9, 12])
QK = np.array([1, 16, 32, 48, 64])
V = np.array([1, 16, 32, 48, 64])
MLP = np.arange(25) * 128

# The table shape and loop bounds come from the grids themselves, so they
# cannot drift out of sync when a grid is edited.
# The EMB == 0 slice (index 0) is never read from the measurements and stays at 0.
data = np.zeros((len(EMB), len(head), len(QK), len(V), len(MLP)))
for i, j, k1, k2, m in np.ndindex(len(EMB) - 1, len(head), len(QK), len(V), len(MLP)):
    e = EMB[i + 1]
    q_h = head[j]
    q = QK[k1]
    v = V[k2]
    # The MLP == 0 grid point is looked up under the 'MLP_1' key.
    h = MLP[m] if MLP[m] else 1
    data[i + 1, j, k1, k2, m] = measurement[f"EMB_{e}"][f"QK_{q_h}_{q}"][f"V_{v}"][f"MLP_{h}"]
latency_look_up_table = RegularGridInterpolator((EMB, head, QK, V, MLP), data)

In [21]:
# MLP latency lookup table for the current (BS, NUM_TOKENS) configuration.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
save_name = f"mlp_lut_BS{BS}_NUM_TOKENS{NUM_TOKENS}_v100.pkl"
with open(save_name, mode="rb") as f:
    mlp_lut = pickle.loads(f.read())

In [22]:
# QK latency lookup table for the current (BS, NUM_TOKENS) configuration.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
save_name = f"qk_lut_BS{BS}_NUM_TOKENS{NUM_TOKENS}_v100.pkl"
with open(save_name, mode="rb") as f:
    qk_lut = pickle.loads(f.read())

In [23]:
# V + projection latency lookup table for the current (BS, NUM_TOKENS) configuration.
# NOTE: unpickling can execute arbitrary code -- only load trusted files.
save_name = f"vandproj_lut_BS{BS}_NUM_TOKENS{NUM_TOKENS}_v100.pkl"
with open(save_name, mode="rb") as f:
    vandproj_lut = pickle.loads(f.read())

In [25]:
# # Define index ranges and strides
# emb_idx_range = np.arange(all_variable_specs["EMB"][2])
# head_idx_range = np.arange(all_variable_specs["HEAD"][2])
# qk_idx_range = np.arange(all_variable_specs["QK"][2])
# v_idx_range = np.arange(all_variable_specs["V"][2])
# mlp_idx_range = np.arange(all_variable_specs["MLP"][2])

# # Calculate and fill latency table
# for emb_idx in emb_idx_range:
#     print(f"progress {emb_idx}/{all_variable_specs['EMB'][2]}")
    
#     emb_values = np.maximum(emb_idx * all_variable_specs["EMB"][1], 1)
#     head_values = np.maximum(head_idx_range * all_variable_specs["HEAD"][1], 1)
#     qk_values = np.maximum(qk_idx_range * all_variable_specs["QK"][1], 1)
#     v_values = np.maximum(v_idx_range * all_variable_specs["V"][1], 1)
#     mlp_values = np.maximum(mlp_idx_range * all_variable_specs["MLP"][1], 1)
    
#     emb_grid, head_grid, qk_grid, v_grid, mlp_grid = np.meshgrid(emb_values, head_values, qk_values, v_values, mlp_values, indexing='ij')

#     dim_vecs = np.vstack([emb_grid.ravel(), head_grid.ravel(), qk_grid.ravel(), v_grid.ravel(), mlp_grid.ravel()]).T
#     latency_values = latency_look_up_table(dim_vecs)
    
#     latency_table[emb_idx, :, :, :, :] = latency_values.reshape(all_variable_specs["HEAD"][2], all_variable_specs["QK"][2], all_variable_specs["V"][2], all_variable_specs["MLP"][2])

In [26]:
# test_emb = 6
# test_head = 11
# test_qk = 16
# test_v = 30
# test_mlp = 40

# # Convert test dimensions to the scale of the data
# test_dim_vec = np.array([
#     max(test_emb * all_variable_specs["EMB"][1], 1),
#     max(test_head * all_variable_specs["HEAD"][1], 1),
#     max(test_qk * all_variable_specs["QK"][1], 1),
#     max(test_v * all_variable_specs["V"][1], 1),
#     max(test_mlp * all_variable_specs["MLP"][1], 1)
# ])

# # Print the latency table value at the specified indices
# print(latency_table[test_emb, test_head, test_qk, test_v, test_mlp])

# # Interpolate to find the latency value
# latency_value = latency_look_up_table(test_dim_vec)
# print(latency_value)

In [27]:
# import pickle
# with open("minlp_latency_table.pkl", "wb") as f:
#     pickle.dump(latency_table, f)

## Sanity check: one-hot width selection solved with GLPK (no latency constraint)

In [28]:
# Sanity-check model: one binary variable per candidate value of each
# dimension, "exactly one per dimension" constraints, and a linear objective,
# i.e. a pure binary linear program.
model = ConcreteModel()

# Flat index space: dimension `var_type` owns decision variables [start, stop).
variable_slices_by_type = {}
counter = 0
for var_type, var_spec in all_variable_specs.items():
    n_choices = var_spec[2]
    variable_slices_by_type[var_type] = (counter, counter + n_choices)
    counter += n_choices

all_items = list(range(counter))
model.decision_vars = Var(all_items, domain=Binary)

# Seeded so the placeholder importance scores (and thus the solution) are
# reproducible across Restart & Run All.
rng = np.random.default_rng(0)

# Define importance and the one-hot constraints.
importance = 0
model.group_unique_constraint = ConstraintList()
for var_type, (start, stop) in variable_slices_by_type.items():
    cur_decision_vars = [model.decision_vars[k] for k in range(start, stop)]
    # Exactly one candidate value is selected per dimension.
    model.group_unique_constraint.add(sum(cur_decision_vars) == 1)
    random_importance = np.abs(rng.standard_normal(len(cur_decision_vars))) * 100
    importance += sum(x * w for x, w in zip(cur_decision_vars, random_importance))

model.obj = Objective(expr=importance, sense=maximize)

In [29]:
from pyomo.opt import TerminationCondition

# The sanity-check model is a pure binary linear program, so GLPK solves it directly.
solver = SolverFactory('glpk')
results = solver.solve(model)
# Fail loudly rather than letting the next cell print stale/unset variable values.
if results.solver.termination_condition != TerminationCondition.optimal:
    raise RuntimeError(
        f"GLPK did not reach optimality: {results.solver.termination_condition}"
    )
results

{'Problem': [{'Name': 'unknown', 'Lower bound': 1276.7676538037, 'Upper bound': 1276.7676538037, 'Number of objectives': 1, 'Number of constraints': 6, 'Number of variables': 226, 'Number of nonzeros': 226, 'Sense': 'maximize'}], 'Solver': [{'Status': 'ok', 'Termination condition': 'optimal', 'Statistics': {'Branch and bound': {'Number of bounded subproblems': '1', 'Number of created subproblems': '1'}}, 'Error rc': 0, 'Time': 0.017297744750976562}], 'Solution': [OrderedDict([('number of solutions', 0), ('number of solutions displayed', 0)])]}

In [30]:
# Print the 0/1 value of every candidate for each dimension (one-hot per group).
for var_type in all_variable_specs:
    start, stop = variable_slices_by_type[var_type]
    print(var_type, [model.decision_vars[k].value for k in range(start, stop)])

EMB [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
HEAD [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
QK [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
V [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
MLP [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.

## MINLP: width selection under a latency budget

In [31]:
# Spec layout: [full size, step, number of candidate values (full size // step + 1)].
all_variable_specs["EMB"]

[768, 16, 49]

In [44]:
# Fraction of the full-model latency allowed by the budget constraint.
LATENCY_BUDGET_RATIO = 0.6

# Latency of the unpruned model. The LUT axes follow the decision-variable
# groups, and the last entry along each axis is presumably the largest
# (full-size) candidate -- TODO confirm against how the LUTs were generated.
total_latency = mlp_lut[-1, -1] + vandproj_lut[-1, -1, -1] + qk_lut[-1, -1, -1]
target_latency = total_latency * LATENCY_BUDGET_RATIO

In [53]:
# MINLP: the one-hot width-selection model plus a latency budget evaluated
# through the lookup tables.
model = ConcreteModel()

# Flat index space: dimension `var_type` owns decision variables [start, stop).
variable_slices_by_type = {}
counter = 0
for var_type, var_spec in all_variable_specs.items():
    n_choices = var_spec[2]
    variable_slices_by_type[var_type] = (counter, counter + n_choices)
    counter += n_choices

all_items = list(range(counter))
model.decision_vars = Var(all_items, domain=Binary)

# Seeded so the placeholder importance scores (and thus the solution) are
# reproducible across Restart & Run All.
rng = np.random.default_rng(0)

# Define importance and the one-hot constraints.
importance = 0
model.group_unique_constraint = ConstraintList()
for var_type, (start, stop) in variable_slices_by_type.items():
    cur_decision_vars = [model.decision_vars[k] for k in range(start, stop)]
    # Exactly one candidate value is selected per dimension.
    model.group_unique_constraint.add(sum(cur_decision_vars) == 1)
    random_importance = np.abs(rng.standard_normal(len(cur_decision_vars))) * 100
    importance += sum(x * w for x, w in zip(cur_decision_vars, random_importance))

# One 1-D object array of decision variables per dimension.
group_vectors = {
    var_type: np.array([model.decision_vars[k] for k in range(start, stop)])
    for var_type, (start, stop) in variable_slices_by_type.items()
}
emb_vectors = group_vectors["EMB"]
head_vectors = group_vectors["HEAD"]
qk_vectors = group_vectors["QK"]
v_vectors = group_vectors["V"]
mlp_vectors = group_vectors["MLP"]

# Latency constraint. Each term is sum(outer product of one-hot groups * LUT).
# Because every group is one-hot, this equals the LUT entry at the selected
# indices. The products of binary variables make the constraint nonlinear.
# Shapes are checked explicitly: elementwise `*` would otherwise silently
# broadcast a mis-shaped LUT and produce a wrong latency model.

# MLP latency, indexed by (EMB, MLP).
T1 = np.tensordot(emb_vectors, mlp_vectors, axes=0)
assert np.shape(mlp_lut) == T1.shape, (np.shape(mlp_lut), T1.shape)
latency_expr_mlp = np.sum(T1 * mlp_lut)
# V + projection latency, indexed by (HEAD, EMB, V).
T2 = np.tensordot(head_vectors, np.tensordot(emb_vectors, v_vectors, axes=0), axes=0)
assert np.shape(vandproj_lut) == T2.shape, (np.shape(vandproj_lut), T2.shape)
latency_expr_vandproj = np.sum(T2 * vandproj_lut)
# QK latency, indexed by (HEAD, EMB, QK).
T3 = np.tensordot(head_vectors, np.tensordot(emb_vectors, qk_vectors, axes=0), axes=0)
assert np.shape(qk_lut) == T3.shape, (np.shape(qk_lut), T3.shape)
latency_expr_qk = np.sum(T3 * qk_lut)

latency_expr = latency_expr_mlp + latency_expr_vandproj + latency_expr_qk
model.latency_constraint = Constraint(expr=latency_expr <= target_latency)
model.obj = Objective(expr=importance, sense=maximize)

In [54]:
# MINLP solve with MindtPy: outer-approximation strategy ('OA') with a
# feasibility-pump ('FP') start; GLPK handles the MIP master problems and
# IPOPT the NLP subproblems. tee=True streams solver progress into the output.
solver = SolverFactory('mindtpy')
results = solver.solve(
    model,
    strategy='OA',
    init_strategy='FP',
    mip_solver='glpk',
    nlp_solver='ipopt',
    tee=True,
)

---------------------------------------------------------------------------------------------
              Mixed-Integer Nonlinear Decomposition Toolbox in Pyomo (MindtPy)               
---------------------------------------------------------------------------------------------
For more information, please visit https://pyomo.readthedocs.io/en/stable/contributed_packages/mindtpy.html
Original model has 6 constraints (1 nonlinear) and 0 disjunctions, with 225 variables, of which 225 are binary, 0 are integer, and 0 are continuous.
Moving objective to constraint set.
FP is the initial strategy being used.

 Iteration | Subproblem Type | Objective Value | Primal Bound |   Dual Bound |   Gap   | Time(s)

         -       Relaxed NLP            1013.2           -inf         1013.2      nan%     14.05
         1            FP-MIP       2.09396e-06           -inf         1013.2      nan%     16.97
         1            FP-NLP       5.70079e-15           -inf         1013.2      nan%     23

In [55]:
# Print the 0/1 value of every candidate for each dimension (one-hot per group).
for var_type in all_variable_specs:
    start, stop = variable_slices_by_type[var_type]
    print(var_type, [model.decision_vars[k].value for k in range(start, stop)])

EMB [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
HEAD [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
QK [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
V [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
MLP [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.