From 97549a41df534f181ab348f647d2b2ed8e7e4450 Mon Sep 17 00:00:00 2001 From: ericup Date: Thu, 29 Feb 2024 18:49:14 +0100 Subject: [PATCH] Update resnet --- celldetection/models/resnet.py | 113 ++++++++++++++++++++++++--------- celldetection/util/util.py | 2 +- 2 files changed, 84 insertions(+), 31 deletions(-) diff --git a/celldetection/models/resnet.py b/celldetection/models/resnet.py index a4bc62f..568769b 100644 --- a/celldetection/models/resnet.py +++ b/celldetection/models/resnet.py @@ -1,10 +1,13 @@ +import torch from torch import nn from torch.nn import functional as F from torchvision.models import resnet as tvr -from ..util.util import Dict, lookup_nn, get_nd_conv, get_nn +from os.path import isfile +from ..util.util import Dict, lookup_nn, get_nd_conv, get_nn, resolve_pretrained from torch.hub import load_state_dict_from_url from .ppm import append_pyramid_pooling_ from typing import Type, Union, Optional +from pytorch_lightning.core.mixins import HyperparametersMixin __all__ = ['get_resnet', 'ResNet50', 'ResNet34', 'ResNet18', 'ResNet152', 'ResNet101', 'WideResNet101_2', 'WideResNet50_2', 'ResNeXt152_32x8d', 'ResNeXt101_32x8d', 'ResNeXt50_32x4d'] @@ -22,15 +25,23 @@ } -def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, kernel_size=3, nd=2) -> nn.Conv2d: """3x3 convolution with padding""" + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * nd + if isinstance(dilation, int): + dilation = (dilation,) * nd + + # Calculate padding for 'same' padding + padding = tuple((ks - 1) * dil // 2 for ks, dil in zip(kernel_size, dilation)) + return get_nd_conv(nd)( in_planes, out_planes, - kernel_size=3, + kernel_size=kernel_size, stride=stride, - padding=dilation, + padding=padding, groups=groups, bias=False, dilation=dilation, @@ -56,6 +67,7 @@ def __init__( # Port from torchvision (to support 3d and add more features) base_width: int = 64, dilation: int = 1, norm_layer='batchnorm2d', + kernel_size=3, nd=2 ) -> None: super().__init__() @@ -64,7 +76,7 @@ def __init__( # Port from torchvision (to support 3d and add more features) raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - self.conv1 = conv3x3(inplanes, planes, stride, nd=nd) + self.conv1 = conv3x3(inplanes, planes, stride, nd=nd, kernel_size=kernel_size) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes, nd=nd) @@ -87,6 +99,7 @@ def __init__( # Port from torchvision (to support 3d and add more features) base_width: int = 64, dilation: int = 1, norm_layer='batchnorm2d', + kernel_size=3, nd=2 ) -> None: super().__init__() @@ -94,7 +107,7 @@ def __init__( # Port from torchvision (to support 3d and add more features) width = int(planes * (base_width / 64.0)) * groups self.conv1 = conv1x1(inplanes, width, nd=nd) self.bn1 = norm_layer(width) - self.conv2 = conv3x3(width, width, stride, groups, dilation, nd=nd) + self.conv2 = conv3x3(width, width, stride, groups, dilation, kernel_size=kernel_size, nd=nd) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion, nd=nd) self.bn3 = norm_layer(planes * self.expansion) @@ -110,6 +123,7 @@ def _make_layer( # Port from torchvision (to support 3d) blocks: int, stride: int = 1, dilate: bool = False, + kernel_size: int = 3, nd=2, secondary_block=None, downsample_method=None, @@ -126,6 +140,7 @@ def _make_layer( # Port from torchvision (to support 3d) blocks: stride: dilate: + kernel_size: nd: secondary_block: downsample_method: Downsample method. None: 1x1Conv with stride, Norm (standard ResNet), @@ -160,7 +175,7 @@ def _make_layer( # Port from torchvision (to support 3d) layers = [] layers.append( block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer, - nd=nd)) + kernel_size=kernel_size, nd=nd)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block( @@ -170,6 +185,7 @@ def _make_layer( # Port from torchvision (to support 3d) base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer, + kernel_size=kernel_size, nd=nd, )) if secondary_block is not None: @@ -178,7 +194,8 @@ def _make_layer( # Port from torchvision (to support 3d) def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, base_width=64, groups=1, stride=1, - dilation=1, dilate=False, nd=2, secondary_block=None, **kwargs) -> nn.Module: + dilation=1, dilate=False, nd=2, secondary_block=None, downsample_method=None, kernel_size=3, + **kwargs) -> nn.Module: """ Args: @@ -194,6 +211,8 @@ def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, b dilate: nd: secondary_block: + downsample_method: + kernel_size: kwargs: Returns: @@ -204,7 +223,7 @@ def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, b groups=groups, dilation=dilation) # almost a ResNet return _make_layer(self=d, block=block, planes=planes, blocks=blocks, stride=stride, dilate=dilate, nd=nd, - secondary_block=secondary_block) + secondary_block=secondary_block, downsample_method=downsample_method, kernel_size=kernel_size) def _apply_mapping_rules(key, rules: dict): @@ -243,7 +262,7 @@ def map_state_dict(in_channels, state_dict, fused_initial): return mapping -class ResNet(nn.Sequential): +class ResNet(nn.Sequential, HyperparametersMixin): def __init__(self, in_channels, *body: nn.Module, initial_strides=2, base_channel=64, initial_pooling=True, final_layer=None, final_activation=None, fused_initial=True, pretrained=False, pyramid_pooling=False, pyramid_pooling_channels=64, pyramid_pooling_kwargs=None, nd=2, **kwargs): @@ -270,32 +289,37 @@ def __init__(self, in_channels, *body: nn.Module, initial_strides=2, base_channe components += [lookup_nn(final_activation)] super(ResNet, self).__init__(*components) if pretrained: - if isinstance(pretrained, str): - state_dict = load_state_dict_from_url(pretrained) - if '.pytorch.org' in pretrained: - state_dict = map_state_dict(in_channels, state_dict, fused_initial=fused_initial) - self.load_state_dict(state_dict) - else: - raise ValueError('There is no default set of weights for this model. ' - 'Please specify a URL using the `pretrained` argument.') + state_dict = resolve_pretrained(pretrained, in_channels=in_channels, fused_initial=fused_initial, + state_dict_mapper=map_state_dict) + self.load_state_dict(state_dict, strict=kwargs.get('pretrained_strict', True)) if pyramid_pooling: pyramid_pooling_kwargs = {} if pyramid_pooling_kwargs is None else pyramid_pooling_kwargs append_pyramid_pooling_(self, pyramid_pooling_channels, nd=nd, **pyramid_pooling_kwargs) class VanillaResNet(ResNet): - def __init__(self, in_channels, out_channels=0, layers=(2, 2, 2, 2), base_channel=64, fused_initial=True, nd=2, - **kwargs): + def __init__(self, in_channels, out_channels=0, layers=(2, 2, 2, 2), base_channel=64, fused_initial=True, + kernel_size=3, per_layer_kernel_sizes: dict = None, nd=2, **kwargs): + if per_layer_kernel_sizes is None: + per_layer_kernel_sizes = {} + if isinstance(per_layer_kernel_sizes, (tuple, list)): + per_layer_kernel_sizes = {i: v for i, v in enumerate(per_layer_kernel_sizes)} + self.save_hyperparameters() self.out_channels = oc = (base_channel, base_channel * 2, base_channel * 4, base_channel * 8) self.out_strides = (4, 8, 16, 32) if out_channels and 'final_layer' not in kwargs.keys(): kwargs['final_layer'] = get_nd_conv(nd)(self.out_channels[-1], out_channels, 1) + super(VanillaResNet, self).__init__( in_channels, - make_res_layer(BasicBlock, base_channel, oc[0], layers[0], stride=1, nd=nd, **kwargs), - make_res_layer(BasicBlock, oc[0], oc[1], layers[1], stride=2, nd=nd, **kwargs), - make_res_layer(BasicBlock, oc[1], oc[2], layers[2], stride=2, nd=nd, **kwargs), - make_res_layer(BasicBlock, oc[2], oc[3], layers[3], stride=2, nd=nd, **kwargs), + make_res_layer(BasicBlock, base_channel, oc[0], layers[0], stride=1, nd=nd, + kernel_size=per_layer_kernel_sizes.get(0, kernel_size), **kwargs), + make_res_layer(BasicBlock, oc[0], oc[1], layers[1], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(1, kernel_size), **kwargs), + make_res_layer(BasicBlock, oc[1], oc[2], layers[2], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(2, kernel_size), **kwargs), + make_res_layer(BasicBlock, oc[2], oc[3], layers[3], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(3, kernel_size), **kwargs), base_channel=base_channel, fused_initial=fused_initial, nd=nd, **kwargs ) if not fused_initial: @@ -322,6 +346,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNet18'] super(ResNet18, self).__init__(in_channels, out_channels=out_channels, layers=(2, 2, 2, 2), pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() class ResNet34(VanillaResNet): @@ -330,13 +356,20 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNet34'] super(ResNet34, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3), pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 34') class BottleResNet(ResNet): - def __init__(self, in_channels, out_channels=0, layers=(3, 4, 6, 3), base_channel=64, fused_initial=True, nd=2, - **kwargs): + def __init__(self, in_channels, out_channels=0, layers=(3, 4, 6, 3), base_channel=64, fused_initial=True, + kernel_size=3, per_layer_kernel_sizes: dict = None, nd=2, **kwargs): + if per_layer_kernel_sizes is None: + per_layer_kernel_sizes = {} + if isinstance(per_layer_kernel_sizes, (tuple, list)): + per_layer_kernel_sizes = {i: v for i, v in enumerate(per_layer_kernel_sizes)} + self.save_hyperparameters() ex = Bottleneck.expansion self.out_channels = oc = (base_channel * 4, base_channel * 8, base_channel * 16, base_channel * 32) self.out_strides = (4, 8, 16, 32) @@ -344,10 +377,14 @@ def __init__(self, in_channels, out_channels=0, layers=(3, 4, 6, 3), base_channe kwargs['final_layer'] = nn.Conv2d(self.out_channels[-1], out_channels, 1) super(BottleResNet, self).__init__( in_channels, - make_res_layer(Bottleneck, base_channel, oc[0] // ex, layers[0], stride=1, nd=nd, **kwargs), - make_res_layer(Bottleneck, base_channel * 4, oc[1] // ex, layers[1], stride=2, nd=nd, **kwargs), - make_res_layer(Bottleneck, base_channel * 8, oc[2] // ex, layers[2], stride=2, nd=nd, **kwargs), - make_res_layer(Bottleneck, base_channel * 16, oc[3] // ex, layers[3], stride=2, nd=nd, **kwargs), + make_res_layer(Bottleneck, base_channel, oc[0] // ex, layers[0], stride=1, nd=nd, + kernel_size=per_layer_kernel_sizes.get(0, kernel_size), **kwargs), + make_res_layer(Bottleneck, base_channel * 4, oc[1] // ex, layers[1], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(1, kernel_size), **kwargs), + make_res_layer(Bottleneck, base_channel * 8, oc[2] // ex, layers[2], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(2, kernel_size), **kwargs), + make_res_layer(Bottleneck, base_channel * 16, oc[3] // ex, layers[3], stride=2, nd=nd, + kernel_size=per_layer_kernel_sizes.get(3, kernel_size), **kwargs), base_channel=base_channel, fused_initial=fused_initial, nd=nd, **kwargs ) if not fused_initial: @@ -361,6 +398,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNet50'] super(ResNet50, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3), pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 50') @@ -371,6 +410,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNet101'] super(ResNet101, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3), pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 101') @@ -381,6 +422,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNet152'] super(ResNet152, self).__init__(in_channels, out_channels=out_channels, layers=(3, 8, 36, 3), pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNet 152') @@ -391,6 +434,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNeXt50_32x4d'] super(ResNeXt50_32x4d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3), groups=32, base_width=4, pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 50') @@ -401,6 +446,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['ResNeXt101_32x8d'] super(ResNeXt101_32x8d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3), groups=32, base_width=8, pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 101') @@ -409,6 +456,8 @@ class ResNeXt152_32x8d(BottleResNet): def __init__(self, in_channels, out_channels=0, nd=2, **kwargs): super(ResNeXt152_32x8d, self).__init__(in_channels, out_channels=out_channels, layers=(3, 8, 36, 3), groups=32, base_width=8, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'ResNeXt 152') @@ -419,6 +468,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['WideResNet50_2'] super(WideResNet50_2, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 6, 3), base_width=64 * 2, pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'Wide ResNet 50') @@ -429,6 +480,8 @@ def __init__(self, in_channels, out_channels=0, pretrained=False, nd=2, **kwargs pretrained = default_model_urls['WideResNet101_2'] super(WideResNet101_2, self).__init__(in_channels, out_channels=out_channels, layers=(3, 4, 23, 3), base_width=64 * 2, pretrained=pretrained, nd=nd, **kwargs) + self.hparams.clear() + self.save_hyperparameters() __init__.__doc__ = ResNet18.__init__.__doc__.replace('ResNet 18', 'Wide ResNet 101') diff --git a/celldetection/util/util.py b/celldetection/util/util.py index 8aaa07c..fe5f92c 100644 --- a/celldetection/util/util.py +++ b/celldetection/util/util.py @@ -39,7 +39,7 @@ 'image_to_base64', 'base64_to_image', 'model2dict', 'dict2model', 'is_ipython', 'grouped_glob', 'tweak_attribute_', 'to_batched_h5', 'compare_file_hashes', 'import_file', 'load_imagej_rois', 'glob_h5_split', 'say_goodbye', 'parse_url_params', 'save_requirements', 'get_installed_packages', - 'resolve_model', 'is_package_installed', 'has_argument', 'dict_to_json_string'] + 'resolve_model', 'is_package_installed', 'has_argument', 'dict_to_json_string', 'resolve_pretrained'] def copy_script(dst, no_script_okay=True, frame=None, verbose=False):