In [1]:
#! Consider using efficient net v2 instead of classic v0.

In [2]:
#Model from HF is probably small overkill here and it would require more hacking than plain torch
#config = transformers.AutoConfig.from_pretrained("google/efficientnet-b0")

In [3]:
from torchvision import models
import torch.nn as nn
import torch

from torchvision.ops.misc import Conv2dNormActivation

from pytorch_lightning.utilities.model_summary import ModelSummary
import lightning as L
import torch.nn.functional as F



In [6]:
import numpy as np

b = np.ones((10,10))

In [77]:
#https://medium.com/vitrox-publication/understanding-circle-loss-bdaa576312f7
class EfficientWordNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        N_BLOCKS_FROM_EFFNET = 5
        features_blocks = models.efficientnet_b0().features
        self.efnet_part = nn.Sequential(*[features_blocks[i] for i in range(N_BLOCKS_FROM_EFFNET)])

        # Our input is one channel spectrogram
        norm_layer = None 
        firstconv_output_channels = 32
        self.efnet_part[0] = Conv2dNormActivation(
                1, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        
        #They claim that input is (1,98,64) spectrogram (98 filterbanks and 64 time steps)
        # Output from 4th block of efficient net yields (80, 7, 4) Quite small image, but they use more convolutions 
        self.our_part = nn.Sequential(
            nn.Conv2d(80, 32, 3, padding="same"),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2), # They're not using ReLU in original implementation. This is actually okay with just max Pooling
            nn.Conv2d(32,64,3,padding="same"),
            nn.Flatten(), # Output at this stage is 384
            nn.Linear(384, 128),
        )

    def forward(self, X):
        X = self.efnet_part(X)
        X = self.our_part(X)
        return F.normalize(X) # We force this embeddings to line in a unit hypersphere




In [78]:
model = EfficientWordNetModule()

In [76]:
ModelSummary(EfficientWordNetModule(), max_depth=-1)

    | Name                                    | Type                 | Params
-----------------------------------------------------------------------------------
0   | efnet_part                              | Sequential           | 308 K 
1   | efnet_part.0                            | Conv2dNormActivation | 320   
2   | efnet_part.0.0                          | Conv2d               | 320   
3   | efnet_part.0.1                          | SiLU                 | 0     
4   | efnet_part.1                            | Sequential           | 1.4 K 
5   | efnet_part.1.0                          | MBConv               | 1.4 K 
6   | efnet_part.1.0.block                    | Sequential           | 1.4 K 
7   | efnet_part.1.0.block.0                  | Conv2dNormActivation | 352   
8   | efnet_part.1.0.block.0.0                | Conv2d               | 288   
9   | efnet_part.1.0.block.0.1                | BatchNorm2d          | 64    
10  | efnet_part.1.0.block.0.2                | SiLU      

In [1]:
from art.enums import TrainingStage

In [2]:
a = TrainingStage.TEST

In [3]:
a  == TrainingStage.TEST

True