In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
import lightgbm as lgb
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.linear_model import LinearRegression, Ridge, LassoCV
from sklearn.model_selection import train_test_split, GridSearchCV

%matplotlib inline

This means that in case of installing LightGBM from PyPI via the ``pip install lightgbm`` command, you don't need to install the gcc compiler anymore.
Instead of that, you need to install the OpenMP library, which is required for running LightGBM on the system with the Apple Clang compiler.
You can install the OpenMP library by the following command: ``brew install libomp``.


### 读取预处理好的数据：

In [2]:
train_data = pd.read_csv('data/train_data_v1.csv', sep=' ')
test_data = pd.read_csv('data/test_data_v1.csv', sep=' ')
print(train_data.shape)
print(test_data.shape)

(149999, 37)
(50000, 37)


### 查看训练数据：

In [3]:
train_data.head()

Unnamed: 0,SaleID,bodyType,brand,creatDate,fuelType,gearbox,kilometer,model,notRepairedDamage,power,...,v_9,name_count,regDates,creatDates,regDate_year,regDate_month,regDate_day,creatDate_year,creatDate_month,creatDate_day
0,0,1.0,6,20160404,0.0,0.0,12.5,30.0,0.0,60,...,0.097462,108,2004-04-02,2016-04-04,2004,4,2,2016,4,4
1,1,2.0,1,20160309,0.0,0.0,15.0,40.0,0.0,0,...,0.020582,29,2003-03-01,2016-03-09,2003,3,1,2016,3,9
2,2,1.0,15,20160402,0.0,0.0,12.5,115.0,0.0,163,...,0.027075,3,2004-04-03,2016-04-02,2004,4,3,2016,4,2
3,3,0.0,10,20160312,0.0,1.0,15.0,109.0,0.0,193,...,0.0,2,1996-09-08,2016-03-12,1996,9,8,2016,3,12
4,4,1.0,5,20160313,0.0,0.0,5.0,110.0,0.0,68,...,0.121534,1,2012-01-03,2016-03-13,2012,1,3,2016,3,13


### 查看数据的统计信息：

In [4]:
train_data.describe()

Unnamed: 0,SaleID,bodyType,brand,creatDate,fuelType,gearbox,kilometer,model,notRepairedDamage,power,...,v_7,v_8,v_9,name_count,regDate_year,regDate_month,regDate_day,creatDate_year,creatDate_month,creatDate_day
count,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,...,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0,149999.0
mean,74999.493837,1.738525,8.052727,20160330.0,0.354096,0.215975,12.597144,47.128581,0.095434,116.861752,...,0.124693,0.058144,0.061995,16.621251,2003.357196,5.998393,6.502863,2015.99988,3.161581,15.833826
std,43301.5588,1.760784,7.864982,106.7332,0.539748,0.411498,3.919584,49.536165,0.293814,70.07484,...,0.20141,0.029185,0.035692,48.697958,5.362246,3.52159,3.450316,0.010954,0.38071,9.132285
min,0.0,0.0,0.0,20150620.0,0.0,0.0,0.5,0.0,0.0,0.0,...,0.0,0.0,0.0,1.0,1991.0,1.0,1.0,2015.0,1.0,1.0
25%,37499.5,0.0,1.0,20160310.0,0.0,0.0,12.5,10.0,0.0,75.0,...,0.062474,0.035334,0.03393,1.0,1999.0,3.0,4.0,2016.0,3.0,8.0
50%,74999.0,1.0,6.0,20160320.0,0.0,0.0,15.0,30.0,0.0,110.0,...,0.095867,0.057014,0.058483,1.0,2003.0,6.0,7.0,2016.0,3.0,16.0
75%,112499.5,3.0,13.0,20160330.0,1.0,0.0,15.0,66.0,0.0,150.0,...,0.125243,0.079382,0.087489,7.0,2007.0,9.0,9.0,2016.0,3.0,24.0
max,149999.0,7.0,39.0,20160410.0,6.0,1.0,15.0,247.0,1.0,600.0,...,1.404936,0.160791,0.222787,376.0,2015.0,12.0,12.0,2016.0,12.0,31.0


### 提取所需特征列：

