In [1]:
import copy

import torch
import torch.nn as nn
from my_tools import make_embedding, unfold_func, miniEncoder, miniDecoder, fold_func


class Geoformer(nn.Module):
    """Patch-based encoder-decoder transformer for gridded space-time fields.

    Fields of shape (batch, time, C, H, W) are cut into patches with
    ``unfold_func``, embedded, passed through a stack of encoder layers and a
    stack of decoder layers, projected back to patch size by a linear layer
    and reassembled into (batch, time, C, H, W) with ``fold_func``.

    Args:
        mypara: configuration object. Attributes read by this class:
            ``d_size``, ``device``, ``needtauxy``, ``input_channal``,
            ``patch_size``, ``emb_spatial_size``, ``input_length``,
            ``output_length``, ``nheads``, ``dim_feedforward``, ``dropout``,
            ``num_encoder_layers`` and ``num_decoder_layers``.
    """

    def __init__(self, mypara):
        super().__init__()
        self.mypara = mypara
        d_size = mypara.d_size
        self.device = mypara.device
        # Two extra channels are counted when ``needtauxy`` is set
        # (presumably the tau-x / tau-y fields -- confirm against the data loader).
        n_channels = mypara.input_channal + (2 if mypara.needtauxy else 0)
        # Length of one flattened patch: channels * patch_height * patch_width.
        self.cube_dim = n_channels * mypara.patch_size[0] * mypara.patch_size[1]
        # Separate embeddings for the input and the output sequence; they differ
        # only in the maximum sequence length they are built for.
        self.predictor_emb = make_embedding(
            cube_dim=self.cube_dim,
            d_size=d_size,
            emb_spatial_size=mypara.emb_spatial_size,
            max_len=mypara.input_length,
            device=self.device,
        )
        self.predictand_emb = make_embedding(
            cube_dim=self.cube_dim,
            d_size=d_size,
            emb_spatial_size=mypara.emb_spatial_size,
            max_len=mypara.output_length,
            device=self.device,
        )
        enc_layer = miniEncoder(
            d_size, mypara.nheads, mypara.dim_feedforward, mypara.dropout
        )
        dec_layer = miniDecoder(
            d_size, mypara.nheads, mypara.dim_feedforward, mypara.dropout
        )
        self.encoder = multi_enc_layer(
            enc_layer=enc_layer, num_layers=mypara.num_encoder_layers
        )
        self.decoder = multi_dec_layer(
            dec_layer=dec_layer, num_layers=mypara.num_decoder_layers
        )
        # Maps decoder features back to one flattened patch.
        self.linear_output = nn.Linear(d_size, self.cube_dim)

    def forward(
        self,
        predictor,
        predictand,
        in_mask=None,
        enout_mask=None,
        train=True,
        sv_ratio=0,
    ):
        """Run the model in training (two-pass) or inference (autoregressive) mode.

        Args:
            predictor: (batch, lb, C, H, W) input sequence.
            predictand: (batch, pre_len, C, H, W) target sequence. Required
                when ``train`` is True; must be ``None`` when it is False.
            in_mask: forwarded unchanged to the encoder.
            enout_mask: forwarded unchanged to the decoder.
            train: if True, decode once without gradients using the shifted
                ground truth, mix those predictions with the ground truth
                according to ``sv_ratio`` and decode again. If False, roll the
                decoder forward ``mypara.output_length`` steps starting from
                the last predictor frame.
            sv_ratio: probability that a given (sample, time step) of the
                decoder input uses the ground-truth frame instead of the
                first-pass prediction. Values <= 1e-7 mean "never".
        Returns:
            outvar_pred: (batch, pre_len, C, H, W)
        Raises:
            ValueError: in inference mode when ``mypara.output_length < 1``
                (there would be nothing to return).
        """
        en_out = self.encode(predictor=predictor, in_mask=in_mask)
        # Masks follow the data, so the model keeps working after ``.to(...)``
        # even if ``mypara.device`` was not updated.
        data_device = predictor.device
        if train:
            # First pass: decoder input is the last predictor frame followed
            # by the ground truth shifted right by one step. No gradients are
            # kept; the result only serves as decoder input below.
            with torch.no_grad():
                connect_inout = torch.cat(
                    [predictor[:, -1:], predictand[:, :-1]], dim=1
                )
                out_mask = self.make_mask_matrix(
                    connect_inout.size(1), device=data_device
                )
                outvar_pred = self.decode(
                    connect_inout,
                    en_out,
                    out_mask,
                    enout_mask,
                )
            if sv_ratio > 1e-7:
                # One Bernoulli draw per (sample, time step), broadcast over
                # C, H and W. Sampled on CPU (as before) so seeded runs keep
                # drawing from the same generator.
                supervise_mask = torch.bernoulli(
                    sv_ratio
                    * torch.ones(predictand.size(0), predictand.size(1) - 1, 1, 1, 1)
                ).to(predictand.device)
            else:
                supervise_mask = 0
            # 1 -> take the ground-truth frame, 0 -> take the model's own frame.
            predictand = (
                supervise_mask * predictand[:, :-1]
                + (1 - supervise_mask) * outvar_pred[:, :-1]
            )
            predictand = torch.cat([predictor[:, -1:], predictand], dim=1)
            # Second pass: the prediction that is actually returned.
            outvar_pred = self.decode(
                predictand,
                en_out,
                out_mask,
                enout_mask,
            )
        else:
            assert predictand is None
            n_steps = self.mypara.output_length
            if n_steps < 1:
                raise ValueError(
                    "mypara.output_length must be >= 1 for inference, got "
                    f"{n_steps}"
                )
            # Seed the decoder with the last observed frame, then append the
            # newest predicted frame after every step.
            predictand = predictor[:, -1:]
            for _ in range(n_steps):
                out_mask = self.make_mask_matrix(
                    predictand.size(1), device=data_device
                )
                outvar_pred = self.decode(
                    predictand,
                    en_out,
                    out_mask,
                    enout_mask,
                )
                predictand = torch.cat([predictand, outvar_pred[:, -1:]], dim=1)
        return outvar_pred

    def encode(self, predictor, in_mask):
        """Patchify, embed and encode the input sequence.

        Args:
            predictor: (B, lb, C, H, W)
            in_mask: forwarded unchanged to the encoder.
        Returns:
            en_out: (Batch, S, lb, d_size), S being the number of patches.
        """
        lb = predictor.size(1)
        predictor = unfold_func(predictor, self.mypara.patch_size)
        # Regroup into (B, lb, cube_dim, S), then move the patch axis forward
        # to get (B, S, lb, cube_dim).
        predictor = predictor.reshape(predictor.size(0), lb, self.cube_dim, -1).permute(
            0, 3, 1, 2
        )
        predictor = self.predictor_emb(predictor)
        en_out = self.encoder(predictor, in_mask)
        return en_out

    def decode(self, predictand, en_out, out_mask, enout_mask):
        """Patchify and embed the decoder input, decode, and fold back to a grid.

        Args:
            predictand: (B, pre_len, C, H, W)
            en_out: encoder output as returned by ``encode``.
            out_mask: forwarded unchanged to the decoder.
            enout_mask: forwarded unchanged to the decoder.
        Returns:
            (B, pre_len, C, H, W)
        """
        H, W = predictand.size()[-2:]
        T = predictand.size(1)
        predictand = unfold_func(predictand, self.mypara.patch_size)
        # Same layout change as in ``encode``: -> (B, S, T, cube_dim).
        predictand = predictand.reshape(
            predictand.size(0), T, self.cube_dim, -1
        ).permute(0, 3, 1, 2)
        predictand = self.predictand_emb(predictand)
        output = self.decoder(predictand, en_out, out_mask, enout_mask)
        # Project features to patch size and move the patch axis last.
        output = self.linear_output(output).permute(0, 2, 3, 1)
        # Spread the patch axis over a (H / patch_h, W / patch_w) grid.
        output = output.reshape(
            predictand.size(0),
            T,
            self.cube_dim,
            H // self.mypara.patch_size[0],
            W // self.mypara.patch_size[1],
        )
        output = fold_func(
            output, output_size=(H, W), kernel_size=self.mypara.patch_size
        )
        return output

    def make_mask_matrix(self, sz: int, device=None):
        """Build the square boolean mask handed to the decoder.

        Args:
            sz: sequence length; the mask has shape (sz, sz).
            device: where to allocate the mask. Defaults to
                ``self.mypara.device``.
        Returns:
            Bool tensor whose entry [i, j] is True exactly when j > i, i.e.
            for positions later than i (presumably the ones the decoder layer
            blocks -- see ``miniDecoder`` for how the mask is consumed).
        """
        if device is None:
            device = self.mypara.device
        # Allocate on the target device directly instead of building on CPU
        # and copying on every call.
        return torch.ones(sz, sz, device=device).triu(diagonal=1).bool()


