Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Jul 15, 2021
1 parent 8d1fccd commit df0647e
Show file tree
Hide file tree
Showing 12 changed files with 35 additions and 15 deletions.
16 changes: 6 additions & 10 deletions celldetection/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
from .commons import TwoConvBnRelu
from .unet import UNetEncoder, UNet, U12, U17, U22, SlimU22, WideU22
from .resnet import get_resnet, ResNet50, ResNet34, ResNet18, ResNet152, ResNet101, WideResNet101_2, WideResNet50_2, \
ResNeXt152_32x8d, ResNeXt101_32x8d, ResNeXt50_32x4d
from .cpn import CPN, CpnSlimU22, CpnU22, CpnWideU22, CpnResNet18FPN, CpnResNet34FPN, CpnResNet50FPN, CpnResNet101FPN, \
CpnResNet152FPN, CpnResNeXt50FPN, CpnResNeXt101FPN, CpnResNeXt152FPN, CpnWideResNet50FPN, \
CpnWideResNet101FPN, CpnMobileNetV3LargeFPN, CpnMobileNetV3SmallFPN
from .fpn import FPN, ResNeXt50FPN, ResNeXt101FPN, ResNet18FPN, ResNet34FPN, ResNeXt152FPN, WideResNet50FPN, \
WideResNet101FPN, ResNet50FPN, ResNet101FPN, ResNet152FPN, MobileNetV3SmallFPN, MobileNetV3LargeFPN
from .inference import Inference
from .commons import *
from .unet import *
from .resnet import *
from .cpn import *
from .fpn import *
from .inference import *
from .mobilenetv3 import *
2 changes: 2 additions & 0 deletions celldetection/models/commons.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch.nn as nn

__all__ = ['TwoConvBnRelu']


class TwoConvBnRelu(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1, mid_channels=None, **kwargs):
Expand Down
4 changes: 4 additions & 0 deletions celldetection/models/cpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from .fpn import ResNet34FPN, ResNet18FPN, ResNet50FPN, ResNet101FPN, ResNet152FPN, ResNeXt50FPN, \
ResNeXt101FPN, ResNeXt152FPN, WideResNet50FPN, WideResNet101FPN, MobileNetV3LargeFPN, MobileNetV3SmallFPN

__all__ = ['CPN', 'CpnSlimU22', 'CpnU22', 'CpnWideU22', 'CpnResNet18FPN', 'CpnResNet34FPN', 'CpnResNet50FPN',
'CpnResNet101FPN', 'CpnResNet152FPN', 'CpnResNeXt50FPN', 'CpnResNeXt101FPN', 'CpnResNeXt152FPN',
'CpnWideResNet50FPN', 'CpnWideResNet101FPN', 'CpnMobileNetV3LargeFPN', 'CpnMobileNetV3SmallFPN']


class ReadOut(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions celldetection/models/fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
ResNeXt152_32x8d, WideResNet50_2, WideResNet101_2
from .mobilenetv3 import MobileNetV3Large, MobileNetV3Small

__all__ = ['FPN', 'ResNeXt50FPN', 'ResNeXt101FPN', 'ResNet18FPN', 'ResNet34FPN', 'ResNeXt152FPN', 'WideResNet50FPN',
'WideResNet101FPN', 'ResNet50FPN', 'ResNet101FPN', 'ResNet152FPN', 'MobileNetV3SmallFPN',
'MobileNetV3LargeFPN']


class FPN(BackboneWithFPN):
def __init__(self, backbone, channels=256, return_layers: dict = None):
Expand Down
2 changes: 2 additions & 0 deletions celldetection/models/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import torch
from ..util.util import asnumpy

__all__ = ['Inference']


class Inference:
def __init__(self, model, device=None, amp=False, transforms=None):
Expand Down
3 changes: 3 additions & 0 deletions celldetection/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from torchvision.models.resnet import ResNet as RN, Bottleneck, BasicBlock
from ..util.util import Dict

__all__ = ['get_resnet', 'ResNet50', 'ResNet34', 'ResNet18', 'ResNet152', 'ResNet101', 'WideResNet101_2',
'WideResNet50_2', 'ResNeXt152_32x8d', 'ResNeXt101_32x8d', 'ResNeXt50_32x4d']


def make_res_layer(block, inplanes, planes, blocks, norm_layer=nn.BatchNorm2d, base_width=64, groups=1, stride=1,
dilation=1, dilate=False, **kwargs) -> nn.Module:
Expand Down
3 changes: 2 additions & 1 deletion celldetection/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
from collections import OrderedDict
from typing import List, Tuple, Dict

from .commons import TwoConvBnRelu

__all__ = ['UNetEncoder', 'UNet', 'U12', 'U17', 'U22', 'SlimU22', 'WideU22']


class UNetEncoder(nn.Sequential):
def __init__(self, in_channels, depth=5, base_channels=64, factor=2, pool=True):
Expand Down
1 change: 1 addition & 0 deletions celldetection/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .commons import *
2 changes: 2 additions & 0 deletions celldetection/ops/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch.nn.functional as F
from typing import List

__all__ = ['downsample_labels']


def downsample_labels(inputs, size: List[int]):
"""
Expand Down
5 changes: 2 additions & 3 deletions celldetection/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .util import Dict, lookup_nn, reduce_loss_dict, to_device, asnumpy, fetch_model, random_code_name, dict_hash, \
fetch_image, random_seed
from .timer import start_timer, stop_timer, print_timing
from .util import *
from .timer import *
5 changes: 4 additions & 1 deletion celldetection/util/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from time import time
import numpy as np

__all__ = ['start_timer', 'stop_timer', 'print_timing']

TIMINGS = {}


Expand All @@ -26,7 +28,8 @@ def convert_seconds(seconds):


def seconds_to_str(seconds):
s = [f'{i} {n[:-1] if i == 1 else n}' for i, n in zip(convert_seconds(seconds), ('days', 'hours', 'minutes', 'seconds'))]
s = [f'{i} {n[:-1] if i == 1 else n}' for i, n in
zip(convert_seconds(seconds), ('days', 'hours', 'minutes', 'seconds'))]
s = ', '.join(s)
return s

Expand Down
3 changes: 3 additions & 0 deletions celldetection/util/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
import hashlib
import json

__all__ = ['Dict', 'lookup_nn', 'reduce_loss_dict', 'to_device', 'asnumpy', 'fetch_model', 'random_code_name',
'dict_hash', 'fetch_image', 'random_seed']


class Dict(dict):
__getattr__ = dict.__getitem__ # alternative: dict.get if KeyError is not desired
Expand Down

0 comments on commit df0647e

Please sign in to comment.