In [1]:
# default_exp layers

In [2]:
#hide
# Development convenience: have IPython re-import edited modules before running each cell.
%load_ext autoreload
%autoreload 2

# Layers

> Common layers, blocks and utils.

In [3]:
# export
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from typing import Sequence, Union, Tuple

In [4]:
#export
def identity(x):
    "Return `x` unchanged"
    return x

class Identity(nn.Module):
    """No-op module: calling it returns the input unchanged.

    It is an `nn.Module` so it can be registered inside containers such as `nn.Sequential`,
    and its constructor accepts (and ignores) any arguments so it can stand in for
    activation classes that are built with options such as `inplace=True`.
    """
    def __init__(self, *args, **kwargs):
        # Constructor arguments are deliberately discarded; there is nothing to configure.
        super().__init__()

    def forward(self, x):
        return x

def exist(x):
    "True when `x` is anything other than None"
    return not (x is None)

def ifnone(x, default):
    "Return `default` when `x` is None, otherwise `x` itself (falsy values are kept)"
    if x is None:
        return default
    return x

In [5]:
# export 
def scale(x):
    "Apply `x*2 - 1` (so 0 maps to -1 and 1 maps to 1); inverse of `unscale`"
    doubled = x * 2
    return doubled - 1

def unscale(x):
    "Apply `(x+1)/2` (so -1 maps to 0 and 1 maps to 1); inverse of `scale`"
    shifted = x + 1
    return shifted / 2

In [None]:
# export
def trainable_parameters(m:nn.Module):
    "List the parameters of `m` that have `requires_grad` set, in `m.parameters()` order"
    return list(filter(lambda param: param.requires_grad, m.parameters()))

In [27]:
# export
class FullyConnected(nn.Sequential):
    """Linear layer with an activation and optional BatchNorm.

    Layer order is LINEAR-[BN]-ACT by default and [BN]-ACT-LINEAR when `preact` is True.

    - `d_in`, `d_out`: input / output features of the linear layer
    - `bn`: add a `BatchNorm1d` (over `d_in` features when pre-activating, `d_out` otherwise)
    - `preact`: put the normalisation and activation before the linear layer
    - `activation`: class (or zero-argument factory) that builds the activation module
    """
    def __init__(self, d_in:int, d_out:int, bn=False, preact=False, activation=nn.ReLU) -> None:
        linear = nn.Linear(d_in, d_out)
        # BatchNorm sees the block input when pre-activating and the linear output otherwise.
        norm = [nn.BatchNorm1d(d_in if preact else d_out)] if bn else []
        norm_act = norm + [activation()]
        # The order is spelled out explicitly: `list.insert(-1, linear)` would put the linear
        # layer *before* the last element (the activation) rather than after it.
        layers = norm_act + [linear] if preact else [linear] + norm_act
        super().__init__(*layers)

In [33]:
# export
class MLP(nn.Sequential):
    """Multi-layer perceptron.

    A stack of `FullyConnected` hidden layers followed by an output layer. The output layer is a
    bare `nn.Linear` by default and a `FullyConnected` block when `preact` is True.

    - `d_in`, `d_out`: input / output features
    - `d_h`, `n_layers`: width and number of hidden layers; only used when `hiddens` is None
    - `hiddens`: explicit hidden widths (any sequence); overrides `d_h` and `n_layers`
    - `bn`, `preact`: forwarded to every `FullyConnected` block
    """
    def __init__(self, d_in:int, d_out:int, d_h:int, n_layers:int, hiddens:Sequence=None, bn:bool=False, preact:bool=False) -> None:
        # Copy into a list so tuples (or other sequences) can be concatenated with `[d_in]`.
        hiddens = list(ifnone(hiddens, [d_h]*n_layers))
        ds = [d_in] + hiddens
        layers = [FullyConnected(n_in, n_out, bn, preact) for n_in, n_out in zip(ds[:-1], ds[1:])]
        # `ds[-1]` is the last hidden width, or `d_in` when there are no hidden layers at all.
        layers.append(FullyConnected(ds[-1], d_out, bn, preact) if preact else nn.Linear(ds[-1], d_out))
        super().__init__(*layers)

In [39]:
# Smoke test: an MLP with three hidden layers of width 16 maps a (4, 5) batch to (4, 10).
model = MLP(5, 10, 16, n_layers=3)
x = torch.randn(4, 5)
out = model(x)
assert out.shape == (4, 10)

