Fix objective functions with zero hessian #1199
Conversation
ping @joseortiz3, I also fixed the quantile loss in this PR, could you have a try and provide feedback? The new solution doesn't need the …
The new …
@joseortiz3 Thanks for the feedback. After refining the percentile solution, the result is much better now. Script:
import numpy as np
import matplotlib.pyplot as plt
import lightgbm as lgb
from sklearn.ensemble import GradientBoostingRegressor
from gbdt_quantiles import plot_figure  # local helper module; not used below
np.random.seed(1)
# Use sklearn or lightgbm?
USE_SKLEARN = True  # Toggle this to observe the issue.
# Quantile to Estimate
alpha = 0.75
# Training data size
N_DATA = 1000
# Function to Estimate
def f(x):
    """The function to predict."""
    return x * np.sin(x)
# model parameters
LEARNING_RATE = 0.1
N_ESTIMATORS = 100
NUM_LEAVES = 8 # lgbm only
MAX_DEPTH = 3
MIN_DATA = 9
#---------------------- DATA GENERATION ------------------- #
# First the noiseless case
X = np.atleast_2d(np.random.uniform(0, 10.0, size=N_DATA)).T
X = X.astype(np.float32)
# Observations
y = f(X).ravel()
dy = 1.5 + 1.0 * np.random.random(y.shape)
noise = np.random.normal(0, dy)
y += noise
y = y.astype(np.float32)
# Mesh the input space for evaluations of the real function, the prediction and
# its MSE
xx = np.atleast_2d(np.linspace(0, 10, 9999)).T
xx = xx.astype(np.float32)
# Train high, low, and mean regressors.
# ------------------- HIGH/UPPER BOUND ------------------- #
if USE_SKLEARN:
    clfh = GradientBoostingRegressor(loss='quantile', alpha=alpha,
                                     n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
                                     learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
                                     min_samples_split=MIN_DATA)
    clfh.fit(X, y)
else:
    ## ADDED
    clfh = lgb.LGBMRegressor(objective='quantile',
                             alpha=alpha,
                             num_leaves=NUM_LEAVES,
                             learning_rate=LEARNING_RATE,
                             n_estimators=N_ESTIMATORS,
                             max_depth=MAX_DEPTH,
                             min_data=MIN_DATA)
    clfh.fit(X, y,
             # eval_set=[(X, y)],
             # eval_metric='quantile'
             )
    ## END ADDED
# ------------------- LOW/LOWER BOUND ------------------- #
if USE_SKLEARN:
    clfl = GradientBoostingRegressor(loss='quantile', alpha=1.0-alpha,
                                     n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
                                     learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
                                     min_samples_split=MIN_DATA)
    clfl.fit(X, y)
else:
    ## ADDED
    clfl = lgb.LGBMRegressor(objective='quantile',
                             alpha=1.0 - alpha,
                             num_leaves=NUM_LEAVES,
                             learning_rate=LEARNING_RATE,
                             n_estimators=N_ESTIMATORS,
                             max_depth=MAX_DEPTH,
                             min_data=MIN_DATA)
    clfl.fit(X, y,
             # eval_set=[(X, y)],
             # eval_metric='quantile'
             )
    ## END ADDED
# ------------------- MEAN/PREDICTION ------------------- #
if USE_SKLEARN:
    clf = GradientBoostingRegressor(loss='lad',
                                    n_estimators=N_ESTIMATORS, max_depth=MAX_DEPTH,
                                    learning_rate=LEARNING_RATE, min_samples_leaf=MIN_DATA,
                                    min_samples_split=MIN_DATA)
    clf.fit(X, y)
else:
    ## ADDED
    clf = lgb.LGBMRegressor(objective='regression_l1',
                            num_leaves=NUM_LEAVES,
                            learning_rate=LEARNING_RATE,
                            n_estimators=N_ESTIMATORS,
                            max_depth=MAX_DEPTH,
                            min_data=MIN_DATA)
    clf.fit(X, y,
            # eval_set=[(X, y)],
            # eval_metric='l2',
            # early_stopping_rounds=5
            )
    ## END ADDED
# ---------------- PREDICTING ----------------- #
# Make the prediction on the meshed x-axis
y_pred = clf.predict(xx)
y_lower = clfl.predict(xx)
y_upper = clfh.predict(xx)
# Check calibration by predicting the training data.
y_autopred = clf.predict(X)
y_autolow = clfl.predict(X)
y_autohigh = clfh.predict(X)
frac_below_upper = round(np.count_nonzero(y_autohigh > y) / len(y),3)
frac_above_upper = round(np.count_nonzero(y_autohigh < y) / len(y),3)
frac_above_lower = round(np.count_nonzero(y_autolow < y) / len(y),3)
frac_below_lower = round(np.count_nonzero(y_autolow > y) / len(y),3)
# Print calibration test
print('fraction below upper estimate: \t actual: ' + str(frac_below_upper) + '\t ideal: ' + str(alpha))
print('fraction above lower estimate: \t actual: ' + str(frac_above_lower) + '\t ideal: ' + str(alpha))
# ------------------- PLOTTING ----------------- #
plt.plot(xx, f(xx), 'g:', label=u'$f(x) = x\,\sin(x)$')
plt.plot(X, y, 'b.', markersize=3, label=u'Observations')
plt.plot(xx, y_pred, 'r-', label=u'Mean Prediction')
plt.plot(xx, y_upper, 'k-')
plt.plot(xx, y_lower, 'k-')
plt.fill(np.concatenate([xx, xx[::-1]]),
np.concatenate([y_upper, y_lower[::-1]]),
alpha=.5, fc='b', ec='None', label=(str(round(100*(alpha-0.5)*2))+'% prediction interval'))
plt.scatter(x=X[y_autohigh < y], y=y[y_autohigh < y], s=20, marker='x', c = 'red',
label = str(round(100*frac_above_upper,1))+'% of training data above upper (expect '+str(round(100*(1-alpha),1))+'%)')
plt.scatter(x=X[y_autolow > y], y=y[y_autolow > y], s=20, marker='x', c = 'orange',
label = str(round(100*frac_below_lower,1))+ '% of training data below lower (expect '+str(round(100*(1-alpha),1))+'%)')
plt.xlabel('$x$')
plt.ylabel('$f(x)$')
plt.ylim(-10, 20)
plt.legend(loc='upper left')
plt.title( ' Alpha: '+str(alpha) +
' Sklearn?: '+str(USE_SKLEARN) +
' N_est: '+str(N_ESTIMATORS) +
' L_rate: '+str(LEARNING_RATE) +
' N_Leaf: '+str(NUM_LEAVES))
plt.show()
@guolinke can you compare against squaring the labels before training and using L2/MSE for training?
@Laurae2 Did you mean use …? The label distribution of this dataset is uniform; as a result, the …
@Laurae2 Do you have other regression datasets to test this?
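For context, a minimal sketch of the kind of comparison asked about above: train with the plain L1 objective, and separately train on transformed labels with the L2 objective and invert the transform at prediction time (a square-root transform is used here, the same idea as the reg_sqrt switch benchmarked further down). The dataset, transform direction, and hyper-parameters are illustrative assumptions, not results from this thread.

import numpy as np
import lightgbm as lgb
from sklearn.datasets import make_regression
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split

# Toy data with non-negative labels so the transform is invertible (assumption).
X, y = make_regression(n_samples=5000, n_features=20, noise=10.0, random_state=42)
y = np.abs(y)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=42)

# Direct L1 objective.
m_l1 = lgb.LGBMRegressor(objective='regression_l1', n_estimators=200)
m_l1.fit(X_tr, y_tr)

# L2 objective on sqrt-transformed labels, prediction squared back.
m_l2 = lgb.LGBMRegressor(objective='regression_l2', n_estimators=200)
m_l2.fit(X_tr, np.sqrt(y_tr))
pred_l2 = np.square(m_l2.predict(X_te))

print('L1 objective      MAE:', mean_absolute_error(y_te, m_l1.predict(X_te)))
print('L2 + sqrt labels  MAE:', mean_absolute_error(y_te, pred_l2))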
@joseortiz3 e7d5691 uses two data points with linear interpolation to approximate the percentile point. When #data is large, the two solutions are almost the same, but the second solution is more stable when #data is small.
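A minimal NumPy sketch (not the PR's C++ code) of the interpolation described above: the alpha-percentile is approximated by linearly interpolating between the two order statistics that bracket the quantile position, which stays well defined even when #data is small. The function name is hypothetical.

import numpy as np

def interpolated_percentile(values, alpha):
    # Sort the values and locate the fractional rank of the alpha-quantile.
    v = np.sort(np.asarray(values, dtype=float))
    pos = alpha * (len(v) - 1)
    lo = int(np.floor(pos))
    hi = int(np.ceil(pos))
    frac = pos - lo
    # Linear interpolation between the two bracketing data points.
    return (1.0 - frac) * v[lo] + frac * v[hi]

# Example: interpolated_percentile([1.0, 2.0, 10.0], 0.75) == 6.0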
I'll check...
@henry0312 it is okay, take your time.
@Laurae2 Thanks so much! The result:
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from scipy.stats import skew, boxcox
train_data = pd.read_csv('train.csv')
train_size=train_data.shape[0]
test_data = pd.read_csv('test.csv')
# Merge data
full_data=pd.concat([train_data,test_data])
del( train_data, test_data)
data_types = full_data.dtypes
cat_cols = list(data_types[data_types=='object'].index)
num_cols = list(data_types[data_types=='int64'].index) + list(data_types[data_types=='float64'].index)
id_col = 'id'
target_col = 'loss'
num_cols.remove('id')
num_cols.remove('loss')
SSL = StandardScaler()
skewed_cols = full_data[num_cols].apply(lambda x: skew(x.dropna()))
skewed_cols = skewed_cols[skewed_cols > 0.25].index.values
for skewed_col in skewed_cols:
    full_data[skewed_col], lam = boxcox(full_data[skewed_col] + 1)
for num_col in num_cols:
    full_data[num_col] = SSL.fit_transform(full_data[num_col].values.reshape(-1, 1))
for cat_name in cat_cols:
    full_data[cat_name] = full_data[cat_name].astype("category")
full_columns = cat_cols + num_cols
train_x = full_data[:train_size][full_columns]
test_x = full_data[train_size:][full_columns]
train_y = full_data[:train_size].loss.values
ID = full_data.id[:train_size].values
X_train, X_val, y_train, y_val = train_test_split(train_x, train_y, test_size=0.1, random_state=42)
train_set = lgb.Dataset(X_train, label=y_train)
val_set = lgb.Dataset(X_val, label=y_val)
params = {"objective": "regression_l1",
"metric":"l1",
"reg_sqrt":True,
"num_leaves": 130,
"learning_rate": 0.02,
"min_data": 150,
"sub_feature": 0.5,
"bagging_fraction": 0.98,
"bagging_freq": 5,
"max_cat_threshold": 3,
"cat_l2": 20,
"lambda_l2": 5,
}
num_rounds = 2000
used_model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
params["reg_sqrt"] = False
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
params["reg_sqrt"] = True
params["objective"] = "regression_l2"
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
params["reg_sqrt"] = False
params["objective"] = "regression_l2"
model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
output:
Overall, the new L1 objective is better.
Updated: add a benchmark for the mape objective:
>>> params = {"objective": "mape",
... "metric":"mape",
... "reg_sqrt":True,
... "num_leaves": 130,
... "learning_rate": 0.02,
... "min_data": 150,
... "sub_feature": 0.5,
... "bagging_fraction": 0.98,
... "bagging_freq": 5,
... "max_cat_threshold": 3,
... "cat_l2": 20,
... "lambda_l2": 5,
... }
>>>
>>> num_rounds = 2000
>>> used_model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Warning] Met 'abs(label) < 1', will convert them to '1' in Mape objective and metric.
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 38.768780
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.45459
[200] valid_0's mape: 0.438935
[300] valid_0's mape: 0.434603
[400] valid_0's mape: 0.433005
[500] valid_0's mape: 0.432443
[600] valid_0's mape: 0.431996
[700] valid_0's mape: 0.431601
[800] valid_0's mape: 0.431349
[900] valid_0's mape: 0.431199
[1000] valid_0's mape: 0.431064
[1100] valid_0's mape: 0.430969
Early stopping, best iteration is:
[1064] valid_0's mape: 0.430941
>>>
>>> params["reg_sqrt"] = False
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Warning] Met 'abs(label) < 1', will convert them to '1' in Mape objective and metric.
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 1096.866699
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.458532
[200] valid_0's mape: 0.454369
Early stopping, best iteration is:
[159] valid_0's mape: 0.453678
>>>
>>> params["reg_sqrt"] = True
>>> params["objective"] = "regression_l1"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 45.982388
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.520485
[200] valid_0's mape: 0.493813
[300] valid_0's mape: 0.487743
[400] valid_0's mape: 0.485336
[500] valid_0's mape: 0.484067
[600] valid_0's mape: 0.483219
[700] valid_0's mape: 0.482755
[800] valid_0's mape: 0.482325
[900] valid_0's mape: 0.482033
[1000] valid_0's mape: 0.481856
[1100] valid_0's mape: 0.481758
[1200] valid_0's mape: 0.481555
[1300] valid_0's mape: 0.481482
[1400] valid_0's mape: 0.481368
Early stopping, best iteration is:
[1388] valid_0's mape: 0.481353
>>>
>>> params["reg_sqrt"] = False
>>> params["objective"] = "regression_l1"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 2114.379883
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.527433
[200] valid_0's mape: 0.497872
[300] valid_0's mape: 0.4918
[400] valid_0's mape: 0.489036
[500] valid_0's mape: 0.487705
[600] valid_0's mape: 0.486944
[700] valid_0's mape: 0.486358
[800] valid_0's mape: 0.48593
[900] valid_0's mape: 0.485569
[1000] valid_0's mape: 0.485325
[1100] valid_0's mape: 0.485153
[1200] valid_0's mape: 0.485083
[1300] valid_0's mape: 0.484913
[1400] valid_0's mape: 0.484787
[1500] valid_0's mape: 0.484732
[1600] valid_0's mape: 0.484713
Early stopping, best iteration is:
[1578] valid_0's mape: 0.484685
>>>
>>>
>>> params["reg_sqrt"] = True
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 50.690943
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.587769
[200] valid_0's mape: 0.538626
[300] valid_0's mape: 0.525753
[400] valid_0's mape: 0.520771
[500] valid_0's mape: 0.518513
[600] valid_0's mape: 0.517409
[700] valid_0's mape: 0.516687
[800] valid_0's mape: 0.516096
[900] valid_0's mape: 0.515683
[1000] valid_0's mape: 0.515387
[1100] valid_0's mape: 0.515168
[1200] valid_0's mape: 0.514913
[1300] valid_0's mape: 0.514759
[1400] valid_0's mape: 0.514568
[1500] valid_0's mape: 0.514466
[1600] valid_0's mape: 0.51439
[1700] valid_0's mape: 0.514125
[1800] valid_0's mape: 0.514087
[1900] valid_0's mape: 0.513967
Early stopping, best iteration is:
[1855] valid_0's mape: 0.51388
>>>
>>> params["reg_sqrt"] = False
>>> params["objective"] = "regression_l2"
>>> model = lgb.train(params, train_set, num_rounds, valid_sets=[val_set],early_stopping_rounds=50, verbose_eval=100)
[LightGBM] [Info] Total Bins 3812
[LightGBM] [Info] Number of data: 169486, number of used features: 122
[LightGBM] [Info] Start training from score 3037.660480
Training until validation scores don't improve for 50 rounds.
[100] valid_0's mape: 0.693352
[200] valid_0's mape: 0.609559
[300] valid_0's mape: 0.586393
[400] valid_0's mape: 0.577855
[500] valid_0's mape: 0.574299
[600] valid_0's mape: 0.572318
[700] valid_0's mape: 0.571067
[800] valid_0's mape: 0.570203
[900] valid_0's mape: 0.569416
[1000] valid_0's mape: 0.569294
[1100] valid_0's mape: 0.568971
[1200] valid_0's mape: 0.568517
[1300] valid_0's mape: 0.568295
[1400] valid_0's mape: 0.567993
[1500] valid_0's mape: 0.567871
[1600] valid_0's mape: 0.567708
Early stopping, best iteration is:
[1570] valid_0's mape: 0.567646
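For readers unfamiliar with the reg_sqrt switch toggled throughout these benchmarks: LightGBM fits the model on the square root of the label and squares the prediction on the way out. A rough hand-rolled equivalent for non-negative labels (illustrative only; the helper name and defaults are assumptions, not the exact benchmark setup above):

import numpy as np
import lightgbm as lgb

def train_with_sqrt_labels(params, X_train, y_train, X_val, y_val, num_rounds=2000):
    # Turn the built-in switch off and apply the transform by hand instead.
    params = dict(params, reg_sqrt=False)
    train_set = lgb.Dataset(X_train, label=np.sqrt(y_train))
    val_set = lgb.Dataset(X_val, label=np.sqrt(y_val), reference=train_set)
    booster = lgb.train(params, train_set, num_rounds, valid_sets=[val_set])
    # Map predictions back to the original label scale.
    return lambda X: np.square(booster.predict(X))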
Use sklearn's solution for the objective functions with zero hessian:
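A minimal Python sketch (not the PR's actual C++ implementation) of this sklearn-style idea: because L1/quantile/MAPE objectives have a zero or uninformative hessian, the usual gradient/hessian leaf value is not meaningful, so after each tree is grown the leaf outputs are renewed with the appropriate percentile of the residuals of the samples falling in that leaf (the median for L1). The helper below is hypothetical and ignores details such as learning-rate scaling.

import numpy as np

def renew_leaf_outputs(leaf_index, residuals, alpha=0.5):
    # leaf_index[i]: leaf that sample i falls into; residuals[i] = label[i] - score[i].
    new_output = {}
    for leaf in np.unique(leaf_index):
        in_leaf = residuals[leaf_index == leaf]
        # Percentile with linear interpolation, as discussed earlier in the thread.
        new_output[leaf] = np.percentile(in_leaf, 100.0 * alpha)
    return new_output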
Some benchmarks on the L1 objective function:
dataset: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html#YearPredictionMSD
Result: (benchmark result figure)