### 数据标准化

In [3]:
import numpy as np
import xarray as xr
def normalize_data(std_dict, mean_dict, data):
    """Standardise every data variable of ``data`` as ``(x - mean) / std``.

    Parameters
    ----------
    std_dict, mean_dict : mapping
        Per-variable standard deviation and mean, keyed by variable name.
        A ``KeyError`` propagates for a data variable missing from either.
    data : xarray.Dataset-like
        Object exposing ``data_vars`` and item access. It is modified in
        place and also returned.

    Returns
    -------
    The same ``data`` object, with each variable replaced by its
    standardised, chunked version.
    """
    # Chunking strategy: one block per time step.
    chunks = {'time': 1}

    # Iterate over a snapshot of the names: variables are reassigned inside
    # the loop, so the live ``data_vars`` view should not be iterated directly.
    for var in list(data.data_vars):
        if var in ['lat', 'lon', 'time']:
            continue
        mean = mean_dict[var]
        std = std_dict[var]

        # Avoid dividing by zero: a zero std turns the variable into NaN
        # instead of producing inf values.
        std = std if std != 0 else np.nan

        # Only request chunking along dimensions this variable actually has;
        # xarray rejects chunk keys that are not dimensions of the array, which
        # would otherwise break variables lacking a ``time`` dimension.
        var_chunks = {dim: size for dim, size in chunks.items() if dim in data[var].dims}
        temp_data = data[var].chunk(var_chunks)

        # Standardise and write the result back into the dataset.
        data[var] = ((temp_data - mean) / std).chunk(var_chunks)

    return data
# Per-variable statistics, stored as pickled dicts inside .npy files.
# NOTE(review): allow_pickle=True executes pickle on load; only use it on
# trusted, locally produced files.
std_dict = np.load(
    r"D:\Data_Store\Dataset\ST_Conv\std_mean\std_dict.npy", allow_pickle=True
).item()
mean_dict = np.load(
    r"D:\Data_Store\Dataset\ST_Conv\std_mean\mean_dict.npy", allow_pickle=True
).item()

# Open the source store, standardise every variable, then drop the last
# time step before writing the normalised copy (overwriting any old one).
data = xr.open_dataset(r"D:\Data_Store\Dataset\Original_Data\0.1_Data.zarr")
data = normalize_data(std_dict, mean_dict, data).isel(time=slice(None, -1))
data.to_zarr('D:/Data_Store/Dataset/Original_Data/0.1_normal_Data.zarr', mode='w')

<xarray.backends.zarr.ZarrStore at 0x1b3e1b8c340>

### 样本的并行生成办法

In [1]:
import os
import time
import pandas as pd
import torch
import argparse
import numpy as np
import xarray as xr 
from torch import optim
import yaml
import zarr
import torch.nn as nn

import sys
sys.path.append(r"C:\Users\Administrator\Desktop\code\ST-Conv")
from tool.utils import Util

def reorder_variables(zarr_file_path, config_file_path):
    """Open a zarr store and return its variables in the config's key order.

    Parameters
    ----------
    zarr_file_path : str
        Path of the (consolidated) zarr store to open.
    config_file_path : str
        YAML file whose top-level keys give the desired variable order.

    Returns
    -------
    xarray.Dataset
        New dataset holding only the configured variables that exist in the
        store, in configuration order. Configured names that are absent from
        the store are reported with a printed warning and skipped.
    """
    # The key order of the YAML mapping defines the variable order.
    with open(config_file_path, 'r') as cfg_file:
        vars_order = list(yaml.safe_load(cfg_file).keys())

    data = xr.open_zarr(zarr_file_path, consolidated=True)

    # Split the configured names into those present in the store and those not.
    present_vars, missing_vars = [], []
    for name in vars_order:
        bucket = present_vars if name in data.data_vars else missing_vars
        bucket.append(name)

    if missing_vars:
        print(f"Warning: The following variables from the config are not present in the zarr file: {missing_vars}")

    # Rebuild the dataset so that iteration follows the configured order.
    data_reordered = xr.Dataset({name: data[name] for name in present_vars})

    print("Variables reordered according to the configuration.")
    return data_reordered