In [26]:
#export
class Conv2dBlock(nn.Sequential):
    """Convolutional block: CONV-BN-ACT by default.

    If `preact` is True the order is BN-ACT-CONV, as proposed in https://arxiv.org/abs/1603.05027

    - `c_in`, `c_out`: input / output channels of the convolution
    - `ks`, `stride`: kernel size and stride of the convolution
    - `padding`: convolution padding; defaults to `(ks-1)//2`
    - `activation`: class (or factory) that builds the activation module; it is built with
      `inplace=True` when it accepts that argument and with no arguments otherwise
    - `preact`: put BatchNorm and the activation before the convolution
    """
    def __init__(self, c_in:int, c_out:int, ks:int, stride:int=1, padding:int=None, activation=nn.ReLU, preact=False):
        if padding is None: padding = (ks-1)//2
        try: act = activation(inplace=True)
        except TypeError: act = activation()  # e.g. nn.Tanh has no `inplace` option
        # BatchNorm sees the block input when pre-activating and the conv output otherwise.
        norm = nn.BatchNorm2d(c_in if preact else c_out)
        conv = nn.Conv2d(c_in, c_out, ks, stride, padding)
        # The order is spelled out explicitly: `list.insert(-1, conv)` would put the convolution
        # *before* the last element (the activation), i.e. BN-CONV-ACT instead of BN-ACT-CONV.
        layers = [norm, act, conv] if preact else [conv, norm, act]
        super().__init__(*layers)



In [29]:
# Stride-2 block with the default padding: channels become c_out, each spatial dim becomes (n+1)//2.
bs, c_in, c_out, h, w = 4, 3, 8, 4, 4 
conv = Conv2dBlock(c_in, c_out, 3, 2)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, (h+1)//2, (w+1)//2)

In [8]:
#hide
# Same output-shape check for the pre-activation variant of the block.
bs, c_in, c_out, h, w = 4, 3, 8, 24, 24 
conv = Conv2dBlock(c_in, c_out, 3, 2, preact=True)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, (h+1)//2, (w+1)//2)

In [9]:
#export
class ConvTranspose2dBlock(nn.Sequential):
    """Transposed-convolution block: CONVT-BN-ACT by default.

    If `preact` is True the order is BN-ACT-CONVT (pre-activation, as proposed for residual
    units in https://arxiv.org/abs/1603.05027).

    - `c_in`, `c_out`: input / output channels of the transposed convolution
    - `ks`, `stride`: kernel size and stride of the transposed convolution
    - `padding`: transposed-convolution padding; defaults to `(ks-1)//2`
    - `activation`: class (or factory) that builds the activation module; it is built with
      `inplace=True` when it accepts that argument and with no arguments otherwise
    - `preact`: put BatchNorm and the activation before the transposed convolution
    """
    def __init__(self, c_in:int, c_out:int, ks:int, stride:int=1, padding:int=None, activation=nn.ReLU, preact=False):
        if padding is None: padding = (ks-1)//2
        try: act = activation(inplace=True)
        except TypeError: act = activation()  # e.g. nn.Tanh has no `inplace` option
        # BatchNorm sees the block input when pre-activating and the conv output otherwise.
        norm = nn.BatchNorm2d(c_in if preact else c_out)
        conv = nn.ConvTranspose2d(c_in, c_out, ks, stride, padding)
        # The order is spelled out explicitly: `list.insert(-1, conv)` would put the convolution
        # *before* the last element (the activation), i.e. BN-CONVT-ACT instead of BN-ACT-CONVT.
        layers = [norm, act, conv] if preact else [conv, norm, act]
        super().__init__(*layers)

In [10]:
# With ks=4, stride=2 and the default padding the block doubles each spatial dimension.
bs, c_in, c_out, h, w = 4, 16, 8, 10, 10 
conv = ConvTranspose2dBlock(c_in, c_out, 4, 2)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, h*2, w*2)

In [11]:
#export
class ResBlock(nn.Module):
    """Convolutional block with skip connection.

    The main path is two `Conv2dBlock`s built with `preact=True`; only the first one is strided.
    The skip path max-pools when `stride != 1` and applies a 1x1 convolution when the channel
    count changes. The two paths are summed and passed through a final activation.

    - `ks`: one kernel size for both convolutions, or a pair `(first, second)`
    - `padding`: forwarded to both `Conv2dBlock`s
    - `activation`: class used for the activations inside the blocks and for the final one
    """
    def __init__(self, c_in:int, c_out:int, ks:Union[int, Tuple], stride:int=1, padding:int=None, activation=nn.ReLU):
        super().__init__()
        ks_first, ks_second = (ks, ks) if isinstance(ks, int) else (ks[0], ks[1])
        main_path = [
            Conv2dBlock(c_in, c_out, ks_first, stride, padding, activation, preact=True),
            Conv2dBlock(c_out, c_out, ks_second, 1, padding, activation, preact=True),
        ]
        self.conv = nn.Sequential(*main_path)

        # Shortcut branch; an empty nn.Sequential passes its input straight through.
        shortcut = []
        if stride != 1: shortcut.append(nn.MaxPool2d(stride, ceil_mode=True))
        if c_in != c_out: shortcut.append(nn.Conv2d(c_in, c_out, 1))
        self.skip = nn.Sequential(*shortcut)

        self.act = activation()

    def forward(self, x):
        shortcut = self.skip(x)
        residual = self.conv(x)
        return self.act(shortcut + residual)


