用22年所有数据进行训练来预测24年的数据

In [None]:
import pandas as pd
import numpy as np
from math import radians, cos, sin, asin, sqrt
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer

# Data-processing helper
def process_data(light_file, species_file):
    """Join one year's species observations with nearby light readings.

    Both CSV files are loaded with ``pd.read_csv``.  For every species row
    the matching light stations are looked up with
    ``find_nearby_light_stations`` and the mean of their ``nsb`` values is
    attached.  Rows without any matching station keep ``nsb=None`` and an
    empty ``stations`` list.

    Args:
        light_file: Path of the light CSV.  The code reads the columns
            ``latitude``, ``longitude`` and ``nsb`` here; the station lookup
            additionally reads ``month`` and ``location_id``.
        species_file: Path of the species CSV.  The code reads the columns
            ``latitude``, ``longitude``, ``scientific``, ``month`` and ``gno``.

    Returns:
        A DataFrame with one row per species observation and the columns
        ``scientific``, ``month``, ``gno``, ``nsb``, ``stations``,
        ``species_id`` and ``category``.

    Side effects:
        Prints the row counts of both inputs and, on the first call only,
        creates the module-level ``species_to_id`` mapping.
    """
    global species_to_id

    light_data = pd.read_csv(light_file)
    species_data = pd.read_csv(species_file)

    # Coerce the coordinate and brightness columns to numeric dtypes
    # (pd.to_numeric raises on values it cannot parse).
    conversions = (
        (light_data, 'latitude'),
        (light_data, 'longitude'),
        (species_data, 'latitude'),
        (species_data, 'longitude'),
        (light_data, 'nsb'),
    )
    for frame, column in conversions:
        frame[column] = pd.to_numeric(frame[column])

    # Report the raw input sizes
    print(f"  光强数据量: {len(light_data)}")
    print(f"  物种数据量: {len(species_data)}")

    records = []
    for _, observation in species_data.iterrows():
        matches = find_nearby_light_stations(observation, light_data)
        if matches:
            mean_nsb = sum(match['nsb'] for match in matches) / len(matches)
        else:
            mean_nsb = None
        records.append({
            'scientific': observation['scientific'],
            'month': observation['month'],
            'gno': observation['gno'],
            'nsb': mean_nsb,
            'stations': [match['location_id'] for match in matches],
        })

    result_df = pd.DataFrame(records)

    # Keep species ids consistent across years: the mapping is built from the
    # first frame processed and reused by every later call.
    # NOTE(review): species that only appear in a later call are absent from
    # the mapping, so `.map` leaves their species_id as NaN - confirm this is
    # intended.
    if 'species_to_id' not in globals():
        unique_names = result_df['scientific'].unique()
        species_to_id = {
            name: code for code, name in enumerate(unique_names, start=1)
        }

    result_df['species_id'] = result_df['scientific'].map(species_to_id)

    # Placeholder for dropping unneeded columns; the list is currently empty,
    # so this is a no-op.
    columns_to_drop = []
    result_df = result_df.drop(columns=columns_to_drop)

    # Attach the species category code
    result_df['category'] = result_df['species_id'].apply(get_species_category)

    return result_df

def haversine_distance(lat1, lon1, lat2, lon2):
    """Return the great-circle distance between two points via the haversine formula.

    Args:
        lat1, lon1: Latitude and longitude of the first point, in degrees.
        lat2, lon2: Latitude and longitude of the second point, in degrees.

    Returns:
        The distance on a sphere of radius 6371 (the Earth's mean radius in
        kilometres), so the result is in kilometres.  NaN inputs yield NaN.
    """
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Floating-point rounding can push `a` marginally above 1 for
    # near-antipodal points, which would make asin() raise ValueError.
    # `a` must be the first argument: min(a, 1.0) keeps a NaN `a` as NaN,
    # whereas min(1.0, a) would silently turn it into 1.0.
    a = min(a, 1.0)
    c = 2 * asin(sqrt(a))
    r = 6371  # sphere radius in km
    return c * r

