In [13]:
import torch
from torch import nn
from torch import Tensor
from pytorch_toolbelt.modules import encoders, decoders, heads
from pytorch_toolbelt.modules import ACT_RELU, ACT_SILU


# Creating an encoder

In [14]:
from pytorch_toolbelt.modules import encoders

# Pretrained ResNet-34 encoder that exposes all five feature maps (layers 0-4).
encoder = encoders.Resnet34Encoder(
    pretrained=True,
    layers=list(range(5)),  # i.e. [0, 1, 2, 3, 4]
)
# Channel count and stride of every feature map the encoder returns.
output_spec = encoder.get_output_spec()
output_spec

FeatureMapsSpecification(channels=(64, 64, 128, 256, 512), strides=(2, 4, 8, 16, 32))

In [15]:
from pytorch_toolbelt.utils import describe_outputs

# Inference-only shape check on random input:
# - eval() disables training-mode behaviour (e.g. normalization layers updating
#   their running statistics from this random batch);
# - no_grad() avoids building an autograd graph we never backpropagate through.
encoder.eval()
with torch.no_grad():
    outputs = encoder(torch.randn(1, 3, 256, 256))
describe_outputs(outputs)

[{'size': (1, 64, 128, 128),
  'mean': 0.2657899856567383,
  'std': 0.30636727809906006,
  'dtype': torch.float32},
 {'size': (1, 64, 64, 64),
  'mean': 0.7731860876083374,
  'std': 0.6579753756523132,
  'dtype': torch.float32},
 {'size': (1, 128, 32, 32),
  'mean': 0.2784508764743805,
  'std': 0.33034390211105347,
  'dtype': torch.float32},
 {'size': (1, 256, 16, 16),
  'mean': 0.10355071723461151,
  'std': 0.21384787559509277,
  'dtype': torch.float32},
 {'size': (1, 512, 8, 8),
  'mean': 0.9371781349182129,
  'std': 1.2523103952407837,
  'dtype': torch.float32}]

## Changing the number of input channels

In [16]:
# Adapt the encoder so it accepts single-channel (e.g. grayscale) input.
encoder = encoder.change_input_channels(1)

# Inference-only shape check: eval() prevents training-mode side effects such as
# running-statistic updates, and no_grad() skips building the autograd graph.
encoder.eval()
with torch.no_grad():
    outputs = encoder(torch.randn(1, 1, 256, 256))
describe_outputs(outputs)

[{'size': (1, 64, 128, 128),
  'mean': 0.26584798097610474,
  'std': 0.3063778281211853,
  'dtype': torch.float32},
 {'size': (1, 64, 64, 64),
  'mean': 0.7738164067268372,
  'std': 0.6547278165817261,
  'dtype': torch.float32},
 {'size': (1, 128, 32, 32),
  'mean': 0.2798773944377899,
  'std': 0.3327867388725281,
  'dtype': torch.float32},
 {'size': (1, 256, 16, 16),
  'mean': 0.10497166216373444,
  'std': 0.2116292417049408,
  'dtype': torch.float32},
 {'size': (1, 512, 8, 8),
  'mean': 0.9424188137054443,
  'std': 1.2670280933380127,
  'dtype': torch.float32}]

In [17]:
# Pretrained HRNet-W18 encoder using only feature levels 2, 3 and 4.
encoder = encoders.HRNetW18Encoder(
    pretrained=True,
    layers=[2, 3, 4],
    use_incre_features=False,
)
encoder.get_output_spec()

FeatureMapsSpecification(channels=(36, 72, 144), strides=(8, 16, 32))

In [18]:
class GenericSegmentationModel(nn.Module):
    """Chain an encoder, a decoder and a head into a single network.

    The forward pass computes
    ``head(decoder(encoder(x)), output_size=x.shape[-2:])``: the head is given
    the spatial size (last two dimensions) of the input alongside the decoded
    features.

    Args:
        encoder: module mapping the input tensor to feature maps.
        decoder: module consuming whatever the encoder returns.
        head: module called as ``head(features, output_size=...)``.
    """

    def __init__(self, encoder, decoder, head):
        super().__init__()
        # Assigning nn.Module instances as attributes registers them as
        # submodules (in this order), so their parameters belong to the model.
        self.encoder, self.decoder, self.head = encoder, decoder, head

    def forward(self, x: Tensor):
        spatial_size = x.shape[-2:]
        decoded = self.decoder(self.encoder(x))
        return self.head(decoded, output_size=spatial_size)

In [19]:
from pytorch_toolbelt.modules.heads import ResizeHead


