# 模块代码

In [3]:
import torch

from torch import nn
from torch.nn import init
class UNet3D(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, init_features=32):
        super(UNet3D, self).__init__()

        features = init_features
        # 编码器部分
        self.encoder1 = UNet3D._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder2 = UNet3D._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder3 = UNet3D._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.encoder4 = UNet3D._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool3d(kernel_size=2, stride=2)

        # 瓶颈层
        self.bottleneck = UNet3D._block(features * 8, features * 16, name="bottleneck")

        # 解码器部分
        self.upconv4 = nn.ConvTranspose3d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = UNet3D._block(features * 16, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose3d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = UNet3D._block(features * 8, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose3d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = UNet3D._block(features * 4, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose3d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = UNet3D._block(features * 2, features, name="dec1")

        # 输出层
        self.conv = nn.Conv3d(in_channels=features, out_channels=out_channels, kernel_size=1)

    def forward(self, x):
        # 编码器
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        # 瓶颈层
        bottleneck = self.bottleneck(self.pool4(enc4))

        # 解码器
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        # 输出层
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm3d(num_features=features),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                in_channels=features,
                out_channels=features,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm3d(num_features=features),
            nn.ReLU(inplace=True),
        )

In [4]:
def init_unet_weights(model, init_method='kaiming', gain=1.0):
    """
    初始化 UNet 模型的权重

    参数:
    - model: UNet 模型
    - init_method: 初始化方法，可选 'xavier', 'kaiming', 'normal', 'uniform', 'zeros', 'ones'
    - gain: 用于 Xavier 和 Kaiming 初始化的增益因子
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            if init_method == 'xavier':
                if param.dim() >= 2:  # 适用于卷积和全连接层
                    init.xavier_uniform_(param, gain=gain)
                else:
                    init.normal_(param, mean=0.0, std=gain)
            elif init_method == 'kaiming':
                if param.dim() >= 2:
                    init.kaiming_uniform_(param, mode='fan_in', nonlinearity='relu')
                else:
                    init.normal_(param, mean=0.0, std=gain)
            elif init_method == 'normal':
                init.normal_(param, mean=0.0, std=gain)
            elif init_method == 'uniform':
                init.uniform_(param, a=-gain, b=gain)
            elif init_method == 'zeros':
                init.zeros_(param)
            elif init_method == 'ones':
                init.ones_(param)
            else:
                raise ValueError(f"不支持的初始化方法: {init_method}")
        elif 'bias' in name:
            init.zeros_(param)  # 偏置初始化为 0

In [5]:
# 创建 UNet 模型实例
unet_model = UNet3D(in_channels=1, out_channels=1, init_features=32)

# 初始化权重
init_unet_weights(unet_model, init_method='kaiming', gain=1.0)

# 打印初始化后的权重
for name, param in unet_model.named_parameters():
    print(f"{name}: {param.data}")

encoder1.0.weight: tensor([[[[[ 0.2731,  0.0355, -0.2945],
           [-0.3625,  0.3545,  0.4521],
           [ 0.4416, -0.2019, -0.0697]],

          [[ 0.3156, -0.3857, -0.4274],
           [ 0.1869, -0.1828, -0.3395],
           [-0.1326,  0.3905, -0.1627]],

          [[-0.1255, -0.3285,  0.3013],
           [-0.2309,  0.1467,  0.4419],
           [ 0.0172,  0.2533,  0.3436]]]],



        [[[[-0.3671,  0.1547,  0.2232],
           [-0.1733,  0.0373, -0.1901],
           [ 0.4335, -0.0732,  0.3145]],

          [[-0.2750, -0.4008, -0.2283],
           [ 0.4510,  0.1334, -0.2706],
           [-0.3896, -0.3180,  0.3793]],

          [[ 0.1505, -0.2814, -0.2281],
           [-0.2015,  0.3748, -0.0906],
           [-0.2720, -0.4594, -0.1935]]]],



        [[[[ 0.4473, -0.3912, -0.0800],
           [-0.3515, -0.0071,  0.4003],
           [-0.1885, -0.4669,  0.2405]],

          [[-0.4402, -0.0878,  0.3297],
           [ 0.4555, -0.0183, -0.4329],
           [-0.4660,  0.4001, -0.2567]]

## EMA


In [4]:
import torch

from torch import nn

class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        super(EMA, self).__init__()
        self.group = factor
        assert channels // self.group > 0
        self.softmax = nn.Softmax(dim=-1)
        self.averagePooling = nn.AdaptiveAvgPool2d((1,1))
        self.Pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.Pool_w = nn.AdaptiveAvgPool2d((1, None))

        self.groupNorm = nn.GroupNorm(channels // self.group, channels//self.group)
        self.conv1x1 = nn.Conv2d(channels // self.group, channels // self.group, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.group, channels // self.group, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b*self.group, -1, h, w)
        x_h = self.Pool_h(group_x)  # 高度方向池化
        x_w = self.Pool_w(group_x).permute(0, 1, 3, 2)  # 宽度方向池化

        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2)) # 拼接之后卷积
        x_h, x_w = torch.split(hw, [h, w], dim=2)       # 拆分

        # 1x1路径
        x1 = self.groupNorm(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())          # 高度的注意力
        x11 = self.softmax(self.averagePooling(x1).reshape(b*self.group, -1, 1).permute(0, 2, 1)) # 对 x1 进行平均池化，然后进行 softmax 操作
        x12 = x1.reshape(b*self.group, c//self.group, -1)

        # 3x3路径
        x2 = self.conv3x3(group_x) # 通过 3x3卷积层
        x21 = self.softmax(self.averagePooling(x2).reshape(b*self.group, -1, 1).permute(0, 2, 1)) # 对 x2 进行平均池化，然后进行 softmax 操作
        x22 = x2.reshape(b*self.group, c//self.group, -1)

        weights = (torch.matmul(x11, x22) + torch.matmul(x21, x12)).reshape(b * self.group, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema = EMA(128).to(device)
input_data = torch.rand(1, 128, 128, 128).to(device)
output_data = ema(input_data)


print(ema)

print(output_data.shape)

EMA(
  (softmax): Softmax(dim=-1)
  (averagePooling): AdaptiveAvgPool2d(output_size=(1, 1))
  (Pool_h): AdaptiveAvgPool2d(output_size=(None, 1))
  (Pool_w): AdaptiveAvgPool2d(output_size=(1, None))
  (groupNorm): GroupNorm(4, 4, eps=1e-05, affine=True)
  (conv1x1): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv3x3): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
torch.Size([1, 128, 128, 128])


## EMA3D

In [19]:
import torch.nn.functional as F

# class SoftPool3D(nn.Module):
#     def __init__(self, kernel_size=2, stride=2, padding=0):
#         super(SoftPool3D, self).__init__()
        
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding
        
#     def forward(self, x):
        
#         x_exp = torch.exp(x)
        
#         sum_exp = F.max_pool3d(x_exp, self.kernel_size, self.stride, self.padding)
        
#         sum_x = F.max_pool3d(x, self.kernel_size, self.stride, self.padding)
        
#         soft_pool = sum_x / (sum_exp + 1e-8)
        
#         return soft_pool

class SE(nn.Module):
    def __init__(self, in_channels, reduction_ratio=4):
        super(SE, self).__init__()

        if in_channels // reduction_ratio <= 0:
                    raise ValueError(f"Reduction ratio {reduction_ratio} is too large for the number of input channels {in_channels}.")
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Sequential(
            nn.Conv3d(in_channels, in_channels // reduction_ratio, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv3d(in_channels // reduction_ratio, in_channels, kernel_size=1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _, _ = x.size()
        y = self.avg_pool(x)
        y = self.fc(y)
        return x * y.expand_as(x)

class EMA3D(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA3D, self).__init__()
        self.group = factor
        assert channels // self.group > 0
        self.softmax = nn.Softmax(dim=-1)
        self.averagePooling = nn.AdaptiveAvgPool3d((1, 1, 1))  # 3D 全局平均池化
        self.maxPooling = nn.AdaptiveMaxPool3d((1, 1, 1))      # 3D 全局最大池化
        self.Pool_h = nn.AdaptiveAvgPool3d((None, 1, 1))       # 高度方向池化
        self.Pool_w = nn.AdaptiveAvgPool3d((1, None, 1))       # 宽度方向池化
        self.Pool_d = nn.AdaptiveAvgPool3d((1, 1, None))       # 深度方向池化

        self.groupNorm = nn.GroupNorm(channels // self.group, channels // self.group)
        self.conv1x1x1 = nn.Sequential(
            nn.Conv3d(channels // self.group, channels // self.group, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm3d(channels // self.group),
            nn.ReLU(inplace=True)
        )
        self.conv3x3x3 = nn.Sequential(
            nn.Conv3d(channels // self.group, channels // self.group, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(channels // self.group),
            nn.ReLU(inplace=True)
        )
        self.SE = SE(channels // self.group)

    def forward(self, x):
        b, c, d, h, w = x.size()
        group_x = x.reshape(b * self.group, -1, d, h, w)  # 分组处理

        # 高度、宽度、深度方向池化
        x_c = self.maxPooling(group_x)  # [B*G, C/G, 1, 1, 1]
        x_h = self.Pool_h(group_x)  # [B*G, C/G, D, 1, 1]
        x_w = self.Pool_w(group_x).permute(0, 1, 3, 2, 4)  # [B*G, C/G, 1, H, 1]
        x_d = self.Pool_d(group_x).permute(0, 1, 4, 3, 2)  # [B*G, C/G, 1, 1, W]

        # 拼接并卷积
        hwd = self.conv1x1x1(torch.cat([x_h, x_w, x_d], dim=2))  # 拼接后卷积
        x_h, x_w, x_d = torch.split(hwd, [d, h, w], dim=2)       # 拆分

        # Apply sigmoid activation
        x_h_sigmoid = x_h.sigmoid().view(b*self.group, c // self.group, d, 1, 1)
        x_w_sigmoid = x_w.sigmoid().view(b*self.group, c // self.group, 1, h, 1)
        x_d_sigmoid = x_d.sigmoid().view(b*self.group, c // self.group, 1, 1, w)
        
        # Apply attention maps using broadcasting
        x_attended = group_x * x_h_sigmoid * x_w_sigmoid * x_d_sigmoid
        
        x1 = self.groupNorm(group_x * x_attended)  # 高度、宽度、深度注意力
        x11 = self.softmax(self.averagePooling(x1).reshape(b * self.group, -1, 1).permute(0, 2, 1))  # 全局平均池化 + softmax
        x12 = x1.reshape(b * self.group, c // self.group, -1)

        # 3x3x3 路径
        x2 = self.SE(self.conv3x3x3(group_x))  # 通过 3x3x3 卷积层
        x21 = self.softmax(self.averagePooling(x2).reshape(b * self.group, -1, 1).permute(0, 2, 1))  # 全局平均池化 + softmax
        x22 = x2.reshape(b * self.group, c // self.group, -1)

        # 计算权重
        weights = (torch.matmul(x11, x22) + torch.matmul(x21, x12)).reshape(b * self.group, -1, d, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, d, h, w)

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema3d = EMA3D(128).to(device)
input_data = torch.rand(1, 128, 32, 32, 32).to(device)  # 3D 输入
output_data = ema3d(input_data)

print(ema3d)
print(output_data.shape)  # 输出形状应与输入形状一致

EMA3D(
  (softmax): Softmax(dim=-1)
  (averagePooling): AdaptiveAvgPool3d(output_size=(1, 1, 1))
  (maxPooling): AdaptiveMaxPool3d(output_size=(1, 1, 1))
  (Pool_h): AdaptiveAvgPool3d(output_size=(None, 1, 1))
  (Pool_w): AdaptiveAvgPool3d(output_size=(1, None, 1))
  (Pool_d): AdaptiveAvgPool3d(output_size=(1, 1, None))
  (groupNorm): GroupNorm(16, 16, eps=1e-05, affine=True)
  (conv1x1x1): Sequential(
    (0): Conv3d(16, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (conv3x3x3): Sequential(
    (0): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (SE): SE(
    (avg_pool): AdaptiveAvgPool3d(output_size=1)
    (fc): Sequential(
      (0): Conv3d(16, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (1): ReLU

## ISSA

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ISSA(nn.Module):
    def __init__(self, channels, reduction_ratio=8):
        super(ISSA, self).__init__()
        
        self.channels = channels
        self.reduction = reduction_ratio
        
        # Pointwise convolutions to generate Q, K, V
        self.query_conv = nn.Conv3d(channels, channels // reduction_ratio, kernel_size=1)
        self.key_conv = nn.Conv3d(channels, channels // reduction_ratio, kernel_size=1)
        self.value_conv = nn.Conv3d(channels, channels, kernel_size=1)
        
        self.softmax = nn.Softmax(dim=-1)
        
    def forward(self, x):
        B, C, L1, L2, L3 = x.size()
        
        Q = self.query_conv(x)
        K = self.key_conv(x)
        V = self.value_conv(x)
        
        Q_reshaped = Q.permute(0, 2, 3, 1, 4).contiguous().view(B, L1*L2, L3, -1)
        K_reshaped = K.permute(0, 2, 3, 1, 4).contiguous().view(B, L1*L2, -1, L3)
        V_reshaped = V.permute(0, 2, 3, 1, 4).contiguous().view(B, L1*L2, L3, -1)
        
        scale_factor = (self.channels // self.reduction) ** 0.5
        
        attention_scores = torch.matmul(Q_reshaped, K_reshaped) / scale_factor
        attention_weights = self.softmax(attention_scores)
        
        attention_out = torch.matmul(attention_weights, V_reshaped)
        
        return attention_out.view(B, C, L1, L2, L3)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema3d = ISSA(128).to(device)
input_data = torch.rand(1, 128, 32, 32, 32).to(device)  # 3D 输入
output_data = ema3d(input_data)

print(ema3d)
print(output_data.shape)  # 输出形状应与输入形状一致

ISSA(
  (query_conv): Conv3d(128, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (key_conv): Conv3d(128, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (value_conv): Conv3d(128, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (softmax): Softmax(dim=-1)
)
torch.Size([1, 128, 32, 32, 32])


## SelfAttention

In [10]:
class SelfAttention3D(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention3D, self).__init__()
        self.query_conv =   nn.Conv3d(in_channels, in_channels // 8, kernel_size=1) # 全连接
        self.key_conv   =   nn.Conv3d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv =   nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, D, H, W = x.size()
        
        # Permute and reshape the input to separate batch and spatial dimensions
        query = self.query_conv(x).view(batch_size, -1, D, H, W).permute(0, 2, 3, 4, 1).contiguous()
        key = self.key_conv(x).view(batch_size, -1, D, H, W).permute(0, 2, 3, 4, 1).contiguous()
        value = self.value_conv(x).view(batch_size, -1, D, H, W).permute(0, 2, 3, 4, 1).contiguous()
        
        # Calculate the attention scores
        energy = torch.matmul(query, key.transpose(-1, -2))
        attention = F.softmax(energy, dim=-1)
        
        # Apply the attention to the values
        out = torch.matmul(attention, value)
        
        # Reshape and permute the output to match the original input shape
        out = out.permute(0, 4, 1, 2, 3).contiguous().view(batch_size, C, D, H, W)
        
        # Scale and add a residual connection
        out = self.gamma * out + x
        
        return out

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema3d = SelfAttention3D(128).to(device)
input_data = torch.rand(1, 128, 32, 32, 32).to(device)  # 3D 输入
output_data = ema3d(input_data)

print(ema3d)
print(output_data.shape)  # 输出形状应与输入形状一致

SelfAttention3D(
  (query_conv): Conv3d(128, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (key_conv): Conv3d(128, 16, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  (value_conv): Conv3d(128, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
)
torch.Size([1, 128, 32, 32, 32])


## CPCA

In [12]:
class CPCA_ChannelAttention(nn.Module):
    def __init__(self, input_channels, internal_neurons):
        super(CPCA_ChannelAttention, self).__init__()
        self.fc1 = nn.Conv3d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
        self.fc2 = nn.Conv3d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True)
        self.input_channels = input_channels

    def forward(self, inputs):
        x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        x1 = self.fc1(x1)
        x1 = F.relu(x1, inplace=True)
        x1 = self.fc2(x1)
        x1 = torch.sigmoid(x1)
        x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1))
        x2 = self.fc1(x2)
        x2 = F.relu(x2, inplace=True)
        x2 = self.fc2(x2)
        x2 = torch.sigmoid(x2)
        x = x1 + x2
        x = x.view(-1, self.input_channels, 1, 1)
        return inputs * x

class CPCA(nn.Module):
    def __init__(self, channels, channelAttention_reduce=4):
        super().__init__()
        self.ca = CPCA_ChannelAttention(input_channels=channels, internal_neurons=channels // channelAttention_reduce)
        self.dconv5_5 = nn.Conv3d(channels, channels, kernel_size=5, padding=2, groups=channels)
        self.dconv1_7 = nn.Conv3d(channels, channels, kernel_size=(1, 7), padding=(0, 3), groups=channels)
        self.dconv7_1 = nn.Conv3d(channels, channels, kernel_size=(7, 1), padding=(3, 0), groups=channels)
        self.dconv1_11 = nn.Conv3d(channels, channels, kernel_size=(1, 11), padding=(0, 5), groups=channels)
        self.dconv11_1 = nn.Conv3d(channels, channels, kernel_size=(11, 1), padding=(5, 0), groups=channels)
        self.dconv1_21 = nn.Conv3d(channels, channels, kernel_size=(1, 21), padding=(0, 10), groups=channels)
        self.dconv21_1 = nn.Conv3d(channels, channels, kernel_size=(21, 1), padding=(10, 0), groups=channels)
        self.conv = nn.Conv3d(channels, channels, kernel_size=1, padding=0)
        self.act = nn.GELU()

    def forward(self, inputs):
        inputs = self.conv(inputs)
        inputs = self.act(inputs)
        inputs = self.ca(inputs)
        x_init = self.dconv5_5(inputs)
        x_1 = self.dconv1_7(x_init)
        x_1 = self.dconv7_1(x_1)
        x_2 = self.dconv1_11(x_init)
        x_2 = self.dconv11_1(x_2)
        x_3 = self.dconv1_21(x_init)
        x_3 = self.dconv21_1(x_3)
        x = x_1 + x_2 + x_3 + x_init
        spatial_att = self.conv(x)
        out = spatial_att * inputs
        out = self.conv(out)
        return out

In [24]:
dconv1_1_7 = nn.Conv3d(32, 64, kernel_size=(1, 1, 7), padding=(0, 0, 3), groups=32)
dconv1_7_1 = nn.Conv3d(32, 64, kernel_size=(1, 7, 1), padding=(0, 3, 0), groups=32)
dconv1_1_7 = nn.Conv3d(32, 64, kernel_size=(7, 1, 1), padding=(3, 0, 0), groups=32)

input_data = torch.rand(1, 32, 32, 32, 32)

out = dconv1_1_7(input_data)
out.shape

torch.Size([1, 64, 32, 32, 32])

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema3d = CPCA(128).to(device)
input_data = torch.rand(1, 128, 32, 32, 32).to(device)  # 3D 输入
output_data = ema3d(input_data)

print(ema3d)
print(output_data.shape)  # 输出形状应与输入形状一致

RuntimeError: adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: [1, 128, 32, 32, 32]

## SKF

In [14]:
class SKFusionv3(nn.Module):
    def __init__(self, dim=1, height=2, reduction=4, kernel_sizes=[3, 5, 7]):
        super(SKFusionv3, self).__init__()
        
        self.height = height
        self.kernel_sizes = kernel_sizes
        # 全局平均池化
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        # 多尺度卷积层，使用不同的卷积核大小
        self.conv1d_layers = nn.ModuleList([
            nn.Conv1d(1, self.height, kernel_size, stride=1, padding=kernel_size//2)
            for kernel_size in self.kernel_sizes
        ])
        # 组合最大池化和平均池化
        self.max_pool = nn.AdaptiveMaxPool3d(1)
        self.combined_pool = nn.Sequential(
            nn.Conv1d(len(kernel_sizes)*self.height, self.height, kernel_size=1, bias=False),
            nn.BatchNorm1d(self.height),
            nn.ReLU(inplace=True)
        )
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, in_feats):
        B, C, D, H, W = in_feats[0].shape

        # 假设 in_feats 是一个列表，长度为 height，每个元素形状为 (B, C, D, H, W)
        in_feats = torch.cat(in_feats, dim=1)  # (B, C * height, D, H, W)
        in_feats = in_feats.view(B, self.height, C, D, H, W)  # (B, height, C, D, H, W)

        # 特征求和
        feats_sum = torch.sum(in_feats, dim=1)  # (B, C, D, H, W)

        # 全局平均池化和全局最大池化
        avg_attn = self.avg_pool(feats_sum)  # (B, C, 1, 1, 1)
        max_attn = self.max_pool(feats_sum)  # (B, C, 1, 1, 1)
        
        # 结合两个池化的特征
        combined_attn = torch.cat([avg_attn, max_attn], dim=1)  # (B, 2C, 1, 1, 1)
        # combined_attn = avg_attn + max_attn
        combined_attn = combined_attn.squeeze(-1).squeeze(-1).permute(0, 2, 1)
        
        # 使用多个不同卷积核大小的 Conv1d 层进行特征变换
        conv_feats = []
        for conv in self.conv1d_layers:
            feat = conv(combined_attn)  # (B, height, 2C, 1)
            feat = F.relu(feat)
            conv_feats.append(feat)
        
        # 将多个卷积层的输出拼接
        conv_feats = torch.cat(conv_feats, dim=1)  # (B, height * len(kernel_sizes), 2C, 1)
        conv_feats = conv_feats.squeeze(-1)  # (B, height * len(kernel_sizes), 2C)
        
        # 通过 1x1 卷积融合不同尺度的特征
        attn = self.combined_pool(conv_feats)  # (B, height, 2C, 1)
        attn = attn.squeeze(-1)  # (B, height, 2C)
        
        # 生成注意力权重
        attn = attn.view(B, self.height, C, 2)  # 假设 2 是因池化方式数量（avg 和 max）
        attn = torch.mean(attn, dim=-1)  # (B, height, C)
        attn = self.softmax(attn).view(B, self.height, C, 1, 1, 1)  # (B, height, C, 1, 1, 1)

        # 特征加权求和
        out = torch.sum(in_feats * attn, dim=1)  # (B, C, D, H, W)
        return out

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ema3d = SKFusionv3(128).to(device)
input_data = torch.rand(1, 128, 32, 32, 32).to(device)  # 3D 输入
output_data = ema3d(input_data)

print(ema3d)
print(output_data.shape)  # 输出形状应与输入形状一致

ValueError: not enough values to unpack (expected 5, got 4)

## 模块测试

EMA(
  (softmax): Softmax(dim=-1)
  (averagePooling): AdaptiveAvgPool2d(output_size=(1, 1))
  (Pool_h): AdaptiveAvgPool2d(output_size=(None, 1))
  (Pool_w): AdaptiveAvgPool2d(output_size=(1, None))
  (groupNorm): GroupNorm(4, 4, eps=1e-05, affine=True)
  (conv1x1): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv3x3): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
torch.Size([1, 128, 128, 128])
