In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

from multiprocessing import Process, Manager

import sys
sys.path.append('./data/')
sys.path.append('./../')
import gen_lemonade_data
import training_util

# Import models
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
import lightgbm as lgb

# Import metrics
from sklearn.metrics import mean_squared_error

# Import OP Solvers
from pulp import LpProblem, LpStatus, lpSum, LpVariable, LpMinimize

In [13]:
def cost_function_1(y, z, c0, c1):
    return z*(c0 - c1) + c1*y

def cost_function_2(y, z, c0, c2):
    return z*(c0 + c2) - c2*y

def cost_function(y, z, c0, c1, c2):
    if y-z >= 0:
        return cost_function_1(y, z, c0, c1)
    else:
        return cost_function_2(y, z, c0, c1)
    
def cost_function_list(y_list, z_list, c0, c1, c2):
    cost_list = []
    for y, z in zip(y_list, z_list):
        cost_list.append(cost_function(y, z, c0, c1, c2))
    return cost_list
        
def SolOpt_1(y, c0, c1):
    t_Model = LpProblem(name="small-problem", sense=LpMinimize)
    z = LpVariable(name="z", lowBound=0)

    t_Model+=(y-z>=0,"cstr1")
    t_Model+=(z>=0,"cstr2")
    
    obj_func = cost_function_1(y, z, c0, c1)
    t_Model += obj_func
    status = t_Model.solve()
    var=t_Model.variables()
    return var[0].value(),t_Model.objective.value()


def SolOpt_2(y, c0, c2):
    t_Model = LpProblem(name="small-problem", sense=LpMinimize)
    z = LpVariable(name="z", lowBound=0)

    t_Model+=(z-y>=0,"cstr1")
    t_Model+=(z>=0,"cstr2")
    
    obj_func = cost_function_2(y, z, c0, c2)
    t_Model += obj_func
    status = t_Model.solve()
    var=t_Model.variables()
    return var[0].value(),t_Model.objective.value()

def SolOpt(y, c0, c1, c2, i, z_opt, f_opt):
    Result_1 = SolOpt_1(y, c0, c1)
    Result_2 = SolOpt_2(y, c0, c2)
    
    if Result_1[1] < Result_2[1]:
        Result = Result_1
    else:
        Result = Result_2
    z_opt[i] = Result[0]
    f_opt[i] = Result[1]
    
def run_solver(data_test, c0, c1, c2, model_type):
    
    y_col = 'y'
    z_col = 'z_opt_from_y'  
    if model_type != 'real':
        y_col = 'y_pred_{}'.format(model_type)
        z_col = 'z_opt_from_y_pred_{}'.format(model_type)
    
    manager = Manager()
    z_opt = manager.dict()
    f_opt = manager.dict()
    jobs = []
    for i in tqdm(range(0, len(data_test))):
        p = Process(target=SolOpt, 
                    args=(data_test[y_col].iloc[i], 
                          c0, c1, c2, i, z_opt, f_opt))
        jobs.append(p)
        p.start()

    for proc in jobs:
        proc.join()

    df_solver = pd.DataFrame(
        data = {'ind':z_opt.keys(), 
                z_col:z_opt.values()}
                ).set_index('ind')

    data_test = pd.concat([data_test, df_solver], axis = 1)
    
    return data_test

In [14]:
N = 3000
train_perc = 0.8

In [15]:
data_lemonade = gen_lemonade_data.generate_lemonade_dataset(N = N)

Ntr = int(train_perc*N)
Nte = N - Ntr
print('Ntr, Nte, N')
print(Ntr, Nte, N)

data_train = data_lemonade.iloc[:Ntr, :].reset_index(drop=True).copy()
data_test = data_lemonade.iloc[Ntr:, :].reset_index(drop=True).copy()

feat_cols = ['x1','x2','x3']
target_col = ['y']

Ntr, Nte, N
2400 600 3000


In [16]:
#########################################################################
##### Set hyperparams for models ########################################
#########################################################################

# Hyperparams for SVM
params_svm = {
    'C':1000
}

# Hyperparams for LGB
params_lgb = {
    'objective':'regression',
    'boosting_type': 'gbdt',
    'learning_rate': 0.04,
    'max_depth':5,
    'num_leaves':7,
}


#########################################################################
##### Train: linear, SVM and LGB regressors #############################
#########################################################################

mdl_lin = training_util.train_linear(
    X_tr=data_train[feat_cols], y_tr=data_train[target_col])

mdl_svm = training_util.train_svm(
    X_tr=data_train[feat_cols], y_tr=data_train[target_col], 
    params=params_svm)

mdl_gbm = training_util.train_lgb(
    X_tr=data_train[feat_cols], y_tr=data_train[target_col], 
    params=params_lgb)



#########################################################################
##### Predict test data: linear, SVM and LGB regressors #################
#########################################################################

data_test.loc[:,'y_pred_lin'] = mdl_lin.predict(data_test[feat_cols])
data_test.loc[:,'y_pred_svm'] = mdl_svm.predict(data_test[feat_cols])
data_test.loc[:,'y_pred_gbm'] = mdl_gbm.predict(data_test[feat_cols])

