In [2]:
import torch
def Inverse_Derivatives_embedding(embedded_data, dt, m=2, patch_len=16, stride=16, pad_len=0, type='CI', pad_type='L0', B=32, T=96, C=7):
    """Invert a patched derivative embedding back to a (B, T, C) series.

    ``embedded_data`` is reshaped to (B, C, P, (m + 1) * patch_len), and each
    patch is read as (patch_len, m + 1): time-major, derivative-order-minor.
    The steps are:

      1. Split the leading B * C axis back into (B, C).
      2. Overlap-add the P patches onto a time axis of length
         ``(P - 1) * stride + patch_len``. Divide by the number of patches
         covering each step, so overlapping patches (stride < patch_len) are
         averaged rather than summed. Steps covered by no patch
         (stride > patch_len) stay 0.
      3. Strip ``pad_len`` steps from the left ('L0', 'L') or from the
         right ('R0', 'R').
      4. Take order 0 as the base, then add every higher order, shifted
         forward by ``dt`` steps and multiplied by ``dt``.

    NOTE(review): step 4 scales every order by the same factor ``dt`` (not
    ``dt**i / i!``) and adds it on top of order 0. Confirm this matches the
    forward embedding; it is preserved here as written.

    Args:
        embedded_data: tensor reshapeable to (B, C, P, (m + 1) * patch_len),
            presumably (B * C, P, (m + 1) * patch_len).
        dt: used both as an integer time shift and as the scale of the
            higher-order terms. Only ``0 < dt < recovered length`` adds
            anything.
        m: highest derivative order stored in the embedding.
        patch_len: number of time steps per patch.
        stride: distance between consecutive patch starts.
        pad_len: number of padded time steps to remove (assumed to have been
            added by the forward embedding on the side given by pad_type).
        type: unused; kept for interface compatibility.
        pad_type: 'L0' / 'L' strip from the left; 'R0' / 'R' strip from the
            right.
        B, T, C: batch size, output length and number of channels.

    Returns:
        Tensor of shape (B, T, C) with the dtype and device of
        ``embedded_data``. Only the first ``min(T, recovered length)`` steps
        are filled. Steps the patches did not cover stay 0, and recovered
        steps beyond T are dropped.

    Raises:
        NotImplementedError: if ``pad_type`` is not one of the values above.
    """
    if pad_type in ('L0', 'L'):
        strip_left = True
    elif pad_type in ('R0', 'R'):
        strip_left = False
    else:
        raise NotImplementedError(f'unsupported pad_type: {pad_type!r}')

    n_orders = m + 1

    # Undo the channel-independent flattening: (B * C, P, D) -> (B, C, P, D).
    patches = embedded_data.reshape(B, C, -1, n_orders * patch_len)
    num_patches = patches.shape[2]

    # Use the real patch count rather than re-deriving it from T: a padded
    # forward pass produces more patches than T alone would.
    seq_len = max(0, (num_patches - 1) * stride + patch_len)

    recovered = patches.new_zeros((B, C, seq_len, n_orders))
    coverage = patches.new_zeros(seq_len)
    for i in range(num_patches):
        start = i * stride
        end = start + patch_len
        recovered[:, :, start:end, :] += patches[:, :, i, :].reshape(B, C, patch_len, n_orders)
        coverage[start:end] += 1

    # Average overlapping contributions (overlap-add alone multiplies them by
    # the overlap count). clamp keeps uncovered steps at 0 instead of 0 / 0.
    recovered = recovered / coverage.clamp(min=1).unsqueeze(-1)

    # Strip padding. Guard pad_len == 0, since x[..., :-0] would be empty.
    if pad_len > 0:
        if strip_left:
            recovered = recovered[:, :, pad_len:]
        else:
            recovered = recovered[:, :, :-pad_len]

    # Time-major view (B, L, C, m + 1) so time slicing lines up with (B, T, C).
    series = recovered.permute(0, 2, 1, 3)
    n = min(T, series.shape[1])

    recovered_x = embedded_data.new_zeros((B, T, C))
    recovered_x[:, :n] = series[:, :n, :, 0]

    # Add each higher order shifted forward by dt along time.
    # dt <= 0 or dt >= n leaves nothing to add.
    if 0 < dt < n:
        for order in range(1, n_orders):
            recovered_x[:, dt:n] += series[:, :n - dt, :, order] * dt

    return recovered_x


In [None]:
# Smoke test on random data: 224 = B * C = 32 * 7 series, each with 6 patches
# of (m + 1) * patch_len = 64 values, i.e. m = 3 with the default
# patch_len = 16. 6 patches * stride 16 reconstruct the default T = 96 steps.
torch.manual_seed(0)
x = torch.rand(224, 6, 64)
y = Inverse_Derivatives_embedding(x, dt=1, m=3)
assert y.shape == (32, 96, 7)
y.shape