Data download preprocessing

In [1]:
import cdsapi
import logging
from pathlib import Path
import time

class ERA5Retriever:
    def __init__(self, output_dir='era5_data'):
        """初始化ERA5数据获取器"""
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.client = cdsapi.Client()
        self.setup_logging()

    def setup_logging(self):
        """设置日志记录"""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler(self.output_dir / 'era5_retrieval.log'),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger('ERA5Retriever')

    def retrieve_data(self, year, month, output_file=None):
        """获取ERA5数据
        
        Parameters:
        -----------
        year : int
            年份
        month : int
            月份
        output_file : str, optional
            输出文件名
        """
        if output_file is None:
            output_file = self.output_dir / f'era5_{year}_{month:02d}.nc'

        # 如果文件已存在，跳过下载
        if output_file.exists():
            self.logger.info(f"File {output_file} already exists, skipping...")
            return output_file

        try:
            # 构建请求参数
            request = {
                "format": "netcdf",
                "product_type": "reanalysis",
                "variable": [
                    "2m_temperature",                    # 温度
                    "total_precipitation",               # 降水
                    "mean_sea_level_pressure",          # 海平面气压
                    "10m_u_component_of_wind",          # 风场
                    "10m_v_component_of_wind",
                    "geopotential_at_500hpa",           # 500hPa位势高度
                    "relative_humidity_at_850hpa",      # 850hPa相对湿度
                ],
                "year": str(year),
                "month": f"{month:02d}",
                "day": [f"{day:02d}" for day in range(1, 32)],
                "time": [f"{hour:02d}:00" for hour in range(0, 24, 6)],  # 6小时间隔
                "area": [70, -20, 30, 60],  # 欧洲区域 [North, West, South, East]
            }

            # 获取数据
            self.logger.info(f"Retrieving data for {year}-{month:02d}")
            self.client.retrieve(
                'reanalysis-era5-single-levels',
                request,
                output_file
            )
            self.logger.info(f"Successfully downloaded {output_file}")
            
            # 添加间隔，避免请求过于频繁
            time.sleep(5)
            
            return output_file

        except Exception as e:
            self.logger.error(f"Error retrieving data for {year}-{month:02d}: {str(e)}")
            return None

def main():
    """主函数示例用法"""
    retriever = ERA5Retriever()

    # 定义时间范围（按照标杆论文，我们需要1959-2021的数据）
    # 但建议先下载小部分数据测试
    years = [2020, 2021]  # 测试用年份
    months = [11, 12, 1, 2, 3]  # 冬季月份

    # 获取数据
    for year in years:
        for month in months:
            if month in [1, 2, 3] and year == years[-1]:
                continue  # 跳过最后一年的1-3月
            retriever.retrieve_data(year, month)

    print("Data retrieval completed!")

if __name__ == "__main__":
    main()