def find_nearby_light_stations(species_row, light_data, max_distance_km=2):
    """List the light readings taken in the same month close to a species record.

    Args:
        species_row: Mapping-like row (e.g. a pandas Series) providing
            ``month``, ``latitude`` and ``longitude``.
        light_data: DataFrame of light readings with the columns ``month``,
            ``latitude``, ``longitude``, ``location_id`` and ``nsb``.
        max_distance_km: Exclusive distance threshold, compared against the
            value returned by ``haversine_distance``.  Defaults to 2.

    Returns:
        A list of dicts with the keys ``location_id``, ``distance`` and
        ``nsb``, one per matching light row, in ``light_data`` row order.
        The list is empty when nothing matches.
    """
    # Filter on month once with a vectorised mask instead of testing every
    # row inside the Python loop.
    same_month = light_data[light_data['month'] == species_row['month']]

    nearby_stations = []
    for _, light_row in same_month.iterrows():
        dist = haversine_distance(
            species_row['latitude'], species_row['longitude'],
            light_row['latitude'], light_row['longitude']
        )
        if dist < max_distance_km:
            nearby_stations.append({
                'location_id': light_row['location_id'],
                'distance': dist,
                'nsb': light_row['nsb']
            })
    return nearby_stations

# Species classification: species ids grouped by taxonomic category.
# NOTE(review): a few ids appear in more than one list (e.g. 378 and 449 in
# both birds and mammals, 182 and 573 in both butterflies_moths and
# dragonflies_damselflies); the first list in the order below wins.
birds = [8, 9, 10, 11, 12, 16, 17, 18, 19, 20, 21, 29, 30, 31, 37, 38, 39, 40, 44, 45, 46, 47, 49, 50, 52, 53, 54, 55, 56, 70, 71, 72, 84, 89, 90, 91, 92, 93, 95, 96, 97, 98, 99, 104, 113, 114, 117, 119, 123, 124, 125, 130, 131, 132, 134, 135, 139, 143, 145, 147, 148, 153, 154, 157, 158, 159, 162, 171, 176, 177, 179, 180, 181, 187, 188, 189, 190, 194, 195, 196, 197, 201, 202, 210, 222, 223, 224, 225, 226, 228, 229, 230, 231, 233, 236, 237, 242, 243, 251, 252, 260, 261, 262, 271, 276, 277, 283, 284, 287, 294, 308, 309, 318, 325, 326, 327, 328, 329, 330, 337, 343, 344, 345, 346, 349, 362, 363, 378, 385, 386, 387, 388, 392, 393, 394, 395, 398, 420, 427, 428, 446, 447, 449, 450, 474, 475, 479, 480, 484, 485, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 510, 511, 512, 524, 525, 543, 544, 545, 546, 547, 548, 551, 552, 553, 563, 564, 565, 578, 580, 588, 595, 596, 597, 600, 601, 604, 605, 606, 609, 610, 611, 612, 613, 614, 617, 622, 632, 633, 636, 637, 645, 649, 650, 651, 652, 653, 654, 655, 660, 661, 662, 663, 664, 668, 669, 671, 672, 678, 690, 691]
mammals = [79, 83, 100, 149, 161, 272, 273, 278, 279, 296, 352, 370, 378, 379, 380, 381, 391, 399, 400, 401, 449, 526, 562, 570, 571, 572, 579, 618, 619, 665, 666, 677]
reptiles = [65, 85, 101, 160, 215, 227, 234, 265, 266, 267, 268, 295, 302, 303, 353, 354, 425, 431, 432, 508, 509, 533, 549, 550, 558, 566, 583, 584, 589, 593, 629, 647, 648, 659]
amphibians = [33, 186, 221, 280, 281, 289, 319, 320, 331, 340, 375, 376, 377, 424, 465, 507, 559, 620]
fish = [5, 6, 7, 13, 28, 32, 48, 75, 76, 77, 105, 120, 121, 122, 136, 140, 141, 151, 163, 191, 218, 219, 220, 239, 241, 263, 264, 269, 275, 292, 293, 341, 350, 351, 389, 390, 434, 435, 445, 471, 481, 513, 534, 539, 540, 568, 569, 581, 582, 591, 627, 628, 635, 646, 676, 679, 680]
butterflies_moths = [1, 2, 15, 22, 23, 34, 35, 56, 57, 58, 59, 60, 61, 62, 73, 74, 78, 86, 87, 102, 103, 106, 107, 108, 109, 111, 112, 115, 126, 127, 128, 129, 133, 154, 155, 156, 165, 166, 169, 170, 172, 173, 182, 183, 184, 193, 200, 207, 208, 209, 211, 212, 213, 254, 255, 256, 259, 274, 282, 285, 290, 291, 297, 299, 310, 311, 312, 313, 314, 315, 316, 317, 321, 323, 324, 333, 334, 335, 336, 338, 339, 357, 360, 361, 367, 368, 369, 384, 396, 397, 402, 403, 407, 408, 409, 410, 411, 417, 423, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 472, 473, 476, 483, 484, 502, 503, 505, 515, 516, 517, 521, 522, 523, 527, 528, 531, 561, 573, 577, 594, 602, 603, 615, 621, 623, 624, 625, 626, 630, 631, 640, 641, 642, 658, 667, 673, 674, 681, 682, 683, 684, 685, 686, 687, 688, 689]
dragonflies_damselflies = [14, 26, 27, 36, 41, 42, 43, 63, 80, 82, 88, 94, 116, 142, 144, 150, 182, 185, 198, 203, 249, 250, 257, 258, 301, 305, 332, 355, 356, 359, 360, 365, 383, 404, 405, 412, 413, 414, 439, 440, 441, 442, 443, 451, 462, 463, 464, 529, 530, 532, 535, 536, 537, 538, 541, 573, 574, 575, 576, 590, 592, 638, 639, 643, 644, 656, 657, 670, 692, 693, 694]
fireflies_beetles = [51, 174, 175, 347, 348, 544, 545, 554, 555, 556, 557, 567, 607, 675]
other_invertebrates = [235, 270, 421, 422, 468, 469, 470, 585, 586]

