In [1]:
%load_ext autoreload
%autoreload 2

In [69]:
from functools import partial

import torch
from glasses.nn.models.classification import *
from glasses.utils.Storage import ForwardModuleStorage
from glasses.nn.models.segmentation.unet import *
from glasses.nn.models.classification.resnet import ResNetEncoder, ResNet, ResNetBottleneckBlock
from glasses.nn.models.classification.efficientnet import *
from glasses.nn.models.classification import *
from torch import nn

In [17]:
# Scratch check: a name carrying the 'resnet' prefix is matched by str.startswith.
name = 'resnet-asdsa'
name.startswith(('resnet',))

True

In [81]:
def from_encoder(model_def, *args, **kwargs):
    """Build a classification model and wrap its encoder so it also yields skip features.

    Intended to be used as a UNet ``encoder`` factory, e.g.
    ``partial(from_encoder, EfficientNet.efficientnet_b2)``. The full model is
    instantiated and only its ``encoder`` is kept.

    Args:
        model_def: callable returning the model, e.g. ``EfficientNet.efficientnet_b2``.
            Only ``ResNet`` and ``EfficientNet`` instances are supported.
        *args, **kwargs: forwarded unchanged to ``model_def``.

    Returns:
        ``WithFeatures`` wrapping ``model.encoder``, with the selected stages and
        the channel width paired with each stage.

    Raises:
        TypeError: if ``model_def`` builds a model that is neither a ``ResNet`` nor an
            ``EfficientNet``. (Previously an ``nn.Identity()`` was returned silently,
            which has no stages/widths for a decoder and only failed much later.)
    """
    model = model_def(*args, **kwargs)

    # Check the type first so unsupported models fail here, before touching `.encoder`.
    if not isinstance(model, (ResNet, EfficientNet)):
        raise TypeError(
            f"from_encoder supports ResNet and EfficientNet models, got {type(model).__name__}"
        )

    backbone = model.encoder
    if isinstance(model, ResNet):
        # Stem's second-to-last module followed by every residual layer,
        # paired positionally with widths[1:].
        stages = [backbone.stem[-2], *backbone.layers]
        features_widths = backbone.widths[1:]
    else:
        # Stem's second-to-last module plus layers 1..3; widths are paired
        # positionally as stem -> widths[0], layers[i] -> widths[i + 1].
        stages = [
            backbone.stem[-2],
            backbone.layers[1],
            backbone.layers[2],
            backbone.layers[3],
        ]
        features_widths = [backbone.widths[i] for i in (0, 2, 3, 4)]

    return WithFeatures(backbone, stages=stages, features_widths=features_widths)

In [90]:
# UNet with an EfficientNet-B2 encoder (skip stages chosen by `from_encoder`).
# NOTE(review): as recorded, this cell raises
#   RuntimeError: Sizes of tensors must match except in dimension 2. Got 14 and 112
# The working B0 cell further down passes decoder widths that start with the
# encoder's final width (1280); the list here starts at 256 — TODO confirm what
# UNetDecoder expects as its first width for B2.
b2_encoder_factory = partial(from_encoder, EfficientNet.efficientnet_b2)
b2_decoder_factory = partial(UNetDecoder, widths=[256, 128, 64, 32, 16])
m = UNet(encoder=b2_encoder_factory, decoder=b2_decoder_factory)

m.summary((1, 224, 224))

RuntimeError: Sizes of tensors must match except in dimension 2. Got 14 and 112 (The offending index is 0)

In [91]:
# Display the module tree of the B2 UNet built above (the one whose summary failed).
m

UNet(
  (encoder): WithFeatures(
    (backbone): EfficientNetEncoder(
      (stem): ConvBnAct(
        (conv): Conv2dPad(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (layers): ModuleList(
        (0): ResNetLayer(
          (block): Sequential(
            (0): EfficientNetBasicBlock(
              (block): Sequential(
                (exp): Identity()
                (depth): ConvBnAct(
                  (conv): Conv2dPad(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
                  (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (act): SiLU(inplace=True)
                )
                (att): ChannelSE(
                  (avg_pool): AdaptiveAvgPool2d(output_size=1)
                  (att): Sequential(
                    (con

In [66]:
# UNet on a hand-wired EfficientNet-B0 encoder (3 input channels, per the summary below).
backbone = EfficientNet.efficientnet_b0().encoder

# Same stage selection as `from_encoder`'s EfficientNet branch. The widths are read
# from the backbone with the same indices instead of hard-coding [32, 24, 40, 80],
# so they cannot drift out of sync with the chosen stages.
encoder = WithFeatures(backbone,
                       stages=[
                           backbone.stem[-2],
                           backbone.layers[1],
                           backbone.layers[2],
                           backbone.layers[3],
                       ],
                       features_widths=[backbone.widths[i] for i in (0, 2, 3, 4)],
                       )

# NOTE(review): the lambda ignores whatever arguments UNet passes to its encoder
# factory, so this prebuilt `encoder` is always used regardless of UNet's settings.
m = UNet(encoder=lambda *args, **kwargs: encoder,
         decoder=partial(UNetDecoder, widths=[1280, 256, 128, 64, 32, 16]))

In [67]:
# Layer-by-layer summary of the B0 UNet for a single 3x224x224 input.
b0_input_shape = (3, 224, 224)
m.summary(b0_input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Conv2dPad-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              SiLU-3         [-1, 32, 112, 112]               0
          Identity-4         [-1, 32, 112, 112]               0
         Conv2dPad-5         [-1, 32, 112, 112]             288
       BatchNorm2d-6         [-1, 32, 112, 112]              64
              SiLU-7         [-1, 32, 112, 112]               0
 AdaptiveAvgPool2d-8             [-1, 32, 1, 1]               0
            Conv2d-9              [-1, 8, 1, 1]             264
             SiLU-10              [-1, 8, 1, 1]               0
           Conv2d-11             [-1, 32, 1, 1]             288
          Sigmoid-12             [-1, 32, 1, 1]               0
        ChannelSE-13         [-1, 32, 112, 112]               0
        Conv2dPad-14         [-1, 16, 1

(tensor(44177230), tensor(44177230), tensor(168.5228), tensor(778.0754))

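# Benchmark

Copy the weights of timm's pretrained `tf_efficientnet_lite1` into glasses' `EfficientNetLite.efficientnet_lite1` (via `clone_model`) and inspect the result; the benchmark calls below are currently commented out.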

In [None]:
from glasses.nn.models.classification.efficientnet import *
import timm

from transfer_weights import clone_model
from benchmark import benchmark

# Reference: timm's pretrained EfficientNet-Lite1.
# `pretrained` is a boolean flag. The previous string 'True' only worked because any
# non-empty string is truthy ('False' would have loaded weights too).
src = timm.create_model('tf_efficientnet_lite1', pretrained=True)
# Target: the glasses implementation. mode='same' presumably selects TF-style "same"
# padding to match the TF-ported weights — TODO confirm.
dst = EfficientNetLite.efficientnet_lite1(mode='same')

In [None]:
# Presumably copies src's weights into dst (helper comes from `transfer_weights`);
# the returned model replaces `dst`.
dst = clone_model(src, dst)

In [None]:
# Layer-by-layer summary of the glasses model after weight transfer.
dst.summary()

In [None]:
# Display dst's module tree.
dst

In [None]:
# NOTE(review): disabled benchmark of the glasses model. It calls .cuda(), so it
# presumably needs a GPU — TODO re-enable behind torch.cuda.is_available() or delete.
# import torch

# transform = dst.configs['efficientnet_lite1'].transform

# benchmark(dst.cuda(), transform)

In [None]:
# NOTE(review): disabled; relies on `transform` from the commented-out cell above,
# so it cannot run on its own if uncommented.
# benchmark(src.cuda(), transform)

In [None]:
# Inspect timm's pretrained EfficientNet-B0 for comparison.
# NOTE: this rebinds `src`, replacing the Lite1 reference model loaded earlier.
# `pretrained` takes a bool; the string 'True' only worked by being truthy.
src = timm.create_model('efficientnet_b0', pretrained=True)
src

In [None]:
# NOTE(review): duplicates the earlier dst.summary() cell; dst has not been reassigned
# since (only src changed) — delete, or point this at the model actually intended.
dst.summary()