In [1]:
from nbdev import *

In [2]:
%nbdev_default_export activations

Cells will be exported to model_constructor.activations,
unless a different module is specified after an export flag: `%nbdev_export special.module`


# Activation functions

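> Activation functions: Swish, Mish, HardSwish and HardMish, each with a plain, a TorchScript-compiled (Jit) and a memory-efficient (JitMe) variant.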

Activation functions forked from timm (pytorch-image-models): https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/activations.py

Mish: A Self Regularized Non-Monotonic Neural Activation Function  
https://github.com/digantamisra98/Mish  
fast.ai forum discussion: https://forums.fast.ai/t/meet-mish-new-activation-function-possible-successor-to-relu  

In [3]:
%nbdev_export
# forked from https://github.com/rwightman/pytorch-image-models/timm/models/layers/activations.py
import torch
from torch import nn as nn
from torch.nn import functional as F

# Swish

In [4]:
%nbdev_export
def swish(x, inplace: bool = False):
    """Swish activation, ``x * sigmoid(x)``.

    Described in: https://arxiv.org/abs/1710.05941

    Args:
        x: input tensor.
        inplace: if True, ``x`` is multiplied in place and returned.

    Returns:
        Tensor with Swish applied.
    """
    gate = x.sigmoid()
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)

class Swish(nn.Module):
    """Module wrapper around ``swish`` (https://arxiv.org/abs/1710.05941).

    Args:
        inplace: forwarded to ``swish``; when True the input tensor is
            modified in place.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return swish(x, inplace=self.inplace)

# SwishJit

In [33]:
%nbdev_export
@torch.jit.script
def swish_jit(x, inplace: bool = False):
    """TorchScript-compiled Swish, ``x * sigmoid(x)``.

    Swish is described in: https://arxiv.org/abs/1710.05941
    ``inplace`` is accepted for API parity with ``swish`` but ignored:
    a new tensor is always returned.
    """
    gate = x.sigmoid()
    return x.mul(gate)

class SwishJit(nn.Module):
    """Module wrapper around the TorchScript-compiled ``swish_jit``.

    Swish is described in: https://arxiv.org/abs/1710.05941
    ``inplace`` is accepted for API parity with ``Swish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return swish_jit(x)

# SwishJitMe - memory-efficient.

In [35]:
%nbdev_export

@torch.jit.script
def swish_jit_fwd(x):
    """Forward pass of memory-efficient Swish: ``x * sigmoid(x)``.

    Called by ``SwishJitAutoFn.forward``; it must be defined here, otherwise
    ``swish_me`` / ``SwishMe`` fail with ``NameError`` on the first call.
    """
    return x.mul(torch.sigmoid(x))


@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """Gradient of Swish w.r.t. ``x``, scaled by ``grad_output``.

    With s = sigmoid(x): d/dx [x * s] = s * (1 + x * (1 - s)).
    """
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))


class SwishJitAutoFn(torch.autograd.Function):
    """Memory-efficient Swish built from TorchScript-compiled fwd/bwd kernels.

    Only the input ``x`` is saved for backward; the gradient is recomputed
    from it in ``swish_jit_bwd``.
    Inspired by a conversation between Jeremy Howard & Adam Paszke:
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return swish_jit_bwd(saved_x, grad_output)


def swish_me(x, inplace=False):
    """Memory-efficient, jit-scripted Swish via ``SwishJitAutoFn``.

    ``inplace`` is accepted for API parity with ``swish`` but ignored.
    """
    result = SwishJitAutoFn.apply(x)
    return result


class SwishMe(nn.Module):
    """Module wrapper around the memory-efficient, jit-scripted Swish.

    ``inplace`` is accepted for API parity with ``Swish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return swish_me(x)

# Mish

In [15]:
%nbdev_export
def mish(x, inplace: bool = False):
    """Mish activation, ``x * tanh(softplus(x))``.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681

    NOTE: there is no working inplace variant; ``inplace`` is ignored and a
    new tensor is always returned.
    """
    gate = F.softplus(x).tanh()
    return x.mul(gate)


class Mish(nn.Module):
    """Module wrapper around ``mish``.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681

    NOTE: inplace variant not working; ``inplace`` is accepted but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return mish(x)

# MishJit

In [32]:
%nbdev_export
@torch.jit.script
def mish_jit(x, _inplace: bool = False):
    """TorchScript-compiled Mish, ``x * tanh(softplus(x))``.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681
    ``_inplace`` is accepted for API parity but ignored.
    """
    gate = F.softplus(x).tanh()
    return x.mul(gate)

class MishJit(nn.Module):
    """Module wrapper around the TorchScript-compiled ``mish_jit``.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681
    ``inplace`` is accepted for API parity with ``Mish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return mish_jit(x)

# MishJitMe - memory-efficient.

In [40]:
%nbdev_export

@torch.jit.script
def mish_jit_fwd(x):
    """Forward pass of memory-efficient Mish: ``x * tanh(softplus(x))``."""
    gate = F.softplus(x).tanh()
    return x.mul(gate)


@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """Gradient of Mish w.r.t. ``x``, scaled by ``grad_output``.

    With t = tanh(softplus(x)) and s = sigmoid(x):
    d/dx [x * t] = t + x * s * (1 - t^2).
    """
    tanh_sp = F.softplus(x).tanh()
    sig = torch.sigmoid(x)
    local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
    return grad_output.mul(local_grad)


class MishJitAutoFn(torch.autograd.Function):
    """Memory-efficient, jit-scripted Mish.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681
    Only the input ``x`` is saved for backward; the gradient is recomputed
    from it in ``mish_jit_bwd``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return mish_jit_bwd(saved_x, grad_output)


def mish_me(x, inplace=False):
    """Memory-efficient, jit-scripted Mish via ``MishJitAutoFn``.

    ``inplace`` is accepted for API parity with ``mish`` but ignored.
    """
    result = MishJitAutoFn.apply(x)
    return result


class MishMe(nn.Module):
    """Module wrapper around the memory-efficient, jit-scripted Mish.

    Mish: A Self Regularized Non-Monotonic Neural Activation Function -
    https://arxiv.org/abs/1908.08681
    ``inplace`` is accepted for API parity with ``Mish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return mish_me(x)


# HardSwish

In [7]:
%nbdev_export
def hard_swish(x, inplace: bool = False):
    """Hard swish activation, ``x * relu6(x + 3) / 6``.

    Args:
        x: input tensor.
        inplace: if True, ``x`` is multiplied in place and returned.
    """
    gate = F.relu6(x + 3.) / 6.
    if inplace:
        return x.mul_(gate)
    return x.mul(gate)


class HardSwish(nn.Module):
    """Module wrapper around ``hard_swish``.

    Args:
        inplace: forwarded to ``hard_swish``; when True the input tensor is
            modified in place.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_swish(x, inplace=self.inplace)

# HardSwishJit

In [None]:
%nbdev_export
@torch.jit.script
def hard_swish_jit(x, inplace: bool = False):
    """TorchScript-compiled hard swish, ``x * clamp(x + 3, 0, 6) / 6``.

    ``inplace`` is accepted for API parity but ignored. ``clamp`` is used
    instead of ``F.relu6`` because it reportedly runs slightly faster
    (not benchmarked here).
    """
    gate = (x + 3).clamp(min=0, max=6).div(6.)
    return x * gate


class HardSwishJit(nn.Module):
    """Module wrapper around the TorchScript-compiled ``hard_swish_jit``.

    ``inplace`` is accepted for API parity with ``HardSwish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return hard_swish_jit(x)

# HardSwishJitMe

In [38]:
%nbdev_export

@torch.jit.script
def hard_swish_jit_fwd(x):
    """Forward pass of memory-efficient hard swish: ``x * clamp(x + 3, 0, 6) / 6``."""
    gate = (x + 3).clamp(min=0, max=6).div(6.)
    return x * gate


@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    """Gradient of hard swish w.r.t. ``x``, scaled by ``grad_output``.

    Piecewise slope: 0 for x < -3, ``x / 3 + 0.5`` on [-3, 3], 1 for x > 3.
    At the boundary x == 3 the ramp branch wins (slope 1.5).
    """
    saturated = torch.ones_like(x) * (x >= 3.)
    in_ramp = (x >= -3.) & (x <= 3.)
    slope = torch.where(in_ramp, x / 3. + .5, saturated)
    return grad_output * slope


class HardSwishJitAutoFn(torch.autograd.Function):
    """Memory-efficient, jit-scripted HardSwish.

    Only the input ``x`` is saved for backward; the gradient is recomputed
    from it in ``hard_swish_jit_bwd``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return hard_swish_jit_bwd(saved_x, grad_output)


def hard_swish_me(x, inplace=False):
    """Memory-efficient, jit-scripted HardSwish via ``HardSwishJitAutoFn``.

    ``inplace`` is accepted for API parity with ``hard_swish`` but ignored.
    """
    result = HardSwishJitAutoFn.apply(x)
    return result


class HardSwishMe(nn.Module):
    """Module wrapper around the memory-efficient, jit-scripted HardSwish.

    ``inplace`` is accepted for API parity with ``HardSwish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return hard_swish_me(x)

# HardMish

In [25]:
%nbdev_export
def hard_mish(x, inplace: bool = False):
    """Hard Mish, ``0.5 * x * clamp(x + 2, 0, 2)``.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md

    Args:
        x: input tensor.
        inplace: if True, ``x`` is multiplied in place and returned.
    """
    # The ramp is computed from x before any in-place mutation happens.
    ramp = (x + 2).clamp(min=0, max=2)
    if not inplace:
        return 0.5 * x * ramp
    return x.mul_(0.5 * ramp)


class HardMish(nn.Module):
    """Module wrapper around ``hard_mish``.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md

    Args:
        inplace: forwarded to ``hard_mish``; when True the input tensor is
            modified in place.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        return hard_mish(x, inplace=self.inplace)

# HardMishJit

In [34]:
%nbdev_export
@torch.jit.script
def hard_mish_jit(x, inplace: bool = False):
    """TorchScript-compiled Hard Mish, ``0.5 * x * clamp(x + 2, 0, 2)``.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    ``inplace`` is accepted for API parity but ignored.
    """
    ramp = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * ramp


class HardMishJit(nn.Module):
    """Module wrapper around the TorchScript-compiled ``hard_mish_jit``.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    ``inplace`` is accepted for API parity with ``HardMish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return hard_mish_jit(x)

# HardMishJitMe - memory efficient.

In [None]:
%nbdev_export
@torch.jit.script
def hard_mish_jit_fwd(x):
    """Forward pass of memory-efficient Hard Mish: ``0.5 * x * clamp(x + 2, 0, 2)``."""
    ramp = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * ramp


@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    """Gradient of Hard Mish w.r.t. ``x``, scaled by ``grad_output``.

    Piecewise slope: 0 for x < -2, ``x + 1`` on [-2, 0], 1 for x > 0.
    """
    saturated = torch.ones_like(x) * (x >= -2.)
    in_ramp = (x >= -2.) & (x <= 0.)
    slope = torch.where(in_ramp, x + 1., saturated)
    return grad_output * slope


class HardMishJitAutoFn(torch.autograd.Function):
    """Memory-efficient, jit-scripted Hard Mish.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    Only the input ``x`` is saved for backward; the gradient is recomputed
    from it in ``hard_mish_jit_bwd``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_x,) = ctx.saved_tensors
        return hard_mish_jit_bwd(saved_x, grad_output)


def hard_mish_me(x, inplace: bool = False):
    """Memory-efficient, jit-scripted Hard Mish via ``HardMishJitAutoFn``.

    ``inplace`` is accepted for API parity with ``hard_mish`` but ignored.
    """
    result = HardMishJitAutoFn.apply(x)
    return result


class HardMishMe(nn.Module):
    """Module wrapper around the memory-efficient, jit-scripted Hard Mish.

    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md
    ``inplace`` is accepted for API parity with ``HardMish`` but ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return hard_mish_me(x)

In [None]:
%nbdev_hide
# Smoke test: Swish(inplace=True) must reuse the input tensor and compute x * sigmoid(x).
act_fn = Swish(inplace=True)
x_test = torch.linspace(-3, 3, 7)
expected = x_test * torch.sigmoid(x_test)
out = act_fn(x_test)
assert out is x_test
assert torch.allclose(out, expected)

In [None]:
%nbdev_hide
# Smoke test: Mish accepts `inplace` but ignores it, returning a new tensor.
act_fn = Mish(inplace=True)
x_test = torch.linspace(-3, 3, 7)
out = act_fn(x_test)
assert out is not x_test
assert torch.allclose(out, x_test * torch.tanh(F.softplus(x_test)))

# end
model_constructor
by ayasyrev

In [41]:
%nbdev_hide
from nbdev.export import *
notebook2script()  # export the `%nbdev_export` cells of the project's notebooks into the package (see "Converted ..." output)

Converted 00_Net.ipynb.
Converted 01_activations.ipynb.
Converted 01_layers.ipynb.
Converted 03_MXResNet.ipynb.
Converted 04_YaResNet.ipynb.
Converted 05_Twist.ipynb.
Converted 10_base_constructor.ipynb.
Converted 11_xresnet.ipynb.
Converted index.ipynb.
