### 无重叠重建：处理后交接

In [None]:
import os
import time
import pandas as pd
import torch
import argparse
import numpy as np
import xarray as xr 
from torch import optim
import yaml
import zarr
import torch.nn as nn

import sys
sys.path.append(r"/root/ST-Conv")
from tool.utils import Util
from model.model import STConvNet
from model.block import BasicBlock
from dataset.dataset_process import data_preparation
from log.logger import setup_logger
def reorder_variables(zarr_file_path, config_file_path):
    """
    Open a zarr store and return its data variables in the order listed by a YAML config.

    :param zarr_file_path: path of the (consolidated) zarr store to open.
    :param config_file_path: path of a YAML mapping; the order of its keys is the target order.
    :return: a new xarray Dataset holding the config variables found in the store, in config order.
    """
    # The key order of the YAML mapping defines the variable order.
    with open(config_file_path, 'r') as cfg_handle:
        cfg = yaml.safe_load(cfg_handle)
    wanted = list(cfg.keys())

    dataset = xr.open_zarr(zarr_file_path, consolidated=True)

    # Split the configured names into those present in the store and those missing.
    selected = {}
    absent = []
    for name in wanted:
        if name in dataset.data_vars:
            selected[name] = dataset[name]
        else:
            absent.append(name)

    # Missing variables are reported but do not abort the run.
    if absent:
        print(f"Warning: The following variables from the config are not present in the zarr file: {absent}")

    ordered = xr.Dataset(selected)
    print("Variables reordered according to the configuration.")
    return ordered

def normalize_data(std_values, mean_values, data):
    """
    Standardize every data variable of a dataset as ``(value - mean) / std``.

    The entries named ``lat``, ``lon`` and ``time`` are left untouched. The
    i-th item of ``std_values`` / ``mean_values`` is applied to the i-th
    remaining variable, in the order ``data.variables`` yields them.

    :param std_values: standard deviations, one per data variable.
    :param mean_values: means, one per data variable.
    :param data: xarray dataset; it is modified in place.
    :return: the same dataset object with its data variables standardized.
    :raises ValueError: if the statistics do not match the number of data variables.
    """
    coord_names = ('lat', 'lon', 'time')
    # Filter the coordinates out *before* pairing. Zipping over all of
    # data.variables would spend statistics on the skipped coordinates and
    # shift (or drop) the ones meant for the real variables.
    target_vars = [var for var in data.variables if var not in coord_names]

    if len(std_values) != len(mean_values) or len(std_values) != len(target_vars):
        raise ValueError("标准差、均值的长度与数据集中的变量数不匹配")

    for var, std, mean in zip(target_vars, std_values, mean_values):
        data[var] = (data[var] - mean) / std

    return data

def batch_process(model, day_sample, half_day_sample, hour_sample, static_sample, batch_size, device, std=None, mean=None,mode=None):
    """
    Run ``model`` over a stack of samples and return the predictions as a NumPy array.

    :param model: callable taking ``hour_input``, ``day_input``, ``half_day_input``
        and ``static_input`` keyword tensors; if it returns a tuple, only the
        first element is used.
    :param day_sample: day-scale input tensor; its first axis is the sample axis.
    :param half_day_sample: half-day-scale input tensor, same sample axis.
    :param hour_sample: hour-scale input tensor, same sample axis.
    :param static_sample: static input tensor, same sample axis.
    :param batch_size: number of samples per forward pass (ignored when ``mode == 'ALL'``).
    :param device: device the inputs are moved to before the forward pass.
    :param std: optional scale; with ``mean`` it de-normalizes the output as ``output * std + mean``.
    :param mean: optional offset, see ``std``.
    :param mode: ``'ALL'`` runs every sample in one forward pass; anything else
        runs mini-batches of ``batch_size``.
    :return: predictions with axis 1 of the model output squeezed away,
        concatenated along the sample axis.
    """
    def _predict(day, half_day, hour, static):
        # Only the tensors handed in here are moved to the device, so in
        # mini-batch mode the full sample stack stays where the caller put it.
        day, half_day, hour, static = day.to(device), half_day.to(device), hour.to(device), static.to(device)
        output = model(hour_input=hour, day_input=day, half_day_input=half_day, static_input=static)
        if isinstance(output, tuple):
            output = output[0]
        output = output.detach().cpu().numpy()
        if std is not None and mean is not None:
            output = output * std + mean
        return np.squeeze(output, axis=1)

    n_samples = day_sample.shape[0]
    # Pure inference: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        if mode == 'ALL':
            return _predict(day_sample, half_day_sample, hour_sample, static_sample)

        outputs = [
            _predict(
                day_sample[start_idx:start_idx + batch_size],
                half_day_sample[start_idx:start_idx + batch_size],
                hour_sample[start_idx:start_idx + batch_size],
                static_sample[start_idx:start_idx + batch_size],
            )
            for start_idx in range(0, n_samples, batch_size)
        ]
    return np.concatenate(outputs, axis=0)