mse_lin = mean_squared_error(y_true=data_test['y'], 
                             y_pred=data_test['y_pred_lin'])

mse_svm = mean_squared_error(y_true=data_test['y'], 
                             y_pred=data_test['y_pred_svm'])

mdl_gbm = mean_squared_error(y_true=data_test['y'], 
                             y_pred=data_test['y_pred_gbm'])

print('\n----------Results----------')
print('MSE using linear model:', mse_lin)
print('MSE using SVM regressor:', mse_svm)
print('MSE using LGBM regressor:', mdl_gbm)

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 267
[LightGBM] [Info] Number of data points in the train set: 1800, number of used features: 3
[LightGBM] [Info] Start training from score 255.822781
[100]	TR's l2: 155.36	VA's l2: 143.07
[200]	TR's l2: 34.5107	VA's l2: 35.8328
[300]	TR's l2: 24.4415	VA's l2: 26.2605
[400]	TR's l2: 18.3037	VA's l2: 20.8778
[500]	TR's l2: 14.4979	VA's l2: 17.5611
[600]	TR's l2: 12.1184	VA's l2: 15.5171
[700]	TR's l2: 10.5344	VA's l2: 14.2707
[800]	TR's l2: 9.73928	VA's l2: 13.416
[900]	TR's l2: 8.59228	VA's l2: 12.4771
[1000]	TR's l2: 7.91049	VA's l2: 12.0527

----------Results----------
MSE using linear model: 5560.041760021315
MSE using SVM regressor: 1033.6262002383862
MSE using LGBM regressor: 11.532233326618316


In [17]:
c0 = 10
c1 = 50
c2 = 30

In [18]:
mdl_types = ['real','lin','svm','gbm']
for mdl_type in mdl_types:
    print('Run solver for', mdl_type)
    data_test = run_solver(data_test, c0, c1, c2, model_type=mdl_type)

100%|██████████| 600/600 [00:57<00:00, 10.43it/s]
100%|██████████| 600/600 [00:54<00:00, 11.00it/s]
100%|██████████| 600/600 [00:53<00:00, 11.12it/s]
100%|██████████| 600/600 [00:55<00:00, 10.88it/s]


In [20]:
mdl_types = ['real','lin','svm','gbm']
for mdl_type in mdl_types:
    y_col = 'y'
    z_col = 'z_opt_from_y'
    f_col = 'f_opt_from_y'
    if mdl_type!='real':
        z_col = 'z_opt_from_y_pred_{}'.format(mdl_type)
        f_col = 'f_opt_from_y_pred_{}'.format(mdl_type)
    
    data_test.loc[:, f_col] = cost_function_list(
        y_list = data_test.loc[:,y_col], 
        z_list = data_test.loc[:,z_col], 
        c0 = c0, c1 = c1, c2 = c2)

In [21]:
data_test

Unnamed: 0,x1,x2,x3,y,y_pred_lin,y_pred_svm,y_pred_gbm,z_opt_from_y,z_opt_from_y_pred_lin,z_opt_from_y_pred_svm,z_opt_from_y_pred_gbm,f_opt_from_y,f_opt_from_y_pred_lin,f_opt_from_y_pred_svm,f_opt_from_y_pred_gbm
0,14.800374,0.0,6.0,396.007476,316.954146,395.398770,391.350123,396.007480,316.954150,395.398770,391.350120,3960.075019,7122.207781,3984.422981,4146.368981
1,15.441118,1.0,4.0,112.426152,180.551815,112.218386,111.910078,112.426150,180.551820,112.218390,111.910080,1124.261585,5211.801615,1132.571985,1144.904385
2,16.517854,1.0,2.0,107.589269,154.120801,92.243707,107.397326,107.589270,154.120800,92.243707,107.397330,1075.892768,3867.784568,1689.715152,1083.570232
3,24.219798,0.0,4.0,321.417773,379.098324,342.388520,321.818259,321.417770,379.098320,342.388520,321.818260,3214.177834,6675.010566,4472.422566,3238.206966
4,22.128466,0.0,3.0,271.284662,337.963030,271.237394,272.056033,271.284660,337.963030,271.237390,272.056030,2712.846707,6713.548693,2714.737507,2759.128693
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
595,23.795182,1.0,1.0,143.975908,212.453537,187.700221,142.315400,143.975910,212.453540,187.700220,142.315400,1439.759182,5548.416982,4063.217782,1506.179418
596,20.952821,0.0,5.0,389.292316,363.343907,389.394780,387.182733,389.292320,363.343910,389.394780,387.182730,3892.923383,4930.859417,3899.070983,3977.306617
597,27.645674,1.0,5.0,244.842556,329.060263,317.634036,235.095892,244.842560,329.060260,317.634040,235.095890,2448.425808,7501.487808,6815.914608,2838.292192
598,16.700301,2.0,2.0,43.400602,50.715503,7.556367,41.428353,43.400602,50.715503,7.556367,41.428353,434.006024,872.900084,1867.775404,512.895976
