In [129]:
import torch
from torch import nn
from torchvision import models, ops
from functools import partial

In [130]:
conv2d_norm_activation = ops.Conv2dNormActivation(in_channels=16, 
                                                  out_channels=32, 
                                                  norm_layer=partial(nn.BatchNorm2d, eps=0.001),
                                                  activation_layer=nn.SiLU)

In [131]:
model = models.efficientnet_v2_l().features
model[0] = conv2d_norm_activation
model = model.to('mps')

In [132]:
model

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
  )
  (1): Sequential(
    (0): FusedMBConv(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
        )
      )
      (stochastic_depth): StochasticDepth(p=0.0, mode=row)
    )
    (1): FusedMBConv(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          (2): SiLU(inplace=True)
        )
      )
      (stochastic_d

In [133]:
input = torch.randn(16, 16, 256, 256).to('mps')

In [134]:
x = model[:2](input)
x = model[2:3](x)
x = model[3:4](x)
x = model[4:5](x)
x = model[5:7](x)

x.shape

RuntimeError: MPS backend out of memory (MPS allocated: 26.90 GB, other allocations: 9.32 GB, max allowed: 36.27 GB). Tried to allocate 84.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# 3 : 128
# 4 : 64
# 6 : 32
# 8 : 16 > bottleneck