2024-11-17 19:43:59,828 INFO [2024-09-28T00:00:00] **Welcome to the New Climate Data Store (CDS)!** This new system is in its early days of full operations and still undergoing enhancements and fine tuning. Some disruptions are to be expected. Your 
[feedback](https://jira.ecmwf.int/plugins/servlet/desk/portal/1/create/202) is key to improve the user experience on the new CDS for the benefit of everyone. Thank you.
2024-11-17 19:43:59,832 INFO [2024-09-26T00:00:00] Watch our [Forum](https://forum.ecmwf.int/) for Announcements, news and other discussed topics.
2024-11-17 19:43:59,833 INFO [2024-09-16T00:00:00] Remember that you need to have an ECMWF account to use the new CDS. **Your old CDS credentials will not work in new CDS!**
2024-11-17 19:43:59,836 - INFO - Retrieving data for 2020-11
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://

94596468c920d7733b6fc1a69e5ded11.nc:   0%|          | 0.00/48.8M [00:00<?, ?B/s]

2024-11-17 19:45:27,284 - INFO - Successfully downloaded era5_data\era5_2020_11.nc
2024-11-17 19:45:32,294 - INFO - Retrieving data for 2020-12
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:45:34,389 INFO Request ID is f073a9bf-6396-4e7c-a71d-63b2bac29bcc
2024-11-17 19:45:34,389 - INFO - Request ID is f073a9bf-6396-4e7c-a71d-63b2bac29bcc
2024-11-17 19:45:34,817 INFO status has been updated to accepted
2024-11-17 19:45:34,817 - INFO - status has been updated to accepted
2024-11-17 19:45:38,364 INFO status has been updated to running
2024-11-17 19:45:38,364 - INFO - status has been updated to running
2024-11-17 19:46:27,992 INFO status has been updated to successful
2024-11-17 19:46:27

17191cfae8a817905dd3846f2ad73ace.nc:   0%|          | 0.00/51.1M [00:00<?, ?B/s]

2024-11-17 19:46:48,966 - INFO - Successfully downloaded era5_data\era5_2020_12.nc
2024-11-17 19:46:53,971 - INFO - Retrieving data for 2020-01
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:46:55,337 INFO Request ID is 31a9c734-f17a-407c-b993-aaf87ef33984
2024-11-17 19:46:55,337 - INFO - Request ID is 31a9c734-f17a-407c-b993-aaf87ef33984
2024-11-17 19:46:55,804 INFO status has been updated to accepted
2024-11-17 19:46:55,804 - INFO - status has been updated to accepted
2024-11-17 19:46:59,404 INFO status has been updated to running
2024-11-17 19:46:59,404 - INFO - status has been updated to running
2024-11-17 19:47:06,101 INFO status has been updated to accepted
2024-11-17 19:47:06,1

b5713c0fc2e93e5eaa2f73fe5d72d751.nc:   0%|          | 0.00/50.5M [00:00<?, ?B/s]

2024-11-17 19:48:30,896 - INFO - Successfully downloaded era5_data\era5_2020_01.nc
2024-11-17 19:48:35,905 - INFO - Retrieving data for 2020-02
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:48:37,542 INFO Request ID is 0b94c20e-24f8-4509-afc5-49609dfb813f
2024-11-17 19:48:37,542 - INFO - Request ID is 0b94c20e-24f8-4509-afc5-49609dfb813f
2024-11-17 19:48:37,970 INFO status has been updated to accepted
2024-11-17 19:48:37,970 - INFO - status has been updated to accepted
2024-11-17 19:48:41,399 INFO status has been updated to running
2024-11-17 19:48:41,399 - INFO - status has been updated to running
2024-11-17 19:49:31,694 INFO status has been updated to successful
2024-11-17 19:49:31

7a47754be08ac75edc9bf9dec591a98b.nc:   0%|          | 0.00/47.2M [00:00<?, ?B/s]

2024-11-17 19:50:03,310 - INFO - Successfully downloaded era5_data\era5_2020_02.nc
2024-11-17 19:50:08,319 - INFO - Retrieving data for 2020-03
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:50:11,172 INFO Request ID is 28ef174a-636a-4b03-909a-1e7c5086a92a
2024-11-17 19:50:11,172 - INFO - Request ID is 28ef174a-636a-4b03-909a-1e7c5086a92a
2024-11-17 19:50:11,870 INFO status has been updated to accepted
2024-11-17 19:50:11,870 - INFO - status has been updated to accepted
2024-11-17 19:50:15,816 INFO status has been updated to running
2024-11-17 19:50:15,816 - INFO - status has been updated to running
2024-11-17 19:50:22,869 INFO status has been updated to accepted
2024-11-17 19:50:22,8

493d1a84364ee0b4998c1b8e44a6261.nc:   0%|          | 0.00/49.8M [00:00<?, ?B/s]

2024-11-17 19:51:46,539 - INFO - Successfully downloaded era5_data\era5_2020_03.nc
2024-11-17 19:51:51,548 - INFO - Retrieving data for 2021-11
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:51:53,367 INFO Request ID is 8a711334-6d14-42f2-a3ce-98cfcd8aa350
2024-11-17 19:51:53,367 - INFO - Request ID is 8a711334-6d14-42f2-a3ce-98cfcd8aa350
2024-11-17 19:51:54,041 INFO status has been updated to accepted
2024-11-17 19:51:54,041 - INFO - status has been updated to accepted
2024-11-17 19:52:07,086 INFO status has been updated to running
2024-11-17 19:52:07,086 - INFO - status has been updated to running
2024-11-17 19:52:50,858 INFO status has been updated to successful
2024-11-17 19:52:50

f363bef6e70757b46cc5f8eb363b3dc1.nc:   0%|          | 0.00/49.0M [00:00<?, ?B/s]

2024-11-17 19:53:15,837 - INFO - Successfully downloaded era5_data\era5_2021_11.nc
2024-11-17 19:53:20,853 - INFO - Retrieving data for 2021-12
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
[Forum announcement](https://forum.ecmwf.int/t/final-validated-era5-product-to-differ-from-era5t-in-july-2024/6685)
for details and watch it for further updates on this.
2024-11-17 19:53:22,357 INFO Request ID is adf569ae-7221-4d99-a9ce-a8bc55ea09c8
2024-11-17 19:53:22,357 - INFO - Request ID is adf569ae-7221-4d99-a9ce-a8bc55ea09c8
2024-11-17 19:53:22,787 INFO status has been updated to accepted
2024-11-17 19:53:22,787 - INFO - status has been updated to accepted
2024-11-17 19:53:26,447 INFO status has been updated to running
2024-11-17 19:53:26,447 - INFO - status has been updated to running
2024-11-17 19:54:17,789 INFO status has been updated to successful
2024-11-17 19:54:17

d3b81914980bbfd43caa545ca89d866.nc:   0%|          | 0.00/51.2M [00:00<?, ?B/s]

2024-11-17 19:54:36,617 - INFO - Successfully downloaded era5_data\era5_2021_12.nc


Data retrieval completed!


Basic descriptive statistics

In [3]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from pathlib import Path
import seaborn as sns

class ERA5Analyzer:
    def __init__(self, data_dir='era5_data'):
        """初始化分析器"""
        self.data_dir = Path(data_dir)
        
    def load_data(self, filename):
        """加载并检查ERA5数据"""
        file_path = self.data_dir / filename
        ds = xr.open_dataset(file_path)
        
        # 打印数据集信息
        print("Dataset dimensions:", list(ds.dims))
        print("Dataset variables:", list(ds.data_vars))
        return ds
    
    def basic_statistics(self, ds):
        """计算基本统计信息"""
        stats = {}
        for var in ds.data_vars:
            stats[var] = {
                'mean': float(ds[var].mean()),
                'std': float(ds[var].std()),
                'min': float(ds[var].min()),
                'max': float(ds[var].max())
            }
        return pd.DataFrame(stats).T
    
    def plot_spatial_mean(self, ds, variable, title=None):
        """绘制空间平均分布"""
        fig = plt.figure(figsize=(15, 10))
        ax = plt.axes(projection=ccrs.PlateCarree())
        
        # 计算时间平均（使用valid_time而不是time）
        mean_data = ds[variable].mean(dim='valid_time')
        
        # 绘制地图
        mean_data.plot(
            ax=ax,
            transform=ccrs.PlateCarree(),
            cmap='RdBu_r',
            robust=True
        )
        
        ax.coastlines()
        ax.gridlines()
        if title:
            plt.title(title)
            
        return fig
    
    def plot_time_series(self, ds, variable, lat=None, lon=None):
        """绘制时间序列"""
        if lat is not None and lon is not None:
            # 提取特定位置的数据
            data = ds[variable].sel(latitude=lat, longitude=lon, method='nearest')
        else:
            # 计算空间平均
            data = ds[variable].mean(dim=['latitude', 'longitude'])
        
        fig, ax = plt.subplots(figsize=(15, 5))
        data.plot(x='valid_time', ax=ax)  # 使用valid_time作为x轴
        plt.title(f'{variable} Time Series')
        return fig
    
    def calculate_anomalies(self, ds, climatology_period=None):
        """计算异常值"""
        # 确保使用valid_time
        time_coord = 'valid_time'
        
        if climatology_period is None:
            climatology = ds.groupby(f'{time_coord}.dayofyear').mean(time_coord)
            anomalies = ds.groupby(f'{time_coord}.dayofyear') - climatology
        else:
            start_year, end_year = climatology_period
            mask = (ds[time_coord].dt.year >= start_year) & (ds[time_coord].dt.year <= end_year)
            climatology = ds.sel({time_coord: mask}).groupby(f'{time_coord}.dayofyear').mean(time_coord)
            anomalies = ds.groupby(f'{time_coord}.dayofyear') - climatology
            
        return anomalies
    
    def plot_hovmoller(self, ds, variable, lat_band=None):
        """绘制霍夫莫勒图"""
        if lat_band is not None:
            # 在特定纬度带计算平均
            data = ds[variable].sel(latitude=slice(*lat_band)).mean('latitude')
        else:
            data = ds[variable].mean('latitude')
            
        fig, ax = plt.subplots(figsize=(15, 10))
        data.plot(
            x='longitude',
            y='valid_time',  # 使用valid_time
            cmap='RdBu_r',
            robust=True
        )
        plt.title(f'{variable} Hovmöller Diagram')
        return fig
    
    def analyze_daily_cycle(self, ds, variable, lat=None, lon=None):
        """分析日变化"""
        if lat is not None and lon is not None:
            data = ds[variable].sel(latitude=lat, longitude=lon, method='nearest')
        else:
            data = ds[variable].mean(dim=['latitude', 'longitude'])
        
        # 获取每个时间点的小时
        hours = pd.to_datetime(data.valid_time.values).hour
        
        # 按小时分组计算平均值
        daily_cycle = data.groupby(hours).mean()
        
        # 绘制日变化
        fig, ax = plt.subplots(figsize=(12, 6))
        daily_cycle.plot(ax=ax)
        plt.xlabel('Hour of Day')
        plt.title(f'{variable} Daily Cycle')
        return fig

def main():
    """主函数示例用法"""
    analyzer = ERA5Analyzer()
    
    # 加载数据
    print("加载数据...")
    ds = analyzer.load_data('era5_2021_12.nc')
    
    # 1. 基本统计分析
    print("\n计算基本统计信息：")
    stats = analyzer.basic_statistics(ds)
    print(stats)
    
    # 2. 空间分布图
    print("\n绘制空间分布图...")
    for var in ds.data_vars:
        try:
            fig = analyzer.plot_spatial_mean(ds, var, 
                title=f'Mean {var} Distribution (December 2021)')
            plt.savefig(f'spatial_mean_{var}.png')
            plt.close()
            print(f"已保存 {var} 的空间分布图")
        except Exception as e:
            print(f"绘制 {var} 的空间分布图时出错: {str(e)}")
    
    # 3. 时间序列分析（柏林）
    print("\n绘制时间序列...")
    berlin_lat, berlin_lon = 52.52, 13.41
    for var in ds.data_vars:
        try:
            fig = analyzer.plot_time_series(ds, var, 
                lat=berlin_lat, lon=berlin_lon)
            plt.savefig(f'time_series_{var}_berlin.png')
            plt.close()
            print(f"已保存 {var} 的时间序列图")
        except Exception as e:
            print(f"绘制 {var} 的时间序列图时出错: {str(e)}")
    
    # 4. 日变化分析
    print("\n分析日变化...")
    for var in ds.data_vars:
        try:
            fig = analyzer.analyze_daily_cycle(ds, var)
            plt.savefig(f'daily_cycle_{var}.png')
            plt.close()
            print(f"已保存 {var} 的日变化图")
        except Exception as e:
            print(f"分析 {var} 的日变化时出错: {str(e)}")

if __name__ == "__main__":
    main()

加载数据...
Dataset dimensions: ['valid_time', 'latitude', 'longitude']
Dataset variables: ['t2m', 'tp', 'msl', 'u10', 'v10']

计算基本统计信息：
              mean          std           min            max
t2m     276.098755    10.294481    235.358200     306.281921
tp        0.000095     0.000280      0.000000       0.009905
msl  101433.968750  1176.391968  95805.125000  104607.125000
u10       0.678640     4.715758    -22.936874      28.554581
v10       0.627906     4.453432    -24.695114      23.146225

绘制空间分布图...




已保存 t2m 的空间分布图
已保存 tp 的空间分布图
已保存 msl 的空间分布图
已保存 u10 的空间分布图
已保存 v10 的空间分布图

绘制时间序列...
已保存 t2m 的时间序列图
已保存 tp 的时间序列图
已保存 msl 的时间序列图
已保存 u10 的时间序列图
已保存 v10 的时间序列图

分析日变化...
分析 t2m 的日变化时出错: `group` must be an xarray.DataArray or the name of an xarray variable or dimension. Received Index([ 0,  6, 12, 18,  0,  6, 12, 18,  0,  6,
       ...
       12, 18,  0,  6, 12, 18,  0,  6, 12, 18],
      dtype='int32', length=124) instead.
分析 tp 的日变化时出错: `group` must be an xarray.DataArray or the name of an xarray variable or dimension. Received Index([ 0,  6, 12, 18,  0,  6, 12, 18,  0,  6,
       ...
       12, 18,  0,  6, 12, 18,  0,  6, 12, 18],
      dtype='int32', length=124) instead.
分析 msl 的日变化时出错: `group` must be an xarray.DataArray or the name of an xarray variable or dimension. Received Index([ 0,  6, 12, 18,  0,  6, 12, 18,  0,  6,
       ...
       12, 18,  0,  6, 12, 18,  0,  6, 12, 18],
      dtype='int32', length=124) instead.
分析 u10 的日变化时出错: `group` must be an xarray.DataArray or the na

pca

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import IncrementalPCA
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')  # 忽略警告信息

# 设置matplotlib和seaborn的绘图样式
plt.style.use('default')
sns.set_theme(style="whitegrid")
sns.set_context("talk")

class ERA5Explorer:
    def __init__(self, data_dir='era5_data', output_dir='analysis_results'):
        """
        初始化数据探索器
        """
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        
    def load_and_preprocess(self):
        """加载并预处理ERA5数据"""
        try:
            print("Loading ERA5 data...")
            data_files = sorted(self.data_dir.glob('era5_*.nc'))
            print(f"Found {len(data_files)} data files")

            # 定义要处理的空间范围（例如，选择经度 0-50，纬度 30-60）
            lon_min, lon_max = 0, 50
            lat_min, lat_max = 30, 60

            # 选择要处理的变量，减少变量数量（例如，只处理 't2m' 和 'tp'）
            variables = ['t2m', 'tp']

            # 准备列表来存储所有月份的异常值数据
            anomalies_list = []
            dates_list = []

            # 存储 climatology 和 std 用于后续计算
            climatology_dict = {}
            std_dict = {}

            # 首先遍历所有数据文件，计算每日的气候态和标准差
            print("Calculating daily climatology and standard deviation...")
            daily_data_list = []

            for file in data_files:
                ds = xr.open_dataset(file)

                # 选择感兴趣的区域和变量
                ds = ds.sel(longitude=slice(lon_min, lon_max), latitude=slice(lat_max, lat_min))
                ds = ds[variables]

                # 计算日平均
                ds_daily = ds.resample(valid_time='1D').mean()

                # 将每日数据添加到列表
                daily_data_list.append(ds_daily)

            # 合并所有日平均数据
            all_daily_data = xr.concat(daily_data_list, dim='valid_time')
            all_daily_data = all_daily_data.sortby('valid_time')

            # 处理缺失值，使用前向填充和后向填充
            all_daily_data = all_daily_data.ffill(dim='valid_time').bfill(dim='valid_time')

            # 计算气候态和标准差
            climatology = all_daily_data.groupby('valid_time.dayofyear').mean('valid_time')
            std_dev = all_daily_data.groupby('valid_time.dayofyear').std('valid_time')

            # 现在逐月处理数据，计算异常值
            print("Processing monthly data and calculating anomalies...")
            for file in data_files:
                ds = xr.open_dataset(file)

                # 选择感兴趣的区域和变量
                ds = ds.sel(longitude=slice(lon_min, lon_max), latitude=slice(lat_max, lat_min))
                ds = ds[variables]

                # 计算日平均
                ds_daily = ds.resample(valid_time='1D').mean()

                # 处理缺失值，使用前向填充和后向填充
                ds_daily = ds_daily.ffill(dim='valid_time').bfill(dim='valid_time')

                # 确保日期有序
                ds_daily = ds_daily.sortby('valid_time')

                # 计算异常值
                anomalies = []
                dates = []
                for time in ds_daily['valid_time'].values:
                    day = pd.to_datetime(time).dayofyear

                    # 获取对应的气候态和标准差
                    clim = climatology.sel(dayofyear=day)
                    std = std_dev.sel(dayofyear=day)

                    # 避免除以零
                    std = std.where(std != 0, np.nan)

                    # 计算异常值
                    anomaly = (ds_daily.sel(valid_time=time) - clim) / std
                    anomalies.append(anomaly)
                    dates.append(time)

                # 将 anomalies 合并
                anomalies_month = xr.concat(anomalies, dim='valid_time')
                anomalies_month['valid_time'] = dates

                # 添加到列表
                anomalies_list.append(anomalies_month)
                dates_list.extend(dates)

            # 合并所有月份的异常值数据
            anomalies = xr.concat(anomalies_list, dim='valid_time')
            anomalies = anomalies.sortby('valid_time')

            print("Data preprocessing completed successfully!")
            return anomalies, dates_list

        except Exception as e:
            print(f"Error in load_and_preprocess: {str(e)}")
            raise

    def perform_incremental_pca(self, anomalies, dates, n_components=3, batch_size=5):
        """
        执行增量式PCA分析
        """
        try:
            print("\nPerforming Incremental PCA analysis...")

            # 准备数据
            variables = list(anomalies.data_vars)
            data_arrays = []

            for var in variables:
                if var in anomalies:
                    # 将空间维度展平，并重置索引
                    flat_data = anomalies[var].stack(spatial=['latitude', 'longitude'])
                    flat_data = flat_data.reset_index('spatial')
                    data_arrays.append(flat_data)

            # 合并所有变量
            combined_data = xr.concat(data_arrays, dim='variable')

            # 将变量和空间维度合并为特征维度
            combined_data = combined_data.stack(features=['variable', 'spatial'])

            # 确保数据维度为 (time, features)
            combined_data = combined_data.transpose('valid_time', 'features')

            # 填充NaN值
            combined_data = combined_data.fillna(0)

            # 将数据转换为 NumPy 数组
            data_np = combined_data.values

            # 创建增量式 PCA 对象
            ipca = IncrementalPCA(n_components=n_components)

            # 创建标准化器
            scaler = StandardScaler()

            # 获取样本数量
            n_samples = data_np.shape[0]

            # 计算批次数量
            n_batches = n_samples // batch_size + (n_samples % batch_size != 0)

            # 逐批次处理数据
            pca_result = []
            dates_array = np.array(dates)
            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, n_samples)

                # 提取当前批次的数据
                batch_data = data_np[start_idx:end_idx, :]

                # 标准化
                batch_data = scaler.partial_fit_transform(batch_data)

                # 增量式 PCA 拟合和变换
                batch_result = ipca.partial_fit_transform(batch_data)

                pca_result.append(batch_result)

                print(f"Processed batch {i + 1}/{n_batches}")

            # 将所有批次的结果合并
            pca_result = np.concatenate(pca_result, axis=0)

            # 获取解释方差比例和主成分
            explained_variance_ratio = ipca.explained_variance_ratio_
            components = ipca.components_

            return pca_result, explained_variance_ratio, components, variables, dates_array

        except Exception as e:
            print(f"Error in perform_incremental_pca: {str(e)}")
            raise

    def plot_explained_variance(self, explained_variance_ratio):
        """绘制解释方差比例"""
        # ...（保持不变）

    def plot_time_evolution(self, pca_result, dates, n_components=3):
        """绘制主成分的时间演变"""
        # ...（保持不变）
                    
    def analyze_seasonal_variation(self, pca_result, dates, n_components=3):
        """分析季节性变化"""
        # ...（保持不变）

    def plot_component_patterns(self, components, variables, anomalies, n_components=3):
        """绘制主成分模式（平均加载值）"""
        # ...（保持不变）

def main():
    """主函数"""
    explorer = ERA5Explorer(
        data_dir='G:/workflow0822/githubclimate/paper_ERC23_code/era5_data',
        output_dir='analysis_results'
    )

    # 1. 加载和预处理数据
    anomalies, dates = explorer.load_and_preprocess()

    # 2. 执行增量式PCA分析
    pca_result, explained_variance, components, variables, dates_array = explorer.perform_incremental_pca(
        anomalies, dates, n_components=3, batch_size=5  # 减少主成分数量和批大小
    )

    # 3. 可视化结果
    # 解释方差比例
    explorer.plot_explained_variance(explained_variance)

    # 主成分模式
    explorer.plot_component_patterns(components, variables, anomalies, n_components=3)

    # 时间演变
    explorer.plot_time_evolution(pca_result, dates_array, n_components=3)

    # 季节性变化
    explorer.analyze_seasonal_variation(pca_result, dates_array, n_components=3)

    # 4. 打印主要发现
    print("\nKey Findings:")
    print("-------------")
    print(f"Total explained variance by first 3 PCs: {np.sum(explained_variance)*100:.2f}%")

    # 保存结果
    np.savez(explorer.output_dir / 'pca_results.npz',
             pca_result=pca_result,
             explained_variance=explained_variance,
             components=components,
             variables=variables)

if __name__ == "__main__":
    main()


Loading ERA5 data...
Found 7 data files
Calculating daily climatology and standard deviation...
Processing monthly data and calculating anomalies...
Data preprocessing completed successfully!

Performing Incremental PCA analysis...
Error in perform_incremental_pca: 'StandardScaler' object has no attribute 'partial_fit_transform'


AttributeError: 'StandardScaler' object has no attribute 'partial_fit_transform'

Memory optimized version

In [4]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')  # 忽略警告信息

# 设置matplotlib和seaborn的绘图样式
plt.style.use('default')
sns.set_theme(style="whitegrid")
sns.set_context("talk")

class ERA5Explorer:
    def __init__(self, data_dir='era5_data', output_dir='analysis_results'):
        """
        初始化数据探索器
        """
        self.data_dir = Path(data_dir)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)
        
    def load_and_preprocess(self):
        """加载并预处理ERA5数据"""
        try:
            print("Loading ERA5 data...")
            data_files = sorted(self.data_dir.glob('era5_*.nc'))
            print(f"Found {len(data_files)} data files")

            # **仅处理一个文件**
            if not data_files:
                raise FileNotFoundError("No data files found in the specified directory.")
            file = data_files[0]
            print(f"Processing file: {file.name}")

            # 定义要处理的空间范围（可以进一步缩小以减少数据量）
            lon_min, lon_max = 0, 50
            lat_min, lat_max = 30, 60

            # 选择要处理的变量，减少变量数量
            variables = ['t2m', 'tp']

            # 打开数据集
            ds = xr.open_dataset(file)

            # 选择感兴趣的区域和变量
            ds = ds.sel(longitude=slice(lon_min, lon_max), latitude=slice(lat_max, lat_min))
            ds = ds[variables]

            # 计算日平均
            ds_daily = ds.resample(valid_time='1D').mean()

            # 处理缺失值，使用前向填充和后向填充
            ds_daily = ds_daily.ffill(dim='valid_time').bfill(dim='valid_time')

            # 确保日期有序
            ds_daily = ds_daily.sortby('valid_time')

            # 计算气候态和标准差（由于只有一个月的数据，这里直接计算平均值和标准差）
            climatology = ds_daily.mean('valid_time')
            std_dev = ds_daily.std('valid_time')

            # 计算异常值
            anomalies = (ds_daily - climatology) / std_dev

            # 避免出现 NaN 和 Inf
            anomalies = anomalies.fillna(0)
            anomalies = anomalies.where(np.isfinite(anomalies), 0)

            # 重置索引（修复错误）
            if 'dayofyear' in anomalies.indexes:
                anomalies = anomalies.reset_index('dayofyear')

            # 确保 'dayofyear' 不在坐标中
            anomalies = anomalies.reset_coords(drop=True)

            dates_list = ds_daily['valid_time'].values

            print("Data preprocessing completed successfully!")
            return anomalies, dates_list

        except Exception as e:
            print(f"Error in load_and_preprocess: {str(e)}")
            raise

    def perform_pca(self, anomalies, dates, n_components=2):
        """
        执行PCA分析
        """
        try:
            print("\nPerforming PCA analysis...")

            # 准备数据
            variables = list(anomalies.data_vars)
            data_arrays = []

            for var in variables:
                if var in anomalies:
                    # 将空间维度展平，并重置索引
                    flat_data = anomalies[var].stack(spatial=['latitude', 'longitude'])
                    flat_data = flat_data.reset_index('spatial')
                    data_arrays.append(flat_data)

            # 合并所有变量
            combined_data = xr.concat(data_arrays, dim='variable')

            # 将变量和空间维度合并为特征维度
            combined_data = combined_data.stack(features=['variable', 'spatial'])

            # 确保数据维度为 (valid_time, features)
            combined_data = combined_data.transpose('valid_time', 'features')

            # 填充NaN值
            combined_data = combined_data.fillna(0)

            # 将数据转换为 NumPy 数组
            data_np = combined_data.values

            # 标准化数据
            scaler = StandardScaler()
            data_scaled = scaler.fit_transform(data_np)

            # 执行PCA
            pca = PCA(n_components=n_components)
            pca_result = pca.fit_transform(data_scaled)

            # 获取解释方差比例和主成分
            explained_variance_ratio = pca.explained_variance_ratio_
            components = pca.components_

            dates_array = np.array(dates)

            return pca_result, explained_variance_ratio, components, variables, dates_array

        except Exception as e:
            print(f"Error in perform_pca: {str(e)}")
            raise

    def plot_explained_variance(self, explained_variance_ratio):
        """绘制解释方差比例"""
        try:
            plt.figure(figsize=(10, 6))
            bars = plt.bar(range(1, len(explained_variance_ratio) + 1), 
                           explained_variance_ratio * 100)
            for bar in bars:
                height = bar.get_height()
                plt.text(bar.get_x() + bar.get_width()/2., height,
                         f'{height:.1f}%',
                         ha='center', va='bottom')
            plt.xlabel('Principal Component')
            plt.ylabel('Explained Variance Ratio (%)')
            plt.title('Explained Variance Ratio by Principal Components')
            plt.grid(True, alpha=0.3)
            # 添加累计解释方差
            cumsum = np.cumsum(explained_variance_ratio) * 100
            ax2 = plt.gca().twinx()
            ax2.plot(range(1, len(explained_variance_ratio) + 1), cumsum, 
                     'r-', label='Cumulative Variance')
            ax2.set_ylabel('Cumulative Explained Variance (%)')
            plt.tight_layout()
            plt.savefig(self.output_dir / 'explained_variance.png', dpi=300, bbox_inches='tight')
            plt.close()
            print(f"Top {len(explained_variance_ratio)} PCs explain {cumsum[-1]:.1f}% of total variance")
        except Exception as e:
            print(f"Error in plot_explained_variance: {str(e)}")

    def plot_time_evolution(self, pca_result, dates, n_components=2):
        """绘制主成分的时间演变"""
        try:
            plt.figure(figsize=(15, 8))
            # 转换dates为datetime对象
            if not isinstance(dates[0], pd.Timestamp):
                dates = pd.to_datetime(dates)
            # 绘制每个主成分的时间序列
            for i in range(min(n_components, pca_result.shape[1])):
                plt.plot(dates, pca_result[:, i], 
                         label=f'PC{i+1}', alpha=0.7)
            plt.xlabel('Time')
            plt.ylabel('PC Value')
            plt.title('Time Evolution of Principal Components')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.xticks(rotation=45)
            plt.tight_layout()
            plt.savefig(self.output_dir / 'pc_time_evolution.png', dpi=300, bbox_inches='tight')
            plt.close()
        except Exception as e:
            print(f"Error in plot_time_evolution: {str(e)}")
                
    def plot_component_patterns(self, components, variables, anomalies, n_components=2):
        """绘制主成分模式（平均加载值）"""
        try:
            # 由于组件是展平的特征，需要重新reshape
            n_features = components.shape[1]
            n_variables = len(variables)
            n_spatial = n_features // n_variables

            fig, axes = plt.subplots(n_components, 1, figsize=(12, 4 * n_components))
            if n_components == 1:
                axes = [axes]
            for i, ax in enumerate(axes):
                # 提取第 i 个主成分的加载值
                component = components[i]
                # 将加载值按变量和空间维度重新reshape
                component_reshaped = component.reshape((n_variables, n_spatial))
                mean_loadings = component_reshaped.mean(axis=1)
                ax.bar(variables, mean_loadings)
                ax.set_title(f'PC{i+1} Mean Loadings')
                ax.set_ylabel('Loading')
                ax.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(self.output_dir / 'component_patterns.png', dpi=300, bbox_inches='tight')
            plt.close()
        except Exception as e:
            print(f"Error in plot_component_patterns: {str(e)}")

def main():
    """主函数"""
    explorer = ERA5Explorer(
        data_dir='G:/workflow0822/githubclimate/paper_ERC23_code/era5_data',
        output_dir='analysis_results'
    )

    # 1. 加载和预处理数据
    anomalies, dates = explorer.load_and_preprocess()

    # 2. 执行PCA分析
    pca_result, explained_variance, components, variables, dates_array = explorer.perform_pca(
        anomalies, dates, n_components=2
    )

    # 3. 可视化结果
    # 解释方差比例
    explorer.plot_explained_variance(explained_variance)

    # 主成分模式
    explorer.plot_component_patterns(components, variables, anomalies, n_components=2)

    # 时间演变
    explorer.plot_time_evolution(pca_result, dates_array, n_components=2)

    # 4. 打印主要发现
    print("\nKey Findings:")
    print("-------------")
    print(f"Total explained variance by first 2 PCs: {np.sum(explained_variance)*100:.2f}%")

    # 保存结果
    np.savez(explorer.output_dir / 'pca_results.npz',
             pca_result=pca_result,
             explained_variance=explained_variance,
             components=components,
             variables=variables)

if __name__ == "__main__":
    main()


Loading ERA5 data...
Found 7 data files
Processing file: era5_2020_01.nc
Data preprocessing completed successfully!

Performing PCA analysis...
Top 2 PCs explain 26.2% of total variance

Key Findings:
-------------
Total explained variance by first 2 PCs: 26.17%


Predictive training

In [12]:
import tensorflow as tf
import numpy as np
import xarray as xr
import pandas as pd
import dask
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import shap
import logging
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')  # 忽略警告信息

class SubseasonalPredictor:
    def __init__(self, base_dir='data'):
        self.base_dir = Path(base_dir)
        self.scaler = StandardScaler()
        self.model = None
        self.logger = self._setup_logger()

    def _setup_logger(self):
        """设置日志记录器"""
        logger = logging.getLogger('SubseasonalPredictor')
        logger.setLevel(logging.INFO)
        # 避免重复添加处理器
        if not logger.handlers:
            ch = logging.StreamHandler()
            ch.setLevel(logging.INFO)
            fh = logging.FileHandler('subseasonal_prediction.log')
            fh.setLevel(logging.INFO)
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            ch.setFormatter(formatter)
            fh.setFormatter(formatter)
            logger.addHandler(ch)
            logger.addHandler(fh)
        return logger

    def validate_data_continuity(self, ds):
        """验证数据的时间连续性和质量"""
        # 检查时间步长
        time_diff = np.diff(ds.valid_time.values)
        expected_diff = np.timedelta64(6, 'h')  # 6小时间隔
        if not np.all(time_diff == expected_diff):
            problematic_times = np.where(time_diff != expected_diff)[0]
            self.logger.warning(f"发现时间不连续: {problematic_times}")
        # 检查异常值
        for var in ['tp', 't2m', 'msl', 'u10', 'v10']:
            if var in ds.data_vars:
                data = ds[var].values
                if np.issubdtype(data.dtype, np.number):
                    mean = np.nanmean(data)
                    std = np.nanstd(data)
                    outliers = np.where(np.abs(data - mean) > 5 * std)
                    if len(outliers[0]) > 0:
                        self.logger.warning(f"{var}存在异常值，位置: {outliers}")
                else:
                    self.logger.info(f"{var}不是数值类型，跳过异常值检查")
        # 检查缺失值比例
        for var in ds.data_vars:
            missing_ratio = ds[var].isnull().mean().values
            if missing_ratio > 0:
                self.logger.warning(
                    f"{var}存在缺失值，比例: {missing_ratio:.2%}"
                )

    def prepare_data(self, ds, target_lead_time=28):
        """准备神经网络输入数据"""
        required_vars = ['tp', 't2m', 'msl', 'u10', 'v10']
        missing_vars = [var for var in required_vars
                        if var not in ds.data_vars]
        if missing_vars:
            raise ValueError(f"缺少必需的变量: {missing_vars}")
        # 处理缺失值
        ds = ds.interpolate_na(dim='valid_time', method='linear')
        ds = ds.ffill('valid_time').bfill('valid_time')
        ds = ds.sortby('valid_time')
        # 处理降水累积值
        if 'tp' in ds.data_vars:
            ds['tp'] = ds['tp'].diff('valid_time').fillna(0)
            ds['tp'] = xr.where(ds['tp'] < 0, 0, ds['tp'])
        # 标准化数据
        for var in required_vars:
            if var in ds.data_vars:
                mean = ds[var].mean()
                std = ds[var].std()
                ds[var] = (ds[var] - mean) / std
        # 提取热带区域降水数据
        tropical_precip = ds['tp'].sel(latitude=slice(26, -26))
        # 定义欧洲三个区域
        regions = {
            'north_europe': {'latitude': slice(65, 55),
                             'longitude': slice(-10, 30)},
            'central_europe': {'latitude': slice(55, 45),
                               'longitude': slice(-5, 20)},
            'south_europe': {'latitude': slice(45, 35),
                             'longitude': slice(-10, 30)}
        }
        targets = {}
        for region, coords in regions.items():
            regional_precip = ds['tp'].sel(**coords)
            targets[region] = regional_precip.mean(
                dim=['latitude', 'longitude']
            )
        # 创建输入和输出数据
        X_list, y_list = [], []
        for i in range(len(ds.valid_time) - target_lead_time):
            X_sample = tropical_precip.isel(
                valid_time=i
            ).values
            y_sample = []
            for region in regions:
                y_value = targets[region].isel(
                    valid_time=i + target_lead_time
                ).values
                y_sample.append(np.sign(y_value))
            X_list.append(X_sample)
            y_list.append(y_sample)
        X = np.array(X_list)
        y = np.array(y_list)
        return X, y

    def build_model(self, input_shape):
        """构建神经网络模型"""
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(
                32, (3, 3), activation='relu', input_shape=input_shape
            ),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation='relu'),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(3, activation='softmax')
        ])
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy']
        )
        return model

    def train(self, X_train, y_train, X_val, y_val,
              epochs=50, batch_size=32):
        """训练模型"""
        input_shape = X_train.shape[1:]
        self.model = self.build_model(input_shape)
        history = self.model.fit(
            X_train, y_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(X_val, y_val),
            callbacks=[
                tf.keras.callbacks.EarlyStopping(
                    monitor='val_loss',
                    patience=5,
                    restore_best_weights=True
                )
            ]
        )
        return history

    def explain_predictions(self, X_test, background_samples=100):
        """使用SHAP解释模型预测"""
        explainer = shap.DeepExplainer(
            self.model, X_test[:background_samples]
        )
        shap_values = explainer.shap_values(X_test)
        return shap_values

    def plot_feature_importance(self, shap_values, region_idx=0):
        """绘制特征重要性图"""
        plt.figure(figsize=(15, 10))
        shap.summary_plot(
            shap_values[region_idx],
            plot_type='bar',
            feature_names=[
                f'Feature_{i}' for i in range(shap_values[0].shape[1])
            ]
        )
        plt.show()

    def analyze_decadal_skill(self, predictions, true_values,
                              window_size=3650):
        """分析预测技巧的年代际变化"""
        accuracies = []
        dates = []
        for i in range(0, len(predictions) - window_size, 365):
            window_preds = predictions[i:i + window_size]
            window_true = true_values[i:i + window_size]
            accuracy = np.mean(window_preds == window_true)
            accuracies.append(accuracy)
            dates.append(i / 365 + 1959)  # 假设从1959年开始
        return dates, accuracies

    def plot_decadal_skill(self, dates, accuracies, region_name):
        """绘制预测技巧的年代际变化"""
        plt.figure(figsize=(12, 6))
        plt.plot(dates, accuracies)
        plt.xlabel('Year')
        plt.ylabel('10-year Moving Average Accuracy')
        plt.title(f'Decadal Variation in Prediction Skill: {region_name}')
        plt.grid(True)
        plt.show()