class multi_enc_layer(nn.Module):
    """Stack of ``num_layers`` encoder layers applied one after another.

    Args:
        enc_layer: prototype layer; called as ``layer(x, mask)``.
        num_layers: number of layers in the stack.
    """

    def __init__(self, enc_layer, num_layers):
        super().__init__()
        # Every slot gets its own deep copy of the prototype. Repeating the
        # same instance would register one module under several indices, so
        # all "layers" would share a single set of weights. State-dict keys
        # (layers.0..., layers.1..., ...) are the same either way.
        self.layers = nn.ModuleList(
            [copy.deepcopy(enc_layer) for _ in range(num_layers)]
        )

    def forward(self, x, mask=None):
        """Feed ``x`` through every layer in order, passing ``mask`` to each."""
        for layer in self.layers:
            x = layer(x, mask)
        return x


class multi_dec_layer(nn.Module):
    """Stack of ``num_layers`` decoder layers applied one after another.

    Args:
        dec_layer: prototype layer; called as
            ``layer(x, en_out, out_mask, enout_mask)``.
        num_layers: number of layers in the stack.
    """

    def __init__(self, dec_layer, num_layers):
        super().__init__()
        # Every slot gets its own deep copy of the prototype. Repeating the
        # same instance would register one module under several indices, so
        # all "layers" would share a single set of weights. State-dict keys
        # (layers.0..., layers.1..., ...) are the same either way.
        self.layers = nn.ModuleList(
            [copy.deepcopy(dec_layer) for _ in range(num_layers)]
        )

    def forward(self, x, en_out, out_mask, enout_mask):
        """Feed ``x`` through every layer in order; the encoder output and both
        masks are passed unchanged to each layer."""
        for layer in self.layers:
            x = layer(x, en_out, out_mask, enout_mask)
        return x


