In [None]:
# default_exp layers

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from torch.nn.utils import weight_norm, spectral_norm

# Layers
> Custom fastai layers and basic functions to grab them.

## Basic manipulations and resize

In [None]:
# export
class Lambda(nn.Module):
    "An easy way to create a pytorch layer for a simple `func`"
    def __init__(self, func):
        super().__init__()
        self.func=func

    def forward(self, x): return self.func(x)
    def __repr__(self): return f'{self.__class__.__name__}({self.func})'

> Warning: In the tests below, we use lambda functions for convenience, but you shouldn't do this when building a real modules as it would make models that won't pickle (so you won't be able to save/export them).

In [None]:
tst = Lambda(lambda x:x+2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)

In [None]:
# export
class PartialLayer(Lambda):
    "Layer that applies `partial(func, **kwargs)`"
    def __init__(self, func, **kwargs):
        super().__init__(partial(func, **kwargs))
        self.repr = f'{func.__name__}, {kwargs}'

    def forward(self, x): return self.func(x)
    def __repr__(self): return f'{self.__class__.__name__}({self.repr})'

In [None]:
def test_func(a,b=2): return a+b
tst = PartialLayer(test_func, b=5)
test_eq(tst(x), x+5)

In [None]:
# export
class View(nn.Module):
    "Reshape `x` to `size`"
    def __init__(self, *size):
        super().__init__()
        self.size = size

    def forward(self, x): return x.view(self.size)

In [None]:
tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])

In [None]:
# export
class ResizeBatch(nn.Module):
    "Reshape `x` to `size`, keeping batch dim the same size"
    def __init__(self, *size):
        super().__init__()
        self.size = size

    def forward(self, x):
        size = (x.size(0),) + self.size
        return x.view(size)

In [None]:
tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])

In [None]:
# export
class Flatten(nn.Module):
    "Flatten `x` to a single dimension, often used at the end of a model. `full` for rank-1 tensor"
    def __init__(self, full=False):
        super().__init__()
        self.full = full

    def forward(self, x):
        return x.view(-1) if self.full else x.view(x.size(0), -1)

In [None]:
tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])

In [None]:
# export
class Debugger(nn.Module):
    "A module to debug inside a model."
    def forward(self,x):
        set_trace()
        return x

In [None]:
# export
def sigmoid_range(x, low, high):
    "Sigmoid function with range `(low, high)`"
    return torch.sigmoid(x) * (high - low) + low

In [None]:
test = tensor([-10.,0.,10.])
assert torch.allclose(sigmoid_range(test, -1,  2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test,  2,  4), tensor([2.,  3., 4.]), atol=1e-4, rtol=1e-4)

In [None]:
# export
class SigmoidRange(nn.Module):
    "Sigmoid module with range `(low, high)`"
    def __init__(self, low, high):
        super().__init__()
        self.low,self.high = low,high

    def forward(self, x): return sigmoid_range(x, self.low, self.high)

In [None]:
tst = SigmoidRange(-1, 2)
assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)

## Pooling layers

In [None]:
# export
class PoolFlatten(nn.Sequential):
    "Combine `nn.AdaptiveAvgPool2d` and `Flatten`."
    def __init__(self): super().__init__(nn.AdaptiveAvgPool2d(1), Flatten())

In [None]:
tst = PoolFlatten()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,5])
test_eq(tst(x), x.mean(dim=[2,3]))

In [None]:
# export
class AdaptiveConcatPool2d(nn.Module):
    "Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`"
    def __init__(self, size=None):
        super().__init__()
        self.size = size or 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)
    def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)

If the input is `bs x nf x h x h`, the output will be `bs x 2*nf x 1 x 1` if no size is passed or `bs x 2*nf x size x size`

In [None]:
tst = AdaptiveConcatPool2d()
test_eq(tst(x).shape, [10,10,1,1])
max1 = torch.max(x,    dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])

## BatchNorm layers

In [None]:
# export
NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')

