In [None]:
from collections import OrderedDict

import torch
import torch.nn as nn

class UNet1d(nn.Module):
    """Four-level 1-D U-Net with skip connections.

    The contracting path is four double-convolution blocks, each followed by
    a stride-2 max-pool; a bottleneck block sits at the lowest resolution;
    the expanding path mirrors it with stride-2 transposed convolutions, each
    followed by channel-wise concatenation with the matching encoder output
    and another double-convolution block. A final 1x1 convolution maps to
    ``out_channels``; no activation is applied to the result.

    Because each pooling step halves the length (rounding down) and each
    transposed convolution exactly doubles it, the input length has to be a
    multiple of 16 for the skip concatenations to line up.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 conv.
        init_features: Width of the first encoder level; deeper levels use
            2x, 4x, 8x and (bottleneck) 16x this value.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super().__init__()

        # Channel widths for the four resolution levels and the bottleneck.
        w1 = init_features
        w2 = w1 * 2
        w3 = w1 * 4
        w4 = w1 * 8
        w_mid = w1 * 16

        # Contracting path. Submodules are registered in the same order and
        # under the same names as before so checkpoints stay loadable.
        self.encoder1 = UNet1d._block(in_channels, w1, name="enc1")
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.encoder2 = UNet1d._block(w1, w2, name="enc2")
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.encoder3 = UNet1d._block(w2, w3, name="enc3")
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.encoder4 = UNet1d._block(w3, w4, name="enc4")
        self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.bottleneck = UNet1d._block(w4, w_mid, name="bottleneck")

        # Expanding path. Every decoder block sees twice the level width
        # because the upsampled tensor is concatenated with the skip tensor.
        self.upconv4 = nn.ConvTranspose1d(w_mid, w4, kernel_size=2, stride=2)
        self.decoder4 = UNet1d._block(w4 * 2, w4, name="dec4")
        self.upconv3 = nn.ConvTranspose1d(w4, w3, kernel_size=2, stride=2)
        self.decoder3 = UNet1d._block(w3 * 2, w3, name="dec3")
        self.upconv2 = nn.ConvTranspose1d(w3, w2, kernel_size=2, stride=2)
        self.decoder2 = UNet1d._block(w2 * 2, w2, name="dec2")
        self.upconv1 = nn.ConvTranspose1d(w2, w1, kernel_size=2, stride=2)
        self.decoder1 = UNet1d._block(w1 * 2, w1, name="dec1")

        # 1x1 projection to the requested number of output channels.
        self.conv = nn.Conv1d(
            in_channels=w1, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Run the network on ``x`` and return the final 1x1 conv output."""
        # Contracting path: keep each pre-pooling activation for the skips.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        deepest = self.bottleneck(self.pool4(skip4))

        # Expanding path: upsample, concatenate (upsampled first, skip
        # second) along the channel axis, then refine.
        up4 = self.upconv4(deepest)
        dec4 = self.decoder4(torch.cat((up4, skip4), dim=1))

        up3 = self.upconv3(dec4)
        dec3 = self.decoder3(torch.cat((up3, skip3), dim=1))

        up2 = self.upconv2(dec3)
        dec2 = self.decoder2(torch.cat((up2, skip2), dim=1))

        up1 = self.upconv1(dec2)
        dec1 = self.decoder1(torch.cat((up1, skip1), dim=1))

        return self.conv(dec1)

    @staticmethod
    def _block(in_channels, features, name):
        """Build a named (Conv1d -> BatchNorm1d -> ReLU) x 2 ``nn.Sequential``.

        Both convolutions use kernel size 3 with padding 1 (length is
        preserved) and no bias. Layer names are ``name`` followed by
        ``conv1``, ``norm1``, ``relu1``, ``conv2``, ``norm2``, ``relu2``.
        """

        def conv3(channels_in):
            # Length-preserving, bias-free 3-tap convolution.
            return nn.Conv1d(
                channels_in, features, kernel_size=3, padding=1, bias=False
            )

        layers = OrderedDict()
        layers[name + "conv1"] = conv3(in_channels)
        layers[name + "norm1"] = nn.BatchNorm1d(features)
        layers[name + "relu1"] = nn.ReLU(inplace=True)
        layers[name + "conv2"] = conv3(features)
        layers[name + "norm2"] = nn.BatchNorm1d(features)
        layers[name + "relu2"] = nn.ReLU(inplace=True)
        return nn.Sequential(layers)

# in_channels = 6
# unet = UNet1d(in_channels=in_channels)
# print(summary(
#     model=unet, 
#     input_size=(1, in_channels, 512),
#     col_names=["input_size", "output_size", "num_params", "trainable"],
#     col_width=20,
#     row_settings=["var_names"]
# ))

In [8]:
import torch

import torch.nn as nn

from collections import OrderedDict
from torchinfo import summary
from functools import reduce
from itertools import accumulate
from torchview import draw_graph