# Category code returned for ids in none of the first eight lists; this
# includes the ids in `other_invertebrates`, unknown ids and NaN.
_DEFAULT_CATEGORY = 9

# id -> category code, built once so each lookup is a single dict access
# instead of a scan over up to eight lists.  setdefault() preserves the
# first-match-wins priority for ids that occur in several lists.
_CATEGORY_BY_SPECIES_ID = {}
for _code, _ids in enumerate(
    (birds, mammals, reptiles, amphibians, fish,
     butterflies_moths, dragonflies_damselflies, fireflies_beetles),
    start=1,
):
    for _species_id in _ids:
        _CATEGORY_BY_SPECIES_ID.setdefault(_species_id, _code)


def get_species_category(species_id):
    """Map a species id to its category code.

    Codes: 1 birds, 2 mammals, 3 reptiles, 4 amphibians, 5 fish,
    6 butterflies/moths, 7 dragonflies/damselflies, 8 fireflies/beetles,
    9 everything else.

    Args:
        species_id: Numeric species id.  Floats equal to a listed integer
            (e.g. ``8.0``) match; NaN matches nothing.

    Returns:
        The integer category code (1-9).
    """
    return _CATEGORY_BY_SPECIES_ID.get(species_id, _DEFAULT_CATEGORY)