In [None]:
#export
def BatchNorm(nf, norm_type=NormType.Batch, ndim=2, **kwargs):
    "BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`."
    assert 1 <= ndim <= 3
    bn = getattr(nn, f"BatchNorm{ndim}d")(nf, **kwargs)
    bn.bias.data.fill_(1e-3)
    bn.weight.data.fill_(0. if norm_type==NormType.BatchZero else 1.)
    return bn

`kwargs` are passed to `nn.BatchNorm` and can be `eps`, `momentum`, `affine` and `track_running_stats`.

In [None]:
tst = BatchNorm(15)
assert isinstance(tst, nn.BatchNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = BatchNorm(15, norm_type=NormType.BatchZero)
test_eq(tst.weight, torch.zeros(15))
tst = BatchNorm(15, ndim=1)
assert isinstance(tst, nn.BatchNorm1d)
tst = BatchNorm(15, ndim=3)
assert isinstance(tst, nn.BatchNorm3d)

In [None]:
# export
class BatchNorm1dFlat(nn.BatchNorm1d):
    "`nn.BatchNorm1d`, but first flattens leading dimensions"
    def forward(self, x):
        if x.dim()==2: return super().forward(x)
        *f,l = x.shape
        x = x.contiguous().view(-1,l)
        return super().forward(x).view(*f,l)

In [None]:
tst = BatchNorm1dFlat(15)
x = torch.randn(32, 64, 15)
y = tst(x)
mean = x.mean(dim=[0,1])
test_close(tst.running_mean, 0*0.9 + mean*0.1)
var = (x-mean).pow(2).mean(dim=[0,1])
test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)
test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)

In [None]:
# export
class BnDropLin(nn.Sequential):
    "Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers"
    def __init__(self, n_in, n_out, bn=True, p=0., act=None):
        layers = [BatchNorm(n_in, ndim=1)] if bn else []
        if p != 0: layers.append(nn.Dropout(p))
        layers.append(nn.Linear(n_in, n_out))
        if act is not None: layers.append(act)
        super().__init__(*layers)

The `BatchNorm` layer is skipped if `bn=False`, as is the dropout if `p=0.`. Optionally, you can add an activation for after the linear laeyr with `act`.

In [None]:
tst = BnDropLin(10, 20)
mods = list(tst.children())
test_eq(len(mods), 2)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

tst = BnDropLin(10, 20, p=0.1)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
assert isinstance(mods[2], nn.Linear)

tst = BnDropLin(10, 20, act=nn.ReLU())
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)
assert isinstance(mods[2], nn.ReLU)

tst = BnDropLin(10, 20, bn=False)
mods = list(tst.children())
test_eq(len(mods), 1)
assert isinstance(mods[0], nn.Linear)

## Convolutions

In [None]:
#export
def init_default(m, func=nn.init.kaiming_normal_):
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func:
        if hasattr(m, 'weight'): func(m.weight)
        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
    return m

In [None]:
#export
def _relu(inplace:bool=False, leaky:float=None):
    "Return a relu activation, maybe `leaky` and `inplace`."
    return nn.LeakyReLU(inplace=inplace, negative_slope=leaky) if leaky is not None else nn.ReLU(inplace=inplace)

def _conv_func(ndim=2, transpose=False):
    "Return the proper conv `ndim` function, potentially `transposed`."
    assert 1 <= ndim <=3
    return getattr(nn, f'Conv{"Transpose" if transpose else ""}{ndim}d')

In [None]:
#hide
test_eq(_conv_func(ndim=1),torch.nn.modules.conv.Conv1d)
test_eq(_conv_func(ndim=2),torch.nn.modules.conv.Conv2d)
test_eq(_conv_func(ndim=3),torch.nn.modules.conv.Conv3d)
test_eq(_conv_func(ndim=1, transpose=True),torch.nn.modules.conv.ConvTranspose1d)
test_eq(_conv_func(ndim=2, transpose=True),torch.nn.modules.conv.ConvTranspose2d)
test_eq(_conv_func(ndim=3, transpose=True),torch.nn.modules.conv.ConvTranspose3d)

