In [2]:
import sys
import torch
import argparse
sys.path.append(r"C:\Users\Administrator\Desktop\code\ST-Conv")
from torchstat import stat
from model.block import BasicBlock,Bottleneck
from model.model import STConvNet
from tool.utils import Util
from torchsummaryX import summary
from fvcore.nn import FlopCountAnalysis, parameter_count_table

def flops_to_string(flops, units='GFLOPs', precision=2):
    """Format a raw FLOP count as a human-readable string.

    Args:
        flops: Raw operation count, e.g. the value returned by
            ``FlopCountAnalysis.total()``.
        units: Target unit: 'TFLOPs', 'GFLOPs', 'MFLOPs', 'KFLOPs' or 'FLOPs'.
            Any other value falls back to the unscaled count labelled 'FLOPs'.
        precision: Number of digits after the decimal point.

    Returns:
        A string such as ``"29917.54 MFLOPs"``.
    """
    scales = {
        'TFLOPs': 1e12,
        'GFLOPs': 1e9,
        'MFLOPs': 1e6,
        'KFLOPs': 1e3,
        'FLOPs': 1,
    }
    if units not in scales:
        # Unknown unit: report the raw count and label it as such, instead of
        # attaching the requested (wrong) unit name to an unscaled number.
        units = 'FLOPs'
    return f"{flops / scales[units]:.{precision}f} {units}"

def main(args):
    """Load the model and path configuration files named on the command line.

    Args:
        args: Parsed argument namespace exposing ``model_config_path`` and
            ``path_config_path``.

    Returns:
        Tuple ``(model_config, path_config)``, each being whatever
        ``Util.load_config`` returns for the corresponding file.
    """
    # Loaded in this order: model config first, then path config.
    config_paths = (args.model_config_path, args.path_config_path)
    model_config, path_config = (Util.load_config(cfg_path) for cfg_path in config_paths)
    return model_config, path_config

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_config_path', type=str, default='../config/path_config.yaml')
    parser.add_argument('--model_config_path', type=str, default='../config/model_config.yaml')
    # parse_known_args ignores unknown argv entries (such as the `-f <kernel.json>`
    # flag Jupyter passes to the kernel), so this also runs inside a notebook.
    args = parser.parse_known_args()[0]
    Hyperparameter_dict, path_dict = main(args)

    # Batch size for the dummy inputs comes from the model config.
    # NOTE: every count reported below (Mult-Adds, FLOPs) is for the WHOLE batch.
    batch_size = Hyperparameter_dict['batch_size']

    # Temporal length of each branch (axis 2 of the 5-D inputs) and the spatial
    # size (assumed H, W) of the dummy inputs.
    # NOTE(review): these literals presumably must match the data pipeline /
    # the *_step config entries -- TODO confirm and source them from config.
    HOUR_STEPS, DAY_STEPS, HALF_DAY_STEPS = 17, 5, 9
    SPATIAL_H, SPATIAL_W = 16, 16

    # Channel counts come from the same config keys used to build the model
    # below, so the dummy inputs cannot silently drift out of sync with it.
    hour_input_shape = (batch_size, Hyperparameter_dict['hour_in_channels'], HOUR_STEPS, SPATIAL_H, SPATIAL_W)
    day_input_shape = (batch_size, Hyperparameter_dict['day_in_channels'], DAY_STEPS, SPATIAL_H, SPATIAL_W)
    half_day_input_shape = (batch_size, Hyperparameter_dict['half_day_in_channels'], HALF_DAY_STEPS, SPATIAL_H, SPATIAL_W)
    static_input_shape = (batch_size, Hyperparameter_dict['static_in_channels'], SPATIAL_H, SPATIAL_W)

    # Random dummy data: only the shapes matter for parameter / FLOP counting.
    hour_input = torch.randn(hour_input_shape)
    day_input = torch.randn(day_input_shape)
    half_day_input = torch.randn(half_day_input_shape)
    static_input = torch.randn(static_input_shape)
    print("Hour input shape:", hour_input_shape)
    print("Day input shape:", day_input_shape)
    print("Half day input shape:", half_day_input_shape)
    print("Static input shape:", static_input_shape)

    # Build the model from the config.
    model = STConvNet(Hyperparameter_dict['hour_in_channels'], Hyperparameter_dict['hour_step'],
                      Hyperparameter_dict['day_in_channels'], Hyperparameter_dict['day_step'],
                      Hyperparameter_dict['half_day_in_channels'], Hyperparameter_dict['half_day_step'],
                      Hyperparameter_dict['static_in_channels'], Hyperparameter_dict['temporal_kernel_size'],
                      Hyperparameter_dict['target_time_steps'],
                      BasicBlock, Hyperparameter_dict['layer18'], Hyperparameter_dict['spatial_out_channels'])

    # Analyse on CPU in inference mode. eval() stops the dummy forward passes
    # from updating BatchNorm running statistics, and no_grad() avoids building
    # an autograd graph that is never used.
    model.cpu()
    model.eval()
    inputs = (hour_input, day_input, half_day_input, static_input)
    with torch.no_grad():
        summary(model, *inputs)
        # FlopCountAnalysis traces lazily, so total() must also run in no_grad.
        flop_analyzer = FlopCountAnalysis(model, inputs)
        flops = flop_analyzer.total()

    # fvcore counts one fused multiply-add as one "flop" (i.e. it reports MACs),
    # and operators it reports as "Unsupported operator" (e.g. aten::add_,
    # aten::leaky_relu_) contribute nothing to the total.
    print(f"FLOPs: {flops_to_string(flops, 'MFLOPs')}")
    print(f"FLOPs per sample: {flops_to_string(flops / batch_size, 'MFLOPs')}")
    print(parameter_count_table(model))