In [12]:
# Stride-1 residual block: the channel count changes, the spatial size is preserved.
bs, c_in, c_out, h, w = 4, 3, 8, 24, 24 
conv = ResBlock(c_in, c_out, 3, 1)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, h, w)

In [13]:
# hide
# Stride-2 residual block: main and skip paths must agree on the (n+1)//2 output size.
bs, c_in, c_out, h, w = 4, 3, 8, 24, 24 
conv = ResBlock(c_in, c_out, 3, 2)
x = torch.randn(bs, c_in, h, w)
out = conv(x)
assert out.shape == (bs, c_out, (h+1)//2, (w+1)//2)
# A (first, second) pair of kernel sizes is accepted as well; the spatial size is preserved.
conv = ResBlock(c_in, c_out, (3, 1))
out = conv(x)
assert out.shape == (bs, c_out, h, w)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [14]:
# export
class ChanLayerNorm(nn.Module):
    """Channelwise LayerNorm for 4D inputs whose channels sit on dim 1.

    The input is permuted so dim 1 becomes the last dim, normalised with `nn.LayerNorm(d)`
    and permuted back. `kwargs` are forwarded to `nn.LayerNorm`.
    """
    def __init__(self, d:int, **kwargs):
        super().__init__()
        self.ln = nn.LayerNorm(d, **kwargs)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        normed = self.ln(channels_last)
        channels_first = normed.permute(0, 3, 1, 2)
        # permute only returns a view; hand back a contiguous tensor
        return channels_first.contiguous()

In [15]:
# After channelwise normalisation the mean over the channel dim should be ~0 at every position.
x = torch.randn(1, 3, 2, 2)
m = ChanLayerNorm(3)
out = m(x)
mu = out.mean(1)
# Shifted by 1 so the comparison is made against ones rather than against zeros.
assert torch.allclose(mu+1, torch.ones_like(mu))

In [32]:
#export
class ConvNet(nn.Sequential):
    """Stack of stride-2 `Conv2dBlock`s ending in an output convolution.

    The output layer is a bare `nn.Conv2d` by default and a pre-activation `Conv2dBlock` when
    `preact` is True, so in both cases the network ends with a convolution.

    - `c_in`: input channels
    - `ks`: kernel size of every convolution
    - `n_layers`: number of layers; only used when `channels` is None, in which case the
      output channels are 8, 16, 32, ... (`2**i` for `i` starting at 3)
    - `channels`: explicit output channels, one per layer (any non-empty sequence)
    - `preact`: forwarded to every `Conv2dBlock`
    """
    def __init__(self, c_in:int, ks:int=3, n_layers=4, channels:Sequence=None, preact=False) -> None:
        channels = list(ifnone(channels, [2**i for i in range(3, 3+n_layers)]))
        if not channels: raise ValueError("ConvNet needs at least one layer (got empty `channels`)")
        # Prepending c_in lets every layer, including the first and the last, be read off as a
        # consecutive (in, out) pair; this also covers the single-layer case.
        dims = [c_in] + channels
        layers = [Conv2dBlock(dims[i], dims[i+1], ks, 2, preact=preact) for i in range(len(dims)-2)]
        # `preact` is forwarded to the last block as well; without it a pre-activation network
        # would end in a CONV-BN-ACT block.
        layers += [Conv2dBlock(dims[-2], dims[-1], ks, 2, preact=True) if preact else nn.Conv2d(dims[-2], dims[-1], ks, 2, padding=(ks-1)//2)]
        super().__init__(*layers)

In [33]:
# Build the default ConvNet for a 1-channel input and display its layer structure.
model = ConvNet(1)
model

ConvNet(
  (0): Conv2dBlock(
    (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (1): Conv2dBlock(
    (0): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (2): Conv2dBlock(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
)

In [18]:
class ResNet(nn.Module):
    """Small residual feature extractor.

    Two stride-2 convolutions with kernel size 4 take `c_in` channels to 256, followed by two
    stride-1, 256-channel `ResBlock`s whose activations are identities.
    """
    def __init__(self, c_in):
        super().__init__()
        self.net = nn.Sequential(
            Conv2dBlock(c_in, 256, 4, 2),
            nn.Conv2d(256, 256, 4, 2, 1),
            # nn.Identity is an nn.Module that ignores constructor arguments, so the blocks can
            # build it like any other activation and register it inside nn.Sequential.
            ResBlock(256, 256, (3,1), 1, activation=nn.Identity),
            ResBlock(256, 256, (3,1), 1, activation=nn.Identity)
        )
        
    def forward(self, x):
        return self.net(x)

In [34]:
#hide
# Run nbdev's notebook-to-script export (kept as the last cell of the notebook).
from nbdev.export import notebook2script; notebook2script()

Converted 00_layers.ipynb.
Converted 01_training.ipynb.
Converted 02_made.ipynb.
Converted 03_pixelcnn.ipynb.
Converted 10_experiments.pixelcnn.ipynb.
Converted index.ipynb.
