In [2]:
import torch
from torchvision import models
import torchvision

In [3]:
def flatten_effnet(model):
    '''Flatten an EfficientNet-style model into a list of layers in execution order.

    The module tree below ``model`` is walked depth-first and every leaf
    module is emitted. Two torchvision container types get special treatment:

    * ``SqueezeExcitation``: its five children are emitted in the order they
      are applied (avgpool, conv1, silu, conv2, sigmoid) instead of the order
      in which they are registered (conv2 is registered before silu).
    * ``MBConv`` whose ``use_res_connect`` is true: the string marker
      ``'MBConvResidual'`` is inserted immediately before the last flattened
      layer of that block, so the consumer knows where to add the skip
      connection.

    Both rules are applied to every module below ``model``, including its
    direct children, so a sub-module (for example a single stage of
    ``model.features``) can be flattened on its own without losing residual
    markers or getting the squeeze-excitation layers in the wrong order.

    Args:
        model: module whose children are flattened. ``model`` itself is never
            emitted, so a module without children yields an empty list.

    Returns:
        list: leaf modules interleaved with ``'MBConvResidual'`` strings.

    Raises:
        ValueError: if a ``SqueezeExcitation`` module does not have exactly
            five children.
    '''

    def dfs_layer(layer):
        # Squeeze-excitation: re-order the children into execution order.
        if isinstance(layer, torchvision.ops.misc.SqueezeExcitation):
            # children() yields registration order; unpacking also guards
            # against a layout this function does not know how to handle.
            avgpool, conv1, conv2, silu, sigmoid = layer.children()
            return [avgpool, conv1, silu, conv2, sigmoid]

        children = list(layer.children())
        if children:
            flat = []
            for sublayer in children:
                flat += dfs_layer(sublayer)
        else:
            # Leaf module: emit it as is.
            flat = [layer]

        # Mark where the skip connection of a residual MBConv block is added.
        # The marker goes *before* the block's final layer (presumably the
        # StochasticDepth module -- confirm against torchvision's MBConv), so
        # consumers must treat that final layer as an identity.
        if (isinstance(layer, torchvision.models.efficientnet.MBConv)
                and layer.use_res_connect):
            flat.insert(-1, 'MBConvResidual')
        return flat

    flattened = []
    for layer in model.children():
        flattened += dfs_layer(layer)

    return flattened

In [4]:
# Create the reference EfficientNet-B0 and run an example input through it.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument -- switch once the installed version is
# confirmed to support it.
model = models.efficientnet_b0(pretrained=True)
model.eval()  # inference behaviour for dropout / batch-norm / stochastic depth

# Dummy batch: one 3-channel 224x224 image filled with ones.
t = torch.ones((1, 3, 224, 224))

# The reference output is only compared against the flattened run in the
# cells shown here, so skip recording an autograd graph for it.
with torch.no_grad():
    out_ref = model(t)

In [5]:
# Flatten the reference model into a plain list of layers (plus residual
# markers) and preview the first five entries.
flattened = flatten_effnet(model)
for layer in flattened[:5]:
    print(layer)

Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
SiLU(inplace=True)
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


In [6]:
# Run the example input through the flattened model, one layer at a time.
#
# Indexing convention: outs[0] is the network input and every iteration
# appends exactly one tensor, so outs[k] is the *input* of flattened[k] and
# outs[k + 1] is its *output*.

# A squeeze-excitation block is flattened to 5 layers ending in the Sigmoid
# (avgpool, conv1, silu, conv2, sigmoid). With the Sigmoid at index i, the
# avgpool sits at i - 4, so the block's input is outs[i - 4].
SE_INPUT_OFFSET = 4

# Distance from the 'MBConvResidual' marker back to the first layer of its
# block: 3 (expand) + 3 (depthwise) + 5 (squeeze-excitation) + 2 (project).
# NOTE(review): only valid for MBConv blocks that have an expansion stage --
# confirm before reusing with other EfficientNet variants.
MBCONV_INPUT_OFFSET = 13

x = t.clone()
outs = [x]
with torch.no_grad():  # inference only: no autograd graph for the stored outputs
    for i, layer in enumerate(flattened):
        if layer == 'MBConvResidual':
            # Skip connection: add the input of the enclosing MBConv block.
            # The marker precedes the block's last layer, which is assumed
            # to act as an identity here (model is in eval mode).
            x = x + outs[i - MBCONV_INPUT_OFFSET]
            outs.append(x)
            continue

        # The classifier expects (batch, features): collapse the last 3 dims.
        if isinstance(layer, torch.nn.Linear):
            x = torch.flatten(x, -3, -1)

        # Run the input through the current layer.
        x = layer(x)
        outs.append(x)

        # A Sigmoid closes a squeeze-excitation block: its output is the
        # per-channel gate that scales the block's input. Note that outs
        # keeps the gate itself; only x carries the gated activation on.
        # (Previously this read outs[i - 5], the tensor *before* the
        # preceding activation, which was only correct because that
        # activation is applied in place.)
        if isinstance(layer, torch.nn.Sigmoid):
            x = x * outs[i - SE_INPUT_OFFSET]
out = x

In [7]:
# Largest absolute element-wise difference between the reference output and
# the layer-by-layer output; 0.0 means both runs agree exactly.
(out_ref - out).abs().max().item()

0.0

In [9]:
# List every plain Conv2d layer of the flattened model, in execution order.
for layer in flattened:
    if type(layer) is not torch.nn.Conv2d:
        continue
    print(layer)

Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
Conv2d(96, 4, kernel_size=(1, 1), stride=(1, 1))
Conv2d(4, 96, kernel_size=(1, 1), stride=(1, 1))
Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
Conv2d(144, 6, kernel_size=(1, 1), stride=(1, 1))
Conv2d(6, 144, kernel_size=(1, 1), stride=(1, 1))
Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
Conv2d(24, 144, kernel_size=(1, 1), stride=(1