In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

# ==================================================================================
# 1. DEFINE THE LMD-KANet MODEL (Required for the summary to work)
# ==================================================================================

# --- Activations ---
class HardSwish(nn.Module):
    def forward(self, x): return x * F.relu6(x + 3.) / 6.

# --- Components ---
class SqueezeExcitation(nn.Module):
    def __init__(self, in_c, r_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_c, r_dim, 1), nn.SiLU(),
            nn.Conv2d(r_dim, in_c, 1), nn.Sigmoid()
        )
    def forward(self, x): return x * self.se(x)

class HybridMBConv(nn.Module):
    def __init__(self, in_c, out_c, k, s, exp):
        super().__init__()
        self.config = {'k': k, 's': s, 'exp': exp} # Store for summary
        self.use_res = (s == 1 and in_c == out_c)
        hid = int(in_c * exp)
        layers = []
        if exp != 1: layers.extend([nn.Conv2d(in_c, hid, 1, bias=False), nn.BatchNorm2d(hid), HardSwish()])
        pad = (k - 1) // 2
        layers.extend([nn.Conv2d(hid, hid, k, s, pad, groups=hid, bias=False), nn.BatchNorm2d(hid), HardSwish()])
        layers.append(SqueezeExcitation(hid, hid // 4))
        layers.extend([nn.Conv2d(hid, out_c, 1, bias=False), nn.BatchNorm2d(out_c)])
        self.conv = nn.Sequential(*layers)
    def forward(self, x): return x + self.conv(x) if self.use_res else self.conv(x)

class EffiMobileBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 24, 3, 2, 1, bias=False), nn.BatchNorm2d(24), HardSwish())
        # Config: [in, out, k, s, exp]
        self.layers_config = [
            [24, 24, 3, 1, 1], [24, 40, 3, 2, 4], [40, 40, 3, 1, 4],
            [40, 80, 5, 2, 6], [80, 80, 5, 1, 6], [80, 80, 5, 1, 6],
            [80, 112, 3, 1, 6], [112, 160, 5, 2, 6], [160, 160, 5, 1, 6], [160, 320, 3, 1, 6]
        ]
        self.blocks = nn.Sequential(*[HybridMBConv(*c) for c in self.layers_config])
        self.final_conv = nn.Sequential(nn.Conv2d(320, 1280, 1, bias=False), nn.BatchNorm2d(1280), HardSwish())
    def forward(self, x): return self.final_conv(self.blocks(self.stem(x)))

class CoordinateAttention(nn.Module):
    def __init__(self, inp, oup):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1)); self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // 32)
        self.conv1 = nn.Conv2d(inp, mip, 1); self.bn1 = nn.BatchNorm2d(mip); self.act = HardSwish()
        self.conv_h = nn.Conv2d(mip, oup, 1); self.conv_w = nn.Conv2d(mip, oup, 1)
    def forward(self, x):
        x_h, x_w = self.pool_h(x), self.pool_w(x).permute(0, 1, 3, 2)
        y = self.act(self.bn1(self.conv1(torch.cat([x_h, x_w], dim=2))))
        x_h, x_w = torch.split(y, [x.size(2), x.size(3)], dim=2)
        return x * torch.sigmoid(self.conv_h(x_h)) * torch.sigmoid(self.conv_w(x_w.permute(0, 1, 3, 2)))