def main():
    """主函数，包含内存管理和错误处理"""
    # 设置dask配置
    dask.config.set(
        {"array.slicing.split_large_chunks": True,
         "array.chunk-size": "100MB"}
    )
    predictor = SubseasonalPredictor()
    try:
        predictor.logger.info("加载数据...")
        ds = xr.open_mfdataset(
            'era5_data/era5_*.nc',
            chunks={'valid_time': -1,
                    'latitude': 'auto',
                    'longitude': 'auto'},
            combine='by_coords'
        )
        predictor.logger.info("验证数据质量...")
        predictor.validate_data_continuity(ds)
        predictor.logger.info("准备数据...")
        X, y = predictor.prepare_data(ds)
        # 划分训练集和测试集
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        predictor.logger.info("训练模型...")
        history = predictor.train(X_train, y_train, X_test, y_test)
        predictor.logger.info("模型解释...")
        shap_values = predictor.explain_predictions(X_test)
        predictor.plot_feature_importance(shap_values)
        predictor.logger.info("分析预测技巧...")
        predictions = predictor.model.predict(X_test)
        for i, region in enumerate(['North', 'Central', 'South']):
            pred_labels = np.argmax(predictions, axis=1)
            true_labels = np.argmax(y_test, axis=1)
            dates, accuracies = predictor.analyze_decadal_skill(
                pred_labels, true_labels
            )
            predictor.plot_decadal_skill(dates, accuracies, region)
            plt.savefig(f'decadal_skill_{region.lower()}.png')
            plt.close()
    except Exception as e:
        predictor.logger.error(f"程序执行出错: {str(e)}")
        raise
    finally:
        import gc
        gc.collect()

if __name__ == "__main__":
    main()


2024-11-17 21:13:28,332 - SubseasonalPredictor - INFO - 加载数据...
2024-11-17 21:13:29,290 - SubseasonalPredictor - INFO - 验证数据质量...
       115, 115, 115, 115, 115, 115, 115, 115, 115, 115, 115, 115, 115,
       115, 115, 115, 116, 116, 116, 116, 116, 128, 128, 128, 128, 128,
       128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128,
       128, 128, 128, 128, 128, 128, 128, 128, 577, 577, 577, 577, 577,
       577, 577, 578, 578, 578, 578, 578, 578, 578, 578, 578, 578, 578,
       578, 578, 578, 578, 578, 578, 578, 578, 578, 578, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
       579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579, 579,
      

ValueError: dimension valid_time on 0th function argument to apply_ufunc with dask='parallelized' consists of multiple chunks, but is also a core dimension. To fix, either rechunk into a single array chunk along this dimension, i.e., ``.chunk(dict(valid_time=-1))``, or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` but beware that this may significantly increase memory usage.