In [5]:
# Quick look at the dimensions of `predictor`. No cell above this one defines
# `predictor`, so this only works when the variable is already in the kernel.
predictor.shape

torch.Size([1, 10, 3, 64, 64])

In [2]:
import torch

# Stand-in for the hyper-parameter object that Geoformer receives as `mypara`.
class MyPara:
    """Fixed set of hyper-parameters used to smoke-test Geoformer."""

    def __init__(self):
        # Run on the GPU when one is visible, otherwise fall back to the CPU.
        if torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        # --- data layout ---
        self.input_channal = 8  # number of input channels
        self.needtauxy = False  # when True, Geoformer counts two extra channels
        self.patch_size = (8, 8)  # patch height and width
        self.input_length = 16  # number of input time steps
        self.output_length = 1  # number of output time steps

        # --- embedding ---
        self.d_size = 64  # embedding dimension
        self.emb_spatial_size = 8  # spatial embedding size

        # --- transformer layers ---
        self.nheads = 2  # attention heads
        self.dim_feedforward = 256  # feed-forward width
        self.dropout = 0.1  # dropout probability
        self.num_encoder_layers = 1  # encoder depth
        self.num_decoder_layers = 1  # decoder depth

# Build the hyper-parameter object.
mypara = MyPara()

# Build the model and move it to the configured device.
model = Geoformer(mypara).to(mypara.device)

