In [2]:
# ==============================================================================
# 0. SETUP: INSTALL REQUIRED LIBRARIES
# ==============================================================================
# Make sure you have the necessary libraries installed
# !pip install torch torchsummary einops -q

# ==============================================================================
# 1. IMPORTS
# ==============================================================================
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from torchsummary import summary

# ==============================================================================
# 2. MODEL ARCHITECTURE DEFINITIONS
# ==============================================================================

# --- Building Block from Original Model: Transformer Encoder ---
class TransformerEncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention, then a GELU MLP.

    Each sub-layer reads a LayerNorm'd copy of the running stream, and its
    result is added back onto the un-normalised stream (residual connection).

    Args:
        dim: Token embedding width (input and output feature size).
        num_heads: Number of attention heads; ``nn.MultiheadAttention``
            requires ``dim`` to be divisible by it.
        mlp_dim: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability used inside attention and the MLP.
    """

    def __init__(self, dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        # Attention sub-layer; batch_first=True means tokens are (batch, seq, dim).
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        # Feed-forward sub-layer: expand to mlp_dim, then project back to dim.
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``x`` refined by attention and MLP; the shape is unchanged."""
        h = self.norm1(x)
        # Only the attended values are used; the returned weights are discarded.
        x = x + self.attn(h, h, h)[0]
        return x + self.mlp(self.norm2(x))

# --- Building Block from Original Model: Transformer Backbone ---
class CustomAttentionBackbone(nn.Module):
    """ViT-style backbone that returns a 2-D feature map instead of a class token.

    The image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches. Each patch is linearly embedded to ``dim`` features, a learned
    positional embedding is added, and the token sequence passes through
    ``depth`` Transformer encoder blocks. The tokens are finally folded back
    into a ``(batch, dim, H // patch_size, W // patch_size)`` grid.

    Args:
        in_channels: Number of input image channels.
        patch_size: Side length of each square patch, in pixels.
        dim: Token embedding width (also the channel count of the output map).
        depth: Number of Transformer encoder blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        image_size: Expected input spatial size, either an int (square image)
            or a ``(height, width)`` pair. Both sides must be divisible by
            ``patch_size``. Defaults to 224, the previously hard-coded value.

    Raises:
        ValueError: At construction if ``image_size`` is not divisible by
            ``patch_size``; in ``forward`` if the input's spatial size differs
            from ``image_size``.
    """

    def __init__(self, in_channels=3, patch_size=16, dim=192, depth=6, num_heads=6, mlp_dim=384,
                 image_size=224):
        super().__init__()
        if isinstance(image_size, int):
            image_height = image_width = image_size
        else:
            image_height, image_width = image_size
        if image_height % patch_size or image_width % patch_size:
            raise ValueError(
                f"image_size {(image_height, image_width)} must be divisible by "
                f"patch_size {patch_size}"
            )
        self.image_size = (image_height, image_width)
        grid_height = image_height // patch_size
        num_patches = grid_height * (image_width // patch_size)
        patch_dim = in_channels * patch_size ** 2

        # Sub-modules are created in the original order so that seeded
        # initialisation produces the same weights as before.
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.LayerNorm(patch_dim), nn.Linear(patch_dim, dim), nn.LayerNorm(dim),
        )
        # One learned position vector per patch, in the (h w) order produced above.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.transformer_layers = nn.ModuleList([
            TransformerEncoderBlock(dim, num_heads, mlp_dim) for _ in range(depth)
        ])
        # Fold tokens back to a grid; fixing h pins the grid layout.
        self.to_2d = Rearrange('b (h w) c -> b c h w', h=grid_height)

    def forward(self, x):
        """Map images ``(batch, in_channels, H, W)`` to ``(batch, dim, H/p, W/p)``."""
        # Reject other sizes explicitly. Without this check, a smaller input
        # slips through the positional-embedding slice, and ``to_2d`` then
        # folds its tokens into a wrongly shaped (scrambled) grid.
        spatial = tuple(x.shape[-2:])
        if spatial != self.image_size:
            raise ValueError(
                f"expected input of spatial size {self.image_size}, got {spatial}"
            )
        tokens = self.to_patch_embedding(x)
        # Out-of-place add: do not mutate the embedding module's output tensor.
        tokens = tokens + self.pos_embedding[:, :tokens.shape[1]]
        for layer in self.transformer_layers:
            tokens = layer(tokens)
        return self.to_2d(tokens)

# --- NEW ATTENTION-BASED CNN BLOCK (CBAM Implementation) ---
# This block replaces the PartitionReconstructionAttentionBlock_LMSA

