In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# local source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial

import torch
from torch import nn

from glasses.nn.models.classification import *
from glasses.utils.Storage import ForwardModuleStorage
from glasses.nn.models.segmentation.unet import *
from glasses.nn.models.classification.resnet import ResNetEncoder, ResNet, ResNetBottleneckBlock
from glasses.nn.models.classification.efficientnet import *
from glasses.nn.models.classification import *


In [8]:
# UNet on an EfficientNet-B0 encoder, tapping four intermediate modules as skip features.
backbone = EfficientNet.efficientnet_b0().encoder

# Tapped modules and their channel counts; the counts match the feature shapes
# printed by the encoder-shape cell (32, 24, 40, 80 channels).
encoder = SegmentationEncoder(
    backbone,
    stages=[backbone.stem[-2], *[backbone.layers[idx] for idx in (1, 2, 3)]],
    widths=[32, 24, 40, 80],
)

# UNet receives a callable for `encoder`; this one ignores its arguments and returns
# the prebuilt encoder. NOTE: the lambda looks up the global `encoder` when it is
# called, so later cells that rebind `encoder` change what it returns.
m = UNet(
    encoder=lambda *args, **kwargs: encoder,
    decoder=partial(UNetDecoder, widths=[1280, 256, 128, 64, 32, 16]),
)

[80, 40, 24, 32, 0, 0]