In [None]:
# export
defaults = SimpleNamespace(activation=nn.ReLU)

In [None]:
# export
class ConvLayer(nn.Sequential):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers."
    def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch,
                 act_cls=defaults.activation, transpose=False, init=nn.init.kaiming_normal_, xtra=None):
        if padding is None: padding = ((ks-1)//2 if not transpose else 0)
        bn = norm_type in (NormType.Batch, NormType.BatchZero)
        if bias is None: bias = not bn
        conv_func = _conv_func(ndim, transpose=transpose)
        conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
        if   norm_type==NormType.Weight:   conv = weight_norm(conv)
        elif norm_type==NormType.Spectral: conv = spectral_norm(conv)
        layers = [conv]
        if act_cls is not None: layers.append(act_cls())
        if bn: layers.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))
        if xtra: layers.append(xtra)
        super().__init__(*layers)

The convolution uses `ks` (kernel size) `stride`, `padding` and `bias`. `padding` will default to the appropriate value (`(ks-1)//2` if it's not a transposed conv) and `bias` will default to `True` the `norm_type` is `Spectral` or `Weight`, `False` if it's `Batch` or `BatchZero`. Note that if you don't want any normalization, you should pass `norm_type=None`.

This defines a conv layer with `ndim` (1,2 or 3) that will be a ConvTranspose if `transpose=True`. `act_cls` is the class of the activation function to use (instantiated inside). Pass `act=None` if you don't want an activation function. If you quickly want to change your default activation, you can change the value of `defaults.activation`.

`init` is used to initialize the weights (the bias are initiliazed to 0) and `xtra` is an optional layer to add at the end.

In [None]:
tst = ConvLayer(16, 32)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.Conv2d)
assert isinstance(mods[1], nn.ReLU)
assert isinstance(mods[2], nn.BatchNorm2d)
test_eq(mods[2].weight, torch.ones(32))
test_eq(mods[0].padding, (1,1))

In [None]:
x = torch.randn(64, 16, 8, 8)
#Padding is selected to make the shape the same if stride=1
test_eq(tst(x).shape, [64,32,8,8])

#Padding is selected to make the shape half if stride=2
tst = ConvLayer(16, 32, stride=2)
test_eq(tst(x).shape, [64,32,4,4])

#But you can always pass your own padding if you want
tst = ConvLayer(16, 32, padding=0)
test_eq(tst(x).shape, [64,32,6,6])

In [None]:
#No bias by default for Batch NormType
assert mods[0].bias is None
#But can be overriden with `bias=True`
tst = ConvLayer(16, 32, bias=True)
test_eq(list(tst.children())[0].bias, torch.zeros(32))
#For no norm, or spectral/weight, bias is True by default
for t in [None, NormType.Spectral, NormType.Weight]:
    tst = ConvLayer(16, 32, norm_type=t)
    test_eq(list(tst.children())[0].bias, torch.zeros(32))

In [None]:
#Various n_dim/tranpose
tst = ConvLayer(16, 32, ndim=3)
assert isinstance(list(tst.children())[0], nn.Conv3d)
assert isinstance(list(tst.children())[2], nn.BatchNorm3d)
tst = ConvLayer(16, 32, ndim=1, transpose=True)
assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)
assert isinstance(list(tst.children())[2], nn.BatchNorm1d)

In [None]:
#No activation/leaky
tst = ConvLayer(16, 32, ndim=3, act_cls=None)
mods = list(tst.children())
test_eq(len(mods), 2)
tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[1], nn.LeakyReLU)

## Flattened loss functions

It's convenient to flatten the tensors before trying to take the losses, here is the general class to do this and the applications we use it for.

