Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
achaiah
authored and
achaiah
committed
Jan 15, 2020
1 parent
1a15aff
commit 1dc03c1
Showing
8 changed files
with
316 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_autofn.py (Apache 2.0) | ||
|
||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
|
||
__all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto'] | ||
|
||
|
||
class SwishAutoFn(torch.autograd.Function):
    """Swish activation, ``x * sigmoid(x)`` - https://arxiv.org/abs/1710.05941

    Written as a custom autograd function that saves only the input tensor
    and recomputes the sigmoid in the backward pass. Memory efficient
    variant from:
    https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
    """

    @staticmethod
    def forward(ctx, x):
        # Only the raw input is stashed; backward() rebuilds sigmoid(x) from it.
        ctx.save_for_backward(x)
        gate = torch.sigmoid(x)
        return x.mul(gate)

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        sig = torch.sigmoid(inp)
        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        local_grad = sig * (1 + inp * (1 - sig))
        return grad_output.mul(local_grad)
|
||
|
||
def swish_auto(x, inplace=False):
    """Apply Swish to ``x`` through ``SwishAutoFn``.

    ``inplace`` is accepted but ignored.
    """
    activated = SwishAutoFn.apply(x)
    return activated
|
||
|
||
class SwishAuto(nn.Module):
    """``nn.Module`` wrapper that applies ``SwishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Recorded on the module only; forward() does not read it.
        self.inplace = inplace

    def forward(self, x):
        activated = SwishAutoFn.apply(x)
        return activated
|
||
|
||
class MishAutoFn(torch.autograd.Function):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681

    Experimental memory-efficient variant: only the input is saved, and the
    sigmoid / tanh(softplus) terms are recomputed in the backward pass.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        softplus_x = F.softplus(x)  # ln(1 + exp(x))
        return x.mul(torch.tanh(softplus_x))

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        sig = torch.sigmoid(inp)
        tanh_sp = F.softplus(inp).tanh()
        # d/dx [x * tanh(softplus(x))]
        #   = tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
        local_grad = tanh_sp + inp * sig * (1 - tanh_sp * tanh_sp)
        return grad_output.mul(local_grad)
|
||
|
||
def mish_auto(x, inplace=False):
    """Apply Mish to ``x`` through ``MishAutoFn``.

    ``inplace`` is accepted but ignored.
    """
    activated = MishAutoFn.apply(x)
    return activated
|
||
|
||
class MishAuto(nn.Module):
    """``nn.Module`` wrapper that applies ``MishAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Recorded on the module only; forward() does not read it.
        self.inplace = inplace

    def forward(self, x):
        activated = MishAutoFn.apply(x)
        return activated
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_jit.py (Apache 2.0) | ||
|
||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
|
||
|
||
__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit'] | ||
#'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit'] | ||
|
||
|
||
@torch.jit.script
def swish_jit_fwd(x):
    """TorchScript forward kernel for Swish: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x.mul(gate)
|
||
|
||
@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """TorchScript backward kernel for Swish.

    Scales ``grad_output`` by the derivative of ``x * sigmoid(x)`` at ``x``.
    """
    sig = torch.sigmoid(x)
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
|
||
|
||
class SwishJitAutoFn(torch.autograd.Function):
    """torch.jit.script optimised Swish.

    Forward and backward are delegated to the scripted kernels
    ``swish_jit_fwd`` and ``swish_jit_bwd``; only the input is saved.
    Inspired by a conversation between Jeremy Howard & Adam Paszke:
    https://twitter.com/jeremyphoward/status/1188251041835315200
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        activated = swish_jit_fwd(x)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        return swish_jit_bwd(inp, grad_output)
|
||
|
||
def swish_jit(x, inplace=False):
    """Apply Swish to ``x`` through ``SwishJitAutoFn``.

    ``inplace`` is accepted but ignored.
    """
    activated = SwishJitAutoFn.apply(x)
    return activated
