In [1]:
import torch
import torch.nn as nn


In [2]:

class Bottleneck(nn.Module):
    """MobileNetV2-style bottleneck: 1x1 expansion -> 3x3 depthwise -> 1x1 projection.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the final 1x1 convolution.
        stride: Stride of the 3x3 depthwise convolution.
        expansion: Positive integer multiplier applied to ``in_channels`` to get
            the width of the hidden (expanded) representation. Defaults to 6,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``expansion`` is smaller than 1.

    Note:
        The block contains no normalisation layers and no residual shortcut;
        ``forward`` is just the sequential stack.
    """

    def __init__(self, in_channels, out_channels, stride, expansion=6):
        super().__init__()
        if expansion < 1:
            raise ValueError(f"expansion must be >= 1, got {expansion}")
        # Width shared by the expansion output and the depthwise convolution.
        hidden = in_channels * expansion
        self.block = nn.Sequential(
            # 1x1 expansion
            nn.Conv2d(in_channels, hidden, 1), nn.ReLU6(),
            # 3x3 depthwise: groups == channels, so one filter per channel
            nn.Conv2d(hidden, hidden, 3, stride, 1, groups=hidden), nn.ReLU6(),
            # 1x1 projection, no activation afterwards
            nn.Conv2d(hidden, out_channels, 1),
        )

    def forward(self, x):
        """Apply expansion, depthwise and projection convolutions to ``x``."""
        return self.block(x)


In [None]:

class Encoder(nn.Module):
    """Stem convolution followed by a stack of ``Bottleneck`` stages.

    ``forward`` returns the final feature map together with the output of
    every bottleneck stage, so a decoder can use them as skip connections.
    """

    def __init__(self):
        super().__init__()
        # Stem: 3 -> 32 channels, 3x3 kernel, stride 2.
        self.initial = nn.Conv2d(3, 32, 3, stride=2, padding=1)
        # One (in_channels, out_channels, stride) entry per bottleneck stage.
        stage_specs = [
            (32, 16, 1),
            (16, 24, 2),
            # ... (continue according to architecture diagram)
        ]
        self.bottlenecks = nn.Sequential(
            *[Bottleneck(c_in, c_out, stride) for c_in, c_out, stride in stage_specs]
        )

    def forward(self, x):
        """Return ``(features, skips)``.

        ``skips`` lists the output of each bottleneck stage in execution
        order; its last entry is the same tensor as ``features``.
        """
        features = self.initial(x)
        skips = []
        for stage in self.bottlenecks:
            features = stage(features)
            skips.append(features)
        return features, skips


In [None]:

class PixelShuffleBlock(nn.Module):
    """Learned upsampling: a 1x1 convolution followed by ``nn.PixelShuffle``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the upsampled result.
        upscale: Spatial upscaling factor handed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, out_channels, upscale=2):
        super().__init__()
        # PixelShuffle consumes upscale**2 input channels per output channel,
        # so the convolution has to produce that many.
        expanded_channels = out_channels * (upscale ** 2)
        self.conv = nn.Conv2d(in_channels, expanded_channels, 1)
        self.pixel_shuffle = nn.PixelShuffle(upscale)

    def forward(self, x):
        """Project ``x`` with the 1x1 convolution, then rearrange channels into space."""
        return self.pixel_shuffle(self.conv(x))


In [None]:

class DecoderBlock(nn.Module):
    """Fuse a skip connection, upsample, then refine with bottlenecks.

    Args:
        in_channels: Channel count the upsampling block receives, i.e. the
            channels of ``x`` and ``skip`` after they are concatenated.
        out_channels: Channels produced by the upsampling block and kept by
            every refinement bottleneck.
        num_bottlenecks: How many stride-1 ``Bottleneck`` layers follow the
            upsampling.
        upscale: Upscaling factor forwarded to ``PixelShuffleBlock``.
    """

    def __init__(self, in_channels, out_channels, num_bottlenecks, upscale=2):
        super().__init__()
        self.pixel_shuffle = PixelShuffleBlock(in_channels, out_channels, upscale)
        refinement = []
        for _ in range(num_bottlenecks):
            refinement.append(Bottleneck(out_channels, out_channels, 1))
        self.bottlenecks = nn.Sequential(*refinement)

    def forward(self, x, skip):
        """Concatenate ``x`` and ``skip`` along dim 1, upsample, and refine."""
        merged = torch.cat([x, skip], dim=1)
        upsampled = self.pixel_shuffle(merged)
        return self.bottlenecks(upsampled)


In [None]:

class OutputHead(nn.Module):
    """Per-pixel prediction head built from a single 1x1 convolution.

    Args:
        in_channels: Channels of the incoming feature map.
        num_classes: Number of output channels. When greater than 1 the
            output is passed through a softmax over dim 1; otherwise the raw
            convolution output is returned.
    """

    def __init__(self, in_channels, num_classes=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, 1)
        if num_classes > 1:
            # Multi-class output: normalise across the channel dimension.
            self.softmax = nn.Softmax(dim=1)
        else:
            # Single-channel output: leave the values untouched.
            self.softmax = nn.Identity()

    def forward(self, x):
        """Return the head's prediction for feature map ``x``."""
        logits = self.conv(x)
        return self.softmax(logits)


