In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [24]:
class ChannelAttention(nn.Module):
    def __init__(self, in_chan):
        super(ChannelAttention, self).__init__()
        self.MLP  = nn.Conv2d(in_chan, in_chan, kernel_size=1)
        self.ReLU = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgPool    = torch.mean(x, dim=(2, 3), keepdim=True)
        maxPool, _ = torch.max(x, dim=2, keepdim=True)  
        maxPool, _ = torch.max(maxPool, dim=3, keepdim=True)  
        out = self.MLP(avgPool) + self.MLP(maxPool)
        return self.sigmoid(out) * x

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv    = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgPool    = torch.mean(x, dim=1, keepdim=True)
        maxPool, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avgPool, maxPool], dim=1)
        out = self.conv(concat)
        return self.sigmoid(out) * x

class CBAM(nn.Module):
    def __init__(self, in_chan, kernel_size=7):
        super(CBAM, self).__init__()
        self.chAttn = ChannelAttention(in_chan)
        self.spAttn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.chAttn(x)
        x = self.spAttn(x)
        return x


In [None]:
class ChannelAttentionDaddy(nn.Module):
    def __init__(self, in_chan):
        super(ChannelAttentionDaddy, self).__init__()
        self.MLP  = nn.Conv2d(in_chan, in_chan, kernel_size=1)
        self.ReLU = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgPool    = torch.mean(x, dim=(2, 3), keepdim=True)
        maxPool, _ = torch.max(x, dim=2, keepdim=True)  
        maxPool, _ = torch.max(maxPool, dim=3, keepdim=True)  
        out = self.MLP(avgPool) + self.MLP(maxPool)
        return self.sigmoid(out)

class SpatialAttentionDaddy(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttentionDaddy, self).__init__()
        self.conv    = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgPool    = torch.mean(x, dim=1, keepdim=True)
        maxPool, _ = torch.max(x, dim=1, keepdim=True)
        concat = torch.cat([avgPool, maxPool], dim=1)
        out = self.conv(concat)
        return self.sigmoid(out)

class CBAMDaddy(nn.Module):
    def __init__(self, in_chan, kernel_size=7):
        super(CBAMDaddy, self).__init__()
        self.chAttn = ChannelAttention(in_chan)
        self.spAttn = SpatialAttention(kernel_size)

    def forward(self, x):
        x = self.chAttn(x) * x
        x = self.spAttn(x) * x 
        return x

In [26]:
input = torch.randn(1, 3, 6, 6)

model = CBAMDaddy(3)

output = model(input)

print(output.shape)
print(output)



torch.Size([1, 3, 6, 6])
tensor([[[[8.1796e-02, 6.5391e-01, 2.8656e-03, 5.2295e+00, 1.4008e+00,
           1.3271e+00],
          [7.3064e-02, 1.2798e-01, 3.0845e+00, 2.8438e-01, 4.7552e-01,
           7.0303e+00],
          [1.3514e+00, 3.4308e-03, 4.1171e-01, 2.4986e-04, 5.0205e-01,
           5.6473e-01],
          [2.1673e-03, 5.5378e-02, 2.8053e-06, 2.3386e-02, 1.7142e-01,
           2.6164e-01],
          [4.8052e-04, 3.3477e-02, 1.9566e-01, 5.7302e-02, 9.6979e-02,
           6.7784e-03],
          [9.7828e+00, 1.0728e-01, 2.3751e+00, 2.5471e-04, 5.5607e+00,
           6.1168e-07]],

         [[7.7405e-02, 2.8208e-02, 4.2619e-02, 6.9947e-01, 1.0849e-01,
           9.3995e-03],
          [2.8000e-08, 1.4173e+00, 7.8529e-04, 4.4211e-01, 3.3867e-03,
           4.9019e-04],
          [3.0995e-02, 1.8595e+00, 6.9131e-01, 7.1287e-04, 1.3922e-02,
           5.1789e-08],
          [8.5691e-02, 2.5090e-02, 2.1725e-02, 2.0520e-01, 2.9485e-01,
           3.4463e-02],
          [1.0213e-03, 

In [None]:


class RepVGGBlock(nn.Module):
    def __init__(self, channels):
        super(RepVGGBlock, self).__init__()
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn3   = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1   = nn.BatchNorm2d(channels)
        self.bn0   = nn.BatchNorm2d(channels)
        self.ReLU  = nn.ReLU(inplace=True)
    
    def forward(self, x):
        out_3x3 = self.bn3(self.conv3(x))
        out_1x1 = self.bn1(self.conv1(x))
        out_id  = self.bn0(x)
        out = self.ReLU(out_3x3 + out_1x1 + out_id, inplace=True)
        return out

class RepVGGBlockDeploy(nn.Module):
    def __init__(self, channels):
        super(RepVGGBlockDeploy, self).__init__()
        self.fused_conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        return self.relu(self.fused_conv(x))

# Fusiona una conv con su BN
def fuse_conv_bn(conv, bn):
    W = conv.weight         # [channels, channels, k, k]
    gamma = bn.weight
    beta  = bn.bias
    mean  = bn.running_mean
    var   = bn.running_var
    eps   = bn.eps
    std = torch.sqrt(var + eps)
    
    W_fused = W * (gamma / std).reshape(-1, 1, 1, 1)
    b_fused = beta - mean * (gamma / std)
    return W_fused, b_fused

# Expande un kernel 1x1 a 3x3
def expand_1x1_to_3x3(W1):
    outC, inC, _, _ = W1.shape
    W1_expanded = torch.zeros((outC, inC, 3, 3), dtype=W1.dtype, device=W1.device)
    W1_expanded[:, :, 1, 1] = W1[:, :, 0, 0]
    return W1_expanded

# Crea el kernel identidad y lo fusiona con su BN
def fuse_identity_bn(channels, bn):
    # Crear el kernel identidad
    I = torch.eye(channels).view(channels, channels, 1, 1).to(bn.weight.device)

    # Extraer parámetros de BN:
    gamma = bn.weight 
    beta  = bn.bias 
    mean  = bn.running_mean
    var   = bn.running_var
    eps   = bn.eps

    # Calcular la desviación estándar:
    std = sqrt(var + eps)

    # Fusionar
    W_fused = I * (gamma / std).reshape([channels, 1, 1, 1])
    b_fused = beta - mean * (gamma / std)

    return W_fused, b_fused

# Función para convertir un bloque de entrenamiento a versión evaluacion
def repVGG_convert_block(block):
    channels = block.conv3.in_channels 
    
    # Fusionar rama 3x3
    W3, b3 = fuse_conv_bn(block.conv3, block.bn3)
    
    # Fusionar rama 1x1 y expandir a 3x3
    W1, b1 = fuse_conv_bn(block.conv1, block.bn1)
    W1_3x3 = expand_1x1_to_3x3(W1)
    
    # Fusionar rama identidad y expandir a 3x3
    W_id, b_id = fuse_identity_bn(channels, block.bn0)
    W_id_3x3 = expand_1x1_to_3x3(W_id)
    
    # Sumar los tres kernels y biases
    W_fused = W3 + W1_3x3 + W_id_3x3
    b_fused = b3 + b1 + b_id
    
    # Emplear los pesos fusionados
    deploy_block = RepVGGBlockDeploy(channels)
    deploy_block.fused_conv.weight.data.copy_(W_fused)
    deploy_block.fused_conv.bias.data.copy_(b_fused)
    return deploy_block