Fix objective functions with zero hessian #1199

Merged
merged 20 commits into master on Jan 16, 2018
Conversation

guolinke (Collaborator)

Use sklearn's solution for objective functions with zero hessian:

  1. build the tree structure using only first-order gradients,
  2. then renew each leaf's output according to the objective function (see the sketch below).
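
Here is a minimal sketch of that leaf-renewal idea for the L1 case (illustrative only: it uses sklearn's DecisionTreeRegressor instead of LightGBM's internal tree learner, and the helper name fit_l1_boosting_step is made up):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_l1_boosting_step(X, y, current_pred, learning_rate=0.1, max_depth=3):
    # 1) grow the tree on the negative gradient of the L1 loss: sign(y - pred)
    grad = np.sign(y - current_pred)
    tree = DecisionTreeRegressor(max_depth=max_depth)
    tree.fit(X, grad)

    # 2) renew each leaf: the L1-optimal constant output is the median of the
    #    residuals of the samples falling into that leaf
    leaf_ids = tree.apply(X)
    leaf_values = {leaf: np.median(y[leaf_ids == leaf] - current_pred[leaf_ids == leaf])
                   for leaf in np.unique(leaf_ids)}

    # 3) update the predictions with the renewed leaf outputs
    update = np.array([leaf_values[leaf] for leaf in leaf_ids])
    return current_pred + learning_rate * update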

Some benchmarks on the L1 objective function:

dataset: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#YearPredictionMSD

Result:
[image: benchmark results]

@guolinke guolinke requested a review from Laurae2 January 13, 2018 14:48
@guolinke (Collaborator, Author)

Refer to #1182, #979.

Ping @joseortiz3: I also fixed the quantile loss in this PR; could you give it a try and provide feedback? The new solution doesn't need reg_sqrt, and quantile_l2 is removed. (The sketch below shows why the quantile loss falls into the zero-hessian case.)
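
For context, a minimal sketch (not LightGBM's actual implementation) of why the quantile/pinball loss belongs to the zero-hessian family: its gradient with respect to the prediction is piecewise constant, so the second derivative is zero almost everywhere and carries no information for the tree.

import numpy as np

def pinball_grad_hess(y_true, y_pred, alpha):
    # pinball loss: alpha * max(r, 0) + (1 - alpha) * max(-r, 0), with r = y_true - y_pred
    residual = y_true - y_pred
    grad = np.where(residual > 0, -alpha, 1.0 - alpha)  # piecewise constant in y_pred
    hess = np.zeros_like(grad)                           # zero except at the kink r == 0
    return grad, hess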

@joseortiz3 commented Jan 14, 2018

The new quantile objective behaves better than before, using the standard test. Parameters that specify small, simple trees (NUM_LEAVES < 10, N_ESTIMATORS around 100, LEARNING_RATE around 0.1) seem to result in fair quantile estimates that are properly calibrated (between 0.22 and 0.28 when 0.25 is expected). It seems more sensitive to parameter tuning than sklearn's quantiles, but at least now it seems possible to find good parameters.

[figure_1: quantile prediction-interval test plot]

@joseortiz3 commented Jan 15, 2018

d412004 definitely made the quantiles behave better; it seems much more reliable now. Now testing e7d5691... OK, it still seems to work great. Not sure what the difference is, but great job :)

@guolinke (Collaborator, Author)

@joseortiz3 Thanks for the feedback.

After refining the percentile solution, the result is much better now.

Script:

import numpy as np
import matplotlib.pyplot as plt
import lightgbm as lgb
from sklearn.ensemble import GradientBoostingRegressor
# from gbdt_quantiles import plot_figure  # optional local helper, not required for this script
np.random.seed(1)

# Use sklearn or lightgbm?
USE_SKLEARN = True # Toggle this to observe issue.


# Quantile to Estimate
alpha = 0.75
# Training data size
N_DATA = 1000
# Function to Estimate
def f(x):
    """The function to predict."""
    return x * np.sin(x)

# model parameters
LEARNING_RATE = 0.1
N_ESTIMATORS = 100
NUM_LEAVES = 8 # lgbm only
MAX_DEPTH = 3
MIN_DATA = 9

#---------------------- DATA GENERATION ------------------- #

#  First the noiseless case
X = np.atleast_2d(np.random.uniform(0, 10.0, size=N_DATA)).T
X = X.astype(np.float32)

# Observations
y = f(X).ravel()

dy = 1.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise
y = y.astype(np.float32)

# Mesh the input space for evaluations of the real function, the prediction and
# its MSE
xx = np.atleast_2d(np.linspace(0, 10, 9999)).T
xx = xx.astype(np.float32)


# Train high, low, and mean regressors.
# ------------------- HIGH/UPPER BOUND ------------------- #
if USE_SKLEARN:
    clfh = GradientBoostingRegressor(loss='quantile', alpha=alpha,
                n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
                learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
                min_samples_split=MIN_DATA)
    clfh.fit(X, y)
else:
    ## ADDED
    clfh = lgb.LGBMRegressor(objective = 'quantile',
                            alpha = alpha,
                            num_leaves = NUM_LEAVES,
                            learning_rate = LEARNING_RATE,
                            n_estimators = N_ESTIMATORS,
                            max_depth = MAX_DEPTH,
                            min_data = MIN_DATA)
    clfh.fit(X, y,
            #eval_set=[(X, y)],
            #eval_metric='quantile'
           )
    ## END ADDED

# ------------------- LOW/LOWER BOUND ------------------- #

if USE_SKLEARN:
    clfl = GradientBoostingRegressor(loss='quantile', alpha=1.0-alpha,
        n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
        learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
        min_samples_split=MIN_DATA)
    clfl.fit(X, y)
else:
    ## ADDED
    clfl = lgb.LGBMRegressor(objective = 'quantile',
                            alpha = 1.0 - alpha,
                            num_leaves = NUM_LEAVES,
                            learning_rate = LEARNING_RATE,
                            n_estimators = N_ESTIMATORS,
                            max_depth = MAX_DEPTH,
                            min_data = MIN_DATA)
    clfl.fit(X, y,
            #eval_set=[(X, y)],
            #eval_metric='quantile'
            )
    ## END ADDED

# ------------------- MEAN/PREDICTION ------------------- #

if USE_SKLEARN:
    clf = GradientBoostingRegressor(loss='lad',
            n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
            learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
            min_samples_split=MIN_DATA)
    clf.fit(X, y)
else:
    ## ADDED
    clf = lgb.LGBMRegressor(objective = 'regression_l1',
                            num_leaves = NUM_LEAVES,
                            learning_rate = LEARNING_RATE,
                            n_estimators = N_ESTIMATORS,
                            max_depth = MAX_DEPTH,
                            min_data = MIN_DATA)
    clf.fit(X, y,
            #eval_set=[(X, y)],
            #eval_metric='l2',
            #early_stopping_rounds=5
            )
    ## END ADDED

# ---------------- PREDICTING ----------------- #

# Make the prediction on the meshed x-axis
y_pred = clf.predict(xx)
y_lower = clfl.predict(xx)
y_upper = clfh.predict(xx)

# Check calibration by predicting the training data.
y_autopred = clf.predict(X)
y_autolow = clfl.predict(X)
y_autohigh = clfh.predict(X)
frac_below_upper = round(np.count_nonzero(y_autohigh > y) / len(y),3)
frac_above_upper = round(np.count_nonzero(y_autohigh < y) / len(y),3)
frac_above_lower = round(np.count_nonzero(y_autolow < y) / len(y),3)
frac_below_lower = round(np.count_nonzero(y_autolow > y) / len(y),3)

# Print calibration test
print('fraction below upper estimate: \t actual: ' + str(frac_below_upper) + '\t ideal: ' + str(alpha))
print('fraction above lower estimate: \t actual: ' + str(frac_above_lower) + '\t ideal: ' + str(alpha))

# ------------------- PLOTTING ----------------- #

plt.plot(xx, f(xx), 'g:', label=r'$f(x) = x\,\sin(x)$')
plt.plot(X, y, 'b.', markersize=3, label=u'Observations')
plt.plot(xx, y_pred, 'r-', label=u'Mean Prediction')
plt.plot(xx, y_upper, 'k-')
plt.plot(xx, y_lower, 'k-')
plt.fill(np.concatenate([xx, xx[::-1]]),
            np.concatenate([y_upper, y_lower[::-1]]),
            alpha=.5, fc='b', ec='None', label=(str(round(100*(alpha-0.5)*2))+'% prediction interval'))
plt.scatter(x=X[y_autohigh < y], y=y[y_autohigh < y], s=20, marker='x', c = 'red', 
        label = str(round(100*frac_above_upper,1))+'% of training data above upper (expect '+str(round(100*(1-alpha),1))+'%)')
plt.scatter(x=X[y_autolow > y], y=y[y_autolow > y], s=20, marker='x', c = 'orange', 
        label = str(round(100*frac_below_lower,1))+ '% of training data below lower (expect '+str(round(100*(1-alpha),1))+'%)')
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.ylim(-10, 20)
plt.legend(loc='upper left')
plt.title(  '  Alpha: '+str(alpha) +
            '  Sklearn?: '+str(USE_SKLEARN) +
            '  N_est: '+str(N_ESTIMATORS) +
            '  L_rate: '+str(LEARNING_RATE) +
            '  N_Leaf: '+str(NUM_LEAVES))

plt.show()

Sklearn's result:
[image: sklearn prediction-interval plot]

LightGBM's result:
[image: LightGBM prediction-interval plot]

@Laurae2 (Contributor) commented Jan 15, 2018

@guolinke Can you compare when squaring labels before training and using L2/MSE for training?

@guolinke (Collaborator, Author)

@Laurae2 Did you mean using sqrt?

[image: benchmark result]

The label distribution of this dataset is uniform; as a result, the sqrt transform doesn't seem to help. (A sketch of the manual equivalent of reg_sqrt is below.)
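
For reference, a sketch of the manual equivalent, under the assumption that reg_sqrt fits the model to sqrt(label) and converts predictions back by squaring (the helper name train_on_sqrt_labels is made up):

import numpy as np
import lightgbm as lgb

def train_on_sqrt_labels(X_train, y_train, params, num_rounds=2000):
    # train on sqrt-transformed labels; reg_sqrt applies this transform internally
    assert (np.asarray(y_train) >= 0).all(), "this sketch assumes non-negative labels"
    train_set = lgb.Dataset(X_train, label=np.sqrt(y_train))
    return lgb.train(params, train_set, num_rounds)

# model = train_on_sqrt_labels(X_train, y_train, {"objective": "regression_l2"})
# y_pred = model.predict(X_val) ** 2  # convert back to the original label scale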

@guolinke (Collaborator, Author)

@Laurae2 Do you have other regression datasets to test this?

@guolinke (Collaborator, Author) commented Jan 15, 2018

@joseortiz3
d412004 simply uses one data point to approximate the percentile point.

e7d5691 uses two data points with linear interpolation to approximate the percentile point.

When #data is large, the two solutions are almost the same, but the second one is more stable when #data is small (see the sketch below).
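
For illustration, here is a sketch contrasting the two approximations (the exact indexing and weighting in LightGBM's C++ code may differ; this only shows nearest-point vs. linear interpolation):

import numpy as np

def percentile_one_point(values, alpha):
    # d412004-style: take a single order statistic at the percentile position
    sorted_v = np.sort(values)
    idx = int(np.floor(alpha * (len(sorted_v) - 1)))
    return sorted_v[idx]

def percentile_interpolated(values, alpha):
    # e7d5691-style: linearly interpolate between the two neighbouring points
    sorted_v = np.sort(values)
    pos = alpha * (len(sorted_v) - 1)
    lo, hi = int(np.floor(pos)), int(np.ceil(pos))
    frac = pos - lo
    return (1 - frac) * sorted_v[lo] + frac * sorted_v[hi]

residuals = np.array([0.5, 1.2, 3.0, 4.8, 7.1])
print(percentile_one_point(residuals, 0.6))    # 3.0
print(percentile_interpolated(residuals, 0.6)) # ~3.72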

@guolinke guolinke requested a review from henry0312 January 15, 2018 06:47
@henry0312 (Contributor)

I'll check...
If I can't give you a response within a week, please ignore me.

@guolinke (Collaborator, Author)

@henry0312 It's okay, take your time.
You can always provide feedback even after this PR is merged.

@Laurae2 (Contributor) commented Jan 15, 2018

@guolinke (Collaborator, Author) commented Jan 16, 2018

@Laurae2 Thanks so much!

the result:

import numpy as np
import pandas as pd
import lightgbm as lgb

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from scipy.stats import skew, boxcox

train_data = pd.read_csv('train.csv')
train_size=train_data.shape[0]
test_data = pd.read_csv('test.csv')
# Merge data
full_data=pd.concat([train_data,test_data])
del( train_data, test_data)


data_types = full_data.dtypes  
cat_cols = list(data_types[data_types=='object'].index)
num_cols = list(data_types[data_types=='int64'].index) + list(data_types[data_types=='float64'].index)

id_col = 'id'
target_col = 'loss'
num_cols.remove('id')
num_cols.remove('loss')


SSL = StandardScaler()
skewed_cols = full_data[num_cols].apply(lambda x: skew(x.dropna()))
skewed_cols = skewed_cols[skewed_cols > 0.25].index.values

for skewed_col in skewed_cols:
    full_data[skewed_col], lam = boxcox(full_data[skewed_col] + 1)

for num_col in num_cols:
    full_data[num_col] = SSL.fit_transform(full_data[num_col].values.reshape(-1,1))


for cat_name in cat_cols:
    full_data[cat_name] = full_data[cat_name].astype("category")


full_columns = cat_cols + num_cols
train_x = full_data[:train_size][full_columns]
test_x = full_data[train_size:][full_columns]
train_y = full_data[:train_size].loss.values
ID = full_data.id[:train_size].values

X_train, X_val, y_train, y_val = train_test_split(train_x, train_y, test_size=0.1, random_state=42)


train_set = lgb.Dataset(X_train, label=y_train)
val_set = lgb.Dataset(X_val, label=y_val)

params = {"objective": "regression_l1",
"metric":"l1",
"reg_sqrt":True,
"num_leaves": 130,
"learning_rate": 0.02,
"min_data": 150,
"sub_feature": 0.5,
"bagging_fraction": 0.98,
"bagging_freq": 5,
"max_cat_threshold": 3,
"cat_l2": 20,
"lambda_l2": 5,
}

num_rounds = 2000
used_model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)

params["reg_sqrt"] = False
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)

params["reg_sqrt"] = True
params["objective"] = "regression_l2"
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)

params["reg_sqrt"] = False
params["objective"] = "regression_l2"
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)

output:

>>> used_model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 45.982388
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's l1: 1241.4
[200]   valid_0's l1: 1164.38
[300]   valid_0's l1: 1148.03
[400]   valid_0's l1: 1141.93
[500]   valid_0's l1: 1138.52
[600]   valid_0's l1: 1136.67
[700]   valid_0's l1: 1135.32
[800]   valid_0's l1: 1133.99
[900]   valid_0's l1: 1132.95
[1000]  valid_0's l1: 1132.47
[1100]  valid_0's l1: 1132.03
[1200]  valid_0's l1: 1131.63
[1300]  valid_0's l1: 1131.34
[1400]  valid_0's l1: 1131.08
[1500]  valid_0's l1: 1130.94
[1600]  valid_0's l1: 1130.7
[1700]  valid_0's l1: 1130.53
[1800]  valid_0's l1: 1130.28
[1900]  valid_0's l1: 1130.11
Early stopping, best iteration is:
[1912]  valid_0's l1: 1130.06
>>>
>>> params["reg_sqrt"] = False
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 2114.379883
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's l1: 1232.74
[200]   valid_0's l1: 1165.64
[300]   valid_0's l1: 1150.57
[400]   valid_0's l1: 1144.76
[500]   valid_0's l1: 1141.65
[600]   valid_0's l1: 1139.52
[700]   valid_0's l1: 1138.19
[800]   valid_0's l1: 1137
[900]   valid_0's l1: 1136.29
[1000]  valid_0's l1: 1135.65
[1100]  valid_0's l1: 1134.95
[1200]  valid_0's l1: 1134.42
[1300]  valid_0's l1: 1133.94
[1400]  valid_0's l1: 1133.61
Early stopping, best iteration is:
[1444]  valid_0's l1: 1133.5
>>>
>>> params["reg_sqrt"] = True
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 50.690943
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's l1: 1234.11
[200]   valid_0's l1: 1164.91
[300]   valid_0's l1: 1150.17
[400]   valid_0's l1: 1144.72
[500]   valid_0's l1: 1141.98
[600]   valid_0's l1: 1140.51
[700]   valid_0's l1: 1139.84
[800]   valid_0's l1: 1139.03
[900]   valid_0's l1: 1138.69
[1000]  valid_0's l1: 1138.36
Early stopping, best iteration is:
[1005]  valid_0's l1: 1138.33
>>>
>>> params["reg_sqrt"] = False
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 3037.660480
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's l1: 1272.72
[200]   valid_0's l1: 1198.76
[300]   valid_0's l1: 1180.78
[400]   valid_0's l1: 1174.3
[500]   valid_0's l1: 1171.89
[600]   valid_0's l1: 1170.19
[700]   valid_0's l1: 1169.11
[800]   valid_0's l1: 1168.38
[900]   valid_0's l1: 1167.67
Early stopping, best iteration is:
[900]   valid_0's l1: 1167.67

Overall, the new L1 objective is better.

Updated:

Added a benchmark for the mape objective:

>>> params = {"objective": "mape",
... "metric":"mape",
... "reg_sqrt":True,
... "num_leaves": 130,
... "learning_rate": 0.02,
... "min_data": 150,
... "sub_feature": 0.5,
... "bagging_fraction": 0.98,
... "bagging_freq": 5,
... "max_cat_threshold": 3,
... "cat_l2": 20,
... "lambda_l2": 5,
... }
>>>
>>> num_rounds = 2000
>>> used_model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Warning] Met 'abs(label) < 1', will convert them to '1' in Mape objective and metric.
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 38.768780
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.45459
[200]   valid_0's mape: 0.438935
[300]   valid_0's mape: 0.434603
[400]   valid_0's mape: 0.433005
[500]   valid_0's mape: 0.432443
[600]   valid_0's mape: 0.431996
[700]   valid_0's mape: 0.431601
[800]   valid_0's mape: 0.431349
[900]   valid_0's mape: 0.431199
[1000]  valid_0's mape: 0.431064
[1100]  valid_0's mape: 0.430969
Early stopping, best iteration is:
[1064]  valid_0's mape: 0.430941
>>>
>>> params["reg_sqrt"] = False
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Warning] Met 'abs(label) < 1', will convert them to '1' in Mape objective and metric.
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 1096.866699
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.458532
[200]   valid_0's mape: 0.454369
Early stopping, best iteration is:
[159]   valid_0's mape: 0.453678
>>>
>>> params["reg_sqrt"] = True
>>> params["objective"] = "regression_l1"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 45.982388
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.520485
[200]   valid_0's mape: 0.493813
[300]   valid_0's mape: 0.487743
[400]   valid_0's mape: 0.485336
[500]   valid_0's mape: 0.484067
[600]   valid_0's mape: 0.483219
[700]   valid_0's mape: 0.482755
[800]   valid_0's mape: 0.482325
[900]   valid_0's mape: 0.482033
[1000]  valid_0's mape: 0.481856
[1100]  valid_0's mape: 0.481758
[1200]  valid_0's mape: 0.481555
[1300]  valid_0's mape: 0.481482
[1400]  valid_0's mape: 0.481368
Early stopping, best iteration is:
[1388]  valid_0's mape: 0.481353
>>>
>>> params["reg_sqrt"] = False
>>> params["objective"] = "regression_l1"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 2114.379883
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.527433
[200]   valid_0's mape: 0.497872
[300]   valid_0's mape: 0.4918
[400]   valid_0's mape: 0.489036
[500]   valid_0's mape: 0.487705
[600]   valid_0's mape: 0.486944
[700]   valid_0's mape: 0.486358
[800]   valid_0's mape: 0.48593
[900]   valid_0's mape: 0.485569
[1000]  valid_0's mape: 0.485325
[1100]  valid_0's mape: 0.485153
[1200]  valid_0's mape: 0.485083
[1300]  valid_0's mape: 0.484913
[1400]  valid_0's mape: 0.484787
[1500]  valid_0's mape: 0.484732
[1600]  valid_0's mape: 0.484713
Early stopping, best iteration is:
[1578]  valid_0's mape: 0.484685
>>>
>>>
>>> params["reg_sqrt"] = True
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 50.690943
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.587769
[200]   valid_0's mape: 0.538626
[300]   valid_0's mape: 0.525753
[400]   valid_0's mape: 0.520771
[500]   valid_0's mape: 0.518513
[600]   valid_0's mape: 0.517409
[700]   valid_0's mape: 0.516687
[800]   valid_0's mape: 0.516096
[900]   valid_0's mape: 0.515683
[1000]  valid_0's mape: 0.515387
[1100]  valid_0's mape: 0.515168
[1200]  valid_0's mape: 0.514913
[1300]  valid_0's mape: 0.514759
[1400]  valid_0's mape: 0.514568
[1500]  valid_0's mape: 0.514466
[1600]  valid_0's mape: 0.51439
[1700]  valid_0's mape: 0.514125
[1800]  valid_0's mape: 0.514087
[1900]  valid_0's mape: 0.513967
Early stopping, best iteration is:
[1855]  valid_0's mape: 0.51388
>>>
>>> params["reg_sqrt"] = False
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 3037.660480
Training until validation scores don't improve for 50 rounds.
[100]   valid_0's mape: 0.693352
[200]   valid_0's mape: 0.609559
[300]   valid_0's mape: 0.586393
[400]   valid_0's mape: 0.577855
[500]   valid_0's mape: 0.574299
[600]   valid_0's mape: 0.572318
[700]   valid_0's mape: 0.571067
[800]   valid_0's mape: 0.570203
[900]   valid_0's mape: 0.569416
[1000]  valid_0's mape: 0.569294
[1100]  valid_0's mape: 0.568971
[1200]  valid_0's mape: 0.568517
[1300]  valid_0's mape: 0.568295
[1400]  valid_0's mape: 0.567993
[1500]  valid_0's mape: 0.567871
[1600]  valid_0's mape: 0.567708
Early stopping, best iteration is:
[1570]  valid_0's mape: 0.567646

@guolinke guolinke merged commit 5392c9e into master Jan 16, 2018
@guolinke guolinke deleted the renew-tree-output branch January 16, 2018 06:17