# 2022 农业系统模型与大数据分析实验课4: 课后作业参考

In [1]:
import random
import warnings

import numpy as np
import pandas as pd

from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.metrics import mean_squared_error, r2_score

In [2]:
# Load the dataset and hold out 20% of the rows as a fixed-seed test set.
df_raw = pd.read_csv('m5.csv', header=0, index_col=None)
df_train, df_test = train_test_split(df_raw, test_size=0.2, random_state=42)

# Predictors: every column except the trailing four
# (presumably the trailing four hold reference/response values — confirm against the data dictionary).
X_train = df_train.iloc[:, :-4]
X_test = df_test.iloc[:, :-4]

# Regression target: the second-to-last column.
y_train = df_train.iloc[:, -2]
y_test = df_test.iloc[:, -2]

In [3]:
# Build the PLS model and tune its number of latent components with
# 5-fold cross-validated grid search, scored by negative MSE.
MAX_COMPONENTS = 20

# PLSRegression accepts at most n_features components, so cap the grid to
# keep every candidate fit valid when there are few predictor columns.
n_features = X_train.shape[1]
n_components_max = min(MAX_COMPONENTS, n_features)

pls = PLSRegression(scale=False)
param_dist = {'n_components': list(range(1, n_components_max + 1))}
grid_search = GridSearchCV(estimator=pls, param_grid=param_dist,
                           cv=5, scoring='neg_mean_squared_error')
grid_search.fit(X_train, y_train)
optim_estimator = grid_search.best_estimator_

# An optimum on the upper edge of the search range (while more components
# were still possible) means the true optimum may lie outside the grid;
# widen MAX_COMPONENTS and re-run in that case.
best_n = grid_search.best_params_['n_components']
if best_n == MAX_COMPONENTS and MAX_COMPONENTS < n_features:
    warnings.warn(
        f"Best n_components ({best_n}) is at the upper limit of the search "
        f"grid; consider increasing MAX_COMPONENTS.")

y_train_pred = optim_estimator.predict(X_train)
y_test_pred = optim_estimator.predict(X_test)

# RMSE on the training and held-out test sets.
rmse_train = np.sqrt(mean_squared_error(y_train, y_train_pred))
rmse_test = np.sqrt(mean_squared_error(y_test, y_test_pred))

# Coefficient of determination (R2) on the training and test sets.
r2_train = r2_score(y_train, y_train_pred)
r2_test = r2_score(y_test, y_test_pred)

In [4]:
# Report the hyperparameter setting selected by cross-validation.
best_params = grid_search.best_params_
print(f"Optimized hyperparamter: {best_params}")

Optimized hyperparamter: {'n_components': 20}


In [5]:
# Summarise train/test RMSE and R2, each shown to 4 significant digits.
summary = (f'Train_RMSE:{rmse_train:.4}, Test_RMSE:{rmse_test:.4}\n'
           f'Train_R2:{r2_train:.4}, Test_R2:{r2_test:.4}')
print(summary)

Train_RMSE:0.02214, Test_RMSE:0.1165
Train_R2:0.998, Test_R2:0.9392


In [6]:
# Collect the evaluation metrics into a one-row summary table.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0, so the
# table is built in one step from a list holding the single record.
result = {'Training RMSE': rmse_train, 'Training R2': r2_train,
          'Test RMSE': rmse_test, 'Test R2': r2_test}
result_table = pd.DataFrame([result], columns=['Training RMSE', 'Training R2',
                                               'Test RMSE', 'Test R2'])

In [7]:
# Display the one-row table of training/test RMSE and R2 built above.
result_table

Unnamed: 0,Training RMSE,Training R2,Test RMSE,Test R2
0,0.02214,0.998047,0.116507,0.939236