def b4_bifpn(num_classes: int = 1, pretrained: bool = True):
    """Build a TimmB4Encoder + BiFPN segmentation model with a ResizeHead.

    Args:
        num_classes: number of output channels of the head (default 1).
        pretrained: whether the encoder is created with pretrained weights.

    Returns:
        GenericSegmentationModel combining the encoder, decoder and head.
    """
    # Encoder and decoder share the same activation (SiLU).
    encoder = encoders.TimmB4Encoder(
        pretrained=pretrained, layers=[1, 2, 3, 4], drop_path_rate=0.2, activation=ACT_SILU
    )
    # The decoder and head are configured from the previous stage's output spec,
    # so channel counts / strides always match.
    decoder = decoders.BiFPNDecoder(
        input_spec=encoder.get_output_spec(),
        out_channels=256,
        num_layers=3,
        activation=ACT_SILU,
    )
    head = ResizeHead(
        input_spec=decoder.get_output_spec(),
        num_classes=num_classes,
        dropout_rate=0.2,
    )
    return GenericSegmentationModel(
        encoder,
        decoder,
        head,
    )



In [29]:
from pytorch_toolbelt.utils import count_parameters

model = b4_bifpn()

# Inference-only check on random input: eval() disables the dropout configured
# in the model and training-mode statistic updates; no_grad() skips building
# the autograd graph.
model.eval()
with torch.no_grad():
    output = model(torch.randn(1, 3, 256, 256))

print(count_parameters(model, human_friendly=True))
print(describe_outputs(output))

{'total': '27.5M', 'trainable': '27.5M', 'encoder': '16.7M', 'decoder': '10.8M', 'head': '2.31K'}
{'size': (1, 1, 256, 256), 'mean': -0.0968250185251236, 'std': 0.28567764163017273, 'dtype': torch.float32}


In [21]:

# Pretty-printed summary of `output` from the most recent model forward pass.
# NOTE(review): this cell's execution count (21) is lower than the cell above
# (29), so its recorded output comes from an earlier `output`; re-run to refresh.
describe_outputs(output)

{'size': (1, 1, 256, 256),
 'mean': -0.3548843264579773,
 'std': 0.3300875425338745,
 'dtype': torch.float32}

In [22]:

def hrnet_fpn(num_classes: int = 1, pretrained: bool = True):
    """Build an HRNetW18Encoder + FPN segmentation model with a ResizeHead.

    Args:
        num_classes: number of output channels of the head (default 1).
        pretrained: whether the encoder is created with pretrained weights.

    Returns:
        GenericSegmentationModel combining the encoder, decoder and head.
    """
    encoder = encoders.HRNetW18Encoder(
        pretrained=pretrained, layers=[1, 2, 3, 4], use_incre_features=False,
    )
    # Decoder and head are sized from the previous stage's output spec.
    decoder = decoders.FPNDecoder(
        input_spec=encoder.get_output_spec(),
        out_channels=256,
    )
    head = ResizeHead(
        input_spec=decoder.get_output_spec(),
        num_classes=num_classes,
        dropout_rate=0.2,
    )
    return GenericSegmentationModel(
        encoder,
        decoder,
        head,
    )


In [28]:
model = hrnet_fpn()

# Inference-only check on random input: eval() disables the head's dropout and
# training-mode statistic updates; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(torch.randn(1, 3, 256, 256))

print(count_parameters(model, human_friendly=True))
print(describe_outputs(output))

{'total': '11.4M', 'trainable': '11.4M', 'encoder': '9.56M', 'decoder': '1.84M', 'head': '2.31K'}
{'size': (1, 1, 256, 256), 'mean': 0.34645456075668335, 'std': 0.1458025723695755, 'dtype': torch.float32}


In [27]:
from pytorch_toolbelt.modules import UpsampleLayerType


def resnet50d_unet(num_classes: int = 1, pretrained: bool = True):
    """Build a TimmResnet50D + UNet segmentation model with a ResizeHead.

    Args:
        num_classes: number of output channels of the head (default 1).
        pretrained: whether the encoder is created with pretrained weights.

    Returns:
        GenericSegmentationModel combining the encoder, decoder and head.
    """
    encoder = encoders.TimmResnet50D(
        pretrained=pretrained, layers=[1, 2, 3, 4],
    )
    # UNet decoder with bilinear upsampling blocks; it is sized from the
    # encoder's output spec.
    decoder = decoders.UNetDecoder(
        input_spec=encoder.get_output_spec(),
        out_channels=(128, 256, 512),
        upsample_block=UpsampleLayerType.BILINEAR,
    )
    head = ResizeHead(
        input_spec=decoder.get_output_spec(),
        num_classes=num_classes,
        dropout_rate=0.2,
    )
    return GenericSegmentationModel(
        encoder,
        decoder,
        head,
    )

model = resnet50d_unet()

# Inference-only check on random input: eval() disables the head's dropout and
# training-mode statistic updates; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(torch.randn(1, 3, 256, 256))

print(count_parameters(model, human_friendly=True))
print(describe_outputs(output))

{'total': '43.7M', 'trainable': '43.7M', 'encoder': '23.5M', 'decoder': '20.2M', 'head': '1.15K'}
{'size': (1, 1, 256, 256), 'mean': 0.1509934365749359, 'std': 0.26733869314193726, 'dtype': torch.float32}
