In [37]:
# Automatically re-import edited project modules (e.g. src.models.tresnet)
# before each cell executes, so changes under src/ are picked up without a
# kernel restart. Re-running this cell prints "already loaded" (harmless).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [38]:
# Needs to be installed in an environment with CUDA available
!pip install inplace-abn
# Restart notebook after installation

Defaulting to user installation because normal site-packages is not writeable


In [39]:
from pathlib import Path
from src.models.tresnet.tresnet import TResNet
import torch

Num classes: 9605


In [48]:
# NOTE(review): `model` is not assigned anywhere above this cell in top-to-bottom
# order (it is first assigned in the torchinfo summary cell further down), so this
# cell only works with leftover kernel state and fails on Restart & Run All.
# The model built further down is nn.Sequential(base, head) with two children, so
# model[5] would raise IndexError on it; the Bottleneck shown in the output came
# from an earlier, different `model` object.
# TODO: remove this cell, or point it at the intended submodule explicitly
# (the printed Bottleneck resembles body.layer4[0] of the TResNet — confirm).
model[5]

Bottleneck(
  (conv1): Sequential(
    (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): InPlaceABN(512, eps=1e-05, momentum=0.1, affine=True, activation=leaky_relu[0.001])
  )
  (conv2): Sequential(
    (0): Sequential(
      (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): InPlaceABN(512, eps=1e-05, momentum=0.1, affine=True, activation=leaky_relu[0.001])
    )
    (1): AntiAliasDownsampleLayer()
  )
  (conv3): Sequential(
    (0): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): InPlaceABN(2048, eps=1e-05, momentum=0.1, affine=True, activation=identity)
  )
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (1): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): InPlaceABN(2048, eps=1e-05, momentum=0.1, affine=True, activation=identity)
    )
  )
)

In [49]:
checkpoint_path = Path("checkpoints/mtresnet_opim_86.72.pth")

def build_tresnet_m_model(
    pretrained: bool = True,
    truncated: bool = False,
    checkpoint: "Path | str | None" = None,
):
    """Build a TResNet feature extractor cut off inside ``layer4``.

    The returned network is ``nn.Sequential(base, head)``. ``base`` runs
    ``body.SpaceToDepth -> body.conv1 -> body.layer1 -> body.layer2 ->
    body.layer3``. ``head`` is either the whole first block of ``body.layer4``
    or, when ``truncated`` is True, only that block's ``conv1``.

    Args:
        pretrained: If True, read ``num_classes`` and the weights from the
            checkpoint and load them with ``strict=True``. If False, the
            network is built with ``num_classes=1`` and is not loaded from
            any checkpoint.
        truncated: If True, stop after ``body.layer4[0].conv1`` instead of
            ``body.layer4[0]``.
        checkpoint: Checkpoint file to load when ``pretrained`` is True.
            Defaults to the module-level ``checkpoint_path``.

    Returns:
        An ``nn.Sequential`` of the two stages described above. It is not
        moved to any device and its train/eval mode is left unchanged.

    Raises:
        KeyError: If the loaded checkpoint has no ``"num_classes"`` or
            ``"model"`` entry (assuming it is a dict).
        RuntimeError: From ``load_state_dict`` if the stored weights do not
            match the architecture exactly (``strict=True``).
    """
    # Imported locally so this function does not depend on a later cell having
    # executed `import torch.nn as nn` before it is called.
    from torch import nn

    if pretrained:
        path = checkpoint_path if checkpoint is None else Path(checkpoint)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code. Only load checkpoints from a trusted source.
        state = torch.load(path, map_location=device)
        num_classes = state["num_classes"]
    else:
        # Placeholder class count. Only `model.body` submodules are returned
        # below, so this presumably affects only the discarded classification
        # head — confirm against src.models.tresnet.
        num_classes = 1

    model = TResNet(
        layers=[3, 4, 11, 3],  # blocks per stage; presumably the TResNet-M config — confirm
        num_classes=num_classes,
        in_chans=3,
        do_bottleneck_head=True,
    )
    if pretrained:
        model.load_state_dict(state["model"], strict=True)

    body = model.body
    base = nn.Sequential(
        body.SpaceToDepth,
        body.conv1,
        body.layer1,
        body.layer2,
        body.layer3,
    )
    first_layer4_block = body.layer4[0]
    head = first_layer4_block.conv1 if truncated else first_layer4_block
    return nn.Sequential(base, head)

In [52]:
# Build the pretrained, truncated extractor and show a per-layer summary
# (three levels deep) for a single 3-channel 1024x1024 input.
from torch import nn
from torchinfo import summary

model = build_tresnet_m_model(truncated=True, pretrained=True)
summary(
    model,
    input_size=(1, 3, 1024, 1024),  # (batch=1, channels=3, height, width)
    depth=3,
)

Layer (type:depth-idx)                                       Output Shape              Param #
Sequential                                                   [1, 512, 64, 64]          --
├─Sequential: 1-1                                            [1, 1024, 64, 64]         --
│    └─SpaceToDepthModule: 2-1                               [1, 48, 256, 256]         --
│    └─Sequential: 2-2                                       [1, 64, 256, 256]         --
│    │    └─Conv2d: 3-1                                      [1, 64, 256, 256]         27,648
│    │    └─InPlaceABN: 3-2                                  [1, 64, 256, 256]         128
│    └─Sequential: 2-3                                       [1, 64, 256, 256]         --
│    │    └─BasicBlock: 3-3                                  [1, 64, 256, 256]         82,304
│    │    └─BasicBlock: 3-4                                  [1, 64, 256, 256]         82,304
│    │    └─BasicBlock: 3-5                                  [1, 64, 256, 256]    