class UNetFlex1d(nn.Module):
    """1-D U-Net whose depth and per-level widths are taken from ``blocks``.

    For every entry of ``blocks`` the network has one encoder block followed
    by a stride-2 max-pool on the way down, and one stride-2 transposed
    convolution followed by a decoder block on the way up. A bottleneck block
    with ``2 * blocks[-1]`` channels sits at the lowest resolution, and a
    final 1x1 convolution maps ``blocks[0]`` channels to ``out_channels``
    (no activation is applied to the result).

    Each pooling step halves the length (rounding down) and each transposed
    convolution exactly doubles it, so the input length has to be a multiple
    of ``2 ** len(blocks)`` for the skip concatenations to line up.

    Submodules are registered as ``encoder{i}``, ``pool{i}``, ``upconv{i}``
    and ``decoder{i}`` (1-based) so that parameter names stay stable.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the output layer.
        blocks: Channel width of each resolution level, shallowest first.
            Widths may be arbitrary positive integers; they do not have to
            double from level to level.

    Raises:
        ValueError: If ``blocks`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=1, blocks=(32, 64, 128, 256)):
        super().__init__()

        # Keep a private copy: the default is immutable and the caller's
        # sequence is never aliased, so later mutation on either side cannot
        # desynchronise ``self.blocks`` from the registered submodules.
        blocks = list(blocks)
        if not blocks:
            raise ValueError("blocks must contain at least one channel width")
        self.blocks = blocks

        bottleneck_features = blocks[-1] * 2

        # Encoder Modules
        in_features = in_channels
        for enc_id, block in enumerate(blocks, start=1):
            setattr(
                self, f'encoder{enc_id}',
                UNetFlex1d._block(in_features, block, name=f'enc{enc_id}')
            )
            in_features = block

        # Pooling Layers
        for pool_id, _ in enumerate(blocks, start=1):
            setattr(
                self, f'pool{pool_id}',
                nn.MaxPool1d(kernel_size=2, stride=2)
            )

        # Bottleneck Module
        self.bottleneck = UNetFlex1d._block(
            blocks[-1], bottleneck_features, name="bottleneck"
        )

        # Upconv Layers
        # upconv{i} consumes the output of the stage directly below it:
        # decoder{i+1} (blocks[i] channels) or, for the deepest level, the
        # bottleneck. Sizing it as ``block * 2`` was only correct when every
        # width was exactly twice the previous one.
        deeper_features = blocks[1:] + [bottleneck_features]
        for up_id, (block, below) in enumerate(
            zip(blocks, deeper_features), start=1
        ):
            setattr(
                self, f'upconv{up_id}',
                nn.ConvTranspose1d(
                    below, block, kernel_size=2, stride=2,
                )
            )

        # Decoder Modules
        # Input is the skip tensor concatenated with the upsampled tensor,
        # each carrying ``block`` channels.
        for dec_id, block in enumerate(blocks, start=1):
            setattr(
                self, f'decoder{dec_id}',
                UNetFlex1d._block(block*2, block, name=f'dec{dec_id}')
            )

        # Output Layer
        self.output = nn.Conv1d(
            in_channels=blocks[0], out_channels=out_channels, kernel_size=1,
        )

    def forward(self, x):
        """Run the network on ``x`` and return the output-layer result."""
        depth = len(self.blocks)

        # 1. Encoder Leg: remember each pre-pooling activation for the skips.
        block_encodings = []
        for block_id in range(1, depth + 1):
            x = getattr(self, f'encoder{block_id}')(x)
            block_encodings.append(x)
            x = getattr(self, f'pool{block_id}')(x)

        # 2. Apply Bottleneck
        x = self.bottleneck(x)

        # 3. Decoder Leg: deepest level first.
        for block_id in range(depth, 0, -1):
            x = getattr(self, f'upconv{block_id}')(x)
            # Skip tensor first, upsampled tensor second; this channel order
            # is what the decoder weights are laid out for.
            x = torch.cat((block_encodings[block_id - 1], x), dim=1)
            x = getattr(self, f'decoder{block_id}')(x)

        # 4. Output
        return self.output(x)

    @staticmethod
    def _block(in_channels, features, name):
        """Build a named (Conv1d -> BatchNorm1d -> ReLU) x 2 ``nn.Sequential``.

        Both convolutions use kernel size 3 with padding 1 (length is
        preserved) and no bias. Layer names are ``name`` followed by
        ``conv1``, ``norm1``, ``relu1``, ``conv2``, ``norm2``, ``relu2``.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm1d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv1d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm1d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

# Build the flexible 1-D U-Net for single-channel input, using the default
# level widths.
model = UNetFlex1d(in_channels=1)

# Per-layer report (output shape, parameter count, trainable flag), labelled
# by attribute name, for an input of shape (16, 1, 512).
summary(
    model,
    input_size=(16, 1, 512),
    row_settings=["var_names"],
    col_names=["output_size", "num_params", "trainable"],
    col_width=20,
)

# graph = draw_graph(
#     model, 
#     input_size=(1, 1, 224, 224), 
#     # expand_nested=True
# )

# # View Model Architecture
# graph.visual_graph

torch.Size([16, 32, 512])
torch.Size([16, 32, 256])
torch.Size([16, 64, 256])
torch.Size([16, 64, 128])
torch.Size([16, 128, 128])
torch.Size([16, 128, 64])
torch.Size([16, 256, 64])
torch.Size([16, 256, 32])


Layer (type (var_name))                  Output Shape         Param #              Trainable
UNetFlex1d (UNetFlex1d)                  [16, 1, 512]         --                   True
├─Sequential (encoder1)                  [16, 32, 512]        --                   True
│    └─Conv1d (enc1conv1)                [16, 32, 512]        96                   True
│    └─BatchNorm1d (enc1norm1)           [16, 32, 512]        64                   True
│    └─ReLU (enc1relu1)                  [16, 32, 512]        --                   --
│    └─Conv1d (enc1conv2)                [16, 32, 512]        3,072                True
│    └─BatchNorm1d (enc1norm2)           [16, 32, 512]        64                   True
│    └─ReLU (enc1relu2)                  [16, 32, 512]        --                   --
├─MaxPool1d (pool1)                      [16, 32, 256]        --                   --
├─Sequential (encoder2)                  [16, 64, 256]        --                   True
│    └─Conv1d (enc2conv1)        