|
||
|
||
class SwishJit(nn.Module):
    """``nn.Module`` wrapper that applies ``SwishJitAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Recorded on the module only; forward() does not read it.
        self.inplace = inplace

    def forward(self, x):
        activated = SwishJitAutoFn.apply(x)
        return activated
|
||
|
||
@torch.jit.script
def mish_jit_fwd(x):
    """TorchScript forward kernel for Mish: ``x * tanh(softplus(x))``."""
    softplus_x = F.softplus(x)
    return x.mul(torch.tanh(softplus_x))
|
||
|
||
@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """TorchScript backward kernel for Mish.

    Scales ``grad_output`` by the derivative of ``x * tanh(softplus(x))``
    evaluated at ``x``.
    """
    sig = torch.sigmoid(x)
    tanh_sp = F.softplus(x).tanh()
    # tanh(softplus(x)) + x * sigmoid(x) * (1 - tanh(softplus(x))^2)
    local_grad = tanh_sp + x * sig * (1 - tanh_sp * tanh_sp)
    return grad_output.mul(local_grad)
|
||
|
||
class MishJitAutoFn(torch.autograd.Function):
    """Mish as a custom autograd function backed by scripted kernels.

    Forward and backward are delegated to ``mish_jit_fwd`` and
    ``mish_jit_bwd``; only the input tensor is saved between them.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        activated = mish_jit_fwd(x)
        return activated

    @staticmethod
    def backward(ctx, grad_output):
        (inp,) = ctx.saved_tensors
        return mish_jit_bwd(inp, grad_output)
|
||
|
||
def mish_jit(x, inplace=False):
    """Apply Mish to ``x`` through ``MishJitAutoFn``.

    ``inplace`` is accepted but ignored.
    """
    activated = MishJitAutoFn.apply(x)
    return activated
|
||
|
||
class MishJit(nn.Module):
    """``nn.Module`` wrapper that applies ``MishJitAutoFn`` to its input."""

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Recorded on the module only; forward() does not read it.
        self.inplace = inplace

    def forward(self, x):
        activated = MishJitAutoFn.apply(x)
        return activated
|
||
|
||
# @torch.jit.script | ||
# def hard_swish_jit(x, inplace: bool = False): | ||
# return x.mul(F.relu6(x + 3.).mul_(1./6.)) | ||
# | ||
# | ||
# class HardSwishJit(nn.Module): | ||
# def __init__(self, inplace: bool = False): | ||
# super(HardSwishJit, self).__init__() | ||
# | ||
# def forward(self, x): | ||
# return hard_swish_jit(x) | ||
# | ||
# | ||
# @torch.jit.script | ||
# def hard_sigmoid_jit(x, inplace: bool = False): | ||
# return F.relu6(x + 3.).mul(1./6.) | ||
# | ||
# | ||
# class HardSigmoidJit(nn.Module): | ||
# def __init__(self, inplace: bool = False): | ||
# super(HardSigmoidJit, self).__init__() | ||
# | ||
# def forward(self, x): | ||
# return hard_sigmoid_jit(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations.py (Apache 2.0) | ||
# Note. Cuda-compiled source can be found here: https://github.com/thomasbrandon/mish-cuda (MIT) | ||
|
||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
def mish(x, inplace: bool = False):
    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681

    Computes ``x * tanh(softplus(x))``. ``inplace`` is accepted but not used.
    """
    softplus_x = F.softplus(x)
    return x.mul(softplus_x.tanh())
|
||
class Mish(nn.Module):
    """
    Module form of Mish - "Mish: A Self Regularized Non-Monotonic Neural
    Activation Function", https://arxiv.org/abs/1908.08681v1

    Implemented for PyTorch / FastAI by lessw2020
    (github: https://github.com/lessw2020/mish).
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Forwarded to mish() on every call.
        self.inplace = inplace

    def forward(self, x):
        activated = mish(x, self.inplace)
        return activated
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.