def add_interaction_terms(df, category_dummies, stations_dummies, degree):
    """Add powers of ``nsb`` and their interactions with the dummy columns.

    ``df`` is modified in place and also returned.  Nothing is added for
    ``degree < 2``.  For every power ``p`` from 2 up to ``min(degree, 5)``
    the column ``nsb^p`` is created, and for every (category, station) pair
    of dummy columns the products ``nsb^p*c``, ``nsb^p*s`` and
    ``nsb^p*c*s``.  The ``p == 2`` pass additionally creates the power-free
    terms ``c*s``, ``nsb*c``, ``nsb*s`` and ``nsb*c*s``.

    All interaction columns are produced inside the nested pair loop, so
    when either dummy frame has no columns only the bare ``nsb^p`` columns
    are added.

    Args:
        df: Frame containing ``nsb`` and every dummy column named below.
        category_dummies: Frame whose column names identify the category
            dummy columns of ``df``; only ``.columns`` is read.
        stations_dummies: Frame whose column names identify the station
            dummy columns of ``df``; only ``.columns`` is read.
        degree: Highest power of ``nsb`` to generate (capped at 5).

    Returns:
        The same ``df`` object with the new columns appended.
    """
    category_cols = list(category_dummies.columns)
    station_cols = list(stations_dummies.columns)
    assigned = set()

    def assign_once(name, base, dummy_col):
        # Single-dummy terms would otherwise be recomputed for every partner
        # column of the nested loop; the first computation already yields
        # the final value and fixes the column position.
        if name not in assigned:
            df[name] = base * df[dummy_col]
            assigned.add(name)

    # Only powers 2..5 are supported; larger degrees add nothing beyond nsb^5.
    for power in range(2, 6):
        # Written as `not >=` so that a NaN degree adds nothing.
        if not degree >= power:
            break
        power_name = f'nsb^{power}'
        df[power_name] = df['nsb'] ** power
        for c in category_cols:
            for s in station_cols:
                if power == 2:
                    df[f'{c}*{s}'] = df[c] * df[s]
                    assign_once(f'nsb*{c}', df['nsb'], c)
                    assign_once(f'nsb*{s}', df['nsb'], s)
                assign_once(f'{power_name}*{c}', df[power_name], c)
                assign_once(f'{power_name}*{s}', df[power_name], s)
                df[f'{power_name}*{c}*{s}'] = df[power_name] * df[c] * df[s]
                if power == 2:
                    df[f'nsb*{c}*{s}'] = df['nsb'] * df[c] * df[s]
    return df

# Process each year's data
print("正在处理2022年数据...")
data_2022 = process_data(
    '/Users/hansen/Desktop/MATH3836proj/香港物种数据/2022年光强.csv',
    '/Users/hansen/Desktop/MATH3836proj/香港物种数据/2022年物种_最终版.csv',
)
print("正在处理2024年数据...")
data_2024 = process_data(
    '/Users/hansen/Desktop/MATH3836proj/香港物种数据/2024年光强.csv',
    '/Users/hansen/Desktop/MATH3836proj/香港物种数据/2024年物种_最终版.csv',
)

# Keep only the observations that received a light reading
valid_data_2022 = data_2022.loc[data_2022['nsb'].notna()].reset_index(drop=True)
valid_data_2024 = data_2024.loc[data_2024['nsb'].notna()].reset_index(drop=True)

print(f"2022年有效数据量: {len(valid_data_2022)}")
print(f"2024年有效数据量: {len(valid_data_2024)}")

# Print the species-category distribution of each year
for year, data in zip(('2022', '2024'), (valid_data_2022, valid_data_2024)):
    print(f"\n{year}年物种类别分布:")
    print(data['category'].value_counts().sort_index())

# Stack both years so the dummy columns cover the union of categories and
# stations; each row is represented by the first station in its list.
all_data = pd.concat([valid_data_2022, valid_data_2024])
all_category_dummies = pd.get_dummies(all_data['category'], prefix='category')
all_stations_dummies = pd.get_dummies(
    all_data['stations'].apply(lambda station_ids: station_ids[0] if station_ids else None),
    prefix='station',
)

