In [16]:
import torch
import torch.nn as nn
import torchvision.models as models

from src.networks import DoubleConv


In [30]:
class DenseNet121UNet(nn.Module):
    def __init__(self, n_channels: int = 3, init_features: int = 64):
        super(DenseNet121UNet, self).__init__()
        self.densenet = models.densenet121(pretrained=True)

        self.center = DoubleConv(init_features * 16, init_features * 8)

        # Decoder (Expanding Path)
        self.up4 = nn.ConvTranspose2d(init_features * 16, init_features * 8, kernel_size=2,
                                      stride=2)
        self.dec4 = DoubleConv(init_features * 16, init_features * 8)

        self.up3 = nn.ConvTranspose2d(init_features * 8, init_features * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(init_features * 8, init_features * 4)

        self.up2 = nn.ConvTranspose2d(init_features * 4, init_features * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(init_features * 4, init_features * 2)

        self.up1 = nn.ConvTranspose2d(init_features * 2, init_features, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(init_features * 2, init_features)

        self.final_conv = nn.Conv2d(init_features, 3, kernel_size=1)

        self.pool = nn.MaxPool2d(2)

    def forward(self, input: torch.Tensor) -> torch.Tensor:

        x = self.densenet.features.conv0(input)
        x = self.densenet.features.norm0(x)
        enc1 = self.densenet.features.relu0(x) # torch.Size([16, 64, 128, 128])
        pool1 = self.densenet.features.pool0(enc1)

        print(f"enc1: {enc1.shape}")
        print(f"pool1: {pool1.shape}")

        x = self.densenet.features.denseblock1(pool1)
        enc2 = x  # torch.Size([16, 256, 64, 64])
        pool2 = self.densenet.features.transition1(enc2)

        print(f"enc2: {enc2.shape}")
        print(f"pool2: {pool2.shape}")

        x = self.densenet.features.denseblock2(pool2)
        enc3 = x  # torch.Size([16, 512, 32, 32])
        pool3 = self.densenet.features.transition2(enc3)

        print(f"enc3: {enc3.shape}")
        print(f"pool3: {pool3.shape}")

        x = self.densenet.features.denseblock3(pool3)
        enc4 = x  # torch.Size([16, 1024, 16, 16])
        pool4 = self.densenet.features.transition3(enc4)

        print(f"enc4: {enc4.shape}")
        print(f"pool4: {pool4.shape}")

        x = self.densenet.features.denseblock4(pool4)
        enc5 = x  # torch.Size([16, 1024, 8, 8])

        print(f"enc5: {enc5.shape}")

        # Decoder
        up4 = self.up4(enc5)
        up4 - torch.functional.inperpolate(up4, )
        concat4 = torch.cat([enc4, up4], dim=1)

        print(f"up4: {up4.shape}")
        print(f"concat4: {concat4.shape}")

        dec4 = self.dec4(concat4)

        up3 = self.up3(dec4)
        concat3 = torch.cat([enc3, up3], dim=1)
        dec3 = self.dec3(concat3)

        up2 = self.up2(dec3)
        concat2 = torch.cat([enc2, up2], dim=1)
        dec2 = self.dec2(concat2)

        up1 = self.up1(dec2)
        concat1 = torch.cat([enc1, up1], dim=1)
        dec1 = self.dec1(concat1)

        return nn.functional.tanh(self.final_conv(dec1))

In [31]:
model = DenseNet121UNet()
input = torch.randn(1, 3, 256, 256)
output = model(input)

print(output.shape)

enc1: torch.Size([1, 64, 128, 128])
pool1: torch.Size([1, 64, 64, 64])
enc2: torch.Size([1, 256, 64, 64])
pool2: torch.Size([1, 128, 32, 32])
enc3: torch.Size([1, 512, 32, 32])
pool3: torch.Size([1, 256, 16, 16])
enc4: torch.Size([1, 1024, 16, 16])
pool4: torch.Size([1, 512, 8, 8])
enc5: torch.Size([1, 1024, 8, 8])
up4: torch.Size([1, 512, 16, 16])
concat4: torch.Size([1, 1536, 16, 16])


RuntimeError: Given groups=1, weight of size [512, 1024, 3, 3], expected input[1, 1536, 16, 16] to have 1024 channels, but got 1536 channels instead

In [18]:
class UNetRegressor(nn.Module):
    def __init__(self, init_features: int = 64) -> None:
        super().__init__()

        # Encoder (Contracting Path)
        self.enc1 = DoubleConv(3, init_features)
        self.enc2 = DoubleConv(init_features, init_features * 2)
        self.enc3 = DoubleConv(init_features * 2, init_features * 4)
        self.enc4 = DoubleConv(init_features * 4, init_features * 8)
        self.enc5 = DoubleConv(init_features * 8, init_features * 16)

        # Decoder (Expanding Path)
        self.up4 = nn.ConvTranspose2d(init_features * 16, init_features * 8, kernel_size=2,
                                      stride=2)
        self.dec4 = DoubleConv(init_features * 16, init_features * 8)

        self.up3 = nn.ConvTranspose2d(init_features * 8, init_features * 4, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(init_features * 8, init_features * 4)

        self.up2 = nn.ConvTranspose2d(init_features * 4, init_features * 2, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(init_features * 4, init_features * 2)

        self.up1 = nn.ConvTranspose2d(init_features * 2, init_features, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(init_features * 2, init_features)

        self.final_conv = nn.Conv2d(init_features, 3, kernel_size=1)

        self.pool = nn.MaxPool2d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder
        enc1 = self.enc1(x)
        pool1 = self.pool(enc1)

        print(f"enc1: {enc1.shape}")
        print(f"pool1: {pool1.shape}")

        enc2 = self.enc2(pool1)
        pool2 = self.pool(enc2)

        print(f"enc2: {enc2.shape}")
        print(f"pool2: {pool2.shape}")

        enc3 = self.enc3(pool2)
        pool3 = self.pool(enc3)

        print(f"enc3: {enc3.shape}")
        print(f"pool3: {pool3.shape}")

        enc4 = self.enc4(pool3)
        pool4 = self.pool(enc4)

        print(f"enc4: {enc4.shape}")
        print(f"pool4: {pool4.shape}")

        # Bridge
        enc5 = self.enc5(pool4)

        print(f"enc5: {enc5.shape}")

        # Decoder
        up4 = self.up4(enc5)
        concat4 = torch.cat([enc4, up4], dim=1)
        dec4 = self.dec4(concat4)

        up3 = self.up3(dec4)
        concat3 = torch.cat([enc3, up3], dim=1)
        dec3 = self.dec3(concat3)

        up2 = self.up2(dec3)
        concat2 = torch.cat([enc2, up2], dim=1)
        dec2 = self.dec2(concat2)

        up1 = self.up1(dec2)
        concat1 = torch.cat([enc1, up1], dim=1)
        dec1 = self.dec1(concat1)

        return nn.functional.tanh(self.final_conv(dec1))

In [19]:
model = UNetRegressor()
input = torch.randn(1, 3, 256, 256)
output = model(input)

print(output.shape)

down1: torch.Size([1, 64, 256, 256])
pool1: torch.Size([1, 64, 128, 128])
down2: torch.Size([1, 128, 128, 128])
pool2: torch.Size([1, 128, 64, 64])
down3: torch.Size([1, 256, 64, 64])
pool3: torch.Size([1, 256, 32, 32])
down4: torch.Size([1, 512, 32, 32])
pool4: torch.Size([1, 512, 16, 16])
down5: torch.Size([1, 1024, 16, 16])
torch.Size([1, 3, 256, 256])