def _window_time_indices(center_indices, n_steps, increment, n_times):
    """Return an (N, n_steps) matrix of time indices centred on each entry.

    Offsets run from ``-(n_steps // 2)`` to ``(n_steps + 1) // 2 - 1`` in units
    of ``increment``; indices falling outside ``[0, n_times - 1]`` are clipped,
    so windows near the edges repeat the first/last time step.
    """
    offsets = np.arange(-(n_steps // 2), (n_steps + 1) // 2)
    indices = np.add.outer(center_indices, offsets * increment)
    return np.clip(indices, 0, n_times - 1)


def _gather_samples(data, var_indices, time_indices):
    """Pick channels and time indices, zero-fill NaNs, return sample-major float32.

    ``data[var_indices][:, time_indices]`` is channel-major
    ``(C, N, ...)``; the first two axes are swapped to ``(N, C, ...)``.
    """
    picked = np.nan_to_num(data[var_indices][:, time_indices])
    return np.ascontiguousarray(np.swapaxes(picked, 0, 1), dtype=np.float32)


def split_dataset_filling_faster(vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu, data, valid_fraction=0.8):
    """Build day / half-day / hour / static samples from one spatial patch.

    A sample is produced for every time step at which the soil-moisture channel
    (the *second* key of ``vars_half_day_argu``) has strictly more than
    ``valid_fraction`` of its pixels non-NaN. NaNs in the extracted windows are
    replaced by 0.

    Parameters
    ----------
    vars_index_map : dict
        Variable name -> channel index in ``data``. Must contain ``'DOY'``.
    vars_day_argu, vars_hour_argu, vars_static_argu : dict
        Variable configurations; the first key of each is skipped and the
        remaining keys are the extracted channels.
    vars_half_day_argu : dict
        Its second key names the channel used for the validity test and as
        half-day channel 0.
    Day_time_step, Half_Day_time_step, Hour_time_step : int
        Window lengths (number of time points) for each group.
    data : numpy.ndarray
        Array indexed as ``(channels, time, lat, lon)``.
    valid_fraction : float, optional
        Required fraction of non-NaN pixels. The threshold is computed from
        the actual patch area; it used to be hard-coded as ``16 * 16 * 0.8``,
        which is only correct for 16x16 patches (results for 16x16 patches
        are unchanged).

    Returns
    -------
    tuple of four float32 arrays ``(day, half_day, hour, static)`` shaped
    ``(N, C_day, Day_time_step, lat, lon)``,
    ``(N, 2, Half_Day_time_step, lat, lon)``,
    ``(N, C_hour, Hour_time_step, lat, lon)`` and ``(N, C_static, lat, lon)``,
    or four empty arrays when no time step passes the validity test.
    """
    # data is (channels, time, lat, lon)
    n_times, n_lat, n_lon = data.shape[1], data.shape[2], data.shape[3]

    vars_day_list = list(vars_day_argu.keys())
    vars_half_day_list = list(vars_half_day_argu.keys())
    vars_hour_list = list(vars_hour_argu.keys())
    vars_static_list = list(vars_static_argu.keys())

    # Keep only the time steps whose soil-moisture field is sufficiently complete.
    sm_index = vars_index_map[vars_half_day_list[1]]
    valid_counts_per_time_step = np.sum(~np.isnan(data[sm_index]), axis=(1, 2))
    min_valid_count = n_lat * n_lon * valid_fraction
    complete_time_indices = np.where(valid_counts_per_time_step > min_valid_count)[0]

    if complete_time_indices.shape[0] == 0:
        return np.array([]), np.array([]), np.array([]), np.array([])

    # Spacing between window entries, in time-axis steps. Per the original
    # comments the time axis has one point every 3 hours.
    day_increment = 24 // 3       # daily spacing
    half_day_increment = 12 // 3  # 12-hourly spacing
    hour_increment = 1            # native (3-hourly) spacing

    doy_index = vars_index_map['DOY']

    day_indices = _window_time_indices(complete_time_indices, Day_time_step, day_increment, n_times)
    half_day_indices = _window_time_indices(complete_time_indices, Half_Day_time_step, half_day_increment, n_times)
    hour_indices = _window_time_indices(complete_time_indices, Hour_time_step, hour_increment, n_times)

    # Channel indices; the first configured variable of each group is skipped.
    day_var_indices = [vars_index_map[var] for var in vars_day_list[1:]]
    hour_var_indices = [vars_index_map[var] for var in vars_hour_list[1:]]
    static_var_indices = [vars_index_map[var] for var in vars_static_list[1:]]

    day_data = _gather_samples(data, day_var_indices, day_indices)
    hour_data = _gather_samples(data, hour_var_indices, hour_indices)
    # Static variables are sampled at the centre time step only.
    static_data = _gather_samples(data, static_var_indices, complete_time_indices)

    # Half-day group: channel 0 is the soil-moisture window, channel 1 is DOY.
    # Stacking the gathered (N, steps, lat, lon) windows keeps the data's own
    # (lat, lon) order, so non-square patches work as well.
    sm_windows = np.nan_to_num(data[sm_index][half_day_indices])
    doy_windows = np.nan_to_num(data[doy_index][half_day_indices])
    half_day_data = np.stack((sm_windows, doy_windows), axis=1).astype(np.float32)

    # (samples, channels, time, lat, lon); static has no time axis
    return day_data, half_day_data, hour_data, static_data

def split_data(vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu, data, patch_size=16, stride=8, time_batch_size=200, time_overlap=17, save_root="D:/Data_Store/Dataset/ST_Data"):
    """Cut the dataset into spatio-temporal samples and save them per time batch.

    The time axis is processed in batches of ``time_batch_size`` steps that
    overlap by ``time_overlap`` steps. Within each batch a ``patch_size`` x
    ``patch_size`` window slides over (lat, lon) with step ``stride``; every
    patch is handed to ``split_dataset_filling_faster`` and the resulting
    samples of all patches are concatenated and written as one ``.npy`` file
    per group (Day / Half_Day / Hour / Static) under ``save_root``.

    Parameters
    ----------
    vars_index_map ... vars_static_argu
        Forwarded unchanged to ``split_dataset_filling_faster``.
    data : xarray.Dataset-like
        Must expose ``time``, ``lat``, ``lon``, ``isel`` and ``to_array``.
    patch_size, stride : int
        Spatial window size and step.
    time_batch_size, time_overlap : int
        Temporal batch length and overlap between consecutive batches.
    save_root : str, optional
        Directory under which the four group folders are created. Defaults to
        the location that was previously hard-coded.

    Raises
    ------
    ValueError
        If ``time_batch_size`` is not larger than ``time_overlap`` (the batch
        step would be zero or negative).

    Notes
    -----
    Batches that yield no sample are skipped instead of saving ``None``
    (which ``np.save`` would pickle into an object array that a plain
    ``np.load`` refuses to read).

    NOTE(review): consecutive batches overlap, and samples are generated for
    every valid time step of each batch, so time steps inside the overlap
    appear to be emitted twice (once per batch), with windows clipped at the
    batch edges rather than at the dataset edges. Confirm whether this is
    intended.
    """
    if time_batch_size <= time_overlap:
        raise ValueError("time_batch_size must be larger than time_overlap")

    times, lat, lon = data.time.shape[0], data.lat.shape[0], data.lon.shape[0]

    # Output folders, keyed by the file-name prefix of each sample group.
    groups = ("day", "half_day", "hour", "static")
    save_dirs = {
        "day": f"{save_root}/Day",
        "half_day": f"{save_root}/Half_Day",
        "hour": f"{save_root}/Hour",
        "static": f"{save_root}/Static",
    }
    for save_dir in save_dirs.values():
        os.makedirs(save_dir, exist_ok=True)

    for time_start in range(0, times, time_batch_size - time_overlap):
        batch_start_time = time.time()
        time_end = min(time_start + time_batch_size, times)
        # Materialise the batch once as a (variable, time, lat, lon) array.
        data_batch = data.isel(time=slice(time_start, time_end)).to_array().compute().values

        # Collect per-patch samples in lists and concatenate once per batch;
        # concatenating inside the loop re-copies everything on each patch.
        collected = {group: [] for group in groups}

        for i in range(0, lat - patch_size + 1, stride):
            for j in range(0, lon - patch_size + 1, stride):
                data_chunk = data_batch[:, :, i:i + patch_size, j:j + patch_size]
                patch_samples = split_dataset_filling_faster(vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu, data_chunk)
                # The day sample decides whether the patch produced anything.
                if patch_samples[0].size > 0:
                    for group, sample in zip(groups, patch_samples):
                        collected[group].append(sample)

        if collected["day"]:
            # One file per group and time batch; no further merging needed here.
            for group in groups:
                file_name = f'{group}_data_{time_start}_{time_end}.npy'
                np.save(os.path.join(save_dirs[group], file_name), np.concatenate(collected[group], axis=0))
        else:
            print(f"处理时间段 {time_start} 到 {time_end} 无有效样本，跳过保存")

        batch_end_time = time.time()
        print(f"处理时间段 {time_start} 到 {time_end} 完成. 总耗时：{batch_end_time - batch_start_time:.2f}s")

    print("数据切分和保存完成")

def main(args):
    """Load the configurations, open the reordered dataset and cut it into samples.

    ``args`` must provide the six ``*_config_path`` attributes defined by the
    command-line parser of this script.
    """
    # The models config is loaded but its value is not used in this function;
    # the call is kept because Util.load_config is opaque from here.
    models_config = Util.load_config(args.models_config_path)
    path_config = Util.load_config(args.path_config_path)

    vars_day_argu = Util.load_config(args.vars_day_config_path)
    vars_half_day_argu = Util.load_config(args.vars_half_day_config_path)
    vars_hour_argu = Util.load_config(args.vars_hour_config_path)
    vars_static_argu = Util.load_config(args.vars_static_config_path)

    data = reorder_variables(
        path_config['0.37_normal_Data'],
        "C:\\Users\\Administrator\\Desktop\\code\\ST-Conv\\config\\construct_Data\\vars_config.yaml",
    )

    # Channel position of every variable, following the reordered dataset.
    vars_index_map = {name: position for position, name in enumerate(data.data_vars)}

    # Split into samples and save; every group uses a window of 5 time points.
    window_steps = 5
    split_data(
        vars_index_map,
        vars_day_argu,
        window_steps,
        vars_half_day_argu,
        window_steps,
        vars_hour_argu,
        window_steps,
        vars_static_argu,
        data,
        patch_size=32,
        stride=6,
        time_batch_size=50,
        time_overlap=17,
    )
    print("数据切分和保存过程完成。")

if __name__ == '__main__':
    # Every option is a string path with a default relative to this script's
    # working directory.
    parser = argparse.ArgumentParser()
    for _flag, _default in (
        ('--models_config_path', '../config/models_config.yaml'),
        ('--path_config_path', '../config/path_config.yaml'),
        ('--vars_day_config_path', '../config/construct_Data/vars_day_config.yaml'),
        ('--vars_half_day_config_path', '../config/construct_Data/vars_half_day_config.yaml'),
        ('--vars_hour_config_path', '../config/construct_Data/vars_hour_config.yaml'),
        ('--vars_static_config_path', '../config/construct_Data/vars_static_config.yaml'),
    ):
        parser.add_argument(_flag, type=str, default=_default)

    # parse_known_args ignores unrecognised arguments (e.g. those injected by
    # a notebook kernel); only the parsed namespace is kept.
    args = parser.parse_known_args()[0]

    main(args)

Variables reordered according to the configuration.
处理时间段 0 到 50 完成. 总耗时：949.10s
处理时间段 33 到 83 完成. 总耗时：11.80s
处理时间段 66 到 116 完成. 总耗时：11.38s
处理时间段 99 到 149 完成. 总耗时：11.17s
处理时间段 132 到 182 完成. 总耗时：11.63s
处理时间段 165 到 215 完成. 总耗时：10.96s
处理时间段 198 到 248 完成. 总耗时：10.95s
处理时间段 231 到 281 完成. 总耗时：10.92s
处理时间段 264 到 314 完成. 总耗时：10.94s
处理时间段 297 到 347 完成. 总耗时：10.91s
处理时间段 330 到 380 完成. 总耗时：9.05s
处理时间段 363 到 413 完成. 总耗时：11.19s
处理时间段 396 到 446 完成. 总耗时：11.40s
处理时间段 429 到 479 完成. 总耗时：11.15s
处理时间段 462 到 512 完成. 总耗时：10.59s
处理时间段 495 到 545 完成. 总耗时：10.95s
处理时间段 528 到 578 完成. 总耗时：10.69s
处理时间段 561 到 611 完成. 总耗时：10.73s
处理时间段 594 到 644 完成. 总耗时：10.58s
处理时间段 627 到 677 完成. 总耗时：10.67s
处理时间段 660 到 710 完成. 总耗时：10.49s


#### 拼接数据

In [None]:
import os
import numpy as np

def list_and_sort_files(directory):
    """Return the full paths of all entries in ``directory``, oldest first.

    Entries are ordered by modification time (``os.path.getmtime``); entries
    with equal timestamps keep the order in which ``os.listdir`` reported them.
    """
    entry_paths = (os.path.join(directory, name) for name in os.listdir(directory))
    return sorted(entry_paths, key=os.path.getmtime)
def concatenate_data(files):
    """Load every ``.npy`` file in ``files`` and join them along axis 0.

    Parameters
    ----------
    files : sequence of str
        Paths to load, in the order in which they should be concatenated.

    Returns
    -------
    numpy.ndarray
        An empty array when ``files`` is empty, the loaded array itself when
        there is a single file, otherwise the concatenation of all arrays.
    """
    if not files:
        return np.array([])  # nothing to load

    arrays = [np.load(path) for path in files]
    if len(arrays) == 1:
        return arrays[0]

    # A single concatenate copies each part once; growing the result file by
    # file re-copied everything accumulated so far on every iteration.
    return np.concatenate(arrays, axis=0)
def save_concatenated_data(directory):
    """Concatenate all sample files of ``directory`` into ``<directory>/<name>.npy``.

    ``<name>`` is the last path component of ``directory``. Nothing is written
    when the concatenated data is empty.

    The combined file lives in the same directory as its inputs, so an
    existing combined file is excluded from the inputs; otherwise running this
    function a second time would concatenate the previous result with the
    parts again.
    """
    # normpath drops a trailing separator so the base name is never empty.
    base_name = os.path.basename(os.path.normpath(directory))
    save_path = os.path.join(directory, base_name + ".npy")

    def _same_path(path):
        return os.path.normcase(os.path.abspath(path)) == os.path.normcase(os.path.abspath(save_path))

    files = [path for path in list_and_sort_files(directory) if not _same_path(path)]
    data = concatenate_data(files)
    if data.size > 0:  # only save non-empty results
        np.save(save_path, data)
        print(f"Data saved to {save_path}")

# Folders holding the per-batch sample files, one folder per sample group.
# NOTE(review): these are ...\ST_Conv\Data\... paths, whereas split_data in
# this notebook writes under D:/Data_Store/Dataset/ST_Data/... — confirm the
# files were moved here before this cell is run.
directories = [
    "D:\\Data_Store\\Dataset\\ST_Conv\\Data\\Day",
    "D:\\Data_Store\\Dataset\\ST_Conv\\Data\\Half_Day",
    "D:\\Data_Store\\Dataset\\ST_Conv\\Data\\Hour",
    "D:\\Data_Store\\Dataset\\ST_Conv\\Data\\Static"
]

# Merge the batch files of every folder into a single <folder name>.npy file.
for directory in directories:
    save_concatenated_data(directory)

#### 查看数据分布情况

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Paths of the concatenated sample arrays (only the half-day one is loaded below).
ST_Static_Data_path = 'D:/Data_Store/Dataset/ST_Data/Static/Static_Data.npy'
ST_Day_Data_path = 'D:/Data_Store/Dataset/ST_Data/Day/Day_Data.npy'
# Raw string: the backslashes are path separators, not escape sequences
# (the non-raw form relied on invalid escapes such as "\D" being kept as-is).
ST_Half_Day_Data_path = r"D:\Data_Store\Dataset\ST_Data\Half_Day\Half_Day.npy"

ST_Half_Day_Data = np.load(ST_Half_Day_Data_path)


# Per-sample missing ratio of ST_Half_Day_Data.
# A pixel counts as missing when it is NaN or exactly 0; only channel 0 at
# time index 4 is inspected.
relevant_slice = ST_Half_Day_Data[:, 0, 4, :, :]
zero_nan_proportions = np.mean((relevant_slice == 0) | np.isnan(relevant_slice), axis=(1, 2))
# Indices of the samples whose missing ratio exceeds 20%.
high_missing_indices = np.where(zero_nan_proportions > 0.2)[0]

# Histogram of the missing ratios, in 1% wide bins.
plt.hist(zero_nan_proportions, bins=np.arange(0, 1.01, 0.01), edgecolor='black')
plt.xlabel('0和NaN数据占比')
plt.ylabel('样本数')
# The title names the slice that is actually analysed above.
plt.title('ST_Half_Day_Data[:, 0, 4, :, :]中0和NaN数据占比分布')
plt.grid(True)
plt.show()

#### 维度转换

In [None]:
import os
import numpy as np
# Folders whose combined <folder name>.npy file gets its axes reordered.
directories = [
    "/root/autodl-tmp/ST_Data/Day",
    "/root/autodl-tmp/ST_Data/Half_Day",
    "/root/autodl-tmp/ST_Data/Hour",
    "/root/autodl-tmp/ST_Data/Static"
]

# NOTE(review): each file is overwritten in place, so running this cell twice
# moves the axis a second time — run it exactly once per file.
for directory in directories:
    # The combined file is named after its folder.
    dirname = os.path.basename(directory)
    file_path = os.path.join(directory, dirname + '.npy')

    if not os.path.exists(file_path):
        print(f"File {file_path} does not exist.")
        continue

    data = np.load(file_path)
    print(f"Original dimensions for {file_path}: {data.shape}")

    # Moving axis 1 requires at least two dimensions.
    if data.ndim < 2:
        print(f"File {file_path} does not have enough dimensions to reorder.")
        continue

    # Move axis 1 to the last position, keeping the order of the other axes.
    new_data = np.moveaxis(data, 1, -1)
    print(f"New dimensions for {file_path}: {new_data.shape}")
    np.save(file_path, new_data)
    print(f"Updated file saved to {file_path}")