Update resnet
ericup committed Feb 29, 2024
1 parent a656b6b commit 97549a4
Showing 2 changed files with 84 additions and 31 deletions.
113 changes: 83 additions & 30 deletions celldetection/models/resnet.py
@@ -1,10 +1,13 @@
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torchvision.models import resnet as tvr
-from ..util.util import Dict, lookup_nn, get_nd_conv, get_nn
+from os.path import isfile
+from ..util.util import Dict, lookup_nn, get_nd_conv, get_nn, resolve_pretrained
 from torch.hub import load_state_dict_from_url
 from .ppm import append_pyramid_pooling_
 from typing import Type, Union, Optional
+from pytorch_lightning.core.mixins import HyperparametersMixin
 
 __all__ = ['get_resnet', 'ResNet50', 'ResNet34', 'ResNet18', 'ResNet152', 'ResNet101', 'WideResNet101_2',
            'WideResNet50_2', 'ResNeXt152_32x8d', 'ResNeXt101_32x8d', 'ResNeXt50_32x4d']
@@ -22,15 +25,23 @@
 }
 
 
-def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1,
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, kernel_size=3,
             nd=2) -> nn.Conv2d:
     """3x3 convolution with padding"""
+    if isinstance(kernel_size, int):
+        kernel_size = (kernel_size,) * nd
+    if isinstance(dilation, int):
+        dilation = (dilation,) * nd
+
+    # Calculate padding for 'same' padding
+    padding = tuple((ks - 1) * dil // 2 for ks, dil in zip(kernel_size, dilation))
+
     return get_nd_conv(nd)(
         in_planes,
         out_planes,
-        kernel_size=3,
+        kernel_size=kernel_size,
         stride=stride,
-        padding=dilation,
+        padding=padding,
         groups=groups,
         bias=False,
         dilation=dilation,
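Note: the padding computed above generalizes 'same' padding to arbitrary kernel sizes and dilations; at stride 1 the spatial shape is preserved whenever (kernel_size - 1) * dilation is even, e.g. for any odd kernel size. A quick standalone check of the formula in plain PyTorch (not library code):

    import torch
    from torch import nn

    # 'same' padding per dimension: (kernel_size - 1) * dilation // 2
    ks, dil = (5, 5), (2, 2)
    padding = tuple((k - 1) * d // 2 for k, d in zip(ks, dil))  # -> (4, 4)

    conv = nn.Conv2d(3, 8, kernel_size=ks, stride=1, padding=padding, dilation=dil, bias=False)
    print(conv(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 8, 64, 64])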
@@ -56,6 +67,7 @@ def __init__(  # Port from torchvision (to support 3d and add more features)
         base_width: int = 64,
         dilation: int = 1,
         norm_layer='batchnorm2d',
+        kernel_size=3,
         nd=2
     ) -> None:
         super().__init__()
@@ -64,7 +76,7 @@ def __init__(  # Port from torchvision (to support 3d and add more features)
             raise ValueError("BasicBlock only supports groups=1 and base_width=64")
         if dilation > 1:
             raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
-        self.conv1 = conv3x3(inplanes, planes, stride, nd=nd)
+        self.conv1 = conv3x3(inplanes, planes, stride, nd=nd, kernel_size=kernel_size)
         self.bn1 = norm_layer(planes)
         self.relu = nn.ReLU(inplace=True)
         self.conv2 = conv3x3(planes, planes, nd=nd)
@@ -87,14 +99,15 @@ def __init__(  # Port from torchvision (to support 3d and add more features)
         base_width: int = 64,
         dilation: int = 1,
         norm_layer='batchnorm2d',
+        kernel_size=3,
         nd=2
     ) -> None:
         super().__init__()
         norm_layer = lookup_nn(norm_layer, call=False, nd=nd)
         width = int(planes * (base_width / 64.0)) * groups
         self.conv1 = conv1x1(inplanes, width, nd=nd)
         self.bn1 = norm_layer(width)
-        self.conv2 = conv3x3(width, width, stride, groups, dilation, nd=nd)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation, kernel_size=kernel_size, nd=nd)
         self.bn2 = norm_layer(width)
         self.conv3 = conv1x1(width, planes * self.expansion, nd=nd)
         self.bn3 = norm_layer(planes * self.expansion)
@@ -110,6 +123,7 @@ def _make_layer(  # Port from torchvision (to support 3d)
         blocks: int,
         stride: int = 1,
         dilate: bool = False,
+        kernel_size: int = 3,
         nd=2,
         secondary_block=None,
         downsample_method=None,
@@ -126,6 +140,7 @@ def _make_layer(  # Port from torchvision (to support 3d)
         blocks:
         stride:
         dilate:
+        kernel_size:
         nd:
         secondary_block:
         downsample_method: Downsample method. None: 1x1Conv with stride, Norm (standard ResNet),
@@ -160,7 +175,7 @@ def _make_layer(  # Port from torchvision (to support 3d)
     layers = []
     layers.append(
         block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer,
-              nd=nd))
+              kernel_size=kernel_size, nd=nd))
     self.inplanes = planes * block.expansion
     for _ in range(1, blocks):
         layers.append(block(
@@ -170,6 +185,7 @@ def _make_layer(  # Port from torchvision (to support 3d)
             base_width=self.base_width,
             dilation=self.dilation,
             norm_layer=norm_layer,
+            kernel_size=kernel_size,
             nd=nd,
         ))
     if secondary_block is not None:
@@ -178,7 +194,8 @@ def _make_layer(  # Port from torchvision (to support 3d)
 
 
 def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, base_width=64, groups=1, stride=1,
-                   dilation=1, dilate=False, nd=2, secondary_block=None, **kwargs) -> nn.Module:
+                   dilation=1, dilate=False, nd=2, secondary_block=None, downsample_method=None, kernel_size=3,
+                   **kwargs) -> nn.Module:
     """
     Args:
@@ -194,6 +211,8 @@ def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, b
         dilate:
         nd:
         secondary_block:
+        downsample_method:
+        kernel_size:
         kwargs:
 
     Returns:
@@ -204,7 +223,7 @@ def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, b
              groups=groups, dilation=dilation)  # almost a ResNet
 
     return _make_layer(self=d, block=block, planes=planes, blocks=blocks, stride=stride, dilate=dilate, nd=nd,
-                       secondary_block=secondary_block)
+                       secondary_block=secondary_block, downsample_method=downsample_method, kernel_size=kernel_size)
 
 
 def _apply_mapping_rules(key, rules: dict):
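Note: with kernel_size and downsample_method now threaded through to _make_layer, a single residual stage with a non-default kernel size can be built directly. A minimal sketch, assuming the signature introduced in this commit and this import path:

    import torch
    from celldetection.models.resnet import BasicBlock, make_res_layer

    # One stage of two BasicBlocks using 'same'-padded 5x5 convs
    layer = make_res_layer(BasicBlock, inplanes=64, planes=64, blocks=2, stride=1, kernel_size=5, nd=2)
    print(layer(torch.rand(1, 64, 32, 32)).shape)  # expected: torch.Size([1, 64, 32, 32])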
@@ -243,7 +262,7 @@ def map_state_dict(in_channels, state_dict, fused_initial):
     return mapping
 
 
-class ResNet(nn.Sequential):
+class ResNet(nn.Sequential, HyperparametersMixin):
     def __init__(self, in_channels, *body: nn.Module, initial_strides=2, base_channel=64, initial_pooling=True,
                  final_layer=None, final_activation=None, fused_initial=True, pretrained=False,
                  pyramid_pooling=False, pyramid_pooling_channels=64, pyramid_pooling_kwargs=None, nd=2, **kwargs):
@@ -270,32 +289,37 @@ def __init__(self, in_channels, *body: nn.Module, initial_strides=2, base_channe
             components += [lookup_nn(final_activation)]
         super(ResNet, self).__init__(*components)
         if pretrained:
-            if isinstance(pretrained, str):
-                state_dict = load_state_dict_from_url(pretrained)
-                if '.pytorch.org' in pretrained:
-                    state_dict = map_state_dict(in_channels, state_dict, fused_initial=fused_initial)
-                self.load_state_dict(state_dict)
-            else:
-                raise ValueError('There is no default set of weights for this model. '
-                                 'Please specify a URL using the `pretrained` argument.')
+            state_dict = resolve_pretrained(pretrained, in_channels=in_channels, fused_initial=fused_initial,
+                                            state_dict_mapper=map_state_dict)
+            self.load_state_dict(state_dict, strict=kwargs.get('pretrained_strict', True))
         if pyramid_pooling:
             pyramid_pooling_kwargs = {} if pyramid_pooling_kwargs is None else pyramid_pooling_kwargs
             append_pyramid_pooling_(self, pyramid_pooling_channels, nd=nd, **pyramid_pooling_kwargs)
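Note: the inline branching on URL strings is folded into resolve_pretrained, so `pretrained` keeps accepting torchvision URLs (remapped via map_state_dict) and, judging by the new isfile import, presumably local checkpoint paths as well. A hedged usage sketch; the URL and the cd.models path are assumptions, not part of this diff:

    import celldetection as cd

    # torchvision weights by URL (keys remapped as before)
    model = cd.models.ResNet50(
        in_channels=3,
        pretrained='https://download.pytorch.org/models/resnet50-0676ba61.pth',
    )

    # presumably also a local checkpoint; key mismatches can be tolerated via the
    # new pretrained_strict flag, which is read from **kwargs
    model = cd.models.ResNet50(in_channels=3, pretrained='weights.pth', pretrained_strict=False)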


 class VanillaResNet(ResNet):
-    def __init__(self, in_channels, out_channels=0, layers=(2, 2, 2, 2), base_channel=64, fused_initial=True, nd=2,
-                 **kwargs):
+    def __init__(self, in_channels, out_channels=0, layers=(2, 2, 2, 2), base_channel=64, fused_initial=True,
+                 kernel_size=3, per_layer_kernel_sizes: dict = None, nd=2, **kwargs):
+        if per_layer_kernel_sizes is None:
+            per_layer_kernel_sizes = {}
+        if isinstance(per_layer_kernel_sizes, (tuple, list)):
+            per_layer_kernel_sizes = {i: v for i, v in enumerate(per_layer_kernel_sizes)}
+        self.save_hyperparameters()
         self.out_channels = oc = (base_channel, base_channel * 2, base_channel * 4, base_channel * 8)
         self.out_strides = (4, 8, 16, 32)
         if out_channels and 'final_layer' not in kwargs.keys():
             kwargs['final_layer'] = get_nd_conv(nd)(self.out_channels[-1], out_channels, 1)
 
         super(VanillaResNet, self).__init__(
             in_channels,
-            make_res_layer(BasicBlock, base_channel, oc[0], layers[0], stride=1, nd=nd, **kwargs),
-            make_res_layer(BasicBlock, oc[0], oc[1], layers[1], stride=2, nd=nd, **kwargs),
-            make_res_layer(BasicBlock, oc[1], oc[2], layers[2], stride=2, nd=nd, **kwargs),
-            make_res_layer(BasicBlock, oc[2], oc[3], layers[3], stride=2, nd=nd, **kwargs),
+            make_res_layer(BasicBlock, base_channel, oc[0], layers[0], stride=1, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(0, kernel_size), **kwargs),
+            make_res_layer(BasicBlock, oc[0], oc[1], layers[1], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(1, kernel_size), **kwargs),
+            make_res_layer(BasicBlock, oc[1], oc[2], layers[2], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(2, kernel_size), **kwargs),
+            make_res_layer(BasicBlock, oc[2], oc[3], layers[3], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(3, kernel_size), **kwargs),
             base_channel=base_channel, fused_initial=fused_initial, nd=nd, **kwargs
         )
         if not fused_initial:
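Note: per_layer_kernel_sizes maps a stage index (0-3) to a kernel size, falling back to kernel_size for unlisted stages; tuples and lists are converted to such a dict by position. A usage sketch, assuming the constructor arguments shown above:

    import celldetection as cd

    # 7x7 convs in the first stage, 5x5 in the second, default 3x3 elsewhere
    model = cd.models.ResNet18(in_channels=3, per_layer_kernel_sizes={0: 7, 1: 5})

    # equivalent positional form, converted internally to {0: 7, 1: 5, 2: 3, 3: 3}
    model = cd.models.ResNet18(in_channels=3, per_layer_kernel_sizes=(7, 5, 3, 3))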
@@ -322,6 +346,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNet18']
         super(ResNet18, self).__init__(in_channels, out_channels=out_channels, layers=(2, 2, 2, 2),
                                        pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
 
 class ResNet34(VanillaResNet):
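Note: the hparams.clear() / save_hyperparameters() pair, repeated in every subclass below, discards the hyperparameters captured by the parent constructors so that self.hparams ends up reflecting the leaf class's own arguments. Roughly the intended effect (a sketch, not asserted by this diff):

    model = ResNet18(in_channels=3)
    # expected: ResNet18's own constructor arguments, e.g.
    # {'in_channels': 3, 'out_channels': 0, 'pretrained': False, 'nd': 2},
    # rather than the expanded layers/kernel_size arguments saved by VanillaResNet
    print(dict(model.hparams))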
@@ -330,24 +356,35 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNet34']
         super(ResNet34, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3),
                                        pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 34')
 
 
 class BottleResNet(ResNet):
-    def __init__(self, in_channels, out_channels=0, layers=(3, 4, 6, 3), base_channel=64, fused_initial=True, nd=2,
-                 **kwargs):
+    def __init__(self, in_channels, out_channels=0, layers=(3, 4, 6, 3), base_channel=64, fused_initial=True,
+                 kernel_size=3, per_layer_kernel_sizes: dict = None, nd=2, **kwargs):
+        if per_layer_kernel_sizes is None:
+            per_layer_kernel_sizes = {}
+        if isinstance(per_layer_kernel_sizes, (tuple, list)):
+            per_layer_kernel_sizes = {i: v for i, v in enumerate(per_layer_kernel_sizes)}
+        self.save_hyperparameters()
         ex = Bottleneck.expansion
         self.out_channels = oc = (base_channel * 4, base_channel * 8, base_channel * 16, base_channel * 32)
         self.out_strides = (4, 8, 16, 32)
         if out_channels and 'final_layer' not in kwargs.keys():
             kwargs['final_layer'] = nn.Conv2d(self.out_channels[-1], out_channels, 1)
         super(BottleResNet, self).__init__(
             in_channels,
-            make_res_layer(Bottleneck, base_channel, oc[0] // ex, layers[0], stride=1, nd=nd, **kwargs),
-            make_res_layer(Bottleneck, base_channel * 4, oc[1] // ex, layers[1], stride=2, nd=nd, **kwargs),
-            make_res_layer(Bottleneck, base_channel * 8, oc[2] // ex, layers[2], stride=2, nd=nd, **kwargs),
-            make_res_layer(Bottleneck, base_channel * 16, oc[3] // ex, layers[3], stride=2, nd=nd, **kwargs),
+            make_res_layer(Bottleneck, base_channel, oc[0] // ex, layers[0], stride=1, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(0, kernel_size), **kwargs),
+            make_res_layer(Bottleneck, base_channel * 4, oc[1] // ex, layers[1], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(1, kernel_size), **kwargs),
+            make_res_layer(Bottleneck, base_channel * 8, oc[2] // ex, layers[2], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(2, kernel_size), **kwargs),
+            make_res_layer(Bottleneck, base_channel * 16, oc[3] // ex, layers[3], stride=2, nd=nd,
+                           kernel_size=per_layer_kernel_sizes.get(3, kernel_size), **kwargs),
             base_channel=base_channel, fused_initial=fused_initial, nd=nd, **kwargs
        )
         if not fused_initial:
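Note: in the bottleneck variant the configurable kernel size only reaches conv2 of each block (the former fixed 3x3); the 1x1 projections stay untouched. Since the layers are still built through get_nd_conv, the new arguments combine with nd; a hedged example using only arguments introduced above:

    import celldetection as cd

    # 3d ResNet-50 whose bottleneck conv2 layers use 5x5x5 kernels
    model = cd.models.ResNet50(in_channels=1, nd=3, kernel_size=5)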
@@ -361,6 +398,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNet50']
         super(ResNet50, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3),
                                        pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 50')

@@ -371,6 +410,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNet101']
         super(ResNet101, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3),
                                         pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 101')

@@ -381,6 +422,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNet152']
         super(ResNet152, self).__init__(in_channels, out_channels=out_channels, layers=(3, 8, 36, 3),
                                         pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 152')

@@ -391,6 +434,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNeXt50_32x4d']
         super(ResNeXt50_32x4d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3), groups=32,
                                               base_width=4, pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 50')

@@ -401,6 +446,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['ResNeXt101_32x8d']
         super(ResNeXt101_32x8d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3), groups=32,
                                                base_width=8, pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 101')

@@ -409,6 +456,8 @@ class ResNeXt152_32x8d(BottleResNet):
     def __init__(self, in_channels, out_channels=0, nd=2, **kwargs):
         super(ResNeXt152_32x8d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 8, 36, 3), groups=32,
                                                base_width=8, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 152')

@@ -419,6 +468,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['WideResNet50_2']
         super(WideResNet50_2, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3),
                                              base_width=64 * 2, pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'Wide ResNet 50')

@@ -429,6 +480,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs
             pretrained = default_model_urls['WideResNet101_2']
         super(WideResNet101_2, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3),
                                               base_width=64 * 2, pretrained=pretrained, nd=nd, **kwargs)
+        self.hparams.clear()
+        self.save_hyperparameters()
 
     __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'Wide ResNet 101')

2 changes: 1 addition & 1 deletion celldetection/util/util.py
@@ -39,7 +39,7 @@
            'image_to_base64', 'base64_to_image', 'model2dict', 'dict2model', 'is_ipython', 'grouped_glob',
            'tweak_attribute_', 'to_batched_h5', 'compare_file_hashes', 'import_file', 'load_imagej_rois',
            'glob_h5_split', 'say_goodbye', 'parse_url_params', 'save_requirements', 'get_installed_packages',
-           'resolve_model', 'is_package_installed', 'has_argument', 'dict_to_json_string']
+           'resolve_model', 'is_package_installed', 'has_argument', 'dict_to_json_string', 'resolve_pretrained']
 
 
 def copy_script(dst, no_script_okay=True, frame=None, verbose=False):
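Note: the diff only shows the new export; based on its call site in resnet.py and the isfile import added there, resolve_pretrained plausibly dispatches between local files and URLs along these lines (a speculative sketch, not the actual implementation):

    from os.path import isfile

    import torch
    from torch.hub import load_state_dict_from_url


    def resolve_pretrained(pretrained, in_channels=None, fused_initial=True, state_dict_mapper=None, **kwargs):
        """Resolve `pretrained` to a state dict (speculative reconstruction)."""
        if isinstance(pretrained, str) and isfile(pretrained):  # local checkpoint
            state_dict = torch.load(pretrained, map_location='cpu')
        else:  # otherwise treat it as a URL
            state_dict = load_state_dict_from_url(pretrained)
            if '.pytorch.org' in pretrained and state_dict_mapper is not None:
                # remap torchvision keys to this module's layout
                state_dict = state_dict_mapper(in_channels, state_dict, fused_initial=fused_initial)
        return state_dict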