class SimAM(nn.Module):
    def __init__(self, e_lambda=1e-4):
        super().__init__(); self.activaton = nn.Sigmoid(); self.e_lambda = e_lambda
    def forward(self, x):
        n = x.shape[2] * x.shape[3] - 1
        d = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
        v = d / (4 * (d.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activaton(v)

class KANLinear(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size))
        nn.init.kaiming_uniform_(self.base_weight, a=5**0.5); nn.init.normal_(self.spline_weight, 0.0, 1.0)
    def forward(self, x):
        return F.linear(x, self.base_weight) + torch.sum(self.spline_weight.mean(dim=2) * x.unsqueeze(1), dim=2)

class LMD_KANet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = EffiMobileBackbone()
        self.coord = CoordinateAttention(1280, 1280)
        self.simam = SimAM()
        self.pool = nn.AdaptiveAvgPool2d(1); self.drop = nn.Dropout(0.3)
        self.kan = KANLinear(1280, num_classes)
    def forward(self, x):
        x = self.backbone(x)
        x = self.simam(self.coord(x))
        return self.kan(self.drop(self.pool(x).flatten(1)))

# ==================================================================================
# 2. CUSTOM SUMMARY GENERATOR (The "Code Man Summary")
# ==================================================================================
def generate_deep_summary(model, input_size=(1, 3, 224, 224)):
    print("\n" + "="*105)
    print(f"{'LAYER / MODULE':<25} | {'INPUT SHAPE':<20} | {'OUTPUT SHAPE':<20} | {'KERNEL':<8} | {'ACTIVATION':<12}")
    print("="*105)

    device = next(model.parameters()).device
    x = torch.randn(input_size).to(device)

    # 1. STEM
    in_s = tuple(x.shape)
    x = model.backbone.stem(x)
    print(f"{'STEM (Conv+BN)':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'3x3':<8} | {'HardSwish':<12}")
    print("-" * 105)

    # 2. BACKBONE BLOCKS
    for i, block in enumerate(model.backbone.blocks):
        in_s = tuple(x.shape)
        x = block(x)

        # Extract details
        k = f"{block.config['k']}x{block.config['k']}"
        s = block.config['s']
        act = "HardSwish"
        module_name = f"HybridMBConv {i+1}"

        # Formatting for readability
        print(f"{module_name:<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {k:<8} | {act:<12}")

    print("-" * 105)

    # 3. FINAL CONV
    in_s = tuple(x.shape)
    x = model.backbone.final_conv(x)
    print(f"{'FINAL EXPANSION':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'1x1':<8} | {'HardSwish':<12}")
    print("=" * 105)

    # 4. ATTENTION
    # Coordinate
    in_s = tuple(x.shape)
    x = model.coord(x)
    print(f"{'COORDINATE ATTN':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'1x1 (Red)':<8} | {'Sigmoid':<12}")

    # SimAM
    in_s = tuple(x.shape)
    x = model.simam(x)
    print(f"{'SIMAM (ENERGY)':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'None':<8} | {'Sigmoid':<12}")

    print("-" * 105)

    # 5. HEAD
    # Pooling
    in_s = tuple(x.shape)
    x = model.pool(x).flatten(1)
    print(f"{'GLOBAL POOLING':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'Avg':<8} | {'-':<12}")

    # Dropout
    in_s = tuple(x.shape)
    x = model.drop(x)
    print(f"{'DROPOUT (p=0.3)':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'-':<8} | {'-':<12}")

    # KAN
    in_s = tuple(x.shape)
    x = model.kan(x)
    print(f"{'KAN CLASSIFIER':<25} | {str(in_s):<20} | {str(tuple(x.shape)):<20} | {'Spline':<8} | {'B-Spline':<12}")

    print("=" * 105)
    print(f"OPTIMIZER CONFIGURATION (For Training Phase)")
    print(f" - Optimizer: AdamW")
    print(f" - Learning Rate: 1e-4")
    print(f" - Scheduler: ReduceLROnPlateau (Factor 0.5, Patience 2)")
    print(f" - Loss Function: CrossEntropyLoss")
    print("=" * 105)

# ==================================================================================
# 3. RUN IT
# ==================================================================================
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Initialize with 4 classes (Chest CT example)
    model = LMD_KANet(num_classes=4).to(device)

    generate_deep_summary(model)


LAYER / MODULE            | INPUT SHAPE          | OUTPUT SHAPE         | KERNEL   | ACTIVATION  
STEM (Conv+BN)            | (1, 3, 224, 224)     | (1, 24, 112, 112)    | 3x3      | HardSwish   
---------------------------------------------------------------------------------------------------------
HybridMBConv 1            | (1, 24, 112, 112)    | (1, 24, 112, 112)    | 3x3      | HardSwish   
HybridMBConv 2            | (1, 24, 112, 112)    | (1, 40, 56, 56)      | 3x3      | HardSwish   
HybridMBConv 3            | (1, 40, 56, 56)      | (1, 40, 56, 56)      | 3x3      | HardSwish   
HybridMBConv 4            | (1, 40, 56, 56)      | (1, 80, 28, 28)      | 5x5      | HardSwish   
HybridMBConv 5            | (1, 80, 28, 28)      | (1, 80, 28, 28)      | 5x5      | HardSwish   
HybridMBConv 6            | (1, 80, 28, 28)      | (1, 80, 28, 28)      | 5x5      | HardSwish   
HybridMBConv 7            | (1, 80, 28, 28)      | (1, 112, 28, 28)     | 3x3      | HardSwish   
HybridMBCon

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==================================================================================
# 1. MODEL DEFINITION (Must be included to inspect it)
# ==================================================================================
class HardSwish(nn.Module):
    def forward(self, x): return x * F.relu6(x + 3.) / 6.

class SqueezeExcitation(nn.Module):
    def __init__(self, in_c, r_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_c, r_dim, 1), nn.SiLU(),
            nn.Conv2d(r_dim, in_c, 1), nn.Sigmoid()
        )
    def forward(self, x): return x * self.se(x)

class HybridMBConv(nn.Module):
    def __init__(self, in_c, out_c, k, s, exp):
        super().__init__()
        self.use_res = (s == 1 and in_c == out_c)
        hid = int(in_c * exp)

        # We define layers individually to make them accessible for inspection
        self.expand_conv = None
        if exp != 1:
            self.expand_conv = nn.Sequential(nn.Conv2d(in_c, hid, 1, bias=False), nn.BatchNorm2d(hid), HardSwish())

        self.depth_conv = nn.Sequential(
            nn.Conv2d(hid, hid, k, s, (k-1)//2, groups=hid, bias=False),
            nn.BatchNorm2d(hid), HardSwish()
        )
        self.se_block = SqueezeExcitation(hid, hid // 4)
        self.proj_conv = nn.Sequential(nn.Conv2d(hid, out_c, 1, bias=False), nn.BatchNorm2d(out_c))

    def forward(self, x):
        out = self.expand_conv(x) if self.expand_conv else x
        out = self.depth_conv(out)
        out = self.se_block(out)
        out = self.proj_conv(out)
        return x + out if self.use_res else out

class EffiMobileBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 24, 3, 2, 1, bias=False), nn.BatchNorm2d(24), HardSwish())
        config = [
            [24, 24, 3, 1, 1], [24, 40, 3, 2, 4], [40, 40, 3, 1, 4],
            [40, 80, 5, 2, 6], [80, 80, 5, 1, 6], [80, 80, 5, 1, 6],
            [80, 112, 3, 1, 6], [112, 160, 5, 2, 6], [160, 160, 5, 1, 6], [160, 320, 3, 1, 6]
        ]
        self.layers = nn.Sequential(*[HybridMBConv(*c) for c in config])
        self.final_conv = nn.Sequential(nn.Conv2d(320, 1280, 1, bias=False), nn.BatchNorm2d(1280), HardSwish())
    def forward(self, x): return self.final_conv(self.layers(self.stem(x)))

