In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde, linregress
import re
from concurrent.futures import ProcessPoolExecutor

# -------------------- 工具函数 --------------------
def extract_key_period(period):
    """
    Return the first known period keyword that occurs in ``period``.

    The keywords are tried in a fixed order (the four seasons, then 'Annual',
    'Apr-Sep', 'top-10' and '98th') with a case-sensitive substring test.
    ``None`` is returned when none of them occurs in ``period``.
    """
    known_keywords = ("DJF", "MAM", "JJA", "SON", "Annual", "Apr-Sep", "top-10", "98th")
    hits = (keyword for keyword in known_keywords if keyword in period)
    return next(hits, None)

def get_prefix(filename):
    """
    Return the label prefix implied by a file name.

    The result is 'FtA' when the name contains the substring 'FtA'
    (case-sensitive) and 'AtF' for every other name.
    """
    return "FtA" if "FtA" in filename else "AtF"

def get_year(filename):
    """
    Return the first year-like token found in a file name, as a string.

    A token is four consecutive characters in the range 2010-2029; it does not
    have to be delimited, so it may be part of a longer digit run. ``None`` is
    returned when the name contains no such token.
    """
    found = re.search(r"(20[1-2][0-9])", filename)
    return found.group(1) if found else None

def get_axis_label(filename, period=None, year=None):
    """
    Build an axis label of the form ``<prefix>_<source>[_<suffix>]`` from a file name.

    ``<prefix>`` comes from ``get_prefix``. ``<source>`` is chosen from the
    first marker found in the file name: 'DFT' is matched case-insensitively,
    the remaining markers case-sensitively, in the order listed below.

    The suffix is ``period`` when a period is given, otherwise ``year`` when a
    year is given, otherwise it is omitted. When both are supplied only the
    period is appended; the year is not part of the label in that case.
    """
    prefix = get_prefix(filename)

    if "DFT" in filename.upper():
        label = "DFT"
    else:
        source_markers = (
            ("BarronResult", "Barron's Result"),
            ("BarronScript", "Barron's Script"),
            ("Python", "Python"),
            ("EQUATES", "EQUATES"),
            ("Harvard", "Harvard"),
        )
        # The fallback spelling is kept exactly as it was: labels built here
        # end up in plot titles and output directory names.
        label = next(
            (text for marker, text in source_markers if marker in filename),
            "unkown",
        )

    suffix = period or year
    if suffix:
        return f"{prefix}_{label}_{suffix}"
    return f"{prefix}_{label}"

# -------------------- 数据加载 --------------------
# Files plotted on the y axis; each one is compared against x_axis_file.
fusion_output_files = [
    "/DeepLearning/mnt/shixiansheng/data_fusion/output/2011_Data_WithoutCV/BarronScript_ALL_2011_FtAIndex_InUSA.csv",
]

# Reference file plotted on the x axis.
x_axis_file = "/DeepLearning/mnt/shixiansheng/data_fusion/output/2011_Data_WithoutCV/BarronResult_VNAeVNA_2011_FtAIndex_InUSA.csv"

# Root under which the per-comparison plot directory is created.
base_output_dir = '/DeepLearning/mnt/shixiansheng/data_fusion/output'

year_x = get_year(x_axis_file)
year_y = get_year(fusion_output_files[0])

# NOTE(review): year, x_label, y_label and output_dir are only assigned when
# the two years agree. process_file reads output_dir, so on a mismatch the
# plotting step has no output directory to use.
if year_x == year_y:
    year = year_x

    x_label = get_axis_label(x_axis_file, year=year)
    y_label = get_axis_label(fusion_output_files[0], year=year)

    output_dir = os.path.join(base_output_dir, f"scatter_plots_{y_label}Vs{x_label}")

    if os.path.exists(output_dir):
        print(f"Directory already exists: {output_dir}")
    else:
        os.makedirs(output_dir)
        print(f"Created directory: {output_dir}")