# Dimensions of the random test batch.
batch_size = 1
channels = mypara.input_channal
patch_size = mypara.patch_size
height, width = 721, 1440  # nominal grid size
# Trim each side down to a whole number of patches (721 -> 720 for 8x8 patches).
height -= height % patch_size[0]
width -= width % patch_size[1]

input_length, output_length = mypara.input_length, mypara.output_length

# Random input and target sequences, drawn on the CPU and then moved over.
field_shape = (channels, height, width)
predictor = torch.rand(batch_size, input_length, *field_shape).to(mypara.device)
predictand = torch.rand(batch_size, output_length, *field_shape).to(mypara.device)

# Inference run: evaluation mode, no gradient tracking, no target passed in.
model.eval()
with torch.no_grad():
    output = model(predictor, None, train=False)

# Report the shape of the prediction.
print("Output shape:", output.shape)


first embedded_space shape: torch.Size([1, 8, 1, 64])
first x shape: torch.Size([1, 16200, 16, 64])
x shape: torch.Size([1, 16200, 16, 64])
pe_time shape: torch.Size([1, 1, 16, 64])
embedded_space shape: torch.Size([1, 16200, 16, 64])
first embedded_space shape: torch.Size([1, 8, 1, 64])
first x shape: torch.Size([1, 16200, 1, 64])
x shape: torch.Size([1, 16200, 1, 64])
pe_time shape: torch.Size([1, 1, 1, 64])
embedded_space shape: torch.Size([1, 16200, 1, 64])
Output shape: torch.Size([1, 1, 8, 720, 1440])


In [2]:
import numpy as np
import torch

# NOTE(review): machine-specific absolute paths -- point these at your own data.
INPUT_PATH = "E:/data/input_25_all.npy"
NSR_PATH = "E:/data/nsr_25_all.npy"
OUTPUT_PATH = "normalized_train.pt"

# The indexing and concatenation below imply a 4-D input array (samples,
# channels, rows, cols) and a 3-D nsr array (samples, rows, cols).
# Deliberately not called `input`, so the builtin stays usable in the kernel.
input_raw = np.load(INPUT_PATH)
nsr = np.load(NSR_PATH)
print('load data ok!')

# Give nsr a channel axis and append it after the existing channels.
nsr_expanded = nsr[:, np.newaxis, :, :]
stacked_np = np.concatenate((input_raw, nsr_expanded), axis=1)
# from_numpy wraps the freshly concatenated array without copying it again.
input_array = torch.from_numpy(stacked_np)

# Per-channel statistics, shaped (1, C, 1, 1) so they broadcast over samples,
# rows and columns.
mean_all = torch.tensor(
    [
        2.8679e+02,
        1.0096e+05,
        -5.3626e+06,
        -5.1725e-02,
        1.8698e-01,
        5.4089e+04,
        1.3745e+04,
        1.1180e+07,
    ]
).reshape(1, -1, 1, 1)
std_all = torch.tensor(
    [
        1.1627e+01,
        1.0610e+03,
        4.9920e+06,
        3.8811e+00,
        2.4887e+00,
        3.2341e+03,
        1.3297e+03,
        7.8841e+06,
    ]
).reshape(1, -1, 1, 1)

# Fail with a clear message instead of an opaque broadcasting error.
if input_array.shape[1] != mean_all.shape[1]:
    raise ValueError(
        f"expected {mean_all.shape[1]} channels after appending nsr, "
        f"got {input_array.shape[1]}"
    )

# Standardise every channel.
normalized_data = (input_array - mean_all) / std_all
# Only channel 0 is cleaned: NaNs become 0, i.e. the channel mean after
# standardisation. (nan_to_num also clamps +/-inf to the largest finite values.)
normalized_data[:, 0, :, :] = torch.nan_to_num(normalized_data[:, 0, :, :], nan=0.0)
torch.save(normalized_data, OUTPUT_PATH)

load data ok!
