# network

> This module contains the networks that are used in this project and a function to retrieve them for training.

In [None]:
#| default_exp network

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *
from fastcore.utils import *

In [None]:
#| export
from typing import Union, BinaryIO, IO
from os import PathLike

import torch
import torchvision
from torch.nn import Module

from birdclef.dataset import get_dataloader


## The models

In [None]:
#| export
class EfficientNetV2(torch.nn.Module):
    "Randomly initialised torchvision EfficientNetV2 classifier preceded by a learnable 1->3 channel convolution."

    # Accepted values for the `size` argument (small, medium, large torchvision variants).
    _SIZES = ('s', 'm', 'l')

    def __init__(self,
                 num_classes=264, # Number of output logits of the final classifier
                 size='s',        # Backbone variant: 's', 'm' or 'l'
                ):
        super().__init__()

        # Validate before touching torchvision: previously any unknown value
        # (a typo such as 'S' or 'x') silently built the large model.
        if size not in self._SIZES:
            raise ValueError(f"size must be one of {list(self._SIZES)}, got {size!r}")

        builders = {
            's': torchvision.models.efficientnet_v2_s,
            'm': torchvision.models.efficientnet_v2_m,
            'l': torchvision.models.efficientnet_v2_l,
        }
        # weights=None: no pretrained weights are loaded, the backbone starts from random init.
        self.efficientnet_v2 = builders[size](weights=None, progress=True, num_classes=num_classes)

        # The torchvision backbone expects 3 input channels; this conv maps a single input
        # channel to 3. padding="same" keeps the spatial size unchanged.
        self.init_conv = torch.nn.Conv2d(1, 3, (3,3), padding="same")

    def forward(self, x):
        "Apply `init_conv` then the EfficientNetV2 backbone; no activation is applied to the output."
        # assumes x is (batch, 1, H, W), presumably single-channel spectrograms — TODO confirm
        x = self.init_conv(x)
        return self.efficientnet_v2(x)

Let's check that the network runs on a batch from the `train_simple` dataloader:

In [None]:
#| eval:false
# Smoke test: build a 3-class network and run a forward pass on a single batch.
# `batch` is reused by the `get_model` example further down.
dl = get_dataloader('train_simple')
model = EfficientNetV2(num_classes=3)
batch = next(iter(dl))
# assumes batch[0] is the input tensor (and batch[1] the targets) — TODO confirm against get_dataloader
model(batch[0])

tensor([[-1.9340, -3.4938,  1.2743]], grad_fn=<AddmmBackward0>)

## Handling models

As with the datasets and dataloaders, we also need a way to retrieve the models defined above by name.

In [None]:
#| export
model_dict = {
    'efficient_net_v2_s': (EfficientNetV2, {}),
    'efficient_net_v2_m': (EfficientNetV2, {'size':'m'}),
    'efficient_net_v2_l': (EfficientNetV2, {'size':'l'}),
}

def get_model(model_key:str, # A key of the model dictionary
              weights_path:Union[str, PathLike, BinaryIO, IO[bytes], None] = None,   # A path or file-like object holding a saved `state_dict`
              num_classes:int = 264,  # Number of classes to predict
              )->Module:      # A pytorch model
    "A getter method to retrieve the wanted (possibly pretrained) model"
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if model_key not in model_dict:
        raise ValueError(f'{model_key} is not an existing network, choose one from {list(model_dict)}.')

    net_class, kwargs = model_dict[model_key]
    model = net_class(num_classes=num_classes, **kwargs)

    if weights_path is not None:
        # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only machines;
        # load_state_dict then copies the tensors into the model's (CPU) parameters.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
        # (consider weights_only=True if the installed torch version supports it).
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)

    return model

In [None]:
#|echo: false
# Show every registered model key, one per line, and pin the registry size.
print("The existing keys are:\n" + "\n".join(model_dict))

test_eq(len(model_dict), 3)

The existing keys are:
efficient_net_v2_s
efficient_net_v2_m
efficient_net_v2_l


Let's retrieve a model through `get_model` and run it on the same batch:

In [None]:
#| eval:false
# Retrieve the small variant through the registry (no weights file, so random init)
# and run it on the `batch` drawn in the smoke-test cell above.
model = get_model('efficient_net_v2_s', num_classes=3)
model(batch[0])

tensor([[-0.4579,  8.3441,  1.5603]], grad_fn=<AddmmBackward0>)

In [None]:
#| hide
# Regenerate the library module from the `#| export` cells of this notebook.
import nbdev
nbdev.nbdev_export()