In [5]:
numerical_cols = train_data.select_dtypes(exclude = 'object').columns
print(numerical_cols)

Index(['SaleID', 'bodyType', 'brand', 'creatDate', 'fuelType', 'gearbox',
       'kilometer', 'model', 'notRepairedDamage', 'power', 'price', 'regDate',
       'regionCode', 'v_0', 'v_1', 'v_10', 'v_11', 'v_12', 'v_13', 'v_14',
       'v_2', 'v_3', 'v_4', 'v_5', 'v_6', 'v_7', 'v_8', 'v_9', 'name_count',
       'regDate_year', 'regDate_month', 'regDate_day', 'creatDate_year',
       'creatDate_month', 'creatDate_day'],
      dtype='object')


In [6]:
feature_cols = [x for x in train_data.columns if x not in ['SaleID','name','regDates','creatDates','price','model','brand','regionCode','creatDate']]

### 构建数据集并进行划分（Train, Val）：

In [7]:
train_X = train_data[feature_cols]
test_X = test_data[feature_cols]
train_Y = train_data['price']

In [8]:
print('X train shape:',train_X.shape)
print('X test shape:',test_X.shape)
print('Y train shape:',train_Y.shape)

X train shape: (149999, 29)
X test shape: (50000, 29)
Y train shape: (149999,)


In [9]:
x_train, x_val, y_train, y_val = train_test_split(train_X, train_Y, test_size=0.3)

### 通过五种模型进行训练，并用MAE评价标准进行比较：

In [10]:
# 线性回归
model_1 = LinearRegression()
model_1.fit(x_train, y_train)
pred_1 = model_1.predict(x_val)
mae_1 = mean_absolute_error(y_val, pred_1)
print('MAE = ', mae_1)

MAE =  0.19394101366403463


In [11]:
# 岭回归
model_2 = Ridge(alpha=0.8)
model_2.fit(x_train, y_train)
pred_2 = model_2.predict(x_val)
mae_2 = mean_absolute_error(y_val, pred_2)
print('MAE = ', mae_2)

MAE =  0.1972789542194812


In [12]:
# Lasso回归
model_3 = LassoCV()
model_3.fit(x_train, y_train)
pred_3 = model_3.predict(x_val)
mae_3 = mean_absolute_error(y_val, pred_3)
print('MAE = ', mae_3)

MAE =  0.6046481117934874


In [13]:
# GDBT
gdbt = GradientBoostingRegressor()
gdbt.fit(x_train, y_train)
pred_4 = gdbt.predict(x_val)
mae_4 = mean_absolute_error(y_val, pred_4)
print('MAE = ', mae_4)

MAE =  0.17866209854560852


In [14]:
# LightGBM
estimator = lgb.LGBMRegressor(num_leaves=63, n_estimators=100)
param_grid = {
    'learning_rate': [0.01, 0.05, 0.1],
    }
gbm = GridSearchCV(estimator, param_grid)
gbm.fit(x_train, y_train)
pred_5 = gbm.predict(x_val)
mae_5 = mean_absolute_error(y_val, pred_5)
print('MAE = ', mae_5)

MAE =  0.13801210531454872


##### 通过对比可知，LightGBM训练得到的模型效果更好，故我们采用LightGBM训练模型并进行预测。

### 采用LightGBM模型在原始数据集上进行预测：

In [15]:
estimator = lgb.LGBMRegressor(num_leaves=63, n_estimators=100)
param_grid = {
    'learning_rate': [0.01, 0.05, 0.1],
    }
pred_model = GridSearchCV(estimator, param_grid)
pred_model.fit(train_X, train_Y)
price = pred_model.predict(test_X)

### 将预测值生成指定格式的csv文件：

In [16]:
submit = pd.DataFrame()
submit['SaleID'] = test_data.SaleID
submit['price'] = price
submit.to_csv('output/submit.csv',index=False)

In [17]:
submit.head(10)

Unnamed: 0,SaleID,price
0,200000,7.146314
1,200001,7.54877
2,200002,8.888396
3,200003,7.08409
4,200004,7.582857
5,200005,7.127663
6,200006,5.990862
7,200007,8.110016
8,200008,9.461447
9,200009,6.423404
