In [5]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.preprocessing import LabelEncoder
import xgboost as xgb
from sklearn.metrics import accuracy_score

In [6]:
data = pd.read_csv('fire_point_data.csv')
# 删除不必要的列
data.drop(columns=['Unnamed: 0', 'date'], inplace=True)
# 对分类特征进行独热编码
data = pd.get_dummies(data, columns=['area'])

In [7]:
train_data = data[data['year']!=2017]
pred_data = data[data['year']==2017]

In [8]:
# 定义特征和目标变量
X = train_data.drop('fire', axis=1)
y = train_data['fire']

# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 显示训练集的前几行
X_train.head()

Unnamed: 0,year,month,day,lon,lat,Alti,TEM_Max,TEM_Min,RHU_Min,PRE_Time_2020,...,DC,FWI,ISI,BUI,DSR,FFDI,ic,area_吉林省,area_辽宁省,area_黑龙江省
2211,2013,4,19,123.3379,40.196,78.447477,12.689836,-1.060047,15.224945,0.0,...,55.193265,34.201992,18.915878,49.43749,14.162946,0.415855,40.414552,0,1,0
2823,2015,7,7,122.6464,52.2967,481.94938,28.191598,15.567442,48.336408,71659.227505,...,349.578361,30.779807,10.383757,94.216332,12.061102,0.419228,12.838865,0,0,1
1601,2011,5,7,130.2803,48.6792,90.4,13.4,3.4,68.0,0.0,...,159.59,5.0,1.81,43.9,0.47,0.17,3.0,0,0,1
208,2010,6,27,123.601,51.5407,514.5,38.3,16.8,20.0,0.0,...,397.15,73.71,34.87,116.45,54.96,0.73,49.0,0,0,1
2553,2015,4,20,124.3087,41.4067,179.26302,19.118347,3.277118,7.0,0.0,...,231.381114,41.750692,25.242003,50.650119,20.223636,0.442752,59.082664,0,1,0


In [10]:
# 创建DMatrix对象
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# 设置参数
params = {
    'objective': 'binary:logistic',
    'eval_metric': 'error',
    'eta': 0.1,
    'max_depth': 5,
}

# 训练模型并显示进度条
evals = [(dtrain, 'train'), (dtest, 'eval')]
model = xgb.train(params, dtrain, num_boost_round=100, evals=evals, early_stopping_rounds=10)

# 预测测试集
y_pred = model.predict(dtest)
y_pred_binary = [1 if p > 0.5 else 0 for p in y_pred]

# 计算准确率
accuracy = accuracy_score(y_test, y_pred_binary)

[0]	train-error:0.15856	eval-error:0.15805
[1]	train-error:0.15856	eval-error:0.15805
[2]	train-error:0.13080	eval-error:0.12462
[3]	train-error:0.12167	eval-error:0.11550
[4]	train-error:0.11901	eval-error:0.10790
[5]	train-error:0.11863	eval-error:0.10638
[6]	train-error:0.11559	eval-error:0.10030
[7]	train-error:0.10951	eval-error:0.10638
[8]	train-error:0.10608	eval-error:0.09878
[9]	train-error:0.10418	eval-error:0.09878
[10]	train-error:0.10266	eval-error:0.09726
[11]	train-error:0.10000	eval-error:0.09878
[12]	train-error:0.09278	eval-error:0.10030
[13]	train-error:0.08707	eval-error:0.09878
[14]	train-error:0.08517	eval-error:0.09726
[15]	train-error:0.08289	eval-error:0.10030
[16]	train-error:0.07985	eval-error:0.09422
[17]	train-error:0.08061	eval-error:0.09726
[18]	train-error:0.07833	eval-error:0.09878
[19]	train-error:0.07795	eval-error:0.10030
[20]	train-error:0.07681	eval-error:0.09878
[21]	train-error:0.07490	eval-error:0.09878
[22]	train-error:0.07605	eval-error:0.1003

In [11]:
model_path = 'save/fire_risk_model.xgb'

# 保存模型
model.save_model(model_path)


In [12]:
# 从预测数据中分离特征
X_pred = pred_data.drop('fire', axis=1)

# 创建DMatrix对象
dpred = xgb.DMatrix(X_pred)

# 使用保存的模型进行预测
predictions = model.predict(dpred)

# 将预测结果转换为二进制分类输出
predictions_binary = [1 if p > 0.5 else 0 for p in predictions]

# 将预测结果添加到pred_data中
pred_data['fire_prediction'] = predictions_binary

# 显示预测数据的前5行
pred_data.head()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  pred_data['fire_prediction'] = predictions_binary


Unnamed: 0,year,month,day,lon,lat,fire,Alti,TEM_Max,TEM_Min,RHU_Min,...,FWI,ISI,BUI,DSR,FFDI,ic,area_吉林省,area_辽宁省,area_黑龙江省,fire_prediction
3288,2017,3,17,121.7095,39.607,0,61.76906,14.623319,0.705787,25.523332,...,39.074034,17.693094,73.119396,18.249361,0.673881,30.52073,0,1,0,1
3289,2017,3,18,121.7095,39.607,0,61.76906,15.617297,3.70516,24.223707,...,40.580502,18.260607,75.602026,19.569267,0.680577,32.34983,0,1,0,1
3290,2017,3,19,121.7095,39.607,1,61.76906,16.882907,1.500834,17.976893,...,41.127981,18.185337,78.456737,20.090944,0.68895,37.311649,0,1,0,1
3291,2017,3,20,121.7095,39.607,0,61.76906,10.633853,3.368595,22.640467,...,51.983364,26.129648,80.194564,30.026138,0.696162,35.839511,0,1,0,1
3292,2017,3,31,124.4428,40.36,0,86.571498,15.426407,4.293616,11.872208,...,41.804042,23.999176,54.968026,20.265244,0.606583,51.948266,0,1,0,1