In [None]:

class FullModel(nn.Module):
    """Encoder/decoder network with a depth and a segmentation head per decoder scale.

    The encoder output is passed through four ``DecoderBlock`` stages
    (``decoder1`` .. ``decoder4``). Every stage concatenates its input with
    one encoder skip, so the skip must have the same spatial size as that
    input. Each decoder output feeds one single-channel depth head and one
    19-class segmentation head.

    Channel counts are not hard-coded. At construction the encoder is run
    once on a dummy tensor to learn the shapes of its outputs, and every
    decoder is paired with the deepest encoder skip whose spatial size
    matches. The encoder therefore has to expose a skip at the resolution of
    its final feature map and at each of the next three doublings.

    Raises:
        ValueError: At construction, if the encoder offers no skip at the
            resolution a decoder stage requires.

    Note:
        Skip pairing is decided once, on a square probe input. Inputs whose
        height or width is not divisible by the encoder's total downsampling
        factor may produce mismatching sizes in ``torch.cat`` at run time.
    """

    # (out_channels, num_bottlenecks) for decoder1 .. decoder4.
    _DECODER_SPECS = ((24, 4), (16, 3), (8, 3), (8, 2))
    # Upscaling factor used by every decoder stage.
    _UPSCALE = 2
    # Side length of the dummy input used to discover encoder output shapes.
    _PROBE_SIZE = 128
    # Number of channels of each segmentation head.
    _NUM_SEG_CLASSES = 19

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()

        feature_shape, skip_shapes = self._probe_encoder()
        widths = [out_ch for out_ch, _ in self._DECODER_SPECS]
        plan = self._plan_decoders(feature_shape, skip_shapes, widths, self._UPSCALE)

        # Register decoder1 .. decoder4 under their original attribute names.
        for number, ((in_ch, _), (out_ch, depth)) in enumerate(
            zip(plan, self._DECODER_SPECS), start=1
        ):
            setattr(
                self,
                f"decoder{number}",
                DecoderBlock(in_ch, out_ch, depth, upscale=self._UPSCALE),
            )
        # Index into the encoder's skip list used by each decoder stage.
        self._skip_indices = tuple(skip_index for _, skip_index in plan)

        # Decoder outputs differ in width, so every head is sized individually.
        self.depth_heads = nn.ModuleList([OutputHead(w, 1) for w in widths])
        self.seg_heads = nn.ModuleList(
            [OutputHead(w, self._NUM_SEG_CLASSES) for w in widths]
        )

    def _probe_encoder(self):
        """Run the encoder on a dummy input and return ``(feature_shape, skip_shapes)``."""
        in_channels = self.encoder.initial.in_channels
        was_training = self.encoder.training
        # eval + no_grad so the probe neither builds a graph nor changes any
        # statistics a layer might track in training mode.
        self.encoder.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, in_channels, self._PROBE_SIZE, self._PROBE_SIZE)
                features, skips = self.encoder(dummy)
        finally:
            self.encoder.train(was_training)
        return tuple(features.shape), [tuple(s.shape) for s in skips]

    @staticmethod
    def _plan_decoders(feature_shape, skip_shapes, out_channels, upscale=2):
        """Pair every decoder stage with an encoder skip of matching spatial size.

        Args:
            feature_shape: ``(N, C, H, W)`` shape of the encoder's final output.
            skip_shapes: ``(N, C, H, W)`` shapes of the encoder skips, in the
                order the encoder returns them.
            out_channels: Output channel count of each decoder stage, in order.
            upscale: Spatial factor by which every stage enlarges its input.

        Returns:
            One ``(in_channels, skip_index)`` tuple per decoder stage, where
            ``in_channels`` is the channel count after concatenating the
            stage input with the chosen skip.

        Raises:
            ValueError: If no skip has the spatial size a stage needs.
        """
        channels = feature_shape[1]
        size = tuple(feature_shape[2:])
        plan = []
        for stage, stage_out in enumerate(out_channels, start=1):
            matches = [
                index for index, shape in enumerate(skip_shapes)
                if tuple(shape[2:]) == size
            ]
            if not matches:
                raise ValueError(
                    f"decoder{stage} needs an encoder skip of spatial size {size}, "
                    f"but the encoder only provides "
                    f"{[tuple(shape[2:]) for shape in skip_shapes]}"
                )
            # Several skips can share a resolution; use the deepest one.
            skip_index = matches[-1]
            plan.append((channels + skip_shapes[skip_index][1], skip_index))
            channels = stage_out
            size = tuple(dim * upscale for dim in size)
        return plan

    def forward(self, x):
        """Return ``(depth_outputs, seg_outputs)``, each a list with one tensor per decoder stage."""
        x, skips = self.encoder(x)
        decoded = []
        for number, skip_index in enumerate(self._skip_indices, start=1):
            x = getattr(self, f"decoder{number}")(x, skips[skip_index])
            decoded.append(x)
        depth_outputs = [head(d) for head, d in zip(self.depth_heads, decoded)]
        seg_outputs = [head(s) for head, s in zip(self.seg_heads, decoded)]
        return depth_outputs, seg_outputs
