## Encoder

In [None]:
import torch
import torch.nn as nn
from VQGAN.helper import ResidualBlock, NonLocalBlock, DownSampleBlock, GroupNorm, Swish

class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        channels = [32, 32, 32, 64, 64]
        attn_resolutions = [2]
        num_res_blocks = 1
        resolution = 256

        # 初始卷积层
        self.conv_in = nn.Conv3d(args.image_channels, channels[0], kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1)) #改变最初的kernal size即可改变T维度的缩放
        
        # 第一层（含残差块和注意力模块）
        self.layer1 = self._make_layer(channels[0], channels[1], num_res_blocks, resolution, attn_resolutions)
        
        # 下采样与第二层
        self.downsample1 = DownSampleBlock(channels[1])
        self.layer2 = self._make_layer(channels[1], channels[2], num_res_blocks, resolution // 2, attn_resolutions)

        # Further downsampling and third layer
        self.downsample2 = DownSampleBlock(channels[2])
        self.layer3 = self._make_layer(channels[2], channels[3], num_res_blocks, resolution // 4, attn_resolutions)

        # 中间层的残差块和注意力模块
        self.mid_block1 = ResidualBlock(channels[3], channels[3])
        #self.mid_attn = NonLocalBlock(channels[3])
        self.mid_block2 = ResidualBlock(channels[3], channels[3])
        
        # 输出层的归一化、激活和最终卷积层
        self.norm_out = GroupNorm(channels[3])
        self.act_out = Swish()
        self.conv_out = nn.Conv3d(channels[3], args.latent_dim, kernel_size=3, stride=1, padding=(1,2,1))

    def _make_layer(self, in_channels, out_channels, num_res_blocks, resolution, attn_resolutions):
        layers = []
        for _ in range(num_res_blocks):
            layers.append(ResidualBlock(in_channels, out_channels))
            in_channels = out_channels
            if resolution in attn_resolutions:
                layers.append(NonLocalBlock(in_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # 初始卷积
        x = self.conv_in(x)
        print(x.shape)
        # 第一层，并存储跳跃连接
        x = self.layer1(x)
        skip = x  # 保存第一层输出，用于后续跳跃连接
        print(x.shape)
        # 下采样，进入第二层
        x = self.downsample1(x)
        x = self.layer2(x)
        print(x.shape)
        # Further downsample and third layer
        x = self.downsample2(x)
        print(x.shape)
        x = self.layer3(x)
        print(x.shape)
        # 中间层的残差块和注意力模块
        x = self.mid_block1(x)
        #x = self.mid_attn(x)
        x = self.mid_block2(x)
        
        # 最终的归一化、激活和卷积输出层
        x = self.norm_out(x)
        x = self.act_out(x)
        x = self.conv_out(x)[:, :, :, :181, :360]
        
        return x




import torch
class Args:
    image_channels = 14  # RGB image channels
    latent_dim = 64     # Set a latent dimension for the output

# Instantiate Args and Encoder
args = Args()
encoder = Encoder(args)

# Create a random input tensor (batch size of 1, RGB image with resolution 256x256)
input_tensor = torch.randn(1, Args.image_channels, 2, 721, 1440).cuda()
#input_tensor = torch.randn(1, Args.image_channels, 2, 721, 1440).cuda()
model = encoder.cuda()

# Forward pass through the encoder
output = model(input_tensor)

# Print the output shape
print("Output shape:", output.shape)

torch.Size([1, 32, 2, 721, 1440])
torch.Size([1, 32, 2, 721, 1440])
torch.Size([1, 32, 2, 722, 1441]) after pad
torch.Size([1, 32, 2, 361, 721])
torch.Size([1, 32, 2, 362, 722]) after pad
torch.Size([1, 32, 2, 181, 361])
torch.Size([1, 64, 2, 181, 361])
Output shape: torch.Size([1, 64, 2, 181, 360])


## Encoder 4D

In [1]:
import torch
import torch.nn as nn
from VQGAN.helper import ResidualBlock, NonLocalBlock, DownSampleBlock, GroupNorm, Swish, ResidualBlock4D, DownSampleBlock4D
from VQGAN.conv import Conv4d

class Encoder4D(nn.Module):
    def __init__(self, args):
        super(Encoder4D, self).__init__()
        channels = [32, 32, 32, 64, 64]
        attn_resolutions = [2]
        num_res_blocks = 1
        resolution = 256

        # 初始卷积层
        self.conv_in = Conv4d(args.image_channels, channels[0], kernel_size=(1, 1, 3, 3), stride=1, padding=(0, 0, 1, 1))
        
        # 第一层（含残差块和注意力模块）
        self.layer1 = self._make_layer(channels[0], channels[1], num_res_blocks, resolution, attn_resolutions)
        
        # 下采样与第二层
        self.downsample1 = DownSampleBlock4D(channels[1])
        self.layer2 = self._make_layer(channels[1], channels[2], num_res_blocks, resolution // 2, attn_resolutions)

        # Further downsampling and third layer
        self.downsample2 = DownSampleBlock4D(channels[2])
        self.layer3 = self._make_layer(channels[2], channels[3], num_res_blocks, resolution // 4, attn_resolutions)

        # 中间层的残差块和注意力模块
        self.mid_block1 = ResidualBlock4D(channels[3], channels[3])
        #self.mid_attn = NonLocalBlock(channels[3])
        self.mid_block2 = ResidualBlock4D(channels[3], channels[3])
        
        # 输出层的归一化、激活和最终卷积层
        self.norm_out = GroupNorm(channels[3])
        self.act_out = Swish()
        self.conv_out = Conv4d(channels[3], args.latent_dim, kernel_size=(2, 1, 3, 3), stride=(2, 1, 1, 1), padding=(0,0,1,1))

    def _make_layer(self, in_channels, out_channels, num_res_blocks, resolution, attn_resolutions):
        layers = []
        for _ in range(num_res_blocks):
            layers.append(ResidualBlock4D(in_channels, out_channels))
            in_channels = out_channels
            if resolution in attn_resolutions:
                layers.append(NonLocalBlock(in_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        # 初始卷积
        x = self.conv_in(x)
        print(x.shape)
        # 第一层，并存储跳跃连接
        x = self.layer1(x)
        skip = x  # 保存第一层输出，用于后续跳跃连接
        print(x.shape)
        # 下采样，进入第二层
        x = self.downsample1(x)
        x = self.layer2(x)
        print(x.shape)
        # Further downsample and third layer
        x = self.downsample2(x)
        print(x.shape)
        x = self.layer3(x)
        print(x.shape)
        # 中间层的残差块和注意力模块
        x = self.mid_block1(x)
        #x = self.mid_attn(x)
        x = self.mid_block2(x)
        
        # 最终的归一化、激活和卷积输出层
        x = self.norm_out(x)
        x = self.act_out(x)
        x = self.conv_out(x)[:, :, :, :, :181, :360]
        
        return x




import torch
class Args:
    image_channels = 7  # RGB image channels
    latent_dim = 64     # Set a latent dimension for the output

# Instantiate Args and Encoder
args = Args()
encoder = Encoder4D(args)

# Create a random input tensor (batch size of 1, RGB image with resolution 256x256)
input_tensor = torch.randn(1, Args.image_channels, 4, 2, 721, 1440).cuda()
#input_tensor = torch.randn(1, Args.image_channels, 2, 721, 1440).cuda()
model = encoder.cuda()

# Forward pass through the encoder
output = model(input_tensor)

# Print the output shape
print("Output shape:", output.shape)

torch.Size([1, 32, 4, 2, 721, 1440])
torch.Size([1, 32, 4, 2, 721, 1440])
torch.Size([1, 32, 4, 2, 722, 1441]) after pad
torch.Size([1, 32, 4, 2, 361, 721])
torch.Size([1, 32, 4, 2, 362, 722]) after pad
torch.Size([1, 32, 4, 2, 181, 361])
torch.Size([1, 64, 4, 2, 181, 361])
Output shape: torch.Size([1, 64, 2, 2, 181, 360])


## Decoder

In [None]:
import torch
import torch.nn as nn
from VQGAN.helper import ResidualBlock, NonLocalBlock, UpSampleBlock, GroupNorm, Swish


# 定义 Decoder 类（略去细节，假设已实现）
class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()
        channels = [64, 64, 32, 32]  # Decoder 的通道配置
        num_res_blocks = 1  # 与 Encoder 对齐

        # 初始卷积层
        self.conv_in = nn.Conv3d(args.latent_dim, channels[0], kernel_size=3, stride=1, padding=1)
        
        # 第一层残差块
        self.layer1 = self._make_layer(channels[0], channels[1], num_res_blocks)
        
        # 上采样和第二层残差块
        self.upsample1 = UpSampleBlock(channels[1])
        self.layer2 = self._make_layer(channels[1], channels[2], num_res_blocks)

        self.upsample2 = UpSampleBlock(channels[2])
        self.layer3 = self._make_layer(channels[2], channels[3], num_res_blocks)
        
        # 中间层的残差块
        self.mid_block1 = ResidualBlock(channels[3], channels[3])
        self.mid_block2 = ResidualBlock(channels[3], channels[3])
        
        # 最终输出层
        self.norm_out = GroupNorm(channels[3])
        self.act_out = Swish()
        self.conv_out = nn.Conv3d(channels[3], args.image_channels, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1))

    def _make_layer(self, in_channels, out_channels, num_res_blocks):
        # 创建指定数量的残差块
        layers = [ResidualBlock(in_channels, out_channels) for _ in range(num_res_blocks)]
        return nn.Sequential(*layers)

    def forward(self, x):
        # 初始卷积
        x = self.conv_in(x)

        # 第一层残差块
        x = self.layer1(x)

        # 上采样和第二层残差块
        x = self.upsample1(x)  # 上采样后通道数保持不变
        print(x.shape)
        x = self.layer2(x)     # 确保输入与 layer2 的期望通道数匹配

        x = self.upsample2(x)  # 上采样后通道数保持不变
        print(x.shape)
        x = self.layer3(x)     # 确保输入与 layer2 的期望通道数匹配
        
        # 中间层的残差块
        x = self.mid_block1(x)
        x = self.mid_block2(x)
        
        # 最终的归一化、激活和卷积输出层
        x = self.norm_out(x)
        x = self.act_out(x)
        print(x.shape)
        x = self.conv_out(x)[:, :, :, :721, :1440]
        
        return x
    
# 创建实例

class Args:
    image_channels = 14   # 输入图像的通道数（例如 RGB 图像为 3）
    latent_dim = 64      # 潜在空间的通道数
    num_codebook_vectors = 512
    beta = 0.25
args = Args()
decoder = Decoder(args).cuda()

# 输入张量
input_tensor = torch.randn(1, 64, 2, 181, 360).cuda()

# 编码过程
decoder = decoder(input_tensor)
print("decoder shape:", decoder.shape)  # 期望输出: torch.Size([1, 64, 360, 720])



torch.Size([1, 64, 2, 181, 360]) before inter
torch.Size([1, 64, 2, 362, 720])
torch.Size([1, 32, 2, 362, 720]) before inter
torch.Size([1, 32, 2, 724, 1440])
torch.Size([1, 32, 2, 724, 1440])
decoder shape: torch.Size([1, 14, 2, 721, 1440])


## Decoder4D

In [1]:
import torch
import torch.nn as nn
from VQGAN.helper import ResidualBlock, NonLocalBlock, UpSampleBlock, GroupNorm, Swish, UpSampleBlock4D, ResidualBlock4D
from VQGAN.conv import ConvTranspose4d
from VQGAN.conv import Conv4d

# 定义 Decoder 类（略去细节，假设已实现）
class Decoder(nn.Module):
    def __init__(self, args):
        super(Decoder, self).__init__()
        channels = [64, 64, 32, 32]  # Decoder 的通道配置
        num_res_blocks = 1  # 与 Encoder 对齐

        # 初始卷积层
        self.conv_in = Conv4d(args.latent_dim, channels[0], kernel_size=(1, 1, 3, 3), stride=1, padding=(0, 0, 1, 1))
        
        # 第一层残差块
        self.layer1 = self._make_layer(channels[0], channels[1], num_res_blocks)
        
        # 上采样和第二层残差块
        #self.upsample1 = UpSampleBlock4D(channels[1])
        self.upsample1 = ConvTranspose4d(channels[1], channels[1], kernel_size=(2, 1, 2, 2), stride=(2, 1, 2, 2),padding=(0, 0, 0, 0))
        self.layer2 = self._make_layer(channels[1], channels[2], num_res_blocks)

        #self.upsample2 = UpSampleBlock4D(channels[2])
        self.upsample2 = ConvTranspose4d(channels[2], channels[2], kernel_size=(1, 1, 2, 2), stride=(1, 1, 2, 2),padding=(0, 0, 0, 0))
        self.layer3 = self._make_layer(channels[2], channels[3], num_res_blocks)
        
        # 中间层的残差块
        self.mid_block1 = ResidualBlock4D(channels[3], channels[3])
        self.mid_block2 = ResidualBlock4D(channels[3], channels[3])
        
        # 最终输出层
        self.norm_out = GroupNorm(channels[3])
        self.act_out = Swish()
        self.conv_out = Conv4d(channels[3], args.image_channels, kernel_size=(1, 1, 3, 3), stride=1, padding=(0, 0, 1, 1))

    def _make_layer(self, in_channels, out_channels, num_res_blocks):
        # 创建指定数量的残差块
        layers = [ResidualBlock4D(in_channels, out_channels) for _ in range(num_res_blocks)]
        return nn.Sequential(*layers)

    def forward(self, x):
        # 初始卷积
        x = self.conv_in(x)

        # 第一层残差块
        x = self.layer1(x)

        # 上采样和第二层残差块
        x = self.upsample1(x)  # 上采样后通道数保持不变
        print(x.shape)
        x = self.layer2(x)     # 确保输入与 layer2 的期望通道数匹配

        x = self.upsample2(x)  # 上采样后通道数保持不变
        print(x.shape)
        x = self.layer3(x)     # 确保输入与 layer2 的期望通道数匹配
        
        # 中间层的残差块
        x = self.mid_block1(x)
        x = self.mid_block2(x)
        
        # 最终的归一化、激活和卷积输出层
        x = self.norm_out(x)
        x = self.act_out(x)
        print(x.shape)
        x = self.conv_out(x)[:, :, :, :721, :1440]
        
        return x
    
# 创建实例

class Args:
    image_channels = 14   # 输入图像的通道数（例如 RGB 图像为 3）
    latent_dim = 64      # 潜在空间的通道数
    num_codebook_vectors = 512
    beta = 0.25
args = Args()
decoder = Decoder(args).cuda()

# 输入张量
input_tensor = torch.randn(1, 64, 2, 2, 181, 360).cuda()

# 编码过程
decoder = decoder(input_tensor)
print("decoder shape:", decoder.shape)  # 期望输出: torch.Size([1, 64, 360, 720])



torch.Size([1, 64, 4, 2, 362, 720])
torch.Size([1, 32, 4, 2, 724, 1440])
torch.Size([1, 32, 4, 2, 724, 1440])
decoder shape: torch.Size([1, 14, 4, 2, 724, 1440])


## Codebook

In [5]:
import torch
import torch.nn as nn


class Codebook(nn.Module):
    def __init__(self, args):
        super(Codebook, self).__init__()
        self.num_codebook_vectors = args.num_codebook_vectors
        self.latent_dim = args.latent_dim
        self.beta = args.beta

        self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors)

    def forward(self, z):
        z = z.view(z.size(0), -1, z.size(3), z.size(4))
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.latent_dim)

        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - \
            2*(torch.matmul(z_flattened, self.embedding.weight.t()))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        loss = torch.mean((z_q.detach() - z)**2) + self.beta * torch.mean((z_q - z.detach())**2)

        z_q = z + (z_q - z).detach()

        z_q = z_q.permute(0, 3, 1, 2)

        return z_q, min_encoding_indices, loss
    
class Args:
    num_codebook_vectors = 512  # 代码本中向量的数量
    latent_dim = 64  # 潜在空间的维度
    beta = 0.25  # 损失中的 beta 参数

args = Args()

# 创建 Codebook 实例
codebook = Codebook(args).cuda()

# 创建输入张量，形状为 [1, 64, 360, 720]
input_tensor = torch.randn(1, 2, 96, 181, 360).cuda()

# 前向传播
z_q, min_encoding_indices, q_loss = codebook(input_tensor)
print("离散的潜在表示 z_q 的形状:", z_q.shape)
print("编码索引的形状:", min_encoding_indices.shape)
print("量化损失:", q_loss.item())

离散的潜在表示 z_q 的形状: torch.Size([1, 192, 181, 360])
编码索引的形状: torch.Size([195480])
量化损失: 1.2486979961395264


In [4]:
import torch

# 创建一个随机张量
x = torch.randn(1, 2, 96, 181, 360).cuda()

# 调整维度，合并第二维和第三维
x_reshaped = x.view(x.size(0), -1, x.size(3), x.size(4))

# 打印结果维度
print("Original shape:", x.shape)
print("Reshaped shape:", x_reshaped.shape)


Original shape: torch.Size([1, 2, 96, 181, 360])
Reshaped shape: torch.Size([1, 192, 181, 360])


In [13]:
import torch.nn as nn
from VQGAN.helper import ResidualBlock, NonLocalBlock, DownSampleBlock, UpSampleBlock, GroupNorm, Swish


class Encoder(nn.Module):
    def __init__(self, args):
        super(Encoder, self).__init__()
        channels = [128, 128, 128, 256, 256, 512]
        attn_resolutions = [2]
        num_res_blocks = 1
        resolution = 256
        layers = [nn.Conv2d(args.image_channels, channels[0], 3, 1, 1)]
        for i in range(len(channels)-1):
            in_channels = channels[i]
            out_channels = channels[i + 1]
            for j in range(num_res_blocks):
                layers.append(ResidualBlock(in_channels, out_channels))
                in_channels = out_channels
                if resolution in attn_resolutions:
                    layers.append(NonLocalBlock(in_channels))
            if i != len(channels)-2:
                layers.append(DownSampleBlock(channels[i+1]))
                resolution //= 2
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(NonLocalBlock(channels[-1]))
        layers.append(ResidualBlock(channels[-1], channels[-1]))
        layers.append(GroupNorm(channels[-1]))
        layers.append(Swish())
        layers.append(nn.Conv2d(channels[-1], args.latent_dim, 3, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [14]:
import torch
class Args:
    image_channels = 14  # RGB image channels
    latent_dim = 64     # Set a latent dimension for the output

# Instantiate Args and Encoder
args = Args()
encoder = Encoder(args)

# Create a random input tensor (batch size of 1, RGB image with resolution 256x256)
input_tensor = torch.randn(1, Args.image_channels, 721, 1440)

# Forward pass through the encoder
output = encoder(input_tensor)

# Print the output shape
print("Output shape:", output.shape)

Output shape: torch.Size([1, 64, 45, 90])