def split_dataset_filling_faster(vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu, data):
    """
    Build the day / half-day / hour / static model inputs for every time step of ``data``.

    For each time step a window of neighbouring time indices is gathered
    around it; indices that fall outside the array are clamped to the first /
    last time step. NaNs are replaced via ``np.nan_to_num``.

    :param vars_index_map: maps a variable name to its row on axis 0 of ``data``; must contain ``'DOY'``.
    :param vars_day_argu: mapping whose keys, except the first one, are the day-scale variables.
    :param Day_time_step: window length of the day-scale input.
    :param vars_half_day_argu: not used by this function.
    :param Half_Day_time_step: window length of the half-day-scale input.
    :param vars_hour_argu: mapping whose keys, except the first one, are the hour-scale variables.
    :param Hour_time_step: window length of the hour-scale input.
    :param vars_static_argu: mapping whose keys, except the first one, are the static variables.
    :param data: NumPy array laid out as (channels, time, lat, lon).
    :return: ``(day_data, half_day_data, hour_data, static_data)`` float32 tensors shaped
        (time, vars, window, lat, lon); ``static_data`` has no window axis.
        ``half_day_data`` has two channels: channel 0 stays zero, channel 1 holds DOY.
    """
    n_time, n_lat, n_lon = data.shape[1], data.shape[2], data.shape[3]

    # Window spacing in time-axis samples (one sample every 3 hours).
    day_increment = 24 // 3       # 24-hour spacing
    half_day_increment = 12 // 3  # 12-hour spacing
    hour_increment = 1            # 3-hour spacing

    def _window_indices(n_steps, increment):
        # (time, n_steps) matrix of time indices centred on each time step,
        # clamped to the valid range.
        offsets = np.arange(-(n_steps // 2), (n_steps + 1) // 2)
        return np.clip(np.add.outer(np.arange(n_time), offsets * increment), 0, n_time - 1)

    def _variable_rows(vars_argu):
        # The first key of each mapping is skipped.
        return [vars_index_map[var] for var in list(vars_argu.keys())[1:]]

    def _gather_windows(vars_argu, n_steps, increment):
        # (vars, time, window, lat, lon) -> (time, vars, window, lat, lon)
        windows = data[_variable_rows(vars_argu)][:, _window_indices(n_steps, increment), :, :]
        return torch.from_numpy(np.nan_to_num(windows)).permute(1, 0, 2, 3, 4).float()

    doy_index = vars_index_map['DOY']

    # Day and hour inputs are built directly from the gathered windows; no
    # zero-filled placeholder is allocated for them first.
    day_data = _gather_windows(vars_day_argu, Day_time_step, day_increment)
    hour_data = _gather_windows(vars_hour_argu, Hour_time_step, hour_increment)

    # Half-day input: channel 0 is an empty (all-zero) slot, channel 1 is DOY.
    half_day_data = torch.zeros((n_time, 2, Half_Day_time_step, n_lat, n_lon), dtype=torch.float32)
    doy_windows = data[doy_index, _window_indices(Half_Day_time_step, half_day_increment), :, :]
    half_day_data[:, 1, :, :, :] = torch.from_numpy(np.nan_to_num(doy_windows))

    # Static input: one slice per time step, no window axis.
    static_slices = data[_variable_rows(vars_static_argu)][:, np.arange(n_time), :, :]
    static_data = torch.from_numpy(np.nan_to_num(static_slices)).permute(1, 0, 2, 3).float()

    # batch_size, channels, time, lat, lon
    return day_data, half_day_data, hour_data, static_data

# 重叠区域计算函数
def calculate_overlap(first_data, second_data):
    '''
    Blend two data blocks covering the same region.

    Wherever ``first_data`` is non-zero the result is the mean of both blocks;
    wherever it is zero (treated as "not filled yet") the result is
    ``second_data``.
    '''
    filled = np.where(first_data != 0, 1, 0)
    blended = (first_data * filled + second_data * filled) / 2
    return blended + (1 - filled) * second_data

def reconstruct_data(file_path,vars_index_map,vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu,Hour_time_step, vars_static_argu, data, model, patch_size=16, stride=12,std = None,mean = None,time_batch_size=100, time_overlap=18, mode=None):
    '''
    Predict the whole dataset window by window and write the result into one array.

    The time axis is processed in windows of ``time_batch_size`` steps that
    advance by ``time_batch_size - time_overlap``. Every window except the
    first discards its leading ``time_shift`` predictions (half of the largest
    temporal window) and keeps what the previous window wrote there.
    NOTE(review): if ``time_shift`` exceeds ``time_overlap`` the discarded
    frames are not covered by the previous window and stay zero.

    With ``mode=None`` each window is cut into ``patch_size`` x ``patch_size``
    spatial patches on a grid with step ``stride``; patches flush with the far
    border are added when the grid does not reach it. Later patches overwrite
    earlier ones where they overlap. With ``mode='all'`` the full spatial
    extent is predicted at once.

    After every window the array is clipped to [0, 1] and saved to ``file_path``.

    :param file_path: ``np.save`` target, rewritten after every time window.
    :param data: dataset exposing ``time`` / ``lat`` / ``lon`` and ``isel``.
    :param model: model passed on to ``batch_process``; moved to the device and put in eval mode.
    :param std: optional de-normalization scale passed on to ``batch_process``.
    :param mean: optional de-normalization offset passed on to ``batch_process``.
    :param mode: ``None`` (patch-wise) or ``'all'`` (full extent).
    :return: array of shape (time, lat, lon) with the clipped predictions.
    :raises ValueError: for an unknown ``mode``, a non-advancing time window, or
        a patch larger than the spatial extent.

    The remaining parameters are forwarded unchanged to ``split_dataset_filling_faster``.
    '''
    if mode not in (None, 'all'):
        raise ValueError(f"Unsupported mode {mode!r}; expected None or 'all'")
    if time_batch_size <= time_overlap:
        raise ValueError("time_batch_size must be larger than time_overlap")

    times, lat, lon = data.time.shape[0], data.lat.shape[0], data.lon.shape[0]
    if mode is None and (patch_size > lat or patch_size > lon):
        raise ValueError("patch_size must not exceed the spatial extent of the data")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    reconstructed_data = np.zeros((times, lat, lon))
    time_shift = max(int(Day_time_step // 2), int(Half_Day_time_step // 2), int(Hour_time_step // 2))
    split_args = (vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu)

    for time_start in range(0, times, time_batch_size - time_overlap):
        batch_start_time = time.time()  # start timing

        # time_end must be known before the effective range is derived from it.
        time_end = min(time_start + time_batch_size, times)
        # Number of leading predictions to discard. Slicing with the same
        # offset on both sides keeps source and target the same length, even
        # for a trailing window shorter than time_shift (both become empty).
        skip = 0 if time_start == 0 else time_shift
        effective_start = time_start + skip
        effective_end = time_end

        data_batch = data.isel(time=slice(time_start, time_end)).to_array().compute().values  # channels, times, lat, lon
        if mode is None:
            for i, j in _iter_patch_origins(lat, lon, patch_size, stride):
                data_chunk = data_batch[:, :, i:i+patch_size, j:j+patch_size]  # channels, times, lat, lon
                samples = split_dataset_filling_faster(*split_args, data_chunk)
                output = batch_process(model, *samples, batch_size=16, device=device, std=std, mean=mean)
                reconstructed_data[effective_start:effective_end, i:i+patch_size, j:j+patch_size] = output[skip:]
        else:
            samples = split_dataset_filling_faster(*split_args, data_batch)
            output = batch_process(model, *samples, batch_size=2, device=device, std=std, mean=mean)
            reconstructed_data[effective_start:effective_end, :, :] = output[skip:]

        batch_end_time = time.time()  # stop timing
        print(f"处理时间段 {time_start} 到 {time_end} 完成. 总耗时：{batch_end_time - batch_start_time:.2f}s")
        # Clip and checkpoint after every window so a partial result survives an interruption.
        np.clip(reconstructed_data, 0, 1, out=reconstructed_data)
        np.save(file_path, reconstructed_data)
    return reconstructed_data


def _iter_patch_origins(lat, lon, patch_size, stride):
    '''Yield the (row, col) origin of every spatial patch: regular grid first, border patches last.'''
    rows = range(0, lat - patch_size + 1, stride)
    cols = range(0, lon - patch_size + 1, stride)
    for i in rows:
        for j in cols:
            yield i, j

    # The grid reaches the far border only when (size - patch_size) is a
    # multiple of stride; otherwise add patches flush with that border.
    ragged_lon = (lon - patch_size) % stride != 0
    ragged_lat = (lat - patch_size) % stride != 0
    if ragged_lon:
        for i in rows:
            yield i, lon - patch_size
    if ragged_lat:
        for j in cols:
            yield lat - patch_size, j
    if ragged_lon and ragged_lat:
        yield lat - patch_size, lon - patch_size

def main(args):
    """
    Load every configuration file named on ``args`` and set up the logger.

    :param args: namespace carrying the ``*_config_path`` attributes read below.
    :return: tuple ``(train_config, models_config, path_config,
        dataloader_hyperparameter_config, logger, var_list, vars_argu,
        vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu)``.
    """
    load = Util.load_config

    models_config = load(args.models_config_path)
    train_config = models_config['train_shared_parameter']
    path_config = load(args.path_config_path)
    dataloader_hyperparameter_config = load(args.dataloader_hyperparameter_config_path)

    # Variable configs, loaded in the same order as they are returned.
    vars_argu, vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu = (
        load(getattr(args, attr))
        for attr in (
            'vars_config_path',
            'vars_day_config_path',
            'vars_half_day_config_path',
            'vars_hour_config_path',
            'vars_static_config_path',
        )
    )

    logger = setup_logger(path_config['log'], append=True)
    var_list = list(vars_argu.keys())
    return (
        train_config,
        models_config,
        path_config,
        dataloader_hyperparameter_config,
        logger,
        var_list,
        vars_argu,
        vars_day_argu,
        vars_half_day_argu,
        vars_hour_argu,
        vars_static_argu,
    )

# Script entry point: load the configs, open the dataset, then rebuild each
# model of "Experiment Group 7" and run the reconstruction with mode='all'.
if __name__ == '__main__':
    # One command-line option per YAML config file; defaults are relative paths.
    parser = argparse.ArgumentParser()
    parser.add_argument('--models_config_path', type=str, default='../config/models_config.yaml')
    parser.add_argument('--path_config_path', type=str, default='../config/path_config.yaml')
    parser.add_argument('--dataloader_hyperparameter_config_path', type=str, default='../config/dataloader_hyperparameter_config.yaml')
    parser.add_argument('--vars_config_path', type=str, default='../config/construct_Data/vars_config.yaml')
    parser.add_argument('--vars_day_config_path', type=str, default='../config/construct_Data/vars_day_config.yaml')
    parser.add_argument('--vars_half_day_config_path', type=str, default='../config/construct_Data/vars_half_day_config.yaml')
    parser.add_argument('--vars_hour_config_path', type=str, default='../config/construct_Data/vars_hour_config.yaml')
    parser.add_argument('--vars_static_config_path', type=str, default='../config/construct_Data/vars_static_config.yaml')
    # parse_known_args tolerates unrecognized arguments; only the parsed namespace is kept.
    args = parser.parse_known_args()[0]

    train_dict, models_config, path_config, dataloader_hyperparameter_config, logger, var_list, vars_argu, vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu = main(args)
    Util.random_seed(seed=train_dict['seed'])
    # NOTE(review): the variable-order config is a hard-coded absolute path here, not args.vars_config_path.
    data = reorder_variables(path_config['0.37_normal_Data'], r"/root/ST-Conv/config/construct_Data/vars_config.yaml")
    # Position of each variable in the reordered dataset.
    vars_index_map = {var: i for i, var in enumerate(data.data_vars)}
    # Unpack the model config
    experiment_groups = models_config['experiment_groups']      # experiment-group parameters: group-shared parameters + per-model parameters
    std = np.load("/root/autodl-fs/ST_Data/std_mean/std.npy")
    mean = np.load("/root/autodl-fs/ST_Data/std_mean/mean.npy")

    for group in experiment_groups:
        # Only this one experiment group is processed.
        if group['group_name'] == "Experiment Group 7":
            print(f"Running experiments for group: {group['group_name']}")
            shared_parameters = models_config['Shared_parameter'] 
            experiment_shared_params = group.get('experiment_shared_parameter', {})

            # For every model of the group: build it, load its saved state, reconstruct.
            for model_config in group['models']:
                model = STConvNet(**shared_parameters, **experiment_shared_params, **model_config['parameters'])
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = model.to(device)
                if torch.cuda.device_count() > 1:
                    print(f"Let's use {torch.cuda.device_count()} GPUs!")
                # NOTE(review): the DataParallel wrap is applied regardless of the GPU count —
                # presumably to match the parameter names of the saved checkpoint; confirm.
                model = nn.DataParallel(model)
                # Set up the optimizer (lr=0, weight_decay=0); it is only handed to the loader below.
                optimizer = optim.AdamW(model.parameters(), lr=0, weight_decay=0)  # optimizer
                Util.load_model_and_optimizer(model, optimizer, device, experiment_shared_params['best_model_path'], logger=logger, model_index=-1,model_name=model_config['model_name'])
                file_path = os.path.join(path_config['reconstruct_data_path'], model_config['model_name'] + '16.npy')
                # Only the first entry of the std / mean arrays is used for de-normalization.
                reconstructed_data = reconstruct_data(file_path, vars_index_map, vars_day_argu, shared_parameters['day_step'], vars_half_day_argu, shared_parameters['half_day_step'], vars_hour_argu,shared_parameters['hour_step'], vars_static_argu, data, model, patch_size=16, stride=16, std = std[0],mean = mean[0],mode = 'all')


### 重叠重建

In [None]:
import os
import time
import pandas as pd
import torch
import argparse
import numpy as np
import xarray as xr 
from torch import optim
import yaml
import zarr
import torch.nn as nn

import sys
sys.path.append(r"/root/ST-Conv")
from tool.utils import Util
from model.model import STConvNet
from model.block import BasicBlock
from dataset.dataset_process import data_preparation
from log.logger import setup_logger
def reorder_variables(zarr_file_path, config_file_path):
    """
    Open a zarr store and return its data variables in the order listed by a YAML config.

    :param zarr_file_path: path of the (consolidated) zarr store to open.
    :param config_file_path: path of a YAML mapping; the order of its keys is the target order.
    :return: a new xarray Dataset holding the config variables found in the store, in config order.
    """
    # The key order of the YAML mapping defines the variable order.
    with open(config_file_path, 'r') as cfg_handle:
        cfg = yaml.safe_load(cfg_handle)
    wanted = list(cfg.keys())

    dataset = xr.open_zarr(zarr_file_path, consolidated=True)

    # Split the configured names into those present in the store and those missing.
    selected = {}
    absent = []
    for name in wanted:
        if name in dataset.data_vars:
            selected[name] = dataset[name]
        else:
            absent.append(name)

    # Missing variables are reported but do not abort the run.
    if absent:
        print(f"Warning: The following variables from the config are not present in the zarr file: {absent}")

    ordered = xr.Dataset(selected)
    print("Variables reordered according to the configuration.")
    return ordered

def normalize_data(std_values, mean_values, data):
    """
    Standardize every data variable of a dataset as ``(value - mean) / std``.

    The entries named ``lat``, ``lon`` and ``time`` are left untouched. The
    i-th item of ``std_values`` / ``mean_values`` is applied to the i-th
    remaining variable, in the order ``data.variables`` yields them.

    :param std_values: standard deviations, one per data variable.
    :param mean_values: means, one per data variable.
    :param data: xarray dataset; it is modified in place.
    :return: the same dataset object with its data variables standardized.
    :raises ValueError: if the statistics do not match the number of data variables.
    """
    coord_names = ('lat', 'lon', 'time')
    # Filter the coordinates out *before* pairing. Zipping over all of
    # data.variables would spend statistics on the skipped coordinates and
    # shift (or drop) the ones meant for the real variables.
    target_vars = [var for var in data.variables if var not in coord_names]

    if len(std_values) != len(mean_values) or len(std_values) != len(target_vars):
        raise ValueError("标准差、均值的长度与数据集中的变量数不匹配")

    for var, std, mean in zip(target_vars, std_values, mean_values):
        data[var] = (data[var] - mean) / std

    return data

def batch_process(model, day_sample, half_day_sample, hour_sample, static_sample, batch_size, device, std=None, mean=None,mode=None):
    """
    Run ``model`` over a stack of samples and return the predictions as a NumPy array.

    :param model: callable taking ``hour_input``, ``day_input``, ``half_day_input``
        and ``static_input`` keyword tensors; if it returns a tuple, only the
        first element is used.
    :param day_sample: day-scale input tensor; its first axis is the sample axis.
    :param half_day_sample: half-day-scale input tensor, same sample axis.
    :param hour_sample: hour-scale input tensor, same sample axis.
    :param static_sample: static input tensor, same sample axis.
    :param batch_size: number of samples per forward pass (ignored when ``mode == 'ALL'``).
    :param device: device the inputs are moved to before the forward pass.
    :param std: optional scale; with ``mean`` it de-normalizes the output as ``output * std + mean``.
    :param mean: optional offset, see ``std``.
    :param mode: ``'ALL'`` runs every sample in one forward pass; anything else
        runs mini-batches of ``batch_size``.
    :return: predictions with axis 1 of the model output squeezed away,
        concatenated along the sample axis.
    """
    def _predict(day, half_day, hour, static):
        # Only the tensors handed in here are moved to the device, so in
        # mini-batch mode the full sample stack stays where the caller put it.
        day, half_day, hour, static = day.to(device), half_day.to(device), hour.to(device), static.to(device)
        output = model(hour_input=hour, day_input=day, half_day_input=half_day, static_input=static)
        if isinstance(output, tuple):
            output = output[0]
        output = output.detach().cpu().numpy()
        if std is not None and mean is not None:
            output = output * std + mean
        return np.squeeze(output, axis=1)

    n_samples = day_sample.shape[0]
    # Pure inference: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        if mode == 'ALL':
            return _predict(day_sample, half_day_sample, hour_sample, static_sample)

        outputs = [
            _predict(
                day_sample[start_idx:start_idx + batch_size],
                half_day_sample[start_idx:start_idx + batch_size],
                hour_sample[start_idx:start_idx + batch_size],
                static_sample[start_idx:start_idx + batch_size],
            )
            for start_idx in range(0, n_samples, batch_size)
        ]
    return np.concatenate(outputs, axis=0)

def split_dataset_filling_faster(vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu, data):
    """
    Build the day / half-day / hour / static model inputs for every time step of ``data``.

    For each time step a window of neighbouring time indices is gathered
    around it; indices that fall outside the array are clamped to the first /
    last time step. NaNs are replaced via ``np.nan_to_num``.

    :param vars_index_map: maps a variable name to its row on axis 0 of ``data``; must contain ``'DOY'``.
    :param vars_day_argu: mapping whose keys, except the first one, are the day-scale variables.
    :param Day_time_step: window length of the day-scale input.
    :param vars_half_day_argu: not used by this function.
    :param Half_Day_time_step: window length of the half-day-scale input.
    :param vars_hour_argu: mapping whose keys, except the first one, are the hour-scale variables.
    :param Hour_time_step: window length of the hour-scale input.
    :param vars_static_argu: mapping whose keys, except the first one, are the static variables.
    :param data: NumPy array laid out as (channels, time, lat, lon).
    :return: ``(day_data, half_day_data, hour_data, static_data)`` float32 tensors shaped
        (time, vars, window, lat, lon); ``static_data`` has no window axis.
        ``half_day_data`` has two channels: channel 0 stays zero, channel 1 holds DOY.
    """
    n_time, n_lat, n_lon = data.shape[1], data.shape[2], data.shape[3]

    # Window spacing in time-axis samples (one sample every 3 hours).
    day_increment = 24 // 3       # 24-hour spacing
    half_day_increment = 12 // 3  # 12-hour spacing
    hour_increment = 1            # 3-hour spacing

    def _window_indices(n_steps, increment):
        # (time, n_steps) matrix of time indices centred on each time step,
        # clamped to the valid range.
        offsets = np.arange(-(n_steps // 2), (n_steps + 1) // 2)
        return np.clip(np.add.outer(np.arange(n_time), offsets * increment), 0, n_time - 1)

    def _variable_rows(vars_argu):
        # The first key of each mapping is skipped.
        return [vars_index_map[var] for var in list(vars_argu.keys())[1:]]

    def _gather_windows(vars_argu, n_steps, increment):
        # (vars, time, window, lat, lon) -> (time, vars, window, lat, lon)
        windows = data[_variable_rows(vars_argu)][:, _window_indices(n_steps, increment), :, :]
        return torch.from_numpy(np.nan_to_num(windows)).permute(1, 0, 2, 3, 4).float()

    doy_index = vars_index_map['DOY']

    # Day and hour inputs are built directly from the gathered windows; no
    # zero-filled placeholder is allocated for them first.
    day_data = _gather_windows(vars_day_argu, Day_time_step, day_increment)
    hour_data = _gather_windows(vars_hour_argu, Hour_time_step, hour_increment)

    # Half-day input: channel 0 is an empty (all-zero) slot, channel 1 is DOY.
    half_day_data = torch.zeros((n_time, 2, Half_Day_time_step, n_lat, n_lon), dtype=torch.float32)
    doy_windows = data[doy_index, _window_indices(Half_Day_time_step, half_day_increment), :, :]
    half_day_data[:, 1, :, :, :] = torch.from_numpy(np.nan_to_num(doy_windows))

    # Static input: one slice per time step, no window axis.
    static_slices = data[_variable_rows(vars_static_argu)][:, np.arange(n_time), :, :]
    static_data = torch.from_numpy(np.nan_to_num(static_slices)).permute(1, 0, 2, 3).float()

    # batch_size, channels, time, lat, lon
    return day_data, half_day_data, hour_data, static_data

# 重叠区域计算函数
def calculate_overlap(first_data, second_data):
    '''
    Blend two data blocks covering the same region.

    Wherever ``first_data`` is non-zero the result is the mean of both blocks;
    wherever it is zero (treated as "not filled yet") the result is
    ``second_data``.
    '''
    filled = np.where(first_data != 0, 1, 0)
    blended = (first_data * filled + second_data * filled) / 2
    return blended + (1 - filled) * second_data

def reconstruct_data(file_path,vars_index_map,vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu,Hour_time_step, vars_static_argu, data, model, patch_size=16, stride=12,std = None,mean = None,time_batch_size=100, time_overlap=18, mode=None):
    '''
    Predict the whole dataset window by window, blending overlapping predictions.

    The time axis is processed in windows of ``time_batch_size`` steps that
    advance by ``time_batch_size - time_overlap``.

    With ``mode=None`` each window is cut into ``patch_size`` x ``patch_size``
    spatial patches on a grid with step ``stride``; patches flush with the far
    border are added when the grid does not reach it. Every patch prediction
    is merged into the result with ``calculate_overlap``, which also blends
    the frames shared with the previous time window. With ``mode='all'`` the
    full spatial extent is predicted at once and overwrites the window.

    After every window the array is clipped to [0, 1] and saved to ``file_path``.

    :param file_path: ``np.save`` target, rewritten after every time window.
    :param data: dataset exposing ``time`` / ``lat`` / ``lon`` and ``isel``.
    :param model: model passed on to ``batch_process``; moved to the device and put in eval mode.
    :param std: optional de-normalization scale passed on to ``batch_process``.
    :param mean: optional de-normalization offset passed on to ``batch_process``.
    :param mode: ``None`` (patch-wise, blended) or ``'all'`` (full extent).
    :return: array of shape (time, lat, lon) with the clipped predictions.
    :raises ValueError: for an unknown ``mode``, a non-advancing time window, or
        a patch larger than the spatial extent.

    The remaining parameters are forwarded unchanged to ``split_dataset_filling_faster``.
    '''
    if mode not in (None, 'all'):
        raise ValueError(f"Unsupported mode {mode!r}; expected None or 'all'")
    if time_batch_size <= time_overlap:
        raise ValueError("time_batch_size must be larger than time_overlap")

    times, lat, lon = data.time.shape[0], data.lat.shape[0], data.lon.shape[0]
    if mode is None and (patch_size > lat or patch_size > lon):
        raise ValueError("patch_size must not exceed the spatial extent of the data")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    reconstructed_data = np.zeros((times, lat, lon))
    split_args = (vars_index_map, vars_day_argu, Day_time_step, vars_half_day_argu, Half_Day_time_step, vars_hour_argu, Hour_time_step, vars_static_argu)

    for time_start in range(0, times, time_batch_size - time_overlap):
        batch_start_time = time.time()  # start timing

        time_end = min(time_start + time_batch_size, times)
        data_batch = data.isel(time=slice(time_start, time_end)).to_array().compute().values  # channels, times, lat, lon
        if mode is None:
            for i, j in _iter_patch_origins(lat, lon, patch_size, stride):
                data_chunk = data_batch[:, :, i:i+patch_size, j:j+patch_size]  # channels, times, lat, lon
                samples = split_dataset_filling_faster(*split_args, data_chunk)
                output = batch_process(model, *samples, batch_size=16, device=device, std=std, mean=mean)
                # Blend exactly once per patch: cells that already hold a
                # value are averaged with the new prediction, empty (zero)
                # cells take the new prediction.
                previous = reconstructed_data[time_start:time_end, i:i+patch_size, j:j+patch_size]
                reconstructed_data[time_start:time_end, i:i+patch_size, j:j+patch_size] = calculate_overlap(previous, output)
        else:
            samples = split_dataset_filling_faster(*split_args, data_batch)
            output = batch_process(model, *samples, batch_size=2, device=device, std=std, mean=mean)
            reconstructed_data[time_start:time_end, :, :] = output

        batch_end_time = time.time()  # stop timing
        print(f"处理时间段 {time_start} 到 {time_end} 完成. 总耗时：{batch_end_time - batch_start_time:.2f}s")
        # Clip and checkpoint after every window so a partial result survives an interruption.
        np.clip(reconstructed_data, 0, 1, out=reconstructed_data)
        np.save(file_path, reconstructed_data)
    return reconstructed_data


def _iter_patch_origins(lat, lon, patch_size, stride):
    '''Yield the (row, col) origin of every spatial patch: regular grid first, border patches last.'''
    rows = range(0, lat - patch_size + 1, stride)
    cols = range(0, lon - patch_size + 1, stride)
    for i in rows:
        for j in cols:
            yield i, j

    # The grid reaches the far border only when (size - patch_size) is a
    # multiple of stride; otherwise add patches flush with that border.
    ragged_lon = (lon - patch_size) % stride != 0
    ragged_lat = (lat - patch_size) % stride != 0
    if ragged_lon:
        for i in rows:
            yield i, lon - patch_size
    if ragged_lat:
        for j in cols:
            yield lat - patch_size, j
    if ragged_lon and ragged_lat:
        yield lat - patch_size, lon - patch_size

def main(args):
    """
    Load every configuration file named on ``args`` and set up the logger.

    :param args: namespace carrying the ``*_config_path`` attributes read below.
    :return: tuple ``(train_config, models_config, path_config,
        dataloader_hyperparameter_config, logger, var_list, vars_argu,
        vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu)``.
    """
    load = Util.load_config

    models_config = load(args.models_config_path)
    train_config = models_config['train_shared_parameter']
    path_config = load(args.path_config_path)
    dataloader_hyperparameter_config = load(args.dataloader_hyperparameter_config_path)

    # Variable configs, loaded in the same order as they are returned.
    vars_argu, vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu = (
        load(getattr(args, attr))
        for attr in (
            'vars_config_path',
            'vars_day_config_path',
            'vars_half_day_config_path',
            'vars_hour_config_path',
            'vars_static_config_path',
        )
    )

    logger = setup_logger(path_config['log'], append=True)
    var_list = list(vars_argu.keys())
    return (
        train_config,
        models_config,
        path_config,
        dataloader_hyperparameter_config,
        logger,
        var_list,
        vars_argu,
        vars_day_argu,
        vars_half_day_argu,
        vars_hour_argu,
        vars_static_argu,
    )

# Script entry point: load the configs, open the dataset, then rebuild each
# model of "Experiment Group 7" and run the patch-wise reconstruction (default mode).
if __name__ == '__main__':
    # One command-line option per YAML config file; defaults are relative paths.
    parser = argparse.ArgumentParser()
    parser.add_argument('--models_config_path', type=str, default='../config/models_config.yaml')
    parser.add_argument('--path_config_path', type=str, default='../config/path_config.yaml')
    parser.add_argument('--dataloader_hyperparameter_config_path', type=str, default='../config/dataloader_hyperparameter_config.yaml')
    parser.add_argument('--vars_config_path', type=str, default='../config/construct_Data/vars_config.yaml')
    parser.add_argument('--vars_day_config_path', type=str, default='../config/construct_Data/vars_day_config.yaml')
    parser.add_argument('--vars_half_day_config_path', type=str, default='../config/construct_Data/vars_half_day_config.yaml')
    parser.add_argument('--vars_hour_config_path', type=str, default='../config/construct_Data/vars_hour_config.yaml')
    parser.add_argument('--vars_static_config_path', type=str, default='../config/construct_Data/vars_static_config.yaml')
    # parse_known_args tolerates unrecognized arguments; only the parsed namespace is kept.
    args = parser.parse_known_args()[0]

    train_dict, models_config, path_config, dataloader_hyperparameter_config, logger, var_list, vars_argu, vars_day_argu, vars_half_day_argu, vars_hour_argu, vars_static_argu = main(args)
    Util.random_seed(seed=train_dict['seed'])
    # NOTE(review): the variable-order config is a hard-coded absolute path here, not args.vars_config_path.
    data = reorder_variables(path_config['0.37_normal_Data'], r"/root/ST-Conv/config/construct_Data/vars_config.yaml")
    # Position of each variable in the reordered dataset.
    vars_index_map = {var: i for i, var in enumerate(data.data_vars)}
    # Unpack the model config
    experiment_groups = models_config['experiment_groups']      # experiment-group parameters: group-shared parameters + per-model parameters
    std = np.load("/root/autodl-fs/ST_Data/std_mean/std.npy")
    mean = np.load("/root/autodl-fs/ST_Data/std_mean/mean.npy")

    for group in experiment_groups:
        # Only this one experiment group is processed.
        if group['group_name'] == "Experiment Group 7":
            print(f"Running experiments for group: {group['group_name']}")
            shared_parameters = models_config['Shared_parameter'] 
            experiment_shared_params = group.get('experiment_shared_parameter', {})

            # For every model of the group: build it, load its saved state, reconstruct.
            for model_config in group['models']:
                model = STConvNet(**shared_parameters, **experiment_shared_params, **model_config['parameters'])
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = model.to(device)
                if torch.cuda.device_count() > 1:
                    print(f"Let's use {torch.cuda.device_count()} GPUs!")
                # NOTE(review): the DataParallel wrap is applied regardless of the GPU count —
                # presumably to match the parameter names of the saved checkpoint; confirm.
                model = nn.DataParallel(model)
                # Set up the optimizer (lr=0, weight_decay=0); it is only handed to the loader below.
                optimizer = optim.AdamW(model.parameters(), lr=0, weight_decay=0)  # optimizer
                Util.load_model_and_optimizer(model, optimizer, device, experiment_shared_params['best_model_path'], logger=logger, model_index=-1,model_name=model_config['model_name'])
                file_path = os.path.join(path_config['reconstruct_data_path'], model_config['model_name'] + '16.npy')
                # NOTE(review): stride equals patch_size here, so the regular patch grid itself has no overlap.
                reconstructed_data = reconstruct_data(file_path, vars_index_map, vars_day_argu, shared_parameters['day_step'], vars_half_day_argu, shared_parameters['half_day_step'], vars_hour_argu,shared_parameters['hour_step'], vars_static_argu, data, model, patch_size=16, stride=16, std = std[0],mean = mean[0])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.block import FFN
class TemporalDilatedConvBlock(nn.Module):
    """Dilated temporal convolution followed by a pointwise expand/project pair.

    All three stages are ``nn.Conv3d`` layers, each followed by GELU:

    * ``conv_time``: kernel ``(kernel_size, 1, 1)`` with the given padding and
      dilation on the first spatial-temporal axis only, ``in_channels`` ->
      ``out_channels``.
    * ``conv1``: 1x1x1 convolution widening to ``out_channels * expand_ratio``.
    * ``conv2``: 1x1x1 convolution projecting back to ``out_channels``.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output tensor.
        kernel_size: Kernel extent along the convolved (depth) axis.
        padding: Zero padding applied on both ends of that axis.
        dilation: Dilation applied along that axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super().__init__()
        self.expand_ratio = 6
        hidden_channels = out_channels * self.expand_ratio

        # Only the depth axis is convolved; height and width use size-1 kernels.
        temporal_conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1, 1),
            padding=(padding, 0, 0),
            dilation=(dilation, 1, 1),
        )
        self.conv_time = nn.Sequential(temporal_conv, nn.GELU())

        widen = nn.Conv3d(out_channels, hidden_channels, kernel_size=(1, 1, 1))
        self.conv1 = nn.Sequential(widen, nn.GELU())

        project = nn.Conv3d(hidden_channels, out_channels, kernel_size=(1, 1, 1))
        self.conv2 = nn.Sequential(project, nn.GELU())

    def forward(self, x):
        """Run the temporal convolution, then the expand and project stages."""
        return self.conv2(self.conv1(self.conv_time(x)))

class TemporalDilatedConv(nn.Module):
    """Fuse hourly, daily and half-daily sequences with static feature maps.

    Each temporal input goes through its own stack of unpadded dilated
    temporal convolutions that shrinks its time axis to ``target_time_steps``.
    The three results are concatenated on the channel axis, passed through
    ``FFN``, squeezed on the time axis, concatenated with the static features
    and finally mixed by a 1x1 2-D convolution.

    Temporal inputs are indexed as 5-D tensors with time on axis 2 and static
    input as a 4-D tensor. When ``temporal_usage_position_embedding`` is set,
    the last channel of every temporal input and the last two channels of the
    static input (longitude, latitude) are treated as encodings: they are added
    onto the remaining channels and removed, which is why the channel counts
    are reduced by 1 / 2 in the constructor.

    Args:
        hour_in_channels, day_in_channels, half_day_in_channels: Channel counts
            of the temporal inputs, including the encoding channel if used.
        hour_step, day_step, half_day_step: Time-axis lengths of the inputs.
        static_in_channels: Channel count of the static input, including the
            two coordinate channels if used.
        kernel_size: Maximum temporal kernel size (must be >= 2 whenever a
            sequence is longer than ``target_time_steps``).
        target_time_steps: Time-axis length each branch is reduced to.
            ``forward`` squeezes that axis, so presumably this is 1 -- TODO
            confirm against the configs.
        temporal_usage_max_avg: Stored on the instance; ``forward`` does not
            add any max/avg features.
        temporal_usage_position_embedding: Enable the encoding handling above.
        reduction_ratio: Accepted for interface compatibility; unused here.
    """

    def __init__(
        self,
        hour_in_channels,
        hour_step,
        day_in_channels,
        day_step,
        half_day_in_channels,
        half_day_step,
        static_in_channels,
        kernel_size,
        target_time_steps,
        temporal_usage_max_avg=True,
        temporal_usage_position_embedding=True,
        reduction_ratio=8,
    ):
        super(TemporalDilatedConv, self).__init__()
        # Temporal convolutions are unpadded (the original note mentions a
        # temporal kernel size of 2; the actual size comes from kernel_size).
        self.expand_ratio = 6
        self.temporal_usage_max_avg = temporal_usage_max_avg
        self.temporal_usage_position_embedding = temporal_usage_position_embedding
        self.target_time_steps = target_time_steps
        if temporal_usage_position_embedding:
            # The encoding channels are folded into the data channels in
            # forward(), so they do not count as convolution inputs.
            hour_in_channels -= 1
            day_in_channels -= 1
            half_day_in_channels -= 1
            static_in_channels -= 2

        self.hour_block = self._create_block(hour_in_channels, hour_in_channels, hour_step, kernel_size)
        self.day_block = self._create_block(day_in_channels, day_in_channels, day_step, kernel_size)
        # Original author's note: this input carries a single data channel.
        self.half_day_block = self._create_block(half_day_in_channels, half_day_in_channels, half_day_step, kernel_size)

        temporal_channels = hour_in_channels + day_in_channels + half_day_in_channels

        # FFN comes from models.block; it is applied to the 5-D concatenation
        # and is assumed to keep the channel count -- TODO confirm.
        self.final_conv = FFN(temporal_channels, temporal_channels, num_layers=6, expansion_ratio=6)

        # The original read self.extra_channel unconditionally, but nothing in
        # this class defines it, so the default configuration raised
        # AttributeError. forward() concatenates only temporal and static
        # channels, hence the default of 0.
        extra_channels = getattr(self, "extra_channel", 0) if temporal_usage_max_avg else 0
        self.out_channels = temporal_channels + static_in_channels + extra_channels

        self.conv_2d = nn.Sequential(nn.Conv2d(self.out_channels, self.out_channels*self.expand_ratio, kernel_size=1),
                                 nn.GELU(),
                                 nn.Conv2d(self.out_channels*self.expand_ratio, self.out_channels, kernel_size=1))

    def _create_block(self, in_channels, out_channels, time_steps, kernel_size):
        """Build an unpadded dilated stack reducing ``time_steps`` to the target.

        Dilation doubles per layer. Each layer removes
        ``dilation * (kernel - 1)`` time steps; kernel and dilation are capped
        so that a layer never removes more steps than remain.

        Returns:
            ``nn.Sequential`` of ``TemporalDilatedConvBlock`` (empty, i.e. an
            identity, when ``time_steps <= target_time_steps``).

        Raises:
            ValueError: If a reduction is needed but ``kernel_size < 2``; such
                a kernel removes no time steps and the loop would never end.
        """
        if time_steps > self.target_time_steps and kernel_size < 2:
            raise ValueError(
                f"kernel_size must be >= 2 to reduce {time_steps} time steps "
                f"to {self.target_time_steps}, got {kernel_size}"
            )

        layers = []
        current_time_steps = time_steps
        dilation = 1
        while current_time_steps > self.target_time_steps:
            required_reduction = current_time_steps - self.target_time_steps
            adjusted_kernel_size = min(kernel_size, required_reduction + 1)
            # Without this cap a large dilation overshoots the target and the
            # convolution fails at run time because its receptive field is
            # longer than the remaining sequence. Configurations that already
            # fit exactly get the same layers as before.
            effective_dilation = min(dilation, required_reduction // (adjusted_kernel_size - 1))

            layers.append(TemporalDilatedConvBlock(in_channels, out_channels, adjusted_kernel_size, 0, effective_dilation))
            current_time_steps -= effective_dilation * (adjusted_kernel_size - 1)
            in_channels = out_channels
            dilation *= 2

        return nn.Sequential(*layers)

    def process_temporal_data(self, input_data):
        """Add the last channel (time encoding) onto all other channels.

        Returns a new tensor with one channel fewer. The addition is out of
        place: slicing yields a view, so an in-place ``+=`` would overwrite the
        caller's tensor and is rejected by autograd on leaf tensors that
        require grad.
        """
        return input_data[:, :-1, :, :, :] + input_data[:, -1:, :, :, :]

    def process_static_data(self, static_input, hour_input, day_input, half_day_input):
        """Fold longitude/latitude into the temporal inputs.

        The last two channels of ``static_input`` are taken as longitude and
        latitude. Their sum is broadcast over the channel and time axes of each
        temporal input and added to it.

        Returns:
            ``(static_without_coords, hour, day, half_day)``; all are new
            tensors or views, the arguments themselves are not modified.
        """
        # Extract the longitude and latitude channels.
        lon = static_input[:, -2:-1, :, :]  # [batch_size, 1, height, width]
        lat = static_input[:, -1:, :, :]    # [batch_size, 1, height, width]

        # Drop the coordinate channels from the static features.
        processed_static = static_input[:, :-2, :, :]

        # Insert a singleton time axis so the sum broadcasts over time.
        position = (lon + lat).unsqueeze(2)

        # Out-of-place additions keep the caller's tensors untouched.
        hour_input = hour_input + position
        day_input = day_input + position
        half_day_input = half_day_input + position

        return processed_static, hour_input, day_input, half_day_input

    def get_out_channels(self):
        """Return the channel count of the tensor produced by ``forward``."""
        return self.out_channels

    def forward(self, hour_input, day_input, half_day_input, static_input):
        """Encode the temporal inputs and fuse them with the static features.

        Returns the output of ``conv_2d`` applied to the channel-wise
        concatenation of the encoded temporal features and the static input.
        """
        if self.temporal_usage_position_embedding:
            # Absolute time encoding.
            hour_input = self.process_temporal_data(hour_input)
            day_input = self.process_temporal_data(day_input)
            half_day_input = self.process_temporal_data(half_day_input)

            # Absolute position encoding.
            static_input, hour_input, day_input, half_day_input = self.process_static_data(
                static_input, hour_input, day_input, half_day_input
            )

        hour_output = self.hour_block(hour_input)
        day_output = self.day_block(day_input)
        half_day_output = self.half_day_block(half_day_input)

        combined_output = torch.cat((hour_output, day_output, half_day_output), dim=1)
        final_output = self.final_conv(combined_output)
        # Only removes the time axis when its length is 1.
        final_output = final_output.squeeze(2)

        final_output = torch.cat((final_output, static_input), dim=1)
        final_outconv = self.conv_2d(final_output)
        return final_outconv
    

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.block import FFN


class SqueezeLayer(nn.Module):
    """Module wrapper around ``squeeze`` so it can be placed in a layer stack.

    Args:
        dim: Axis to squeeze; it is only removed when its length is 1.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return ``x`` with axis ``self.dim`` squeezed."""
        return torch.squeeze(x, self.dim)
    
class TemporalDilatedConvBlock(nn.Module):
    """Dilated temporal convolution that widens channels, then projects back.

    * ``conv_first``: ``nn.Conv3d`` with kernel ``(kernel_size, 1, 1)``, the
      given padding and dilation on the depth axis only, mapping
      ``in_channels`` to ``out_channels * expand_ratio``; followed by GELU.
    * ``conv1``: 1x1x1 ``nn.Conv3d`` projecting down to ``out_channels``;
      followed by GELU.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the output tensor.
        kernel_size: Kernel extent along the convolved (depth) axis.
        padding: Zero padding applied on both ends of that axis.
        dilation: Dilation applied along that axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super().__init__()
        self.expand_ratio = 6
        hidden_channels = out_channels * self.expand_ratio

        # Height and width use size-1 kernels, so only the depth axis is mixed.
        widening_conv = nn.Conv3d(
            in_channels,
            hidden_channels,
            kernel_size=(kernel_size, 1, 1),
            padding=(padding, 0, 0),
            dilation=(dilation, 1, 1),
        )
        self.conv_first = nn.Sequential(widening_conv, nn.GELU())

        projection = nn.Conv3d(hidden_channels, out_channels, kernel_size=(1, 1, 1))
        self.conv1 = nn.Sequential(projection, nn.GELU())

    def forward(self, x):
        """Apply the widening temporal convolution, then the projection."""
        widened = self.conv_first(x)
        return self.conv1(widened)

class TemporalDilatedConv(nn.Module):
    """Fuse hourly, daily and half-daily sequences with static feature maps.

    Each temporal input runs through its own list of temporal convolution
    layers built by ``_create_block``, which collapses the time axis to length
    1. The three results are squeezed on the time axis, concatenated with the
    static features on the channel axis and passed through ``FFN``.

    Temporal inputs are indexed as 5-D tensors with time on axis 2 and static
    input as a 4-D tensor. When ``temporal_usage_position_embedding`` is set,
    the last channel of every temporal input and the last two channels of the
    static input (longitude, latitude) are treated as encodings: they are added
    onto the remaining channels and removed, which is why the channel counts
    are reduced by 1 / 2 in the constructor.

    Args:
        hour_in_channels, day_in_channels, half_day_in_channels: Channel counts
            of the temporal inputs, including the encoding channel if used.
        hour_step, day_step, half_day_step: Time-axis lengths of the inputs.
        static_in_channels: Channel count of the static input, including the
            two coordinate channels if used.
        kernel_size: Temporal kernel size of every dilated layer.
        target_time_steps: Length at which the dilated reduction stops; any
            remainder above 1 is collapsed by one final convolution.
        temporal_usage_max_avg: Stored on the instance; not used by this class.
        temporal_usage_position_embedding: Enable the encoding handling above.
        reduction_ratio: Accepted for interface compatibility; unused here.
    """

    def __init__(
        self,
        hour_in_channels,
        hour_step,
        day_in_channels,
        day_step,
        half_day_in_channels,
        half_day_step,
        static_in_channels,
        kernel_size,
        target_time_steps,
        temporal_usage_max_avg=True,
        temporal_usage_position_embedding=True,
        reduction_ratio=8,
    ):
        super(TemporalDilatedConv, self).__init__()
        # Original author's note: temporal kernel size is 2, without padding.
        # The actual kernel size comes from kernel_size.
        self.expand_ratio = 6
        self.temporal_usage_max_avg = temporal_usage_max_avg
        self.temporal_usage_position_embedding = temporal_usage_position_embedding
        self.target_time_steps = target_time_steps
        if temporal_usage_position_embedding:
            # The encoding channels are folded into the data channels in
            # forward(), so they do not count as convolution inputs.
            hour_in_channels -= 1
            day_in_channels -= 1
            half_day_in_channels -= 1
            static_in_channels -= 2
        self.num_layers = 3
        self.hour_blocks = self._create_block(hour_in_channels, hour_in_channels, hour_step, kernel_size, num_layers=self.num_layers)
        self.day_blocks = self._create_block(day_in_channels, day_in_channels, day_step, kernel_size, num_layers=self.num_layers)
        # Original author's note: this input carries a single data channel.
        self.half_day_blocks = self._create_block(half_day_in_channels, half_day_in_channels, half_day_step, kernel_size, num_layers=self.num_layers)
        self.out_channels = hour_in_channels + day_in_channels + half_day_in_channels + static_in_channels

        # FFN comes from models.block; here it receives the 4-D concatenation
        # that includes the static channels -- TODO confirm it accepts
        # (batch, channels, height, width) input.
        self.final_conv = FFN(self.out_channels, self.out_channels, num_layers=4, expansion_ratio=6)

    def _create_block(self, in_channels, out_channels, time_steps, kernel_size, num_layers):
        """Build the temporal layer list for one input branch.

        Phase 1 adds ``num_layers`` padded layers with doubling dilation and
        padding ``dilation * (kernel_size - 1) // 2``. Phase 2 restarts the
        dilation at 1 and adds unpadded layers until the tracked length is at
        most ``target_time_steps`` or the next dilated kernel no longer fits.
        If more than one time step is left, a final full-length convolution
        collapses the axis to length 1.

        Returns:
            ``nn.ModuleList`` of the layers, to be applied in order. A plain
            list would hide the layers from ``parameters()``, ``to(device)``
            and ``state_dict()``.
        """
        layers = []
        current_time_steps = time_steps
        dilation = 1
        layer_count = 0
        while current_time_steps > self.target_time_steps or layer_count < num_layers:
            span = dilation * (kernel_size - 1)  # steps removed by an unpadded layer
            if layer_count < num_layers:
                padding = span // 2
            else:
                padding = 0
                # Stop when the layer would not shrink the sequence
                # (kernel_size 1) or its receptive field exceeds what is left;
                # the collapse convolution below handles the remainder. The
                # original clamped the tracked length with max(..., 1) and
                # still built a layer that fails at run time.
                if span == 0 or span >= current_time_steps:
                    break
            layers.append(TemporalDilatedConvBlock(in_channels, out_channels, kernel_size, padding, dilation))

            # Exact Conv3d length arithmetic. For kernel_size 2 this equals the
            # previous hard-coded "time_steps - 1" after the padded phase.
            current_time_steps += 2 * padding - span
            in_channels = out_channels
            layer_count += 1
            dilation = 1 if layer_count == num_layers else dilation * 2

        if current_time_steps > 1:
            layers.append(nn.Sequential(nn.Conv3d(in_channels, out_channels, kernel_size=(current_time_steps, 1, 1), padding=(0, 0, 0)),
                                        nn.GELU()))
        return nn.ModuleList(layers)

    def process_temporal_data(self, input_data):
        """Add the last channel (time encoding) onto all other channels.

        Returns a new tensor with one channel fewer. The addition is out of
        place: slicing yields a view, so an in-place ``+=`` would overwrite the
        caller's tensor and is rejected by autograd on leaf tensors that
        require grad.
        """
        return input_data[:, :-1, :, :, :] + input_data[:, -1:, :, :, :]

    def process_static_data(self, static_input, hour_input, day_input, half_day_input):
        """Fold longitude/latitude into the temporal inputs.

        The last two channels of ``static_input`` are taken as longitude and
        latitude. Their sum is broadcast over the channel and time axes of each
        temporal input and added to it.

        Returns:
            ``(static_without_coords, hour, day, half_day)``; all are new
            tensors or views, the arguments themselves are not modified.
        """
        # Extract the longitude and latitude channels.
        lon = static_input[:, -2:-1, :, :]  # [batch_size, 1, height, width]
        lat = static_input[:, -1:, :, :]    # [batch_size, 1, height, width]

        # Drop the coordinate channels from the static features.
        processed_static = static_input[:, :-2, :, :]

        # Insert a singleton time axis so the sum broadcasts over time.
        position = (lon + lat).unsqueeze(2)

        # Out-of-place additions keep the caller's tensors untouched.
        hour_input = hour_input + position
        day_input = day_input + position
        half_day_input = half_day_input + position

        return processed_static, hour_input, day_input, half_day_input

    def get_out_channels(self):
        """Return the channel count of the tensor produced by ``forward``."""
        return self.out_channels

    def apply_layers(self, layers, x):
        """Feed ``x`` through every callable in ``layers``, in order."""
        for layer in layers:
            x = layer(x)
        return x

    def forward(self, hour_input, day_input, half_day_input, static_input):
        """Encode the temporal inputs and fuse them with the static features.

        Returns the output of ``final_conv`` applied to the channel-wise
        concatenation of the three encoded branches and the static input.
        """
        if self.temporal_usage_position_embedding:
            # Absolute time encoding.
            hour_input = self.process_temporal_data(hour_input)
            day_input = self.process_temporal_data(day_input)
            half_day_input = self.process_temporal_data(half_day_input)

            # Absolute position encoding.
            static_input, hour_input, day_input, half_day_input = self.process_static_data(
                static_input, hour_input, day_input, half_day_input
            )

        # The branch stacks end with a time axis of length 1; squeezing it
        # gives 4-D maps that can be concatenated with the 4-D static input.
        hour_output = self.apply_layers(self.hour_blocks, hour_input).squeeze(2)
        day_output = self.apply_layers(self.day_blocks, day_input).squeeze(2)
        half_day_output = self.apply_layers(self.half_day_blocks, half_day_input).squeeze(2)

        combined_out = torch.cat([hour_output, day_output, half_day_output, static_input], dim=1)
        final_out = self.final_conv(combined_out)

        return final_out
    