In [None]:
#default_exp layers

# Layers

> Utilities for creating torch Modules for self-supervised learning.

In [2]:
#export
from fastai.vision.all import *
import timm

In [3]:
# export
# https://github.com/rwightman/pytorch-image-models/blob/3a7aa95f7e5fc90a6a2683c756e854e26201d82e/timm/models/layers/adaptive_avgmax_pool.py#L79
# `PoolType` exposes the global-pooling names as attributes; each value is the lowercase
# form of the name (e.g. `PoolType.CatAvgMax == 'catavgmax'`), which is what gets passed
# to timm as `global_pool` by `create_timm_encoder`.
mk_class('PoolType', **{o:o.lower() for o in ['Fast', 'Avg', 'AvgMax', 'CatAvgMax', 'Max']},
         doc="All possible pooling types as attributes to get tab-completion and typo-proofing")

In [4]:
#export
def create_fastai_encoder(arch, n_in=3, pretrained=True, cut=None, concat_pool=True):
    "Create fastai encoder from `arch`: the `create_body` layers followed by pooling and `Flatten`"
    body = create_body(arch, n_in, pretrained, cut)
    # Pick the pooling layer: concatenated pooling, or plain average pooling to a 1x1 map.
    if concat_pool:
        pooling = AdaptiveConcatPool2d()
    else:
        pooling = nn.AdaptiveAvgPool2d(1)
    # Unpack the body so the result is a single flat `nn.Sequential`.
    modules = [*body, pooling, Flatten()]
    return nn.Sequential(*modules)

def create_timm_encoder(arch:str, pretrained=True, n_in=3, pool_type=None):
    "Creates a headless (`num_classes=0`) `timm` model. `pool_type` is forwarded as `global_pool` unless it is None or `arch` contains 'vit'"
    create_kwargs = dict(pretrained=pretrained, in_chans=n_in, num_classes=0)
    # 'vit' archs and `pool_type=None` both fall back to whatever pooling timm picks by default.
    keep_timm_default = ('vit' in arch) or (pool_type is None)
    if not keep_timm_default:
        create_kwargs['global_pool'] = pool_type
    return timm.create_model(arch, **create_kwargs)

In [5]:
# Random dummy batch (1 sample, 3 channels, 384x384) used for the shape checks below.
inp = torch.randn((1,3,384,384))

In [6]:
# No `pool_type` given, so the model keeps timm's default global pooling (shown in the output).
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False); model.global_pool

SelectAdaptivePool2d (pool_type=avg, flatten=True)

In [7]:
# The encoder output is already pooled and flattened: one feature vector per sample.
out = model(inp); out.shape

torch.Size([1, 1280])

In [8]:
# Same arch, but explicitly requesting CatAvgMax pooling through `PoolType`.
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolType.CatAvgMax); model.global_pool

SelectAdaptivePool2d (pool_type=catavgmax, flatten=True)

In [9]:
# With CatAvgMax pooling the feature width is twice that of the default-pooled model above.
out = model(inp); out.shape

torch.Size([1, 2560])

In [10]:
# For archs whose name contains 'vit', `create_timm_encoder` never passes `global_pool`,
# so the model is always created with timm's defaults.
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)

In [11]:
# The ViT encoder also returns one flat feature vector per sample.
out = vit_model(inp); out.shape

torch.Size([1, 1024])

In [12]:
# For comparison: fastai's `create_head` prepends pooling and flatten layers (see output),
# which are not needed when the encoder already returns flat features.
create_head(1024, 256, lin_ftrs=[2048], ps=0., first_bn=False)

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(full=False)
  (2): Linear(in_features=2048, out_features=2048, bias=True)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Linear(in_features=2048, out_features=256, bias=False)
)

