In [1]:
#default_exp layers

# Layers

> Utilities for creating torch Modules for self-supervised learning.

In [26]:
#export
from fastai.vision.all import *
from fastai.vision.learner import _update_first_layer
import timm

In [27]:
# export
# https://github.com/rwightman/pytorch-image-models/blob/3a7aa95f7e5fc90a6a2683c756e854e26201d82e/timm/models/layers/adaptive_avgmax_pool.py#L79
# Namespace class whose attributes map each pooling name to its lowercase string
# (e.g. `PoolType.CatAvgMax == 'catavgmax'`), for use as `pool_type` below.
mk_class('PoolType', **{o:o.lower() for o in ['Fast', 'Avg', 'AvgMax', 'CatAvgMax', 'Max']},
         doc="All possible pooling types as attributes to get tab-completion and typo-proofing")

In [28]:
#export
def create_fastai_encoder(arch, n_in=3, pretrained=True, cut=None, concat_pool=True):
    "Create a fastai encoder: the `create_body` backbone for `arch`, followed by pooling and `Flatten`"
    body = create_body(arch, n_in, pretrained, cut)
    # Pick the pooling layer: fastai's `AdaptiveConcatPool2d`, or average pooling to output size 1.
    if concat_pool:
        pooling = AdaptiveConcatPool2d()
    else:
        pooling = nn.AdaptiveAvgPool2d(1)
    # Unpack the body so its layers sit directly in the returned `nn.Sequential`.
    layers = [*body, pooling, Flatten()]
    return nn.Sequential(*layers)

def create_timm_encoder(arch:str, pretrained=True, cut=None, n_in=3, pool_type=None):
    "Create an encoder from any `timm` model (built with `num_classes=0`). If pool_type is None then it uses timm default"
    # NOTE(review): `cut` is accepted but never used in this function.
    create_kwargs = dict(pretrained=pretrained, in_chans=n_in, num_classes=0)
    # `global_pool` is only forwarded for non-'vit' archs with an explicit `pool_type`;
    # otherwise the argument is left out of the `timm.create_model` call.
    skip_global_pool = ('vit' in arch) or (pool_type is None)
    if not skip_global_pool:
        create_kwargs['global_pool'] = pool_type
    return timm.create_model(arch, **create_kwargs)

In [29]:
# Random input tensor of shape (1, 3, 384, 384) fed to the models in the cells below.
inp = torch.randn(1, 3, 384, 384)

In [34]:
# No `pool_type` given, so the encoder is created without a `global_pool` argument.
# Uses `create_timm_encoder`, the function defined in this notebook.
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False)
model.global_pool

SelectAdaptivePool2d (pool_type=avg, flatten=True)

In [38]:
# Run the random input through the model and display the output shape.
out = model(inp)
out.shape

torch.Size([1, 1280])

In [40]:
# Explicit `pool_type`: forwarded to timm as `global_pool` for this non-'vit' arch.
# Uses `create_timm_encoder`, the function defined in this notebook.
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolType.CatAvgMax)
model.global_pool

SelectAdaptivePool2d (pool_type=catavgmax, flatten=True)

In [41]:
# Run the random input through the model and display the output shape.
out = model(inp)
out.shape

torch.Size([1, 2560])

In [42]:
# vision transformer archs are created without a `global_pool` argument, so `pool_type` does not apply
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)

In [43]:
# Run the random input through the ViT encoder and display the output shape.
out = vit_model(inp)
out.shape

torch.Size([1, 1024])

## Export -

In [53]:
#hide
# Run nbdev's notebook-to-script conversion; the output below lists the converted notebooks.
from nbdev.export import notebook2script
notebook2script()

Converted 00-utils.ipynb.
Converted 01-augmentations.ipynb.
Converted 10-simclr.ipynb.
Converted 10b-simclr_v2.ipynb.
Converted 20-byol.ipynb.
Converted 30-swav.ipynb.
Converted index.ipynb.
