In [None]:
#default_exp layers

# Layers

> Utilities for creating PyTorch modules for self-supervised learning.

In [None]:
#export
from fastai.vision.all import *
import timm

In [None]:
# export
# https://github.com/rwightman/pytorch-image-models/blob/3a7aa95f7e5fc90a6a2683c756e854e26201d82e/timm/models/layers/adaptive_avgmax_pool.py#L79
# String-valued pooling names derived from their CamelCase attribute names, e.g. PoolType.CatAvgMax == 'catavgmax'.
mk_class('PoolType', doc="All possible pooling types as attributes to get tab-completion and typo-proofing",
         **{name: name.lower() for name in ('Fast', 'Avg', 'AvgMax', 'CatAvgMax', 'Max')})

In [None]:
#export
def create_fastai_encoder(arch, pretrained=True, n_in=3, pool_type=PoolType.CatAvgMax):
    """Create a fastai encoder (body + global pooling + flatten) from the architecture function `arch`.

    `arch` is a model constructor such as `xresnet34`, not a string. `PoolType.CatAvgMax` uses
    `AdaptiveConcatPool2d`, which doubles the number of output features. Every other `pool_type`
    (including `None`) falls back to `nn.AdaptiveAvgPool2d(1)`.
    """
    encoder = create_body(arch, n_in, pretrained, cut=None)
    # Only concat pooling is handled explicitly; anything else means plain average pooling.
    pool = AdaptiveConcatPool2d() if pool_type == PoolType.CatAvgMax else nn.AdaptiveAvgPool2d(1)
    return nn.Sequential(*encoder, pool, Flatten())

def create_timm_encoder(arch:str, pretrained=True, n_in=3, pool_type=PoolType.CatAvgMax):
    """Create a headless encoder from any model name in the `timm` library.

    The model is built with `num_classes=0`. `pool_type` is forwarded as timm's `global_pool`.
    When `pool_type` is None, or `arch` contains 'vit', it is not forwarded and timm's default
    pooling is used.
    """
    kwargs = dict(pretrained=pretrained, in_chans=n_in, num_classes=0)
    # ViT-style names never get a custom pooling; neither does an explicit None.
    if 'vit' not in arch and pool_type is not None:
        kwargs['global_pool'] = pool_type
    return timm.create_model(arch, **kwargs)

def create_encoder(arch:str, pretrained=True, n_in=3, pool_type=PoolType.CatAvgMax):
    """Create an encoder from the name `arch`, dispatching to fastai or timm automatically.

    If `arch` names an object in this module's global namespace (e.g. 'xresnet34', available via
    the fastai star import), that object is passed to `create_fastai_encoder`. Otherwise `arch` is
    treated as a timm model name and passed to `create_timm_encoder`.
    """
    namespace = globals()
    if arch not in namespace:
        return create_timm_encoder(arch, pretrained, n_in, pool_type)
    return create_fastai_encoder(namespace[arch], pretrained, n_in, pool_type)

In [None]:
# Seed once so the random inputs, the randomly initialised (pretrained=False) weights and the
# printed logits below are reproducible on a fresh Restart & Run All.
torch.manual_seed(42)
inp = torch.randn(1, 3, 384, 384)

The fastai encoder expects an architecture function as its first argument, whereas timm expects the model name as a string. Also, fastai defaults to concat pooling, a.k.a. `catavgmax` in timm. With timm's selectable pooling, any `PoolType` can be used. Experiments show that concat pooling is better on average, so it is set as our default.

For any other `pool_type`, the fastai encoder uses `AdaptiveAvgPool2d`; for timm you can choose from the remaining `PoolType` values.

In [None]:
fastai_encoder = create_fastai_encoder(xresnet34, pretrained=False)
out = fastai_encoder(inp)
assert out.shape == (1, 1024)  # concat pooling doubles xresnet34's 512 features
out.shape

torch.Size([1, 1024])

In [None]:
fastai_encoder = create_fastai_encoder(xresnet34, pretrained=False, pool_type=PoolType.Avg)
out = fastai_encoder(inp)
assert out.shape == (1, 512)  # any non-concat pool_type falls back to AdaptiveAvgPool2d
out.shape

torch.Size([1, 512])

In [None]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False)
out = model(inp)
assert out.shape == (1, 2560)  # default CatAvgMax: 2 x 1280 features
out.shape

torch.Size([1, 2560])

In [None]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolType.Avg)
out = model(inp)
assert out.shape == (1, 1280)
out.shape

torch.Size([1, 1280])

In [None]:
model = create_encoder("xresnet34", pretrained=False, pool_type=PoolType.Avg)
out = model(inp)
assert out.shape == (1, 512)  # dispatched to the fastai encoder
out.shape

torch.Size([1, 512])

In [None]:
model = create_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolType.Avg)
out = model(inp)
assert out.shape == (1, 1280)  # dispatched to the timm encoder
out.shape

torch.Size([1, 1280])

Vision Transformer is a special case which uses `LayerNorm`: for any `arch` containing `'vit'`, `pool_type` is ignored and timm's default pooling is used.

In [None]:
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)
out = vit_model(inp)
assert out.shape == (1, 1024)  # pool_type is ignored for ViT; timm's default output is used
out.shape

torch.Size([1, 1024])