Hour input shape: (32, 8, 17, 16, 16)
Day input shape: (32, 22, 5, 16, 16)
Half day input shape: (32, 1, 9, 16, 16)
Static input shape: (32, 11, 16, 16)


  df_sum = df.sum()


                                                          Kernel Shape          Output Shape     Params   Mult-Adds
Layer                                                                                                              
0_temporal_conv.hour_block.0.Conv3d_conv              [8, 16, 2, 1, 1]  [32, 16, 16, 16, 16]      272.0   1.048576M
1_temporal_conv.hour_block.0.BatchNorm3d_batchnorm                [16]  [32, 16, 16, 16, 16]       32.0        16.0
2_temporal_conv.hour_block.0.LeakyReLU_relu                          -  [32, 16, 16, 16, 16]          -           -
3_temporal_conv.hour_block.1.Conv3d_conv             [16, 16, 2, 1, 1]  [32, 16, 14, 16, 16]      528.0   1.835008M
4_temporal_conv.hour_block.1.BatchNorm3d_batchnorm                [16]  [32, 16, 14, 16, 16]       32.0        16.0
5_temporal_conv.hour_block.1.LeakyReLU_relu                          -  [32, 16, 14, 16, 16]          -           -
6_temporal_conv.hour_block.2.Conv3d_conv             [16, 16, 2, 1, 1]  

Unsupported operator aten::add_ encountered 44 time(s)
Unsupported operator aten::leaky_relu_ encountered 9 time(s)
Unsupported operator aten::max_pool2d encountered 1 time(s)


FLOPs: 29917.54 MFLOPs
| name                               | #elements or shape    |
|:-----------------------------------|:----------------------|
| model                              | 56.7M                 |
|  temporal_conv                     |  23.9K                |
|   temporal_conv.hour_block         |   2.0K                |
|    temporal_conv.hour_block.0      |    0.3K               |
|    temporal_conv.hour_block.1      |    0.6K               |
|    temporal_conv.hour_block.2      |    0.6K               |
|    temporal_conv.hour_block.3      |    0.6K               |
|   temporal_conv.day_block          |   6.1K                |
|    temporal_conv.day_block.0       |    2.1K               |
|    temporal_conv.day_block.1       |    4.0K               |
|   temporal_conv.half_day_block     |   38                  |
|    temporal_conv.half_day_block.0  |    10                 |
|    temporal_conv.half_day_block.1  |    14                 |
|    temporal_conv.half_day_bloc

In [None]:
import numpy as np
# Inspect the stored mean statistics file.
# NOTE(review): absolute, machine-specific path -- adjust locally or move into
# the path config used by the cell above.
file_data = r"D:\Data_Store\Dataset\ST_Conv\std_mean\mean.npy"
data = np.load(file_data)
# len() raises TypeError on a 0-d array and only reports the first axis;
# the full shape (plus dtype) is what is actually useful to check here.
print(data.shape, data.dtype)