In [3]:
# Test Mlp: profile the nn.Linear and Conv1d (kernel_size=1) variants with thop.
import torch
import torch.nn as nn
from timm.layers.helpers import to_2tuple
from functools import partial
from thop import profile,clever_format 

In [2]:
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> act -> drop1 -> norm -> fc2 -> drop2.

    The projections are either ``nn.Linear`` layers or, when ``use_conv`` is
    set, ``nn.Conv1d`` layers with ``kernel_size=1``. In both modes the input
    is channel-last (features on the last dimension) and the output has the
    same layout with ``out_features`` on the last dimension.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of the hidden projection; falls back to
            ``in_features`` when falsy.
        out_features: Size of the last output dimension; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument callable building the activation module.
        norm_layer: Optional callable taking ``hidden_features`` and building
            the normalisation module; ``nn.Identity`` is used when ``None``.
        bias: Passed through ``to_2tuple``; element 0 is used for ``fc1`` and
            element 1 for ``fc2``.
        drop: Passed through ``to_2tuple``; element 0 is the dropout
            probability after the activation, element 1 after ``fc2``.
        use_conv: Use ``nn.Conv1d(kernel_size=1)`` instead of ``nn.Linear``.

    Note:
        In conv mode the tensor is channel-first between the two transposes,
        so ``norm`` receives ``(..., hidden_features, L)`` rather than
        ``(..., L, hidden_features)``. Pick ``norm_layer`` accordingly.
    """

    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        self.use_conv = use_conv
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # A 1x1 Conv1d applies the same per-position affine map as nn.Linear,
        # but on a channel-first tensor.
        linear_layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        """Apply the block to a channel-last tensor.

        Args:
            x: Tensor whose last dimension is ``in_features``. In conv mode it
                must be ``(L, C)`` or ``(B, L, C)``, the shapes ``nn.Conv1d``
                accepts once the last two dimensions are swapped.

        Returns:
            Tensor with the same leading dimensions as ``x`` and
            ``out_features`` on the last dimension.
        """
        if self.use_conv:
            # Conv1d wants channels before length. Swapping the last two dims
            # (instead of a fixed 3-D permute) also covers unbatched (L, C).
            x = x.transpose(-1, -2)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        if self.use_conv:
            # Restore the caller's channel-last layout.
            x = x.transpose(-1, -2)
        return x

In [5]:
# Build both variants with identical widths (512 in, 2048 hidden) so their
# profiled cost is directly comparable. The Conv1d-backed model is created
# first, then the nn.Linear-backed one.
model_conv = Mlp(in_features=512, hidden_features=2048, use_conv=True)
model = Mlp(in_features=512, hidden_features=2048)

In [7]:
# Random probe input; the last dimension (512) matches the models' in_features.
# Presumably laid out as (batch, tokens, channels) -- confirm against the real caller.
img = torch.rand((1, 49, 512))
# thop.profile runs one forward pass and returns (total ops, total params).
total_ops, total_params = profile(model, inputs=(img,))
# '%.3f' gives three decimals; the earlier '.%3f' printed a stray leading dot
# followed by the default six decimals (e.g. '.102.760448M').
clever_format([total_ops, total_params], '%.3f')

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.


('.102.760448M', '.2.099712M')

In [8]:
# Profile the Conv1d-backed variant on the same probe tensor as the Linear one.
total_ops_conv, total_params_conv = profile(model_conv, inputs=(img,))
# '%.3f' gives three decimals; the earlier '.%3f' printed a stray leading dot
# followed by the default six decimals (e.g. '.102.760448M').
clever_format([total_ops_conv, total_params_conv], '%.3f')

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.


('.102.760448M', '.2.099712M')

In [None]:
from .CAE import ClassAttentionBlock