class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Global average-pooled and max-pooled descriptors of the input go through a
    shared two-layer 1x1-conv bottleneck. The two results are summed and
    passed through a sigmoid, giving one weight in (0, 1) per channel.

    Args:
        in_planes: Number of input channels.
        ratio: Bottleneck reduction factor. The hidden width is
            ``in_planes // ratio``, clamped to at least 1, so narrow inputs
            (``in_planes < ratio``) do not get a zero-width layer.

    Shape:
        Input ``(N, in_planes, H, W)`` -> output ``(N, in_planes, 1, 1)``.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        hidden_planes = max(1, in_planes // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared bottleneck applied to both pooled descriptors (weights are tied).
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel weights for ``x``; multiply them onto ``x`` to apply."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """CBAM spatial-attention gate.

    Collapses the channel axis into two maps (per-pixel mean and max), fuses
    them with a single ``kernel_size`` convolution, and applies a sigmoid,
    giving one weight in (0, 1) per spatial position.

    Args:
        kernel_size: Side length of the fusion convolution; ``padding='same'``
            keeps the output at the input's H and W.

    Shape:
        Input ``(N, C, H, W)`` -> output ``(N, 1, H, W)``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding='same', bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return spatial weights for ``x``; multiply them onto ``x`` to apply."""
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        # Channel order (mean, max) must match the conv's two input channels.
        stacked = torch.cat((mean_map, max_map), dim=1)
        return self.sigmoid(self.conv1(stacked))

class AttentionCNNBlock(nn.Module):
    """Conv-BN-ReLU followed by CBAM-style channel, then spatial, attention.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count; also the width the gates act on.
        kernel_size: Convolution kernel size (``padding='same'`` preserves H, W).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Local feature extractor; the two gates below reweight its output.
        self.conv = nn.Sequential(*layers)
        self.ca = ChannelAttention(out_channels)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Return refined features of shape ``(N, out_channels, H, W)``."""
        feats = self.conv(x)
        # Gates are applied in sequence: channel weights first, then spatial.
        for gate in (self.ca, self.sa):
            feats = feats * gate(feats)
        return feats

# --- Building Block from Original Model: Squeeze-and-Excitation Block ---
class ConvSEBlock(nn.Module):
    """Conv-BN-ReLU followed by a Squeeze-and-Excitation channel gate.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count of the convolution (and the block).
        kernel_size: Convolution kernel size; ``padding='same'`` keeps H and W.
        reduction: SE bottleneck factor. The hidden width is
            ``out_channels // reduction``, clamped to at least 1, so narrow
            blocks (``out_channels < reduction``) still get a working gate.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, reduction=8):
        super().__init__()
        se_channels = max(1, out_channels // reduction)
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding='same', bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)
        )
        # Squeeze (global average pool) -> excite (bottleneck) -> sigmoid gate.
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, se_channels, 1, bias=False), nn.ReLU(inplace=True),
            nn.Conv2d(se_channels, out_channels, 1, bias=False), nn.Sigmoid()
        )

    def forward(self, x):
        """Return the conv features of ``x`` rescaled channel-wise by the SE gate."""
        cnn_out = self.conv_block(x)
        return cnn_out * self.se(cnn_out)

# --- Final Model: Hybrid with Attention-Based CNN Refinement ---
class SemaCheXFormer_AttentionCNN(nn.Module):
    """Hybrid classifier: Transformer backbone -> attention CNN -> SE CNN -> head.

    Args:
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Single source for the widths: the backbone's output map has ``dim``
        # channels, so the same value must feed stage 2. Previously this was
        # duplicated as a separate literal that could silently drift.
        backbone_dim = 192
        refine_channels = 128

        # Stage 1: Transformer backbone for global feature extraction.
        self.backbone = CustomAttentionBackbone(
            patch_size=16, dim=backbone_dim, depth=6, num_heads=6, mlp_dim=384
        )

        # Stage 2: attention-based CNN block (channel + spatial gates) for refinement.
        self.attention_cnn_stage = AttentionCNNBlock(
            in_channels=backbone_dim,
            out_channels=refine_channels
        )

        # Stage 3: SE block to further recalibrate channels.
        self.cnn_se_stage = ConvSEBlock(in_channels=refine_channels, out_channels=refine_channels)

        # Stage 4: global average pool -> dropout -> linear classifier.
        # Dropout is not in-place. An in-place dropout would overwrite the
        # Flatten output tensor, which forward hooks (summaries, feature
        # extractors) may be holding.
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Dropout(p=0.2),
            nn.Linear(refine_channels, num_classes)
        )

    def forward(self, x):
        """Map an image batch to outputs of shape ``(batch, num_classes)``."""
        features = self.backbone(x)
        refined_features = self.attention_cnn_stage(features)
        cnn_out = self.cnn_se_stage(refined_features)
        return self.classification_head(cnn_out)

# ==============================================================================
# 3. GENERATE AND PRINT MODEL SUMMARY
# ==============================================================================

# Build the model on the best available device and print a layer-by-layer summary.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
input_size = (3, 224, 224)  # (channels, height, width); torchsummary adds the batch dim
attention_cnn_model = SemaCheXFormer_AttentionCNN(num_classes=1).to(device)

rule = "=" * 80
print(rule)
print("     SemaCheXFormer - Model Summary")
print(rule)
summary(attention_cnn_model, input_size=input_size, device=str(device))
print(rule)

     SemaCheXFormer with Attention-Based CNN Refinement - Model Summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1             [-1, 196, 768]               0
         LayerNorm-2             [-1, 196, 768]           1,536
            Linear-3             [-1, 196, 192]         147,648
         LayerNorm-4             [-1, 196, 192]             384
         LayerNorm-5             [-1, 196, 192]             384
MultiheadAttention-6  [[-1, 196, 192], [-1, 196, 196]]               0
         LayerNorm-7             [-1, 196, 192]             384
            Linear-8             [-1, 196, 384]          74,112
              GELU-9             [-1, 196, 384]               0
          Dropout-10             [-1, 196, 384]               0
           Linear-11             [-1, 196, 192]          73,920
          Dropout-12             [-1, 196, 192]               0
TransformerEncoderBlock-