In [1]:
import copy
from collections import namedtuple

from torch import nn
from torch.utils import model_zoo

# NOTE: the registry cell rebinds the global name `models` to a dict, which
# shadows this module alias; refer to the module as `selim_unet` instead.
import models.selim_zoo.unet as models
import models.selim_zoo.unet as selim_unet

In [2]:
# Registry of downloadable checkpoints: name -> (weights URL, architecture).
# Each entry's `model` field is an already-constructed network instance, so
# both architectures are built as soon as this cell runs.
# The architecture module is referenced as `selim_unet` rather than `models`:
# this cell rebinds `models` to the registry dict below, so using the module
# alias `models` here would raise AttributeError when the cell is re-run.
model = namedtuple("model", ["url", "model"])

models = {
    "resnet34": model(
        url="https://github.com/alimbekovKZ/lungs_segmentation/releases/download/1.0.0/resnet34.pth",
        model=selim_unet.Resnet(seg_classes=2, backbone_arch="resnet34"),
    ),
    "densenet121": model(
        url="https://github.com/alimbekovKZ/lungs_segmentation/releases/download/1.0.0/densenet121.pth",
        model=selim_unet.DensenetUnet(seg_classes=2, backbone_arch="densenet121"),
    ),
}

In [6]:
def create_model(model_name: str) -> nn.Module:
    """Build a registered segmentation network and load its pre-trained weights.

    Args:
        model_name: Key into the ``models`` registry (e.g. ``"resnet34"``).

    Returns:
        A fresh copy of the registered architecture with the checkpoint's
        state dict loaded (tensors mapped to CPU).

    Raises:
        KeyError: If ``model_name`` is not in the registry; the message lists
            the available names.
    """
    if model_name not in models:
        available = ", ".join(sorted(models))
        raise KeyError(f"Unknown model {model_name!r}; available: {available}")
    entry = models[model_name]
    # The registry holds a single template instance per name. Copy it so that a
    # later call does not hand back the same module and overwrite the weights
    # of a model returned earlier.
    net = copy.deepcopy(entry.model)
    # The checkpoint file is loaded directly as a state dict (no wrapper key).
    state_dict = model_zoo.load_url(entry.url, progress=True, map_location="cpu")
    net.load_state_dict(state_dict)
    return net

In [8]:
# Download the ResNet-34 checkpoint through the notebook-local `create_model`.
model = create_model(model_name="resnet34")

In [9]:
from lungs_segmentation.pre_trained_models import create_model

In [10]:
# `create_model` now refers to the version imported from
# `lungs_segmentation.pre_trained_models` in the previous cell; that import
# shadows the notebook-local definition.
model = create_model("resnet34")

In [11]:
# Display the loaded network's module tree as the cell output.
model

Resnet(
  (bottlenecks): ModuleList(
    (0): ConvBottleneck(
      (seq): Sequential(
        (0): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (1): ConvBottleneck(
      (seq): Sequential(
        (0): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (2): ConvBottleneck(
      (seq): Sequential(
        (0): Conv2d(192, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
    (3): ConvBottleneck(
      (seq): Sequential(
        (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
    )
  )
  (decoder_stages): ModuleList(
    (0): UnetDecoderBlock(
      (layer): Sequential(
        (0): Upsample(scale_factor=2.0, mode=nearest)
        (1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): ReLU(inplace=True)
