In [5]:
import json
import numpy as np
import pandas as pd
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# Lift pandas' display limits so printed DataFrames are never truncated.
for option_name in ('display.max_rows', 'display.max_columns'):
    pd.set_option(option_name, None)

# Load the JSON data.
with open("data.json", mode="r", encoding="utf-8") as json_file:
    data = json.load(json_file)

#print(data)

# 提取特征和标签
def extract_features_and_labels(data, label_index=0):
    """Flatten grouped entries into a feature array and a label array.

    Args:
        data: Mapping whose values are iterables of entries. Each entry is
            indexed with ``"Feats"`` (the feature vector) and ``"label"``
            (an indexable collection of target values). The mapping's keys
            are not used.
        label_index: Position inside each entry's ``"label"`` collection to
            use as the target. Defaults to 0 (the first label), which is the
            behavior callers relied on before this parameter existed.

    Returns:
        A ``(features, labels)`` tuple of NumPy arrays, with one row/element
        per entry, in the mapping's iteration order.

    Raises:
        KeyError: If an entry lacks ``"Feats"`` or ``"label"``.
        IndexError: If ``label_index`` is out of range for an entry's labels.
    """
    features = []
    labels = []
    # Only the per-group entry lists matter; the group keys are ignored.
    for entries in data.values():
        for entry in entries:
            features.append(entry["Feats"])
            labels.append(entry["label"][label_index])
    return np.array(features), np.array(labels)

# Build the feature matrix and the target vector from the loaded data.
X, y = extract_features_and_labels(data)

# Hold out 20% of the samples for testing; the fixed seed keeps the split
# reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Create the Lasso regressor (alpha is the regularization strength and can
# be tuned) and fit it on the training split.
lasso = Lasso(alpha=0.1).fit(X_train, y_train)

# Predict on the held-out split.
y_pred = lasso.predict(X_test)

# Evaluate the model.
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")

# Report the fitted coefficients.
print("Lasso Coefficients:", lasso.coef_)

# Show actual vs. predicted values side by side.
predictions = pd.DataFrame(dict(Actual_Gold=y_test, Predicted_Gold=y_pred))
print("预测结果:")
print(predictions)


Mean Squared Error: 17.324425248548177
Lasso Coefficients: [ 1.08853527e-01  1.38266456e-02 -7.94967772e-02  2.40762654e-02
 -8.08636493e-04 -0.00000000e+00  2.38454534e-01  1.04276920e-01
 -5.15265029e-04  2.14107731e-02 -7.38724938e-04 -0.00000000e+00
  2.10954833e-01  9.24474505e-02 -0.00000000e+00  8.25035716e-02
  1.04639416e-04 -0.00000000e+00  7.32777662e+00]
预测结果:
     Actual_Gold  Predicted_Gold
0             17       13.989546
1              0        1.575626
2              1        1.634683
3              3        2.818175
4             11        5.705821
5              1        1.097373
6              7        6.205509
7              2        1.265677
8              1        1.159465
9              4        5.188876
10            26       34.590305
11             8       17.229579
12             1        4.593669
13             1        3.098301
14             1        3.374644
15             3        3.744777
16             1        1.683378
17             0        0.17537