In [None]:
# export
class FlattenedLoss():
    "Same as `loss_cls`, but flattens input and target."
    def __init__(self, loss_cls, *args, axis=-1, floatify=False, is_2d=True, **kwargs):
        self.func,self.axis,self.floatify,self.is_2d = loss_cls(*args,**kwargs),axis,floatify,is_2d
        functools.update_wrapper(self, self.func)

    def __repr__(self): return f"FlattenedLoss of {self.func}"
    @property
    def reduction(self): return self.func.reduction
    @reduction.setter
    def reduction(self, v): self.func.reduction = v

    def __call__(self, input, target, **kwargs):
        input  = input .transpose(self.axis,-1).contiguous()
        target = target.transpose(self.axis,-1).contiguous()
        if self.floatify: target = target.float()
        input = input.view(-1,input.shape[-1]) if self.is_2d else input.view(-1)
        return self.func.__call__(input, target.view(-1), **kwargs)

The `args` and `kwargs` will be passed to `loss_cls` during the initialization to instantiate a loss function. `axis` is put at the end for losses like softmax that are often performed on the last axis. If `floatify=True` the targs will be converted to float (usefull for losses that only accept float targets like `BCEWithLogitsLoss`) and `is_2d` determines if we flatten while keeping the first dimension (batch size) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.

In [None]:
# export
def CrossEntropyLossFlat(*args, axis:int=-1, **kwargs):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    return FlattenedLoss(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

In [None]:
tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropy would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.CrossEntropyLoss()(output,target))
#In a segmentation task, we want to take the softmax over the channel dimension
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)

In [None]:
# export
def BCEWithLogitsLossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.BCEWithLogitsLoss`, but flattens input and target."
    return FlattenedLoss(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

In [None]:
tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))

In [None]:
# export
def BCELossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.BCELoss`, but flattens input and target."
    return FlattenedLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

In [None]:
tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.BCELoss()(output,target))

In [None]:
# export
def MSELossFlat(*args, axis:int=-1, floatify:bool=True, **kwargs):
    "Same as `nn.MSELoss`, but flattens input and target."
    return FlattenedLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)

In [None]:
tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda x: nn.MSELoss()(output,target))

## Embeddings

In [None]:
# export
def trunc_normal_(x, mean=0., std=1.):
    "Truncated normal initialization (approximation)"
    # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12
    return x.normal_().fmod_(2).mul_(std).add_(mean)

In [None]:
# export
class Embedding(nn.Embedding):
    "Embedding layer with truncated normal initialization"
    def __init__(self, ni, nf):
        super().__init__(ni, nf)
        trunc_normal_(self.weight.data, std=0.01)

Truncated normal initialization bounds the distribution to avoid large value. For a given standard deviation `std`, the bounds are roughly `-std`, `std`.

In [None]:
tst = Embedding(10, 30)
assert tst.weight.min() > -0.02
assert tst.weight.max() < 0.02
test_close(tst.weight.mean(), 0, 1e-3)
test_close(tst.weight.std(), 0.01, 0.1)

## Self attention

In [None]:
class SelfAttention(nn.Module):
    "Self attention layer for `n_channels`."
    def __init__(self, n_channels):
        super().__init__()
        self.query = self._conv(n_channels, n_channels//8)
        self.key   = self._conv(n_channels, n_channels//8)
        self.value = self._conv(n_channels, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))

    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, use_activ=False, bias=False)
    
    def forward(self, x):
        #Notation from the paper.
        size = x.size()
        x = x.view(*size[:2],-1)
        f,g,h = self.query(x),self.key(x),self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0,2,1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()

Self-attention layer as introduced in [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318).