class CoordinateAttention(nn.Module):
    def __init__(self, inp, oup):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // 32)
        self.conv1 = nn.Conv2d(inp, mip, 1)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = HardSwish()
        self.conv_h = nn.Conv2d(mip, oup, 1)
        self.conv_w = nn.Conv2d(mip, oup, 1)
    def forward(self, x):
        # Dummy forward for shape tracing
        return x

class SimAM(nn.Module):
    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.activaton = nn.Sigmoid()
    def forward(self, x): return x

class KANLinear(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size))
    def forward(self, x): return x

class LMD_KANet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = EffiMobileBackbone()
        self.coord = CoordinateAttention(1280, 1280)
        self.simam = SimAM()
        self.pool = nn.AdaptiveAvgPool2d(1); self.drop = nn.Dropout(0.3)
        self.kan = KANLinear(1280, num_classes)
    def forward(self, x): return x

# ==================================================================================
# 2. DEEP INSPECTOR FUNCTION
# ==================================================================================
def print_recursive_summary(model):
    print(f"{'MODULE / LAYER':<40} | {'KERNEL':<8} | {'STRIDE':<8} | {'ACTIVATION':<15} | {'PARAMS'}")
    print("="*90)

    def count_p(m): return sum(p.numel() for p in m.parameters())

    # 1. STEM
    print(">> [1] STEM")
    s = model.backbone.stem
    print(f"{'  Conv2d (Input->24)':<40} | {'3x3':<8} | {'2':<8} | {'-':<15} | {count_p(s[0]):,}")
    print(f"{'  BatchNorm2d':<40} | {'-':<8} | {'-':<8} | {'-':<15} | {count_p(s[1]):,}")
    print(f"{'  HardSwish':<40} | {'-':<8} | {'-':<8} | {'HardSwish':<15} | 0")
    print("-" * 90)

    # 2. BACKBONE BLOCKS (Sample Block 4 as it is complex)
    print(">> [2] HYBRID MB-CONV BLOCKS (Example: Block 4 - 5x5 Kernel)")
    # We inspect the 4th block (index 3) manually to show structure
    b = model.backbone.layers[3]

    # Expand
    print(f"{'  [Expansion Phase]':<40}")
    print(f"{'    Conv2d (1x1)':<40} | {'1x1':<8} | {'1':<8} | {'-':<15} | {count_p(b.expand_conv[0]):,}")
    print(f"{'    BatchNorm2d':<40} | {'-':<8} | {'-':<8} | {'HardSwish':<15} | {count_p(b.expand_conv[1]):,}")

    # Depthwise
    print(f"{'  [Depthwise Phase]':<40}")
    print(f"{'    Conv2d (5x5, Groups=C)':<40} | {'5x5':<8} | {'2':<8} | {'-':<15} | {count_p(b.depth_conv[0]):,}")
    print(f"{'    BatchNorm2d':<40} | {'-':<8} | {'-':<8} | {'HardSwish':<15} | {count_p(b.depth_conv[1]):,}")

    # SE
    print(f"{'  [Squeeze-Excitation]':<40}")
    print(f"{'    AdaptiveAvgPool2d':<40} | {'-':<8} | {'-':<8} | {'-':<15} | 0")
    print(f"{'    Conv2d (Squeeze)':<40} | {'1x1':<8} | {'1':<8} | {'SiLU':<15} | {count_p(b.se_block.se[1]):,}")
    print(f"{'    Conv2d (Excite)':<40} | {'1x1':<8} | {'1':<8} | {'Sigmoid':<15} | {count_p(b.se_block.se[3]):,}")

    # Project
    print(f"{'  [Projection Phase]':<40}")
    print(f"{'    Conv2d (1x1)':<40} | {'1x1':<8} | {'1':<8} | {'-':<15} | {count_p(b.proj_conv[0]):,}")
    print(f"{'    BatchNorm2d':<40} | {'-':<8} | {'-':<8} | {'Linear':<15} | {count_p(b.proj_conv[1]):,}")
    print("-" * 90)

    # 3. ATTENTION
    print(">> [3] COORDINATE ATTENTION")
    c = model.coord
    print(f"{'  Pool_H + Pool_W':<40} | {'Avg':<8} | {'-':<8} | {'-':<15} | 0")
    print(f"{'  Conv2d (Shared)':<40} | {'1x1':<8} | {'1':<8} | {'HardSwish':<15} | {count_p(c.conv1):,}")
    print(f"{'  Conv2d (X-Attn)':<40} | {'1x1':<8} | {'1':<8} | {'Sigmoid':<15} | {count_p(c.conv_h):,}")
    print(f"{'  Conv2d (Y-Attn)':<40} | {'1x1':<8} | {'1':<8} | {'Sigmoid':<15} | {count_p(c.conv_w):,}")
    print("-" * 90)

    print(">> [4] SIMAM ATTENTION")
    print(f"{'  Energy Calculation':<40} | {'-':<8} | {'-':<8} | {'-':<15} | 0")
    print(f"{'  Sigmoid Activation':<40} | {'-':<8} | {'-':<8} | {'Sigmoid':<15} | 0")
    print("-" * 90)

    # 4. KAN HEAD
    print(">> [5] KAN CLASSIFICATION HEAD")
    k = model.kan
    print(f"{'  Global Avg Pool':<40} | {'-':<8} | {'-':<8} | {'-':<15} | 0")
    print(f"{'  Dropout (p=0.3)':<40} | {'-':<8} | {'-':<8} | {'-':<15} | 0")
    print(f"{'  KAN Linear (Base Weights)':<40} | {'Linear':<8} | {'-':<8} | {'SiLU':<15} | {k.base_weight.numel():,}")
    print(f"{'  KAN Linear (Spline Grid)':<40} | {'Spline':<8} | {'-':<8} | {'B-Spline':<15} | {k.spline_weight.numel():,}")

    print("=" * 90)
    print(f"TOTAL MODEL PARAMS: {sum(p.numel() for p in model.parameters()):,}")
    print("=" * 90)

