In [None]:
#hide
#skip
! [ -e /content ] && pip install -Uqq self-supervised  # upgrade self-supervised on colab

In [None]:
#default_exp layers

# Layers

> Utilities for creating torch Modules for self-supervised learning.

In [None]:
#export
from fastai.vision.all import *
import timm

In [None]:
# export
# https://github.com/rwightman/pytorch-image-models/blob/3a7aa95f7e5fc90a6a2683c756e854e26201d82e/timm/models/layers/adaptive_avgmax_pool.py#L79
# String-valued constants: each attribute holds its lowercased name, e.g. PoolingType.CatAvgMax == 'catavgmax',
# the same strings that are passed to timm's `global_pool` argument in `create_timm_encoder`.
mk_class('PoolingType', **{o:o.lower() for o in ['Fast', 'Avg', 'AvgMax', 'CatAvgMax', 'Max']},
         doc="All possible pooling types as attributes to get tab-completion and typo-proofing")

In [None]:
#export
# nbdev: force these names into the exported module's `__all__` (`PoolingType` is created by `mk_class`
# rather than a `def`/`class`, and `_splitter` is private because of its leading underscore).
_all_ = ['PoolingType', '_splitter']

In [None]:
#export
def create_fastai_encoder(arch, pretrained=True, n_in=3, pool_type=PoolingType.CatAvgMax):
    """Create a fastai encoder from the model function `arch` (e.g. `xresnet34`): body + pooling + flatten.

    `pool_type=PoolingType.CatAvgMax` uses `AdaptiveConcatPool2d` (avg and max features concatenated);
    any other value falls back to `nn.AdaptiveAvgPool2d(1)`.
    """
    body = create_body(arch, n_in, pretrained, cut=None)
    if pool_type == PoolingType.CatAvgMax: pool = AdaptiveConcatPool2d()
    else:                                  pool = nn.AdaptiveAvgPool2d(1)
    return nn.Sequential(*body, pool, Flatten())

def create_timm_encoder(arch:str, pretrained=True, n_in=3, pool_type=PoolingType.CatAvgMax):
    """Create a headless (`num_classes=0`) model for `arch` from the `timm` library.

    `pool_type` is forwarded as timm's `global_pool`. If `pool_type` is None, or `arch` contains 'vit',
    `global_pool` is not passed and the timm default is used.
    """
    create_kwargs = dict(pretrained=pretrained, in_chans=n_in, num_classes=0)
    keep_default_pool = 'vit' in arch or pool_type is None
    if not keep_default_pool: create_kwargs['global_pool'] = pool_type
    return timm.create_model(arch, **create_kwargs)

def create_encoder(arch, pretrained=True, n_in=3, pool_type=PoolingType.CatAvgMax):
    """A utility for creating encoder without specifying the package.

    `arch` may be:
    - a fastai model function (e.g. `xresnet34`) -> `create_fastai_encoder`
    - the name of a model function in this module's namespace (e.g. `"xresnet34"`, from the fastai
      star import) -> `create_fastai_encoder`. Such names take precedence over timm models of the same name.
    - any other string, treated as a `timm` model name -> `create_timm_encoder`
    """
    # Accept the function itself, which is what `create_fastai_encoder` expects
    if callable(arch): return create_fastai_encoder(arch, pretrained, n_in, pool_type)
    if arch in globals(): return create_fastai_encoder(globals()[arch], pretrained, n_in, pool_type)
    return create_timm_encoder(arch, pretrained, n_in, pool_type)

In [None]:
inp = torch.randn((1,3,384,384))

The fastai encoder expects a model function as its first argument, whereas timm expects a string. Also, fastai defaults to concat pooling, aka `catavgmax` in timm. With timm's selectable pooling, any `PoolingType` can be used. Experiments show that concat pooling is better on average, so it is set as our default.

For any other `pool_type`, the fastai encoder uses `AdaptiveAvgPool2d`; for timm you can choose any of the remaining `PoolingType` values.

In [None]:
fastai_encoder = create_fastai_encoder(xresnet34)
out = fastai_encoder(inp); out.shape

torch.Size([1, 1024])

In [None]:
# Any pool_type other than CatAvgMax gives average pooling for fastai encoders
fastai_encoder = create_fastai_encoder(xresnet34, pool_type=PoolingType.Avg)
out = fastai_encoder(inp); out.shape

torch.Size([1, 512])

In [None]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False)
out = model(inp); out.shape

torch.Size([1, 2560])

In [None]:
model = create_timm_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape

torch.Size([1, 1280])

In [None]:
model = create_encoder("xresnet34", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape

torch.Size([1, 512])

In [None]:
model = create_encoder("tf_efficientnet_b0_ns", pretrained=False, pool_type=PoolingType.Avg)
out = model(inp); out.shape

torch.Size([1, 1280])

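Vision Transformer is a special case which uses `LayerNorm` rather than a pooling layer: for any `arch` containing `'vit'`, `pool_type` is ignored and timm's default ViT features are returned.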

In [None]:
vit_model = create_timm_encoder("vit_large_patch16_384", pretrained=False)
out = vit_model(inp); out.shape

torch.Size([1, 1024])

In [None]:
#export
def create_mlp_module(dim,hidden_size,projection_size,bn=False,nlayers=2):
    """MLP module as described in papers, used as projection layer.

    Builds `nlayers` Linear layers: dim -> hidden_size -> ... -> hidden_size -> projection_size.
    Each layer except the last is followed by an optional `BatchNorm1d` (if `bn`) and a ReLU.
    With `nlayers=1` the module is a single linear projection `dim -> projection_size`.
    Raises `ValueError` if `nlayers < 1`.
    """
    if nlayers < 1: raise ValueError(f"`nlayers` must be >= 1, got {nlayers}")
    layers, in_features = [], dim
    for _ in range(nlayers-1):
        layers.append(nn.Linear(in_features, hidden_size))
        if bn: layers.append(nn.BatchNorm1d(hidden_size))
        layers.append(nn.ReLU(inplace=True))
        in_features = hidden_size
    # `in_features` is still `dim` when there are no hidden layers (nlayers=1)
    layers.append(nn.Linear(in_features, projection_size))
    return nn.Sequential(*layers)

In [None]:
#SimCLR
create_mlp_module(1024,4096,128)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
#SimCLR-v2
create_mlp_module(1024,4096,128,nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=4096, out_features=4096, bias=True)
  (3): ReLU(inplace=True)
  (4): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
#BYOL
create_mlp_module(1024,4096,128,bn=True)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
#SWAV
create_mlp_module(1024,4096,128,bn=True,nlayers=3)

Sequential(
  (0): Linear(in_features=1024, out_features=4096, bias=True)
  (1): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): Linear(in_features=4096, out_features=4096, bias=True)
  (4): BatchNorm1d(4096, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Linear(in_features=4096, out_features=128, bias=True)
)

In [None]:
#export
def create_cls_module(nf, n_out, lin_ftrs=None, ps=0.5, use_bn=True, first_bn=True, bn_final=False, lin_first=False, y_range=None):
    """Creates classification layer which takes nf flatten features and outputs n_out logits

    There is no pooling/flatten layer here, so the input should already be flattened encoder
    features of size `nf` (e.g. the output of `create_encoder`).

    - `lin_ftrs`: hidden layer sizes between `nf` and `n_out`; `None` means one hidden layer of 512.
    - `ps`: dropout probability per `LinBnDrop` block. A single value `p` is expanded to `p/2` for
      every hidden block and `p` for the last one.
    - `first_bn`: `bn` flag of the first `LinBnDrop` block; `use_bn`: `bn` flag of all the others.
    - `bn_final`: append `nn.BatchNorm1d(n_out, momentum=0.01)` at the end.
    - `lin_first`: forwarded to `LinBnDrop`; additionally a leading `nn.Dropout` is added and the last
      layer becomes a plain `nn.Linear`.
    - `y_range`: if not `None`, append `SigmoidRange(*y_range)`.
    Returns an `nn.Sequential`.
    """
    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + lin_ftrs + [n_out]
    # One `bn` flag per entry of lin_ftrs; the zip below only consumes the first len(lin_ftrs)-1
    bns = [first_bn] + [use_bn]*len(lin_ftrs[1:])
    ps = L(ps)
    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps
    # ReLU after every block except the last one (`None` -> no activation on the output layer)
    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]
    layers = []
    # `pop(0)` uses the first dropout value for a leading Dropout. `ps` is now one entry short, so the
    # zip below stops before the last block, which is appended after the loop as a plain Linear instead.
    if lin_first: layers.append(nn.Dropout(ps.pop(0)))
    for ni,no,bn,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], bns, ps, actns):
        layers += LinBnDrop(ni, no, bn=bn, p=p, act=actn, lin_first=lin_first)
    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))
    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))
    if y_range is not None: layers.append(SigmoidRange(*y_range))
    return nn.Sequential(*layers)

In [None]:
inp = torch.randn((2,3,384,384))

In [None]:
encoder = create_encoder("xresnet34", pretrained=False)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)

In [None]:
with torch.no_grad(): print(model(inp))

tensor([[-0.5934,  0.0218, -1.0546, -0.0870, -0.0212],
        [ 0.8928,  1.1403,  0.0279, -0.5045, -1.0595]])


In [None]:
encoder = create_encoder("vit_large_patch16_384", pretrained=False)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)

In [None]:
with torch.no_grad(): print(model(inp))

tensor([[ 0.0023, -0.0434, -0.1689,  0.7236,  1.4304],
        [ 0.2860,  0.3319, -1.1037, -0.1302, -1.2017]])


`create_model` builds a classification model (encoder + classification head), e.g. to quickly create a model for downstream classification training.

In [None]:
#export
@delegates(create_cls_module)
def create_model(arch, n_out, pretrained=True, n_in=3, pool_type=PoolingType.CatAvgMax, **kwargs):
    """Create `nn.Sequential(encoder, head)` for classification into `n_out` outputs.

    The encoder comes from `create_encoder(arch, ...)`; extra `kwargs` are passed to `create_cls_module`.
    The head's input size is found by running a dummy batch of `n_in`-channel images through the encoder.
    Index `model[0]` for the encoder and `model[1]` for the head.
    """
    encoder = create_encoder(arch, pretrained=pretrained, n_in=n_in, pool_type=pool_type)
    sz = 224
    if 'vit' in arch:
        # ViTs need their native resolution: prefer timm's `default_cfg`, else parse a `..._384`-style suffix
        cfg = getattr(encoder, 'default_cfg', None) or {}
        sz = cfg['input_size'][-1] if 'input_size' in cfg else int(arch.split("_")[-1])
    # Probe in eval mode: in train mode the random batch would update BatchNorm running statistics
    # (`no_grad` does not prevent that) and corrupt pretrained stats. Restore the original mode afterwards.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad(): nf = encoder(torch.randn(2, n_in, sz, sz)).size(-1)
    finally:
        encoder.train(was_training)
    head = create_cls_module(nf, n_out, **kwargs)
    apply_init(head)
    return nn.Sequential(encoder, head)

`_splitter` can be passed to `Learner(..., splitter=_splitter)` to freeze or unfreeze the encoder: the first parameter group is the encoder and the second is the classification head. Models from `create_model` are `nn.Sequential(encoder, head)`, so `model[0]` and `model[1]` give the encoder and the classification head modules.

In [None]:
#export 
def _splitter(m):
    "Two parameter groups for `Learner(splitter=...)`: the encoder `m[0]` first, the head `m[1]` second"
    encoder, head = m[0], m[1]
    return L(encoder, head).map(params)

In [None]:
model = create_model("xresnet34", 10, pretrained=False)
model[1]

Sequential(
  (0): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.25, inplace=False)
  (2): Linear(in_features=1024, out_features=512, bias=False)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=512, out_features=10, bias=False)
)

In [None]:
with torch.no_grad(): print(model(inp))

tensor([[ 1.6586,  2.8627,  0.5942,  1.3333, -0.7247, -0.1750,  0.4790, -4.1426,
          0.3254,  0.5920],
        [-0.6557,  0.3469, -3.1064,  0.4271,  3.6438,  0.0830, -1.9096,  4.2991,
         -1.3772,  0.3817]])


In [None]:
model = create_model("vit_large_patch16_384", 10, pretrained=False, use_bn=False, first_bn=False, bn_final=False)
model[1]

Sequential(
  (0): Dropout(p=0.25, inplace=False)
  (1): Linear(in_features=1024, out_features=512, bias=True)
  (2): ReLU(inplace=True)
  (3): Dropout(p=0.5, inplace=False)
  (4): Linear(in_features=512, out_features=10, bias=True)
)

In [None]:
with torch.no_grad(): print(model(inp))

tensor([[ 2.2183,  2.0234, -4.9572, -1.5017,  5.2824, -0.1557,  1.8053,  2.5815,
          1.0612,  1.0911],
        [ 0.8053, -0.1254, -1.0162, -2.4544,  3.7484,  0.2554,  1.4608,  0.5014,
         -1.6777, -2.0474]])


### Gradient Checkpointing

Gradient checkpointing saves memory, so you can train with a larger image resolution and/or batch size. For now it's compatible with **timm** EfficientNet and ResNet models, and with **fastai** models, but it should be easy to implement for any encoder model that you are using.

See https://github.com/pytorch/pytorch/pull/49757/files for the current fix for using gradient checkpointing with autocast / `to_fp16()`.

In [None]:
#export 
from torch.utils.checkpoint import checkpoint_sequential

class CheckpointResNet(Module):
    "Run a timm ResNet with its residual stages `layer1`-`layer4` under gradient checkpointing"
    def __init__(self, resnet_model, checkpoint_nchunks=2):
        "Up to 4 chunks"
        self.checkpoint_nchunks = checkpoint_nchunks
        self.resnet_model = resnet_model
        # The same stage modules (shared, not copied) grouped into one sequence for `checkpoint_sequential`
        self.forward_layers = nn.Sequential(*[
            self.resnet_model.layer1,
            self.resnet_model.layer2,
            self.resnet_model.layer3,
            self.resnet_model.layer4
        ])

    def forward(self, x):
        x = self.resnet_model.conv1(x)
        x = self.resnet_model.bn1(x)
        x = self.resnet_model.act1(x)
        x = self.resnet_model.maxpool(x)
        # Reentrant checkpointing only back-propagates into checkpointed chunks whose input requires grad.
        # With a frozen stem that input doesn't, so those stages would silently get no gradients.
        if (torch.is_grad_enabled() and not x.requires_grad
                and any(p.requires_grad for p in self.forward_layers.parameters())):
            x = x.detach().requires_grad_()
        x = checkpoint_sequential(self.forward_layers, self.checkpoint_nchunks, x)
        x = self.resnet_model.global_pool(x)
        if self.resnet_model.drop_rate:
            x = F.dropout(x, p=float(self.resnet_model.drop_rate), training=self.resnet_model.training)
        x = self.resnet_model.fc(x)
        return x

class CheckpointEfficientNet(Module):
    "Run a timm EfficientNet with its `blocks` under gradient checkpointing in `checkpoint_nchunks` chunks"
    def __init__(self, effnet_model, checkpoint_nchunks=2):
        self.checkpoint_nchunks = checkpoint_nchunks
        self.effnet_model = effnet_model

    def forward_features(self, x):
        x = self.effnet_model.conv_stem(x)
        x = self.effnet_model.bn1(x)
        x = self.effnet_model.act1(x)
        # Reentrant checkpointing only back-propagates into checkpointed chunks whose input requires grad.
        # With a frozen stem that input doesn't, so those blocks would silently get no gradients.
        if (torch.is_grad_enabled() and not x.requires_grad
                and any(p.requires_grad for p in self.effnet_model.blocks.parameters())):
            x = x.detach().requires_grad_()
        x = checkpoint_sequential(self.effnet_model.blocks, self.checkpoint_nchunks, x)
        x = self.effnet_model.conv_head(x)
        x = self.effnet_model.bn2(x)
        x = self.effnet_model.act2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.effnet_model.global_pool(x)
        if self.effnet_model.drop_rate > 0.:
            x = F.dropout(x, p=self.effnet_model.drop_rate, training=self.effnet_model.training)
        return self.effnet_model.classifier(x)

class CheckpointSequential(Module):
    "Run a sequential encoder (e.g. from `create_fastai_encoder`) under gradient checkpointing in `checkpoint_nchunks` chunks"
    def __init__(self, fastai_model, checkpoint_nchunks=2):
        "This can be used for checkpointing fastai encoders which are sequential models"
        self.checkpoint_nchunks = checkpoint_nchunks
        self.fastai_model = fastai_model

    def forward(self, x):
        # A raw image batch doesn't require grad. With reentrant checkpointing every checkpointed chunk
        # would then get no gradients (only the last, non-checkpointed chunk would train), so make the
        # input require grad whenever the model has trainable parameters.
        if (torch.is_grad_enabled() and not x.requires_grad
                and any(p.requires_grad for p in self.fastai_model.parameters())):
            x = x.detach().requires_grad_()
        return checkpoint_sequential(self.fastai_model, self.checkpoint_nchunks, x)

In [None]:
L(timm.list_models("*resnet50*"))[-10:]

(#10) ['seresnet50','seresnet50tn','skresnet50','skresnet50d','ssl_resnet50','swsl_resnet50','tv_resnet50','vit_base_resnet50d_224','vit_small_resnet50d_s3_224','wide_resnet50_2']

In [None]:
encoder = create_encoder("seresnet50", pretrained=False)
encoder = CheckpointResNet(encoder, checkpoint_nchunks=4)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))



tensor([[ 0.2153, -0.8222, -0.1195, -0.1419,  0.2558],
        [-0.0267,  1.1275,  0.4353, -0.2715, -1.3025]])


In [None]:
L(timm.list_models("*efficientnet*"))[-10:]

(#10) ['tf_efficientnet_el','tf_efficientnet_em','tf_efficientnet_es','tf_efficientnet_l2_ns','tf_efficientnet_l2_ns_475','tf_efficientnet_lite0','tf_efficientnet_lite1','tf_efficientnet_lite2','tf_efficientnet_lite3','tf_efficientnet_lite4']

In [None]:
encoder = create_encoder("tf_efficientnet_b0_ns", pretrained=False)
encoder = CheckpointEfficientNet(encoder, checkpoint_nchunks=4)
out = encoder(inp) 
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))

tensor([[ 0.2183, -1.7747,  0.6225, -0.2091, -0.6604],
        [-0.4133,  1.4024,  0.4160, -0.6159,  0.6558]])


In [None]:
encoder = create_encoder("xresnet34", pretrained=False)
encoder = CheckpointSequential(encoder, checkpoint_nchunks=4)
out = encoder(inp)
classifier = create_cls_module(out.size(-1), n_out=5, first_bn=False)
model = nn.Sequential(encoder, classifier)
with torch.no_grad(): print(model(inp))



tensor([[ 0.2094, -0.0202,  1.5631,  0.8257,  0.8442],
        [-0.4477, -0.2046, -0.9960, -1.3508,  0.2298]])


## Export -

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 01 - augmentations.ipynb.
Converted 02 - layers.ipynb.
Converted 03 - distributed.ipynb.
Converted 10 - simclr.ipynb.
Converted 11 - moco.ipynb.
Converted 12 - byol.ipynb.
Converted 13 - swav.ipynb.
Converted 20 - clip.ipynb.
Converted 21 - clip-moco.ipynb.
Converted index.ipynb.