In [9]:
# Layer-by-layer summary of the current `m`; the tuple it returns is shown as the cell output.
m.summary()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Conv2dPad-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              SiLU-3         [-1, 32, 112, 112]               0
          Identity-4         [-1, 32, 112, 112]               0
         Conv2dPad-5         [-1, 32, 112, 112]             288
       BatchNorm2d-6         [-1, 32, 112, 112]              64
              SiLU-7         [-1, 32, 112, 112]               0
 AdaptiveAvgPool2d-8             [-1, 32, 1, 1]               0
            Conv2d-9              [-1, 8, 1, 1]             264
             SiLU-10              [-1, 8, 1, 1]               0
           Conv2d-11             [-1, 32, 1, 1]             288
          Sigmoid-12             [-1, 32, 1, 1]               0
        ChannelSE-13         [-1, 32, 112, 112]               0
        Conv2dPad-14         [-1, 16, 1

(tensor(7319630), tensor(7319630), tensor(27.9222), tensor(308.2561))

In [7]:
# Push one random 3-channel 224x224 input through the encoder only to populate the
# stage outputs it stores. Gradients are not needed to read shapes, so skip building
# the autograd graph. The call is not assigned to `_`, which IPython reserves for the
# last cell output.
with torch.no_grad():
    encoder(torch.randn(1, 3, 224, 224))

[f.shape for f in encoder.features]

[torch.Size([1, 32, 112, 112]),
 torch.Size([1, 24, 56, 56]),
 torch.Size([1, 40, 28, 28]),
 torch.Size([1, 80, 14, 14])]

In [None]:
# Variant: tap the stem's second-to-last module plus every encoder layer.
# To try a ResNet backbone, swap in the commented ResNet-18 line instead.
backbone = EfficientNet.efficientnet_b0()
# backbone = ResNet.resnet18()

encoder = SegmentationEncoder(
    backbone.encoder,
    [backbone.encoder.stem[-2]] + list(backbone.encoder.layers),
)

# NOTE(review): the decoder's first width is 512 here, while the EfficientNet-B0 cell
# above uses 1280. 512 looks like it was sized for the ResNet-18 alternative — confirm
# before running with EfficientNet-B0.
m = UNet(
    encoder=lambda *args, **kwargs: encoder,
    decoder=partial(UNetDecoder, widths=[512, 256, 128, 64, 32, 16]),
)

In [None]:
# Display the module tree (repr) of the current `m`.
m

In [None]:
# Layer-by-layer summary of the current `m`.
m.summary()

In [None]:
# Repr of a UNet built entirely from default arguments.
UNet()

In [None]:
from typing import List

class SegmentationEncoder(nn.Module):
    """Wrap a backbone built from a factory and expose the outputs of its stages.

    NOTE(review): this notebook-local class shadows the ``SegmentationEncoder`` that
    earlier cells call with ``stages=``/``widths=`` (presumably from the ``glasses``
    star imports). Once this cell has run, re-running those earlier cells uses this
    definition with its different constructor instead.

    Args:
        backbone: callable invoked as ``backbone(*args, **kwargs)`` to build the wrapped
            module. The built module must provide a ``stages()`` method; its result is
            handed to ``ForwardModuleStorage`` together with the module.
        *args, **kwargs: forwarded unchanged to ``backbone``.
    """

    def __init__(self, backbone, *args, **kwargs):
        super().__init__()
        # Register the backbone as a submodule before creating the storage.
        self.backbone = backbone(*args, **kwargs)
        tapped_stages = self.backbone.stages()
        # Presumably hooks `tapped_stages` so their outputs are recorded during
        # forward passes — see glasses.utils.Storage.ForwardModuleStorage.
        self.storage = ForwardModuleStorage(self.backbone, tapped_stages)

    def forward(self, x):
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        out = self.backbone(x)
        return out

    @property
    def features(self) -> list:
        """Values currently held in ``self.storage.state``, as a list in mapping order.

        Presumably these are the stage outputs recorded by the most recent forward
        pass — confirm against ForwardModuleStorage.
        """
        return [*self.storage.state.values()]

In [None]:
# Requires the notebook-local SegmentationEncoder cell above to have been run: that
# class takes a factory for the backbone (the library version takes stages/widths).
# Wrap a single-channel ResNet-18 encoder and inspect the stage outputs it records.
encoder = SegmentationEncoder(lambda: ResNet.resnet18(in_channels=1).encoder)
x = torch.randn(8, 1, 224, 224)

# Only the recorded outputs are of interest, so skip building the autograd graph.
with torch.no_grad():
    encoder(x)

# Last expression -> rich display of the shapes instead of a plain-text print.
[f.shape for f in encoder.features]

In [None]:
# Default UNet, summarised with input shape (1, 224, 224) — presumably (channels, H, W).
m = UNet()
m.summary((1,224,224))

In [None]:
# Same summary, but using `summary()`'s own default input shape.
m.summary()

In [None]:
import segmentation_models_pytorch as smp

# segmentation_models_pytorch's UNet with a 'resnet34' encoder, displayed so it can be
# inspected next to the glasses UNet.
model = smp.Unet('resnet34')
model

In [None]:
from torchsummary import summary


# torchsummary's `summary` defaults to device="cuda" and the original `.cuda()` call
# fails without a GPU; choose the device at runtime so the cell also runs on CPU.
summary_device = "cuda" if torch.cuda.is_available() else "cpu"
summary(model.to(summary_device), (3, 224, 224), device=summary_device)

In [None]:
# Layer summary of glasses' ResNet-26d built for 3 input channels.
ResNet.resnet26d(in_channels=3).summary()

In [None]:
# Repr of glasses' ResNet-26 built for 1 input channel.
ResNet.resnet26(in_channels=1)

# Benchmark: port pretrained `timm` weights into `glasses` models and compare them

In [None]:
from glasses.nn.models.classification.efficientnet import *
import timm

from transfer_weights import clone_model
from benchmark import benchmark

# Source: timm's 'tf_efficientnet_lite1' with pretrained weights. `pretrained` is a
# bool flag — the original passed the string 'True', which only worked because any
# non-empty string (including 'False') is truthy.
# Destination: glasses' EfficientNet-Lite1 built with mode='same' (presumably
# TF-style "same" padding to match the TF-ported weights — TODO confirm).
src = timm.create_model('tf_efficientnet_lite1', pretrained=True)
dst = EfficientNetLite.efficientnet_lite1(mode='same')

In [None]:
# Transfer `src`'s weights into `dst` via the local `transfer_weights.clone_model`
# helper and keep the model it returns.
dst = clone_model(src, dst)

In [None]:
# Layer summary of the glasses model after weight transfer.
dst.summary()

In [None]:
# Display the module tree (repr) of the glasses model.
dst

In [None]:
# Disabled: benchmark the glasses model on GPU (`.cuda()`) using the transform from its
# 'efficientnet_lite1' config. Uncomment to run.
# import torch

# transform = dst.configs['efficientnet_lite1'].transform

# benchmark(dst.cuda(), transform)

In [None]:
# Disabled: the same benchmark for the timm source model. Needs `transform` from the
# commented-out cell above.
# benchmark(src.cuda(), transform)

In [None]:
# timm's 'efficientnet_b0' with pretrained weights, displayed. `pretrained` is a bool
# flag (the original string 'True' only worked because non-empty strings are truthy).
# NOTE: this rebinds `src`, replacing the EfficientNet-Lite1 model loaded earlier.
src = timm.create_model('efficientnet_b0', pretrained=True)
src

In [None]:
# Layer summary of `dst` (not affected by the cell above, which only rebinds `src`).
dst.summary()