else:
    print("Warning: The years in the input files do not match!")

# -------------------- 定义绘图函数 --------------------
def plot_density_scatter(dataframe_x, dataframe_y, x_column, y_column, period_column, output_dir, period_value, file_name):
    """
    Plot, save and show a density-coloured scatter of y against x for one period.

    Rows whose ``period_column`` text contains ``period_value``
    (case-insensitive) are selected from each frame, and the two selections
    are paired by position. Pairs in which either value is NaN are dropped.
    The figure also carries a 1:1 line, a least-squares regression line and
    the R^2, MAE and RMSE of y against x, and is written to ``output_dir`` as
    a PNG before being shown.

    Args:
        dataframe_x: Frame providing the x values.
        dataframe_y: Frame providing the y values.
        x_column: Column of ``dataframe_x`` to plot on the x axis.
        y_column: Column of ``dataframe_y`` to plot on the y axis.
        period_column: Column holding the period text in both frames.
        output_dir: Directory the PNG is written to.
        period_value: Text to look for in ``period_column``.
        file_name: Path of the y file; used for the output file name, the year
            and the y-axis label. The x-axis label is derived from the
            module-level ``x_axis_file``.

    Returns:
        None. When there is nothing usable to plot a message is printed and
        no figure is produced.

    Raises:
        ValueError: If the two period selections have different lengths and
            therefore cannot be paired row by row.
    """
    df_period_x = dataframe_x[dataframe_x[period_column].str.contains(period_value, case=False, na=False)]
    df_period_y = dataframe_y[dataframe_y[period_column].str.contains(period_value, case=False, na=False)]

    if df_period_x.empty or df_period_y.empty:
        print(f"数据中没有有效数据，跳过 Period: {period_value} 的绘图。")
        return

    x_data = df_period_x[x_column].values
    y_data = df_period_y[y_column].values

    # The values are paired purely by position, so unequal selections cannot
    # be compared; fail with an explicit message instead of a broadcast error.
    if len(x_data) != len(y_data):
        raise ValueError(
            f"Period {period_value!r}: {len(x_data)} x rows but {len(y_data)} y rows; "
            "the two selections cannot be paired row by row."
        )

    valid_indices = ~(np.isnan(x_data) | np.isnan(y_data))
    x_data = x_data[valid_indices]
    y_data = y_data[valid_indices]

    if len(x_data) == 0:
        print(f"数据为空，跳过 Period: {period_value} 的绘图。")
        return

    # A regression line and a density estimate both need at least two points.
    if len(x_data) < 2:
        print(f"Only one valid data pair for Period: {period_value}; skipping the plot.")
        return

    file_base_name = os.path.basename(file_name).split(".")[0]
    year = get_year(file_name)

    # Prefer the canonical keyword (e.g. 'DJF') in labels and file names.
    period_search = extract_key_period(period_value)
    if period_search:
        period_value = period_search

    formatted_period = f"{period_value}_{year}"
    x_label = get_axis_label(x_axis_file, period_value, year)
    y_label = get_axis_label(file_name, period_value, year)

    xy = np.vstack([x_data, y_data])
    try:
        z = gaussian_kde(xy)(xy)
    except (np.linalg.LinAlgError, ValueError) as exc:
        # The estimator cannot be built when the points are degenerate, for
        # example when y equals x for every pair. Fall back to one colour.
        print(f"Density estimate unavailable for Period: {period_value} ({exc}); using a uniform colour.")
        z = np.zeros(len(x_data))

    # Scale the density to [0, 1]; a constant density has no range to scale by.
    z_span = z.max() - z.min()
    z = (z - z.min()) / z_span if z_span > 0 else np.zeros_like(z)

    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        scatter = ax.scatter(x_data, y_data, c=z, cmap='jet', s=20, alpha=0.8)
        fig.colorbar(scatter, ax=ax)

        # Same limits on both axes so the 1:1 line is a true diagonal.
        max_val = max(x_data.max(), y_data.max())
        max_val1 = max_val + 3
        ax.set_xlim(-3, max_val1)
        ax.set_ylim(-3, max_val1)

        ax.plot([0, max_val], [0, max_val], 'k-', lw=0.5, label="1:1 line")

        slope, intercept, r_value, _, _ = linregress(x_data, y_data)
        regression_line = slope * np.array([0, max_val]) + intercept
        ax.plot([0, max_val], regression_line, 'r-', lw=0.5, label="Regression Line")
        r_squared = r_value ** 2
        mae = np.mean(np.abs(y_data - x_data))
        rmse = np.sqrt(np.mean((y_data - x_data) ** 2))
        print(f"RMSE: {rmse}, MAE: {mae}")
        ax.text(0.95, 0.05, f"$R^2$ = {r_squared:.4f}\nMAE = {mae:.3f}\nRMSE = {rmse:.3f}",
                transform=ax.transAxes, ha="right", va="bottom", fontsize=12)

        ax.set_xlabel(f'{x_label} ({x_column})', fontsize=12)
        ax.set_ylabel(f'{y_label} ({y_column})', fontsize=12)
        ax.legend(loc='upper left', fontsize=10)

        fig.subplots_adjust(top=0.85)
        ax.set_title(f'{y_label} vs {x_label}', fontsize=13, loc='center')

        output_file_name = f'{file_base_name}_{formatted_period}_{y_column}_vs_{x_column}_density.png'
        output_path = os.path.join(output_dir, output_file_name)
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        plt.show()
        print(f"散点密度图已保存至 {output_path}")
    finally:
        # This function is called many times per file; release each figure so
        # they do not pile up in memory, including when plotting fails.
        plt.close(fig)