if __name__ == "__main__":
    model = LMD_KANet(num_classes=4)
    print_recursive_summary(model)

MODULE / LAYER                           | KERNEL   | STRIDE   | ACTIVATION      | PARAMS
>> [1] STEM
  Conv2d (Input->24)                     | 3x3      | 2        | -               | 648
  BatchNorm2d                            | -        | -        | -               | 48
  HardSwish                              | -        | -        | HardSwish       | 0
------------------------------------------------------------------------------------------
>> [2] HYBRID MB-CONV BLOCKS (Example: Block 4 - 5x5 Kernel)
  [Expansion Phase]                     
    Conv2d (1x1)                         | 1x1      | 1        | -               | 9,600
    BatchNorm2d                          | -        | -        | HardSwish       | 480
  [Depthwise Phase]                     
    Conv2d (5x5, Groups=C)               | 5x5      | 2        | -               | 6,000
    BatchNorm2d                          | -        | -        | HardSwish       | 480
  [Squeeze-Excitation]                  
    AdaptiveA

In [3]:
import torch
import torch.nn as nn

# --- Helper Classes (Same as your model) ---
class SqueezeExcitation(nn.Module):
    def __init__(self, in_c, r_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_c, r_dim, 1), nn.SiLU(),
            nn.Conv2d(r_dim, in_c, 1), nn.Sigmoid()
        )
    def forward(self, x): return x * self.se(x)

