In [1]:
#default_exp layers

# Layers

> Utilities for creating torch Modules for self supervised learning.

In [2]:
#export
from fastai.vision.all import *
import timm

In [3]:
# export
# https://github.com/rwightman/pytorch-image-models/blob/3a7aa95f7e5fc90a6a2683c756e854e26201d82e/timm/models/layers/adaptive_avgmax_pool.py#L79
mk_class('PoolType', **{o:o.lower() for o in ['Fast', 'Avg', 'AvgMax', 'CatAvgMax', 'Max']},
         doc="All possible resize method as attributes to get tab-completion and typo-proofing")

In [4]:
#export
def create_fastai_encoder(arch, n_in=3, pretrained=True, cut=None, concat_pool=True):
    "Create timm encoder from a given arch backbone"
    encoder = create_body(arch, n_in, pretrained, cut)
    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)
    return nn.Sequential(*encoder, pool, Flatten())

def create_timm_encoder(arch:str, pretrained=True, cut=None, n_in=3, pool_type=None):
    "Creates a body from any model in the `timm` library. If pool_type is None then it uses timm default"
    if ('vit' in arch) or (pool_type is None):
        model = timm.create_model(arch, pretrained=pretrained, in_chans=n_in, num_classes=0)
    else:
        model = timm.create_model(arch, pretrained=pretrained, in_chans=n_in, num_classes=0, global_pool=pool_type)
    return model

In [5]:
inp = torch.randn((1,3,384,384))

In [7]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False); model.global_pool

SelectAdaptivePool2d (pool_type=avg, flatten=True)

In [8]:
out = model(inp); out.shape

torch.Size([1, 1280])

In [9]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolType.CatAvgMax); model.global_pool

SelectAdaptivePool2d (pool_type=catavgmax, flatten=True)

In [10]:
out = model(inp); out.shape

torch.Size([1, 2560])

In [11]:
# vision transformer doesn't have global pool
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)

In [12]:
out = vit_model(inp); out.shape

torch.Size([1, 1024])

In [17]:
create_head(1024, 256, lin_ftrs=[2048], ps=0., first_bn=False)

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): Linear(in_features=2048, out_features=2048, bias=True)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Linear(in_features=2048, out_features=256, bias=False)
)

In [58]:
#export
def create_mlp_module(dim,hidden_size,projection_size,bn=False,nlayers=2):
    "MLP module as described in papers, used as projection layer"
    l = [nn.Linear(dim, hidden_size)] 
    if bn: l += [nn.BatchNorm1d(hidden_size)]
    l += [nn.ReLU(inplace=True)]
    ls = l*(nlayers-1) + [nn.Linear(hidden_size, projection_size)]
    return nn.Sequential(*ls)

In [61]:
#SimCLR
create_mlp_module(1024,4096,128)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=128, bias=True)
)

In [65]:
#SimCLR-v2
create_mlp_module(1024,4096,128,nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=1024, out_features=4096, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=128, bias=True)
)

In [62]:
#BYOL
create_mlp_module(1024,4096,128,bn=True)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)

In [63]:
#SWAV
create_mlp_module(1024,4096,128,bn=True)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)

## Export -

In [64]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00-utils.ipynb.
Converted 01-augmentations.ipynb.
Converted 02-layers.ipynb.
Converted 10-simclr.ipynb.
Converted 10b-simclr_v2.ipynb.
Converted 20-byol.ipynb.
Converted 30-swav.ipynb.
Converted index.ipynb.
