In [4]:
import torch
import torch.nn as nn

class NonLocalBlock(nn.Module):
    """Dot-product self-attention over all positions of a 5-D feature map.

    Every (t, h, w) position of the input attends to every other position.
    The attended features are projected back to ``in_channels`` and added to
    the input through a learnable scalar ``gamma``.  ``gamma`` starts at
    zero, so a freshly constructed block returns its input unchanged.

    Args:
        in_channels: Number of channels ``C`` of the input tensor.  Must be
            a positive integer.

    Raises:
        ValueError: If ``in_channels`` is smaller than 1.
    """

    def __init__(self, in_channels):
        super().__init__()
        if in_channels < 1:
            raise ValueError(
                f"in_channels must be a positive integer, got {in_channels}"
            )
        self.in_channels = in_channels
        # Halve the channel count to cut the cost of the attention products,
        # but never go below one channel: with in_channels == 1 the plain
        # floor division would give 0 and the forward pass would fail.
        self.inter_channels = max(in_channels // 2, 1)

        # 1x1x1 convolutions producing the query, key and value embeddings.
        self.query_conv = nn.Conv3d(in_channels, self.inter_channels, kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels, self.inter_channels, kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels, self.inter_channels, kernel_size=1)

        # Projects the attended features back to the original channel count.
        self.out_conv = nn.Conv3d(self.inter_channels, in_channels, kernel_size=1)

        # Normalises each row of the affinity matrix into attention weights.
        self.softmax = nn.Softmax(dim=-1)

        # Learnable weight of the attention branch in the residual sum.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply the attention block.

        Args:
            x: Tensor of shape ``(N, C, T, H, W)`` with ``C == in_channels``.

        Returns:
            Tensor of the same shape as ``x``:
            ``gamma * out_conv(attended) + x``.
        """
        batch_size, _, T, H, W = x.shape

        # Embed and collapse the T, H, W axes into one position axis:
        # (N, C', T, H, W) -> (N, C', T*H*W).  flatten(2) yields the same
        # layout as view(N, C', -1) without having to infer a -1 dimension.
        query = self.query_conv(x).flatten(2)
        key = self.key_conv(x).flatten(2)
        value = self.value_conv(x).flatten(2)

        # Pairwise similarity: affinity[n, i, j] = <query[:, i], key[:, j]>.
        affinity = torch.bmm(query.transpose(1, 2), key)  # (N, T*H*W, T*H*W)

        # Softmax over j turns each row i into weights that sum to one.
        attention = self.softmax(affinity)  # (N, T*H*W, T*H*W)

        # out[n, :, i] = sum_j attention[n, i, j] * value[n, :, j].
        out = torch.bmm(value, attention.transpose(1, 2))  # (N, C', T*H*W)

        # Restore the T, H, W layout and map back to C channels.
        out = out.view(batch_size, self.inter_channels, T, H, W)
        out = self.out_conv(out)  # (N, C, T, H, W)

        # Residual connection; gamma scales the attention branch.
        return self.gamma * out + x

# Smoke test: push one random batch through the block and print the shapes.
if __name__ == "__main__":
    # Example input laid out as (batch=2, channels=64, time=8, height=32, width=32).
    x = torch.rand((2, 64, 8, 32, 32))
    nonlocal_block = NonLocalBlock(in_channels=64)
    out = nonlocal_block(x)
    # The residual design keeps the output shape equal to the input shape.
    print("输入维度:", x.size())
    print("输出维度:", out.size())


输入维度: torch.Size([2, 64, 8, 32, 32])
输出维度: torch.Size([2, 64, 8, 32, 32])