# -------------------- 并行化绘图 --------------------
def process_file(fusion_output_file):
    """
    Produce every period/column scatter plot for one fusion output file.

    The file is read as the y data and the module-level ``x_axis_file`` as the
    x data. For each period keyword, the 'vna_ozone' columns are plotted first
    and the 'evna_ozone' columns second, into the module-level ``output_dir``.
    """
    y_frame = pd.read_csv(fusion_output_file)
    x_frame = pd.read_csv(x_axis_file)

    # (x column, y column) pairs, plotted in this order for every period.
    column_pairs = (
        ('vna_ozone', 'vna_ozone'),
        ('evna_ozone', 'evna_ozone'),
    )
    period_keywords = ("DJF", "MAM", "JJA", "SON", 'Annual', 'Apr-Sep', 'top-10', '98th')

    for period_keyword in period_keywords:
        for x_col, y_col in column_pairs:
            plot_density_scatter(
                x_frame, y_frame, x_col, y_col, 'Period',
                output_dir, period_keyword, fusion_output_file,
            )

# -------------------- 运行并行化任务 --------------------
def process_files_parallel():
    """
    Run ``process_file`` on every entry of ``fusion_output_files`` in a process pool.

    Each file is submitted as its own task. The outcome of every task is
    collected, and a failure is reported on standard output together with the
    file it belongs to; the remaining files are still processed. Returns None.
    """
    with ProcessPoolExecutor() as executor:
        pending = {
            executor.submit(process_file, path): path
            for path in fusion_output_files
        }
        # An exception raised in a worker only becomes visible when the
        # task's result is retrieved, so every result is retrieved here.
        for future, path in pending.items():
            try:
                future.result()
            except Exception as exc:
                print(f"Failed to process {path}: {exc!r}")

# Only start the pool when this file is run as the main program: worker
# processes created with the "spawn" start method re-import the main module,
# and an unguarded call would try to start a new pool from inside each worker.
if __name__ == "__main__":
    process_files_parallel()


Directory already exists: /DeepLearning/mnt/shixiansheng/data_fusion/output/scatter_plots_FtA_Barron's Script_2011VsFtA_Barron's Result_2011


KeyboardInterrupt: 