In [13]:
#export
def create_mlp_module(dim,hidden_size,projection_size,bn=False,nlayers=2):
    """MLP module as described in papers, used as projection layer.

    Builds `nlayers` linear layers in total: `nlayers-1` hidden blocks of
    `Linear -> [BatchNorm1d] -> ReLU` (all `hidden_size` wide), followed by a
    final `Linear` to `projection_size` with no norm or activation.

    `dim`: number of input features.
    `hidden_size`: width of every hidden layer.
    `projection_size`: number of output features.
    `bn`: if True, insert `BatchNorm1d` after each hidden `Linear`.
    `nlayers`: total number of `Linear` layers; must be >= 1. With `nlayers=1`
    the module is a single `Linear(dim, projection_size)`.

    Returns an `nn.Sequential`. Raises `ValueError` if `nlayers < 1`.
    """
    if nlayers < 1: raise ValueError(f"nlayers must be >= 1, got {nlayers}")
    layers = []
    # Track the width feeding the next Linear so that the final layer is wired to `dim`
    # (not `hidden_size`) when there are no hidden blocks.
    n_in = dim
    for _ in range(nlayers-1):
        layers.append(nn.Linear(n_in, hidden_size))
        if bn: layers.append(nn.BatchNorm1d(hidden_size))
        layers.append(nn.ReLU(inplace=True))
        n_in = hidden_size
    layers.append(nn.Linear(n_in, projection_size))
    return nn.Sequential(*layers)

In [14]:
#SimCLR: defaults (2 linear layers, no batch norm)
create_mlp_module(1024,4096,128)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=128, bias=True)
)

In [15]:
#SimCLR-v2: 3 linear layers, no batch norm
create_mlp_module(1024,4096,128,nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=4096, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=128, bias=True)
)

In [16]:
#BYOL: 2 linear layers with batch norm after the hidden layer
create_mlp_module(1024,4096,128,bn=True)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)

In [17]:
#SWAV: 3 linear layers with batch norm after each hidden layer
create_mlp_module(1024,4096,128,bn=True,nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=128, bias=True)
)

In [18]:
#export
def create_cls_module(nf, n_out, lin_ftrs=None, ps=0.5, first_bn=True, bn_final=False, lin_first=False, y_range=None):
    "Build a classification head that maps `nf` flattened features to `n_out` logits through a stack of `LinBnDrop` blocks"
    # Layer widths: input, hidden sizes (a single 512-wide layer by default), output.
    if lin_ftrs is None:
        widths = [nf, 512, n_out]
    else:
        widths = [nf] + lin_ftrs + [n_out]
    n_hidden = len(widths) - 2
    # Batch norm is controlled by `first_bn` for the first block and always on afterwards.
    use_bn = [first_bn] + [True]*len(widths[1:])
    ps = L(ps)
    # A single dropout value is used as-is for the last block and halved for the earlier ones.
    if len(ps) == 1:
        ps = [ps[0]/2] * n_hidden + ps
    # The same ReLU instance is reused for every hidden block; the final block has no activation.
    activations = [nn.ReLU(inplace=True)] * n_hidden + [None]
    modules = []
    if lin_first:
        modules.append(nn.Dropout(ps.pop(0)))
    for n_inp, n_outp, bn, p, act in zip(widths[:-1], widths[1:], use_bn, ps, activations):
        modules += LinBnDrop(n_inp, n_outp, bn=bn, p=p, act=act, lin_first=lin_first)
    if lin_first:
        modules.append(nn.Linear(widths[-2], n_out))
    if bn_final:
        modules.append(nn.BatchNorm1d(widths[-1], momentum=0.01))
    if y_range is not None:
        modules.append(SigmoidRange(*y_range))
    return nn.Sequential(*modules)

In [40]:
# Rebind `inp` to a batch of 2 random samples for the end-to-end check below.
inp = torch.randn((2,3,384,384))

In [45]:
# End-to-end example: timm encoder followed by a classification head.
encoder = create_timm_encoder("vit_large_patch16_384", pretrained=False)
out = encoder(inp)
# Infer the encoder's feature width from a forward pass instead of hard-coding it.
nf = out.size(-1)
classifier = create_cls_module(nf, n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)

In [48]:
# Forward pass without gradient tracking; prints one row of 5 logits per sample.
# NOTE(review): `model.eval()` is not called in this notebook, so dropout / batch norm
# presumably run in training mode here — confirm if deterministic output is wanted.
with torch.no_grad(): print(model(inp))

tensor([[-0.7594,  1.7426, -0.1524,  0.0806,  0.2837],
        [-0.0398, -1.1073,  0.3935,  0.3417, -0.1568]])


## Export -

In [1]:
#hide
# Run nbdev's notebook-to-script conversion; the output below lists the notebooks converted.
from nbdev.export import notebook2script
notebook2script()

Converted 00-utils.ipynb.
Converted 01-augmentations.ipynb.
Converted 02-layers.ipynb.
Converted 10-simclr.ipynb.
Converted 10b-simclr_v2.ipynb.
Converted 20-byol.ipynb.
Converted 30-swav.ipynb.
Converted index.ipynb.