def inspect_se_and_convs():
    print("\n" + "="*80)
    print("      DETAILED INSPECTION: SE & CONVOLUTION INTERNALS")
    print("="*80)

    # 1. SETUP DUMMY DATA
    # Simulating a block deep in the network (e.g., Block 8)
    # Input: 112 channels, Expanded to 672 (6x), Kernel 5x5
    in_c = 112
    expand_ratio = 6
    hidden_dim = in_c * expand_ratio # 672
    kernel = 5

    x = torch.randn(1, hidden_dim, 14, 14) # Dummy input for depthwise/SE phase

    print(f"Context: Analyzing Block 8 (Deep Layer)")
    print(f"Input Channels: {in_c} -> Expanded: {hidden_dim}")
    print("-" * 80)

    # 2. INSPECT DEPTHWISE CONV
    depth_conv = nn.Conv2d(hidden_dim, hidden_dim, kernel, 1, 2, groups=hidden_dim, bias=False)
    params_depth = sum(p.numel() for p in depth_conv.parameters())

    print(f"[A] Depthwise Convolution ({kernel}x{kernel})")
    print(f"    - Input Shape:  {list(x.shape)}")
    print(f"    - Groups:       {hidden_dim} (1 filter per channel)")
    print(f"    - Params:       {kernel} * {kernel} * {hidden_dim} = {params_depth:,}")
    print(f"    - Logic:        Spatial Filtering only (Shape extraction)")

    # 3. INSPECT SQUEEZE-AND-EXCITATION
    se = SqueezeExcitation(hidden_dim, hidden_dim // 4)
    x_se = se(x)

    # Analyze internal weights of SE
    sq_conv = se.se[1] # Reduction layer
    ex_conv = se.se[3] # Expansion layer
    p_sq = sum(p.numel() for p in sq_conv.parameters())
    p_ex = sum(p.numel() for p in ex_conv.parameters())

    print(f"\n[B] Squeeze-and-Excitation (SE)")
    print(f"    - Input:        {list(x.shape)}")
    print(f"    - 1. AvgPool:   [1, {hidden_dim}, 1, 1] (Global Descriptor)")
    print(f"    - 2. Squeeze:   1x1 Conv ({hidden_dim}->{hidden_dim//4}) | Act: SiLU | Params: {p_sq}")
    print(f"    - 3. Excite:    1x1 Conv ({hidden_dim//4}->{hidden_dim}) | Act: Sigmoid | Params: {p_ex}")
    print(f"    - 4. Scale:     Element-wise Multiplication")
    print(f"    - Output:       {list(x_se.shape)}")

    # 4. INSPECT POINTWISE PROJECTION
    out_c = 160 # Target output for Block 8
    proj_conv = nn.Conv2d(hidden_dim, out_c, 1, bias=False)
    p_proj = sum(p.numel() for p in proj_conv.parameters())

    print(f"\n[C] Pointwise Projection (1x1)")
    print(f"    - Input:        {list(x_se.shape)}")
    print(f"    - Kernel:       1x1")
    print(f"    - Operation:    Linear Combination of channels")
    print(f"    - Params:       {hidden_dim} * {out_c} = {p_proj:,}")
    print(f"    - Activation:   None (Linear bottleneck to preserve info)")
    print("=" * 80)

if __name__ == "__main__":
    inspect_se_and_convs()


      DETAILED INSPECTION: SE & CONVOLUTION INTERNALS
Context: Analyzing Block 8 (Deep Layer)
Input Channels: 112 -> Expanded: 672
--------------------------------------------------------------------------------
[A] Depthwise Convolution (5x5)
    - Input Shape:  [1, 672, 14, 14]
    - Groups:       672 (1 filter per channel)
    - Params:       5 * 5 * 672 = 16,800
    - Logic:        Spatial Filtering only (Shape extraction)

[B] Squeeze-and-Excitation (SE)
    - Input:        [1, 672, 14, 14]
    - 1. AvgPool:   [1, 672, 1, 1] (Global Descriptor)
    - 2. Squeeze:   1x1 Conv (672->168) | Act: SiLU | Params: 113064
    - 3. Excite:    1x1 Conv (168->672) | Act: Sigmoid | Params: 113568
    - 4. Scale:     Element-wise Multiplication
    - Output:       [1, 672, 14, 14]

[C] Pointwise Projection (1x1)
    - Input:        [1, 672, 14, 14]
    - Kernel:       1x1
    - Operation:    Linear Combination of channels
    - Params:       672 * 160 = 107,520
    - Activation:   None (Linear 

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ==================================================================================
# 1. DEEP INSPECTION: COORDINATE ATTENTION (The "Where")
# ==================================================================================
def inspect_coordinate_attention():
    print("\n" + "="*80)
    print("   1. COORDINATE ATTENTION: INNER MECHANICS (Step-by-Step)")
    print("="*80)

    # 1. Setup Dummy Input (Simulating a feature map deep in the network)
    # Batch=1, Channels=32, Height=14, Width=14
    x = torch.randn(1, 32, 14, 14)
    n, c, h, w = x.size()
    print(f"[INPUT] Feature Map: {list(x.shape)}")

    # 2. Define the Pooling Layers
    # pool_h: Keeps Height, Squeezes Width to 1
    pool_h = nn.AdaptiveAvgPool2d((None, 1))
    # pool_w: Keeps Width, Squeezes Height to 1
    pool_w = nn.AdaptiveAvgPool2d((1, None))

    # 3. Perform Pooling
    x_h = pool_h(x)
    x_w = pool_w(x)

    print(f"\n--- A. Direction-Aware Pooling ---")
    print(f"   > pool_h(x): {list(x_h.shape)}  <- Compresses Width, keeps Row information")
    print(f"   > pool_w(x): {list(x_w.shape)}  <- Compresses Height, keeps Column information")

    # 4. Concatenation
    # We must permute x_w to stack it with x_h
    x_w_perm = x_w.permute(0, 1, 3, 2) # [1, 32, 14, 1]

    x_cat = torch.cat([x_h, x_w_perm], dim=2)
    print(f"\n--- B. Concatenation (Spatial Map) ---")
    print(f"   > Concat([h, w_perm]): {list(x_cat.shape)} <- Merges X and Y coordinates into one spatial map")

    # 5. Shared Convolution (Reduction)
    reduction = 8
    conv1 = nn.Conv2d(c, c // reduction, kernel_size=1)
    bn1 = nn.BatchNorm2d(c // reduction)
    act = nn.Hardswish()

    f_spatial = act(bn1(conv1(x_cat)))
    print(f"\n--- C. Shared Processing ---")
    print(f"   > Conv1x1 -> BN -> HardSwish: {list(f_spatial.shape)} <- Learns spatial relationships")

    # 6. Split & Generate Attention Maps
    x_h_prime, x_w_prime = torch.split(f_spatial, [h, w], dim=2)
    x_w_prime = x_w_prime.permute(0, 1, 3, 2) # Flip back

    conv_h = nn.Conv2d(c // reduction, c, kernel_size=1)
    conv_w = nn.Conv2d(c // reduction, c, kernel_size=1)

    a_h = torch.sigmoid(conv_h(x_h_prime))
    a_w = torch.sigmoid(conv_w(x_w_prime))

    print(f"\n--- D. Attention Generation ---")
    print(f"   > Attn_H (Height Weights): {list(a_h.shape)} <- 'Which Rows are important?'")
    print(f"   > Attn_W (Width Weights):  {list(a_w.shape)} <- 'Which Columns are important?'")

    # 7. Final Refinement
    out = x * a_h * a_w
    print(f"\n[OUTPUT] Refined Map: {list(out.shape)}")


# ==================================================================================
# 2. DEEP INSPECTION: SimAM (The "Energy" Calculation)
# ==================================================================================
def inspect_simam():
    print("\n\n" + "="*80)
    print("   2. SimAM: ENERGY CALCULATION (Math Logic)")
    print("="*80)

    # Input
    x = torch.randn(1, 32, 14, 14)
    print(f"[INPUT] Feature Map: {list(x.shape)}")

    n = x.shape[2] * x.shape[3] - 1
    d = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)

    # FORMULA VARIABLES
    # mu: Mean of the channel
    # sigma^2: Variance of the channel
    # t: The target neuron value
    # lambda: Regularization constant

    print(f"\n--- A. Statistical Analysis ---")
    print(f"   > Calculating Mean (mu) per channel...")
    print(f"   > Calculating Variance (sigma) per channel...")

    # Energy Formula: e_t = (t - mu)^2 / (4 * (sigma^2 + lambda)) + 0.5
    e_lambda = 1e-4
    v = d / (4 * (d.sum(dim=[2,3], keepdim=True) / n + e_lambda)) + 0.5

    print(f"\n--- B. Energy Matrix (v) ---")
    print(f"   > Shape: {list(v.shape)}")
    print(f"   > Logic: Lower Energy = More Informative Neuron (Outlier from background)")

    # Attention = Sigmoid(1 / Energy)
    attention_map = torch.sigmoid(v)

    print(f"\n--- C. Attention Scaling ---")
    print(f"   > Sigmoid(Energy): {list(attention_map.shape)}")
    print(f"   > Result: Neurons with LOW energy get HIGH weights.")

    out = x * attention_map
    print(f"\n[OUTPUT] Reweighted Features: {list(out.shape)}")


# ==================================================================================
# 3. DEEP INSPECTION: KAN LINEAR (The "Learnable Activation")
# ==================================================================================
def inspect_kan_linear():
    print("\n\n" + "="*80)
    print("   3. KAN LINEAR: SPLINE MECHANICS (Detailed)")
    print("="*80)

    # Setup
    in_f = 32
    out_f = 4 # 4 Classes
    grid_size = 5

    x = torch.randn(1, in_f) # Flattened vector
    print(f"[INPUT] Flattened Vector: {list(x.shape)}")

    # 1. Base Linear Path (SiLU)
    base_weight = torch.randn(out_f, in_f)
    base_output = F.linear(x, base_weight)
    base_act = F.silu(x) # SiLU activation on input

    print(f"\n--- A. Base Path (Residual) ---")
    print(f"   > Operation: W_base * SiLU(x)")
    print(f"   > Shape: {list(base_output.shape)}")

    # 2. Spline Path (The KAN Magic)
    # In a real KAN, we compute B-Spline basis functions here.
    # We expand dimensions to apply a different function for every input-output pair.

    print(f"\n--- B. Spline Path (Learnable Functions) ---")
    print(f"   > Grid Size: {grid_size} (Control points for the curve)")

    # Simulate Spline computation shape
    # We expand X to match the Spline Weight Matrix [Out, In, Grid]
    x_expanded = x.unsqueeze(1).expand(1, out_f, in_f)

    spline_weights = torch.randn(out_f, in_f, grid_size)

    print(f"   > Input Expanded: {list(x_expanded.shape)}")
    print(f"   > Spline Weights: {list(spline_weights.shape)} <- 3D Matrix (Learnable Curves)")

    # Aggregation
    # Summing the contribution of splines for each output class
    spline_output = torch.sum(spline_weights.mean(dim=2) * x.unsqueeze(1), dim=2)

    print(f"   > Spline Output: {list(spline_output.shape)}")

    # 3. Final Sum
    final = base_output + spline_output
    print(f"\n--- C. Aggregation ---")
    print(f"   > Final = Base_Linear + Spline_Function")
    print(f"   > Logits: {list(final.shape)} (Ready for Softmax)")

# ==================================================================================
# 4. RUN ALL INSPECTIONS
# ==================================================================================
if __name__ == "__main__":
    inspect_coordinate_attention()
    inspect_simam()
    inspect_kan_linear()


   1. COORDINATE ATTENTION: INNER MECHANICS (Step-by-Step)
[INPUT] Feature Map: [1, 32, 14, 14]

--- A. Direction-Aware Pooling ---
   > pool_h(x): [1, 32, 14, 1]  <- Compresses Width, keeps Row information
   > pool_w(x): [1, 32, 1, 14]  <- Compresses Height, keeps Column information

--- B. Concatenation (Spatial Map) ---
   > Concat([h, w_perm]): [1, 32, 28, 1] <- Merges X and Y coordinates into one spatial map

--- C. Shared Processing ---
   > Conv1x1 -> BN -> HardSwish: [1, 4, 28, 1] <- Learns spatial relationships

--- D. Attention Generation ---
   > Attn_H (Height Weights): [1, 32, 14, 1] <- 'Which Rows are important?'
   > Attn_W (Width Weights):  [1, 32, 1, 14] <- 'Which Columns are important?'

[OUTPUT] Refined Map: [1, 32, 14, 14]


   2. SimAM: ENERGY CALCULATION (Math Logic)
[INPUT] Feature Map: [1, 32, 14, 14]

--- A. Statistical Analysis ---
   > Calculating Mean (mu) per channel...
   > Calculating Variance (sigma) per channel...

--- B. Energy Matrix (v) ---
   > Sh

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

# ==================================================================================
# 1. DEFINE MODEL COMPONENTS
# ==================================================================================

# --- Helpers ---
class HardSwish(nn.Module):
    def forward(self, x): return x * F.relu6(x + 3.) / 6.

class SqueezeExcitation(nn.Module):
    def __init__(self, in_c, r_dim):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_c, r_dim, 1), nn.SiLU(),
            nn.Conv2d(r_dim, in_c, 1), nn.Sigmoid()
        )
    def forward(self, x): return x * self.se(x)

# --- Backbone Block ---
class HybridMBConv(nn.Module):
    def __init__(self, in_c, out_c, k, s, exp):
        super().__init__()
        self.stats = {'k': k, 's': s, 'exp': exp} # For printing
        self.use_res = (s == 1 and in_c == out_c)
        hid = int(in_c * exp)
        layers = []
        if exp != 1: layers.extend([nn.Conv2d(in_c, hid, 1, bias=False), nn.BatchNorm2d(hid), HardSwish()])
        pad = (k - 1) // 2
        layers.extend([nn.Conv2d(hid, hid, k, s, pad, groups=hid, bias=False), nn.BatchNorm2d(hid), HardSwish()])
        layers.append(SqueezeExcitation(hid, hid // 4))
        layers.extend([nn.Conv2d(hid, out_c, 1, bias=False), nn.BatchNorm2d(out_c)])
        self.conv = nn.Sequential(*layers)
    def forward(self, x): return x + self.conv(x) if self.use_res else self.conv(x)

# --- Module 1: Backbone ---
class EffiMobileBackbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 24, 3, 2, 1, bias=False), nn.BatchNorm2d(24), HardSwish())
        # [in, out, k, s, exp]
        config = [
            [24, 24, 3, 1, 1], [24, 40, 3, 2, 4], [40, 40, 3, 1, 4],
            [40, 80, 5, 2, 6], [80, 80, 5, 1, 6], [80, 80, 5, 1, 6],
            [80, 112, 3, 1, 6], [112, 160, 5, 2, 6], [160, 160, 5, 1, 6], [160, 320, 3, 1, 6]
        ]
        self.layers = nn.Sequential(*[HybridMBConv(*c) for c in config])
        self.final_conv = nn.Sequential(nn.Conv2d(320, 1280, 1, bias=False), nn.BatchNorm2d(1280), HardSwish())
    def forward(self, x): return self.final_conv(self.layers(self.stem(x)))

# --- Module 2: Custom Channel Attention (Table 7) ---
class CustomChannelAttentionBlock(nn.Module):
    def __init__(self, in_channels=128):
        super().__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(in_channels, 16, 5, padding=2), nn.ReLU())
        self.layer2 = nn.Sequential(nn.Conv2d(16, 16, 5, padding=2), nn.MaxPool2d(2, 2))
        self.layer3 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        self.layer4 = nn.Sequential(nn.Conv2d(32, 32, 3, padding=1), nn.Sigmoid())
    def forward(self, x): return self.layer4(self.layer3(self.layer2(self.layer1(x))))

# --- Module 3: Spatial Attention (Table 8) ---
class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        scale = self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * scale

# --- Module 4: KAN Head ---
class KANLinear(nn.Module):
    def __init__(self, in_features, out_features, grid_size=5):
        super().__init__()
        self.base_weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = nn.Parameter(torch.Tensor(out_features, in_features, grid_size))
        nn.init.kaiming_uniform_(self.base_weight, a=5**0.5); nn.init.normal_(self.spline_weight, 0.0, 1.0)
    def forward(self, x):
        return F.linear(x, self.base_weight) + torch.sum(self.spline_weight.mean(dim=2) * x.unsqueeze(1), dim=2)

# --- MAIN MODEL ---
class LMD_KANet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = EffiMobileBackbone()
        self.adapter = nn.Sequential(nn.Conv2d(1280, 128, 1, bias=False), nn.BatchNorm2d(128), nn.SiLU())
        self.channel_att = CustomChannelAttentionBlock(128)
        self.spatial_att = SpatialAttentionModule(7)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.drop = nn.Dropout(0.3)
        self.kan = KANLinear(32, num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.adapter(x)
        x = self.channel_att(x)
        x = self.spatial_att(x)
        return self.kan(self.drop(self.pool(x).flatten(1)))

# ==================================================================================
# 2. MODULE-WISE INSPECTOR (The Code you asked for)
# ==================================================================================
def inspect_modules_separately(model):
    print(f"\n{'='*40}")
    print(f" LMD-KANet: MODULAR BREAKDOWN & SUMMARY")
    print(f"{'='*40}")

    device = next(model.parameters()).device
    # Create Dummy Input
    x = torch.randn(1, 3, 224, 224).to(device)

    # --- 1. BACKBONE INSPECTION ---
    print("\n" + "-"*80)
    print("MODULE 1: HYBRID BACKBONE (EffiMobile)")
    print("-"*80)
    print(f"{'Sub-Layer':<20} | {'Kernel':<10} | {'Input':<18} | {'Output':<18} | {'Params':<10}")

    # Stem
    in_shape = tuple(x.shape)
    x = model.backbone.stem(x)
    p = sum(p.numel() for p in model.backbone.stem.parameters())
    print(f"{'Stem':<20} | {'3x3':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # Blocks
    for i, block in enumerate(model.backbone.layers):
        in_shape = tuple(x.shape)
        x = block(x)
        p = sum(p.numel() for p in block.parameters())
        k = f"{block.stats['k']}x{block.stats['k']}"
        print(f"{f'HybridMBConv {i+1}':<20} | {k:<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # Final
    in_shape = tuple(x.shape)
    x = model.backbone.final_conv(x)
    p = sum(p.numel() for p in model.backbone.final_conv.parameters())
    print(f"{'Final Expansion':<20} | {'1x1':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # --- 2. ADAPTER INSPECTION ---
    print("\n" + "-"*80)
    print("MODULE 2: ADAPTER LAYER (Bridge)")
    print("-"*80)
    in_shape = tuple(x.shape)
    x = model.adapter(x)
    p = sum(p.numel() for p in model.adapter.parameters())
    print(f"{'Conv1x1+BN':<20} | {'1x1':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # --- 3. CUSTOM CHANNEL ATTENTION INSPECTION ---
    print("\n" + "-"*80)
    print("MODULE 3: CUSTOM CHANNEL ATTENTION (Table 7)")
    print("-"*80)

    # Layer 1
    in_shape = tuple(x.shape)
    x = model.channel_att.layer1(x)
    p = sum(p.numel() for p in model.channel_att.layer1.parameters())
    print(f"{'Layer 1 (Conv+Relu)':<20} | {'5x5':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # Layer 2
    in_shape = tuple(x.shape)
    x = model.channel_att.layer2(x)
    p = sum(p.numel() for p in model.channel_att.layer2.parameters())
    print(f"{'Layer 2 (Conv+Pool)':<20} | {'5x5':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # Layer 3
    in_shape = tuple(x.shape)
    x = model.channel_att.layer3(x)
    p = sum(p.numel() for p in model.channel_att.layer3.parameters())
    print(f"{'Layer 3 (Conv+Relu)':<20} | {'3x3':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # Layer 4
    in_shape = tuple(x.shape)
    x = model.channel_att.layer4(x)
    p = sum(p.numel() for p in model.channel_att.layer4.parameters())
    print(f"{'Layer 4 (Conv+Sigm)':<20} | {'3x3':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # --- 4. SPATIAL ATTENTION INSPECTION ---
    print("\n" + "-"*80)
    print("MODULE 4: SPATIAL ATTENTION (Table 8)")
    print("-"*80)
    in_shape = tuple(x.shape)
    x = model.spatial_att(x)
    p = sum(p.numel() for p in model.spatial_att.parameters())
    print(f"{'Spatial Map (7x7)':<20} | {'7x7':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    # --- 5. CLASSIFICATION HEAD INSPECTION ---
    print("\n" + "-"*80)
    print("MODULE 5: KAN CLASSIFIER")
    print("-"*80)

    # Pool/Drop
    in_shape = tuple(x.shape)
    x = model.drop(model.pool(x).flatten(1))
    print(f"{'GlobalPool+Drop':<20} | {'-':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | 0")

    # KAN
    in_shape = tuple(x.shape)
    x = model.kan(x)
    p = sum(p.numel() for p in model.kan.parameters())
    print(f"{'KAN Linear':<20} | {'Spline':<10} | {str(in_shape):<18} | {str(tuple(x.shape)):<18} | {p:,}")

    print("="*80 + "\n")

# Run it
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LMD_KANet(num_classes=4).to(device)
    inspect_modules_separately(model)


 LMD-KANet: MODULAR BREAKDOWN & SUMMARY

--------------------------------------------------------------------------------
MODULE 1: HYBRID BACKBONE (EffiMobile)
--------------------------------------------------------------------------------
Sub-Layer            | Kernel     | Input              | Output             | Params    
Stem                 | 3x3        | (1, 3, 224, 224)   | (1, 24, 112, 112)  | 696
HybridMBConv 1       | 3x3        | (1, 24, 112, 112)  | (1, 24, 112, 112)  | 1,206
HybridMBConv 2       | 3x3        | (1, 24, 112, 112)  | (1, 40, 56, 56)    | 12,200
HybridMBConv 3       | 3x3        | (1, 40, 56, 56)    | (1, 40, 56, 56)    | 27,960
HybridMBConv 4       | 5x5        | (1, 40, 56, 56)    | (1, 80, 28, 28)    | 65,020
HybridMBConv 5       | 5x5        | (1, 80, 28, 28)    | (1, 80, 28, 28)    | 206,680
HybridMBConv 6       | 5x5        | (1, 80, 28, 28)    | (1, 80, 28, 28)    | 206,680
HybridMBConv 7       | 3x3        | (1, 80, 28, 28)    | (1, 112, 28, 28)  