<a href="https://colab.research.google.com/github/Adithyan-mp/EfficientNet/blob/main/EfficientNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [65]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [66]:
def round_to_8(in_ch):
    """Return ``in_ch`` rounded *up* to the nearest multiple of 8, as an int.

    Used to keep scaled channel counts hardware-friendly; because it rounds
    up, the result is never smaller than the input.
    """
    groups_of_eight = math.ceil(in_ch / 8)
    return int(groups_of_eight * 8)

Swish / SiLU activation: x · sigmoid(x)

In [67]:
class Swish(nn.Module):
    """Element-wise Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``.

    Stateless; has no parameters or buffers.
    """

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


Squeeze-and-Excitation (SE) block

In [68]:
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Global-average-pools each channel and passes the pooled vector through a
    two-layer bottleneck MLP (SiLU, then sigmoid). The input's channels are
    then rescaled by the resulting per-channel gates in ``(0, 1)``.

    Args:
        in_ch: number of channels of the input feature map.
        reduction: bottleneck divisor; the hidden width is
            ``max(1, in_ch // reduction)``.
    """

    def __init__(self, in_ch, reduction=16):
        super().__init__()
        # Clamp to >= 1: when in_ch < reduction, in_ch // reduction is 0. That
        # would build zero-width Linear layers, whose gate is just
        # sigmoid(fc2.bias) and ignores the input entirely.
        squeezed_ch = max(1, in_ch // reduction)
        self.fc1 = nn.Linear(in_ch, squeezed_ch)
        self.fc2 = nn.Linear(squeezed_ch, in_ch)

    def forward(self, x):
        # The 4-way unpack requires a 4-D input, read as (batch, channels, H, W).
        b, c, _, _ = x.size()
        # Squeeze: one scalar per channel.
        y = F.adaptive_avg_pool2d(x, 1).view(b, c)
        # Excite: bottleneck MLP producing per-channel gates.
        y = F.silu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y))
        # Broadcast the gates over the spatial dimensions.
        return x * y.view(b, c, 1, 1)

MBConv (mobile inverted bottleneck) block

In [69]:
class MBConv(nn.Module):
    """Mobile inverted bottleneck (MBConv) block with squeeze-and-excitation.

    Layer order:
      1. optional 1x1 expansion conv + BN + Swish;
      2. depthwise conv + BN + Swish;
      3. ``SEBlock`` on the expanded channels;
      4. 1x1 projection conv + BN, with no activation.

    The input is added back (identity shortcut) when ``stride == 1`` and
    ``in_ch == out_ch``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the depthwise convolution.
        expand_ratio: hidden width is ``in_ch * expand_ratio``; with a ratio
            of 1 the expansion conv is omitted.
        kernel_size: odd depthwise kernel size. Defaults to 3, the value that
            was previously hard-coded. Padding is ``kernel_size // 2``, so a
            stride-1 block preserves spatial size.

    Raises:
        ValueError: if ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, in_ch, out_ch, stride, expand_ratio, kernel_size=3):
        super().__init__()
        # With padding k // 2, an even kernel grows the output by one pixel
        # at stride 1, which would break the identity shortcut.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )

        hidden_ch = in_ch * expand_ratio
        self.use_residual = (stride == 1 and in_ch == out_ch)

        layers = []

        # Expansion (skipped when expand_ratio == 1, since hidden_ch == in_ch).
        if expand_ratio != 1:
            layers += [
                nn.Conv2d(in_ch, hidden_ch, 1, bias=False),
                nn.BatchNorm2d(hidden_ch),
                Swish(),
            ]

        # Depthwise convolution: groups == channels, one filter per channel.
        layers += [
            nn.Conv2d(hidden_ch, hidden_ch, kernel_size, stride,
                      padding=kernel_size // 2, groups=hidden_ch, bias=False),
            nn.BatchNorm2d(hidden_ch),
            Swish(),
        ]

        # Channel attention on the expanded features.
        layers.append(SEBlock(hidden_ch))

        # Linear projection: no activation after the final BN.
        layers += [
            nn.Conv2d(hidden_ch, out_ch, 1, bias=False),
            nn.BatchNorm2d(out_ch),
        ]

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return x + out if self.use_residual else out


In [70]:
# Per-stage settings of the unscaled (phi = 0) network, one tuple per stage:
#   (expand_ratio, out_channels, repeats, stride_of_first_block)
# Written column-wise; zip() turns the columns back into per-stage tuples.
BASE_CONFIG = list(zip(
    (1, 6, 6, 6, 6, 6, 6),              # expand ratio
    (16, 24, 40, 80, 112, 192, 320),    # output channels
    (1, 2, 2, 3, 3, 4, 1),              # number of repeats
    (1, 2, 2, 2, 1, 2, 1),              # stride of the first repeat
))

In [71]:
class EfficientNet(nn.Module):
    """EfficientNet-style classifier built from ``BASE_CONFIG`` with compound scaling.

    Depth (stage repeats) is scaled by ``1.2 ** phi``, rounded up. Width
    (channel counts) is scaled by ``1.1 ** phi`` and rounded up to a multiple
    of 8 by ``round_to_8``. Input resolution is not scaled by this module.

    Args:
        phi: compound scaling coefficient; 0 uses ``BASE_CONFIG`` as-is.
        num_classes: output size of the final linear classifier.
    """

    def __init__(self, phi=0, num_classes=1000):
        super().__init__()

        depth_scale = 1.2 ** phi
        width_scale = 1.1 ** phi

        def scaled_width(channels):
            return round_to_8(channels * width_scale)

        # Stem: stride-2 3x3 conv on the 3-channel input.
        channels = scaled_width(32)
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            Swish(),
        )

        # Body: each stage repeats an MBConv; only the first repeat of a stage
        # uses that stage's stride, the rest use stride 1.
        mbconvs = []
        for expand, base_out_ch, base_reps, stride in BASE_CONFIG:
            stage_out = scaled_width(base_out_ch)
            repeats = math.ceil(base_reps * depth_scale)
            for idx in range(repeats):
                mbconvs.append(MBConv(
                    in_ch=channels,
                    out_ch=stage_out,
                    stride=1 if idx else stride,
                    expand_ratio=expand,
                ))
                channels = stage_out
        self.blocks = nn.Sequential(*mbconvs)

        # Head: 1x1 conv, global average pool, then the linear classifier.
        head_width = scaled_width(1280)
        self.head = nn.Sequential(
            nn.Conv2d(channels, head_width, 1, bias=False),
            nn.BatchNorm2d(head_width),
            Swish(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(head_width, num_classes),
        )

    def forward(self, x):
        return self.head(self.blocks(self.stem(x)))


In [72]:
def main():
    """Build an EfficientNet, run one dummy forward pass, and print a summary.

    Prints the selected device, the output shape, and the total and trainable
    parameter counts.
    """
    # -------- Device --------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # -------- Model --------
    phi = 1                 # 0=B0, 1=B1, 2=B2, ...
    num_classes = 10        # change for your dataset

    model = EfficientNet(phi=phi, num_classes=num_classes).to(device)
    model.eval()

    # -------- Dummy Input --------
    # The input resolution must follow phi. The model scales only depth and
    # width, so always feeding the B0 size (224) misrepresents any other
    # variant. Use the reference resolutions for B0-B7; beyond that, fall back
    # to the resolution coefficient gamma = 1.15.
    reference_resolutions = {
        0: 224, 1: 240, 2: 260, 3: 300, 4: 380, 5: 456, 6: 528, 7: 600,
    }
    input_size = reference_resolutions.get(phi, round(224 * 1.15 ** phi))
    x = torch.randn(1, 3, input_size, input_size).to(device)

    # -------- Forward Pass --------
    with torch.no_grad():
        y = model(x)

    print("Output shape:", y.shape)

    # -------- Parameter Count --------
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")


if __name__ == "__main__":
    main()

Using device: cpu
Output shape: torch.Size([1, 10])
Total parameters: 8,678,997
Trainable parameters: 8,678,997