In [None]:
#export
def create_mlp_module(dim,hidden_size,projection_size,bn=False,nlayers=2):
    """MLP module as described in papers, used as projection layer.

    Builds `nlayers-1` blocks of Linear -> (BatchNorm1d if `bn`) -> ReLU, followed by a final
    Linear to `projection_size`. The first layer takes `dim` inputs and every hidden layer has
    `hidden_size` units. With `nlayers=1` the module is a single `Linear(dim, projection_size)`.
    """
    layers, n_in = [], dim
    for _ in range(nlayers-1):
        layers.append(nn.Linear(n_in, hidden_size))
        if bn: layers.append(nn.BatchNorm1d(hidden_size))
        layers.append(nn.ReLU(inplace=True))
        # after the first block every layer consumes the hidden width
        n_in = hidden_size
    # the projection reads from the last hidden layer, or straight from the input if there is none
    layers.append(nn.Linear(n_in, projection_size))
    return nn.Sequential(*layers)

In [None]:
# SimCLR: 2-layer MLP without BatchNorm (defaults nlayers=2, bn=False)
create_mlp_module(dim=1024, hidden_size=4096, projection_size=128)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
# SimCLR-v2: 3-layer MLP without BatchNorm
create_mlp_module(dim=1024, hidden_size=4096, projection_size=128, nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=4096, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
# BYOL: 2-layer MLP with BatchNorm after the hidden layer
create_mlp_module(dim=1024, hidden_size=4096, projection_size=128, bn=True)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
# SwAV: 3-layer MLP with BatchNorm after each hidden layer
create_mlp_module(dim=1024, hidden_size=4096, projection_size=128, bn=True, nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
#export
def create_cls_module(nf, n_out, lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None):
    """Create a classification head mapping `nf` flattened features to `n_out` outputs (logits).

    Args:
        nf: number of input features (the flattened encoder output width).
        n_out: number of outputs.
        lin_ftrs: widths of the hidden linear layers; defaults to a single hidden layer of 512.
        ps: dropout probability, or one per linear layer. A single value `p` is expanded to
            `p/2` for every hidden layer and `p` for the last layer.
        first_bn: whether the first `LinBnDrop` block uses batch norm (all later ones do).
        bn_final: if True, append `BatchNorm1d(n_out, momentum=0.01)` at the end.
        lin_first: forwarded to every `LinBnDrop`; also prepends a `Dropout` (consuming the first
            entry of `ps`) and appends `Linear(lin_ftrs[-2], n_out)`.
        y_range: if not None, append `SigmoidRange(*y_range)`.

    Returns:
        An `nn.Sequential` containing the layers described above.
    """
    # Full list of layer widths: input, hidden layer(s), output.
    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]
    # Batch-norm flag per block: `first_bn` for the first, True for the rest (zip() drops the extra entry).
    bns = [first_bn] + [True]*len(lin_ftrs[1:])
    ps = L(ps)
    # A single dropout value p becomes p/2 for each hidden layer and p for the last layer.
    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    # ReLU after every hidden layer, no activation on the last one. List multiplication means the
    # same ReLU module instance is reused for every hidden layer.
    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    layers = []
    # NOTE(review): with lin_first, pop(0) leaves `ps` one entry short. When `ps` has one entry per
    # linear layer, zip() below then builds one fewer LinBnDrop, and the trailing Linear produces
    # the n_out outputs instead. Confirm this is intended.
    if lin_first: layers.append(nn.Dropout(ps.pop(0)))
    for ni,no,bn,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):
        # `+=` extends `layers` with the sub-layers of each LinBnDrop (presumably an nn.Sequential).
        layers += LinBnDrop(ni, no, bn=bn, p=p, act=actn, lin_first=lin_first)
    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    if y_range is not None: layers.append(SigmoidRange(*y_range))
    return nn.Sequential(*layers)

In [None]:
# Batch of 2 rather than 1: the classification head below likely contains BatchNorm1d layers,
# which cannot normalise a single-sample batch in training mode.
inp = torch.randn(2, 3, 384, 384)

In [None]:
encoder = create_encoder("xresnet34", pretrained=False)
# Dummy forward pass only to read the encoder's output width; no autograd graph needed.
with torch.no_grad(): nf = encoder(inp).size(-1)
classifier = create_cls_module(nf, n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)

In [None]:
with torch.no_grad(): logits = model(inp)
assert logits.shape == (2, 5)  # (batch, n_out)
logits

tensor([[-0.0195, -0.2775, -0.7257,  0.8391, -0.1943],
        [-0.2854, -0.0407,  1.4847, -0.3034,  0.3028]])


In [None]:
encoder = create_encoder("vit_large_patch16_384", pretrained=False)
# Dummy forward pass only to read the encoder's output width; no autograd graph needed.
with torch.no_grad(): nf = encoder(inp).size(-1)
classifier = create_cls_module(nf, n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)

In [None]:
with torch.no_grad(): logits = model(inp)
assert logits.shape == (2, 5)  # (batch, n_out)
logits

tensor([[-0.1113,  1.4554, -0.0675, -1.2252, -0.6768],
        [-0.6076, -0.1960,  1.1632,  1.0209,  0.3946]])


## Export -

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()  # regenerate the library modules from every notebook's `#export` cells; the output lists each converted notebook

Converted 00-utils.ipynb.
Converted 01-augmentations.ipynb.
Converted 02-layers.ipynb.
Converted 10-simclr.ipynb.
Converted 10b-simclr_v2.ipynb.
Converted 20-byol.ipynb.
Converted 30-swav.ipynb.
Converted index.ipynb.