# Give every year's frame the same dummy-variable columns
def prepare_data_with_dummies(valid_data, all_category_dummies, all_stations_dummies):
    """Append category and station dummy columns aligned to a shared column set.

    Dummies are built from ``valid_data['category']`` and from the first
    entry of each row's ``stations`` list, then aligned to the column names
    of the two reference frames: reference columns missing for this year are
    added as all-zero columns, and the reference order is used.

    Args:
        valid_data: Frame with a ``category`` column and a ``stations``
            column holding lists of station ids.
        all_category_dummies: Reference frame; only ``.columns`` is read.
        all_stations_dummies: Reference frame; only ``.columns`` is read.

    Returns:
        A new frame: ``valid_data`` followed by the category dummies and the
        station dummies, all with a fresh 0..n-1 index.
    """
    first_station = valid_data['stations'].apply(lambda x: x[0] if x else None)

    # dtype=int is explicit on purpose: from pandas 2.0 get_dummies returns
    # bool columns by default, and select_dtypes(include=[np.number]) - used
    # later to choose the model features - silently skips bool columns.
    category_dummies = pd.get_dummies(valid_data['category'], prefix='category', dtype=int)
    stations_dummies = pd.get_dummies(first_station, prefix='station', dtype=int)

    # Add the columns this year lacks (as zeros) and enforce a common order.
    category_dummies = category_dummies.reindex(columns=all_category_dummies.columns, fill_value=0)
    stations_dummies = stations_dummies.reindex(columns=all_stations_dummies.columns, fill_value=0)

    return pd.concat([valid_data.reset_index(drop=True),
                      category_dummies.reset_index(drop=True),
                      stations_dummies.reset_index(drop=True)], axis=1)

valid_data_2022_with_dummies = prepare_data_with_dummies(valid_data_2022, all_category_dummies, all_stations_dummies)
valid_data_2024_with_dummies = prepare_data_with_dummies(valid_data_2024, all_category_dummies, all_stations_dummies)

# Models to evaluate
models = {
    "Ridge": Ridge,  # the class, not an instance: a fresh estimator is built per degree
}

# Metrics collected per model name
results = {}