In [None]:
class PooledSelfAttention2d(nn.Module):
    "Pooled self attention layer for 2d."
    def __init__(self, n_channels:int):
        super().__init__()
        self.n_channels = n_channels
        self.theta = self._conv(n_channels, n_channels//8)
        self.phi   = self._conv(n_channels, n_channels//8)
        self.g     = self._conv(n_channels, n_channels//2)
        self.o     = self._conv(n_channels//2, n_channels)
        self.gamma = nn.Parameter(tensor([0.]))
        
    def _conv(self,n_in,n_out):
        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, use_activ=False, bias=False)

    def forward(self, x):
        #Inspired by https://github.com/ajbrock/BigGAN-PyTorch/blob/7b65e82d058bfe035fc4e299f322a1f83993e04c/layers.py#L156
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), [2,2])
        g =   F.max_pool2d(self.g(x),   [2,2])    
        theta = theta.view(-1, self.n_channels//8, x.shape[2]*x.shape[3])
        phi   = phi  .view(-1, self.n_channels//8, x.shape[2]*x.shape[3]//4)
        g     = g    .view(-1, self.n_channels//2, x.shape[2]*x.shape[3]//4)
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
        o = self.o(torch.bmm(g, beta.transpose(1,2)).view(-1, self.ch//2, x.shape[2], x.shape[3]))
        return self.gamma * o + x

Self-attention layer used in the [Big GAN paper](https://arxiv.org/abs/1809.11096).

## PixelShuffle

In [None]:
def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function."
    ni,nf,h,w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(torch.zeros([ni2,nf,h,w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    k = k.contiguous().view([nf,ni,h,w]).transpose(0, 1)
    x.data.copy_(k)

In [None]:
class PixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."
    def __init__(self, ni:int, nf:int=None, scale:int=2, blur:bool=False, norm_type=NormType.Weight, leaky:float=None):
        super().__init__()
        nf = ifnone(nf, ni)
        self.conv = conv_layer(ni, nf*(scale**2), ks=1, norm_type=norm_type, use_activ=False)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1,0,1,0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self,x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.blur else x

## Sequential extensions

In [None]:
class SequentialEx(nn.Module):
    "Like `nn.Sequential`, but with ModuleList semantics, and can access module input"
    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for l in self.layers:
            res.orig = x
            nres = l(res)
            # We have to remove res.orig to avoid hanging refs and therefore memory leaks
            res.orig = None
            res = nres
        return res

    def __getitem__(self,i): return self.layers[i]
    def append(self,l):      return self.layers.append(l)
    def extend(self,l):      return self.layers.extend(l)
    def insert(self,i,l):    return self.layers.insert(i,l)

In [None]:
class MergeLayer(nn.Module):
    "Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`."
    def __init__(self, dense:bool=False):
        super().__init__()
        self.dense=dense

    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)

## Ready-to-go models

In [None]:
class ResBlock(nn.Module):
    def __init__(self, expansion, ni, nh, stride=1, norm_type=NormType.Batch, **kwargs):
        super().__init__()
        norm2 = NormType.BatchZero if norm_type==NormType.Batch else norm_type 
        nf,ni = nh*expansion,ni*expansion
        layers  = [ConvLayer(ni, nh, 3, stride=stride, norm_type=norm_type, **kwargs),
                   ConvLayer(nh, nf, 3, norm_type=norm2, act=None)
        ] if expansion == 1 else [
                   ConvLayer(ni, nh, 1, norm_type=norm_type, **kwargs),
                   ConvLayer(nh, nh, 3, stride=stride, norm_type=norm_type, **kwargs),
                   ConvLayer(nh, nf, 1, norm_type=norm2, act=None, **kwargs)
        ]
        self.convs = nn.Sequential(*layers)
        self.idconv = noop if ni==nf else ConvLayer(ni, nf, 1, act=None, **kwargs)
        self.pool = noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x): return act_fn(self.convs(x) + self.idconv(self.pool(x)))

In [None]:
class SimpleCNN(nn.Sequential):
    def __init__(self, filters, kernel_szs=None, strides=None, bn=True):
        nl = len(filters)-1
        kernel_szs = ifnone(kernel_szs, [3]*nl)
        strides    = ifnone(strides   , [2]*nl)
        layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],
                  norm_type=(NormType.Batch if bn and i<nl-1 else None)) for i in range_of(strides)]
        layers.append(PoolFlatten())
        super().__init__(*layers)

In [None]:
SimpleCNN([16,32,64])

SimpleCNN(
  (0): ConvLayer(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): ConvLayer(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
  )
  (2): PoolFlatten(
    (0): AdaptiveAvgPool2d(output_size=1)
    (1): Flatten()
  )
)