diff --git a/README.md b/README.md
index f2277ff..875dd1b 100644
--- a/README.md
+++ b/README.md
@@ -29,8 +29,18 @@ have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
 work in progress though so apologies for anything that's broken.
 
 ## What's New (highlights)
+- **Jan. 15, 2020**
+  - New release: 0.5.5
+  - Mish activation function (SoTA)
+  - [rwightman's](https://github.com/rwightman/gen-efficientnet-pytorch) models of pretrained/ported variants for classification (44 total)
+    - efficientnet Tensorflow port b0-b8, with and without AP, el/em/es, cc
+    - mixnet L/M/S
+    - mobilenetv3
+    - mnasnet
+    - spnasnet
+  - Additional loss functions
 - **Aug. 1, 2019**
-  - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+  - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
   - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
   - Major restructuring and standardization of NN models and loading functionality
   - General bug fixes and code improvements
@@ -40,8 +50,7 @@ work in progress though so apologies for anything that's broken.
 or specific version from git:
 
-`pip
-install git+https://github.com/achaiah/pywick.git@v0.5.3`
+`pip install git+https://github.com/achaiah/pywick.git@v0.5.5`
 
 ## ModuleTrainer
 The `ModuleTrainer` class provides a high-level training interface which abstracts
diff --git a/docs/source/README.md b/docs/source/README.md
index f2277ff..875dd1b 100644
--- a/docs/source/README.md
+++ b/docs/source/README.md
@@ -29,8 +29,18 @@ have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
 work in progress though so apologies for anything that's broken.
 
 ## What's New (highlights)
+- **Jan. 15, 2020**
+  - New release: 0.5.5
+  - Mish activation function (SoTA)
+  - [rwightman's](https://github.com/rwightman/gen-efficientnet-pytorch) models of pretrained/ported variants for classification (44 total)
+    - efficientnet Tensorflow port b0-b8, with and without AP, el/em/es, cc
+    - mixnet L/M/S
+    - mobilenetv3
+    - mnasnet
+    - spnasnet
+  - Additional loss functions
 - **Aug. 1, 2019**
-  - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
+  - New segmentation NNs: BiSeNet, DANet, DenseASPP, DUNet, OCNet, PSANet
   - New Loss Functions: Focal Tversky Loss, OHEM CrossEntropy Loss, various combination losses
   - Major restructuring and standardization of NN models and loading functionality
   - General bug fixes and code improvements
@@ -40,8 +50,7 @@ work in progress though so apologies for anything that's broken.
 or specific version from git:
 
-`pip
-install git+https://github.com/achaiah/pywick.git@v0.5.3`
+`pip install git+https://github.com/achaiah/pywick.git@v0.5.5`
 
 ## ModuleTrainer
 The `ModuleTrainer` class provides a high-level training interface which abstracts
diff --git a/pywick/functions/activations_autofn.py b/pywick/functions/activations_autofn.py
new file mode 100644
index 0000000..929d533
--- /dev/null
+++ b/pywick/functions/activations_autofn.py
@@ -0,0 +1,74 @@
+# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_autofn.py (Apache 2.0)
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+__all__ = ['swish_auto', 'SwishAuto', 'mish_auto', 'MishAuto']
+
+
+class SwishAutoFn(torch.autograd.Function):
+    """Swish - Described in: https://arxiv.org/abs/1710.05941
+    Memory efficient variant from:
+    https://medium.com/the-artificial-impostor/more-memory-efficient-swish-activation-function-e07c22c12a76
+    """
+    @staticmethod
+    def forward(ctx, x):
+        result = x.mul(torch.sigmoid(x))
+        ctx.save_for_backward(x)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output.mul(x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+def swish_auto(x, inplace=False):
+    # inplace ignored
+    return SwishAutoFn.apply(x)
+
+
+class SwishAuto(nn.Module):
+    def __init__(self, inplace: bool = False):
+        super(SwishAuto, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return SwishAutoFn.apply(x)
+
+
+class MishAutoFn(torch.autograd.Function):
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    Experimental memory-efficient variant
+    """
+
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        y = x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))
+        return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        x_sigmoid = torch.sigmoid(x)
+        x_tanh_sp = F.softplus(x).tanh()
+        return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+
+
+def mish_auto(x, inplace=False):
+    # inplace ignored
+    return MishAutoFn.apply(x)
+
+
+class MishAuto(nn.Module):
+    def __init__(self, inplace: bool = False):
+        super(MishAuto, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return MishAutoFn.apply(x)
+
diff --git a/pywick/functions/activations_jit.py b/pywick/functions/activations_jit.py
new file mode 100644
index 0000000..a22b424
--- /dev/null
+++ b/pywick/functions/activations_jit.py
@@ -0,0 +1,114 @@
+# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations_jit.py (Apache 2.0)
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit']
+#'hard_swish_jit', 'HardSwishJit', 'hard_sigmoid_jit', 'HardSigmoidJit']
+
+
+@torch.jit.script
+def swish_jit_fwd(x):
+    return x.mul(torch.sigmoid(x))
+
+
+@torch.jit.script
+def swish_jit_bwd(x, grad_output):
+    x_sigmoid = torch.sigmoid(x)
+    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+class SwishJitAutoFn(torch.autograd.Function):
+    """ torch.jit.script optimised Swish
+    Inspired by conversation btw Jeremy Howard & Adam Pazske
+    https://twitter.com/jeremyphoward/status/1188251041835315200
+    """
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return swish_jit_fwd(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        return swish_jit_bwd(x, grad_output)
+
+
+def swish_jit(x, inplace=False):
+    # inplace ignored
+    return SwishJitAutoFn.apply(x)
+
+
+class SwishJit(nn.Module):
+    def __init__(self, inplace: bool = False):
+        super(SwishJit, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return SwishJitAutoFn.apply(x)
+
+
+@torch.jit.script
+def mish_jit_fwd(x):
+    return x.mul(torch.tanh(F.softplus(x)))
+
+
+@torch.jit.script
+def mish_jit_bwd(x, grad_output):
+    x_sigmoid = torch.sigmoid(x)
+    x_tanh_sp = F.softplus(x).tanh()
+    return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp))
+
+
+class MishJitAutoFn(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x):
+        ctx.save_for_backward(x)
+        return mish_jit_fwd(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x = ctx.saved_tensors[0]
+        return mish_jit_bwd(x, grad_output)
+
+
+def mish_jit(x, inplace=False):
+    # inplace ignored
+    return MishJitAutoFn.apply(x)
+
+
+class MishJit(nn.Module):
+    def __init__(self, inplace: bool = False):
+        super(MishJit, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return MishJitAutoFn.apply(x)
+
+
+# @torch.jit.script
+# def hard_swish_jit(x, inplac: bool = False):
+#     return x.mul(F.relu6(x + 3.).mul_(1./6.))
+#
+#
+# class HardSwishJit(nn.Module):
+#     def __init__(self, inplace: bool = False):
+#         super(HardSwishJit, self).__init__()
+#
+#     def forward(self, x):
+#         return hard_swish_jit(x)
+#
+#
+# @torch.jit.script
+# def hard_sigmoid_jit(x, inplace: bool = False):
+#     return F.relu6(x + 3.).mul(1./6.)
+#
+#
+# class HardSigmoidJit(nn.Module):
+#     def __init__(self, inplace: bool = False):
+#         super(HardSigmoidJit, self).__init__()
+#
+#     def forward(self, x):
+#         return hard_sigmoid_jit(x)
diff --git a/pywick/functions/mish.py b/pywick/functions/mish.py
new file mode 100644
index 0000000..f95287b
--- /dev/null
+++ b/pywick/functions/mish.py
@@ -0,0 +1,24 @@
+# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations.py (Apache 2.0)
+# Note. Cuda-compiled source can be found here: https://github.com/thomasbrandon/mish-cuda (MIT)
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+def mish(x, inplace: bool = False):
+    """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
+    """
+    return x.mul(F.softplus(x).tanh())
+
+class Mish(nn.Module):
+    """
+    Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
+    https://arxiv.org/abs/1908.08681v1
+    implemented for PyTorch / FastAI by lessw2020
+    github: https://github.com/lessw2020/mish
+    """
+    def __init__(self, inplace: bool = False):
+        super(Mish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return mish(x, self.inplace)
diff --git a/pywick/functions/swish.py b/pywick/functions/swish.py
index d846487..aab6b11 100644
--- a/pywick/functions/swish.py
+++ b/pywick/functions/swish.py
@@ -1,5 +1,6 @@
 # Source: https://forums.fast.ai/t/implementing-new-activation-functions-in-fastai-library/17697
 
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -51,4 +52,19 @@ def __init__(self, a=1.5, b = 2.):
 
     def forward(self, x):
         aria2 = 1 + ((F.exp(-x) ** self.b) ** (-self.a))
-        return x * aria2
\ No newline at end of file
+        return x * aria2
+
+
+# Source: https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/activations/activations.py (Apache 2.0)
+def hard_swish(x, inplace: bool = False):
+    inner = F.relu6(x + 3.).div_(6.)
+    return x.mul_(inner) if inplace else x.mul(inner)
+
+
+class HardSwish(nn.Module):
+    def __init__(self, inplace: bool = False):
+        super(HardSwish, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return hard_swish(x, self.inplace)
diff --git a/pywick/models/model_utils.py b/pywick/models/model_utils.py
index e27d298..bf89fe4 100644
--- a/pywick/models/model_utils.py
+++ b/pywick/models/model_utils.py
@@ -4,6 +4,7 @@
 from enum import Enum
 from torchvision import models as torch_models
 from torchvision.models.inception import InceptionAux
+import torch
 import torch.nn as nn
 import os
 import errno
@@ -62,7 +63,9 @@ def get_model(model_type, model_name, num_classes, pretrained=True, **kwargs):
     :param pretrained: (bool): whether to load the default pretrained version of the model
         NOTE! NOTE! For classification, the lowercase model names are the pretrained variants while the Uppercase model names are not.
-        It is IN ERROR to specify an Uppercase model name variant with pretrained=True but one can specify a lowercase model variant with pretrained=False
+        The only exception applies to torch.hub models (all efficientnet, mixnet, mobilenetv3, mnasnet, spnasnet variants) where a single
+        lower-case string can be used for vanilla and pretrained versions. Otherwise, it is IN ERROR to specify an Uppercase model name variant
+        with pretrained=True but one can specify a lowercase model variant with pretrained=False
         (default: True)
     :return: model
     """
@@ -74,60 +77,63 @@ def get_model(model_type, model_name, num_classes, pretrained=True, **kwargs):
     print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))
 
     if model_type == ModelType.CLASSIFICATION:
-
-        # 1. Load model (pretrained or vanilla)
-        fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1:][0]       # we're only interested in the last layer name
-        new_fc = None                       # Custom layer to replace with (if none, then it will be handled generically)
-        if model_name in torch_models.__dict__:
-            print('INFO: Loading torchvision model: {}\t Pretrained: {}'.format(model_name, pretrained))
-            model = torch_models.__dict__[model_name](pretrained=pretrained)    # find a model included in the torchvision package
+        torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
+        if model_name in torch_hub_names:
+            model = torch.hub.load('rwightman/gen-efficientnet-pytorch', model_name, pretrained=pretrained, num_classes=num_classes)
         else:
-            net_list = ['fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet', 'resnext', 'se_resnet', 'senet', 'shufflenet', 'xception']
-            if pretrained:
-                print('INFO: Loading a pretrained model: {}'.format(model_name))
-                if 'dpn' in model_name:
-                    model = classification.__dict__[model_name](pretrained=True)       # find a model included in the pywick classification package
-                elif any(net_name in model_name for net_name in net_list):
-                    model = classification.__dict__[model_name](pretrained='imagenet')
+            # 1. Load model (pretrained or vanilla)
+            fc_name = get_fc_names(model_name=model_name, model_type=model_type)[-1:][0]       # we're only interested in the last layer name
+            new_fc = None                       # Custom layer to replace with (if none, then it will be handled generically)
+            if model_name in torch_models.__dict__:
+                print('INFO: Loading torchvision model: {}\t Pretrained: {}'.format(model_name, pretrained))
+                model = torch_models.__dict__[model_name](pretrained=pretrained)    # find a model included in the torchvision package
             else:
-                print('INFO: Loading a vanilla model: {}'.format(model_name))
-                model = classification.__dict__[model_name](pretrained=None)  # pretrained must be set to None for the extra models... go figure
-
-        # 2. Create custom FC layers for non-standardized models
-        if 'squeezenet' in model_name:
-            final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
-            new_fc = nn.Sequential(
-                nn.Dropout(p=0.5),
-                final_conv,
-                nn.ReLU(inplace=True),
-                nn.AvgPool2d(13, stride=1)
-            )
-            model.num_classes = num_classes
-        elif 'vgg' in model_name:
-            new_fc = nn.Sequential(
-                nn.Linear(512 * 7 * 7, 4096),
-                nn.ReLU(True),
-                nn.Dropout(),
-                nn.Linear(4096, 4096),
-                nn.ReLU(True),
-                nn.Dropout(),
-                nn.Linear(4096, num_classes)
-            )
-        elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
-            # Replace the extra aux_logits FC layer if aux_logits are enabled
-            if getattr(model, 'aux_logits', False):
-                model.AuxLogits = InceptionAux(768, num_classes)
-        elif 'dpn' in model_name.lower():
-            old_fc = getattr(model, fc_name)
-            new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)
-
-        # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
-        if new_fc is None:
-            old_fc = getattr(model, fc_name)
-            new_fc = nn.Linear(old_fc.in_features, num_classes)
-
-        # 4. perform replacement of the last FC / Linear layer with a new one
-        setattr(model, fc_name, new_fc)
+                net_list = ['fbresnet', 'inception', 'mobilenet', 'nasnet', 'polynet', 'resnext', 'se_resnet', 'senet', 'shufflenet', 'xception']
+                if pretrained:
+                    print('INFO: Loading a pretrained model: {}'.format(model_name))
+                    if 'dpn' in model_name:
+                        model = classification.__dict__[model_name](pretrained=True)       # find a model included in the pywick classification package
+                    elif any(net_name in model_name for net_name in net_list):
+                        model = classification.__dict__[model_name](pretrained='imagenet')
+                else:
+                    print('INFO: Loading a vanilla model: {}'.format(model_name))
+                    model = classification.__dict__[model_name](pretrained=None)  # pretrained must be set to None for the extra models... go figure
+
+            # 2. Create custom FC layers for non-standardized models
+            if 'squeezenet' in model_name:
+                final_conv = nn.Conv2d(512, num_classes, kernel_size=1)
+                new_fc = nn.Sequential(
+                    nn.Dropout(p=0.5),
+                    final_conv,
+                    nn.ReLU(inplace=True),
+                    nn.AvgPool2d(13, stride=1)
+                )
+                model.num_classes = num_classes
+            elif 'vgg' in model_name:
+                new_fc = nn.Sequential(
+                    nn.Linear(512 * 7 * 7, 4096),
+                    nn.ReLU(True),
+                    nn.Dropout(),
+                    nn.Linear(4096, 4096),
+                    nn.ReLU(True),
+                    nn.Dropout(),
+                    nn.Linear(4096, num_classes)
+                )
+            elif 'inception3' in model_name.lower() or 'inception_v3' in model_name.lower():
+                # Replace the extra aux_logits FC layer if aux_logits are enabled
+                if getattr(model, 'aux_logits', False):
+                    model.AuxLogits = InceptionAux(768, num_classes)
+            elif 'dpn' in model_name.lower():
+                old_fc = getattr(model, fc_name)
+                new_fc = nn.Conv2d(old_fc.in_channels, num_classes, kernel_size=1, bias=True)
+
+            # 3. For standard FC layers (nn.Linear) perform a reflection lookup and generate a new FC
+            if new_fc is None:
+                old_fc = getattr(model, fc_name)
+                new_fc = nn.Linear(old_fc.in_features, num_classes)
+
+            # 4. perform replacement of the last FC / Linear layer with a new one
+            setattr(model, fc_name, new_fc)
 
     return model
 
@@ -231,7 +237,9 @@ def get_supported_models(type):
                 pt_excludes.append(modname.split('.')[-1])
         pt_names = [x for x in torch_models.__dict__.keys() if '__' not in x and x not in pt_excludes]       # includes directory and filenames
 
-        return pywick_names + pt_names
+        torch_hub_names = torch.hub.list('rwightman/gen-efficientnet-pytorch')
+
+        return pywick_names + pt_names + torch_hub_names
     else:
         return None
diff --git a/setup.py b/setup.py
index a8c197b..3e1da17 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from setuptools import setup, find_packages
 
 setup(name='pywick',
-      version='0.5.3',
+      version='0.5.5',
       description='High-level batteries-included training framework for Pytorch',
       author='Achaiah',
       install_requires=[
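
Illustrative note (not part of the patch above): the hand-written backward passes in activations_autofn.py and activations_jit.py implement d/dx[x * tanh(softplus(x))] = tanh(sp(x)) + x * sigmoid(x) * (1 - tanh^2(sp(x))), recomputing the activation in backward() instead of storing intermediates. A minimal sketch, assuming pywick is installed with the modules added above, that checks the custom gradient against plain autograd:

# Illustrative sketch only -- assumes pywick 0.5.5 with the new activation modules.
import torch

from pywick.functions.activations_autofn import MishAuto
from pywick.functions.mish import Mish

x = torch.randn(8, 16, requires_grad=True)
g_ref, = torch.autograd.grad(Mish()(x).sum(), x)          # autograd through x * tanh(softplus(x))

x2 = x.detach().clone().requires_grad_(True)
g_auto, = torch.autograd.grad(MishAuto()(x2).sum(), x2)   # hand-written backward from MishAutoFn

print(torch.allclose(g_ref, g_auto, atol=1e-6))           # expected: True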
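
Illustrative note (not part of the patch above): Mish, HardSwish and the *Auto/*Jit variants added here are ordinary nn.Module subclasses, so they drop in wherever nn.ReLU would normally sit. A sketch assuming pywick is installed with the files added above:

# Illustrative sketch only -- the new activations are plain nn.Modules.
import torch
import torch.nn as nn

from pywick.functions.mish import Mish
from pywick.functions.swish import HardSwish

net = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.BatchNorm2d(32),
    Mish(),          # the SoTA activation called out in the 0.5.5 release notes
    nn.Conv2d(32, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    HardSwish(),     # cheap piecewise approximation of Swish added to swish.py
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(64, 10),
)

print(net(torch.randn(2, 3, 32, 32)).shape)   # torch.Size([2, 10])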
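
Illustrative note (not part of the patch above): with the model_utils.py change, get_model() consults torch.hub.list('rwightman/gen-efficientnet-pytorch') before anything else, so any listed name is loaded through torch.hub.load() with num_classes passed straight through, pretrained or not. A sketch of a caller; 'mixnet_m' is only an example name, and the authoritative list is whatever get_supported_models() returns at runtime:

# Illustrative sketch only -- valid names are whatever torch.hub lists for
# rwightman/gen-efficientnet-pytorch at runtime; 'mixnet_m' is used as an example.
import torch
from pywick.models.model_utils import ModelType, get_model, get_supported_models

print(get_supported_models(ModelType.CLASSIFICATION))   # now includes the torch.hub names

model = get_model(ModelType.CLASSIFICATION, 'mixnet_m', num_classes=10, pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 10])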