# Fit on 2022 and evaluate on 2024 for each polynomial degree
for degree in range(1, 6):
    print(f"\n正在处理多项式次数 {degree}...")

    # Add the polynomial / interaction terms (on copies, the helper mutates its input)
    train_data = add_interaction_terms(valid_data_2022_with_dummies.copy(),
                                       all_category_dummies,
                                       all_stations_dummies,
                                       degree)

    test_data_2024 = add_interaction_terms(valid_data_2024_with_dummies.copy(),
                                           all_category_dummies,
                                           all_stations_dummies,
                                           degree)

    # Candidate features: every numeric column except the target 'gno'.
    # NOTE(review): this also picks up the integer-coded 'species_id' and
    # 'category' columns as plain numeric features - confirm that is intended.
    numeric_columns = train_data.select_dtypes(include=[np.number]).columns.tolist()
    if 'gno' in numeric_columns:
        numeric_columns.remove('gno')

    # Keep the features that are numeric in both frames.  The training-frame
    # column order is preserved so the feature order is identical on every
    # run; list(set(...)) would reorder the string names per process because
    # of hash randomisation.
    test_numeric_columns = set(test_data_2024.select_dtypes(include=[np.number]).columns)
    common_features = [col for col in numeric_columns if col in test_numeric_columns]

    print(f"使用 {len(common_features)} 个共同特征进行训练和测试")

    # Training data
    X_train = train_data[common_features]
    y_train = train_data['gno']

    # Report NaNs in the training features (they are mean-imputed below)
    nan_count = X_train.isna().sum().sum()
    if nan_count > 0:
        print(f"训练数据中有 {nan_count} 个NaN值，将使用均值填充")

    # 2024 test data
    X_test = test_data_2024[common_features]
    y_test = test_data_2024['gno']

    # Report NaNs in the test features (they are mean-imputed below)
    nan_count = X_test.isna().sum().sum()
    if nan_count > 0:
        print(f"2024年测试数据中有 {nan_count} 个NaN值，将使用均值填充")

    # Imputation and scaling do not depend on the model, so they are fitted
    # once per degree on the training data and then applied to the test data.
    imputer = SimpleImputer(strategy='mean')
    scaler = StandardScaler()
    X_train_imputed = imputer.fit_transform(X_train)
    X_train_scaled = scaler.fit_transform(X_train_imputed)
    X_test_imputed = imputer.transform(X_test)
    X_test_scaled = scaler.transform(X_test_imputed)

    for name, model_class in models.items():
        print(f"\n使用 {name} 模型...")

        # Build and fit a fresh model instance
        model = model_class(alpha=1.0)
        model.fit(X_train_scaled, y_train)

        # Evaluate on the training set
        y_pred_train = model.predict(X_train_scaled)
        r2_train = r2_score(y_train, y_pred_train)
        rmse_train = np.sqrt(mean_squared_error(y_train, y_pred_train))

        # Evaluate on the test set
        y_pred_test = model.predict(X_test_scaled)
        r2_test = r2_score(y_test, y_pred_test)
        rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))

        # Store the metrics, creating the per-model lists on first use
        if name not in results:
            results[name] = {
                'degree': [],
                'r2_train': [], 'rmse_train': [],
                'r2_test': [], 'rmse_test': []
            }

        results[name]['degree'].append(degree)
        results[name]['r2_train'].append(r2_train)
        results[name]['rmse_train'].append(rmse_train)
        results[name]['r2_test'].append(r2_test)
        results[name]['rmse_test'].append(rmse_test)

        print(f"  多项式次数 {degree}：")
        print(f"    2022年(训练集) R²: {r2_train:.4f}, RMSE: {rmse_train:.4f}")
        print(f"    2024年(测试集) R²: {r2_test:.4f}, RMSE: {rmse_test:.4f}")

# Visualise the results (plot labels are in English)
for name, result in results.items():
    plt.figure(figsize=(15, 10))
    degrees = result['degree']

    # One stacked panel per metric: R² on top, RMSE below
    for position, (key, label) in enumerate((('r2', 'R²'), ('rmse', 'RMSE')), start=1):
        plt.subplot(2, 1, position)
        plt.plot(degrees, result[f'{key}_train'], marker='o', label='2022 (Training Set)')
        plt.plot(degrees, result[f'{key}_test'], marker='^', label='2024 (Test Set)')
        plt.title(f'{name} Model {label} Values on Different Datasets')
        plt.xlabel('Polynomial Degree')
        plt.ylabel(f'{label} Value')
        plt.xticks(degrees)
        plt.grid(True)
        plt.legend()

    plt.tight_layout()
    # Saved relative to the current working directory
    plt.savefig(f'{name}_model_results.png')
    plt.show()

# Print the best result of each model
for name, result in results.items():
    degrees = result['degree']
    # Degree with the highest R² on each data set
    best_degree_train = degrees[np.argmax(result['r2_train'])]
    best_degree_test = degrees[np.argmax(result['r2_test'])]

    print(f"\n{name} 模型结果总结:")
    print(f"2022年训练集最佳多项式次数: {best_degree_train}, R²: {max(result['r2_train']):.4f}, RMSE: {min(result['rmse_train']):.4f}")
    print(f"2024年测试集最佳多项式次数: {best_degree_test}, R²: {max(result['r2_test']):.4f}, RMSE: {min(result['rmse_test']):.4f}")


正在处理2022年数据...
  光强数据量: 112
  物种数据量: 11178
