In [None]:
import torch
import torch.nn as nn

# Release cached but currently unused GPU memory held by PyTorch's CUDA
# allocator (e.g. left over from an earlier run in the same kernel).
torch.cuda.empty_cache()

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Number of input channels of the network (one per imaging modality).
modalities = 3

# Spatial size of the input; also the default (H, W) in every `out_shape`.
H_in = 512
W_in = 512

In [None]:
class ResBlock(nn.Module):
    """Pre-activation residual block with a learned 1x1 shortcut.

    The main branch applies GroupNorm -> ReLU -> 3x3 conv twice; the shortcut
    is a 1x1 convolution mapping ``in_channels`` to ``out_channels``. Every
    convolution uses stride 1 with matching padding, so the spatial size of
    the input is preserved.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        gn_groups: Group count used by both GroupNorm layers.
    """

    def __init__(self, in_channels, out_channels, gn_groups=8):
        super().__init__()
        # The shortcut is created before the main branch so parameters are
        # registered (and randomly initialised) in a fixed order.
        self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

        # Two identical (norm, activation, conv) stages; only the first one
        # changes the channel count.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(nn.GroupNorm(gn_groups, stage_in))
            stages.append(nn.ReLU())
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the main-branch output plus the projected shortcut."""
        shortcut = self.residual(x)
        return self.block(x) + shortcut

In [None]:
class UNetEncoder(nn.Module):
    """Convolutional encoder: an initial conv followed by pooled residual stages.

    Stage layout (``n = n_init_features``, ``H, W = out_shape[1:]``):

    * ``initial_conv`` -> dropout -> one ``ResBlock`` at ``n`` channels.
    * two ``ResBlock``s up to ``2n`` channels, pooled to ``(H // 2, W // 2)``.
    * two ``ResBlock``s up to ``4n`` channels, pooled to ``(H // 4, W // 4)``.
    * four ``ResBlock``s up to ``8n`` channels, pooled to ``(H // 8, W // 8)``.

    Args:
        in_channels: Channels of the input image.
        out_shape: ``(C, H, W)``; only ``H`` and ``W`` are read, to fix the
            pooling target sizes.
        gn_groups: GroupNorm group count forwarded to every ``ResBlock``.
        n_init_features: Channel width ``n`` of the first stage.
    """

    def __init__(self, in_channels, out_shape=(1, H_in, W_in), gn_groups=8, n_init_features=32):
        super(UNetEncoder, self).__init__()
        height, width = out_shape[1], out_shape[2]
        self.initial_conv = nn.Conv2d(in_channels, n_init_features, kernel_size=3, stride=1, padding=1)
        # Built once here rather than inside forward(): a module created on
        # the fly is always in training mode, so dropout would stay active
        # even after model.eval(). Dropout has no parameters, so registering
        # it does not change the state_dict.
        self.dropout = nn.Dropout2d(0.2)
        self.blocks = nn.ModuleList([
            ResBlock(n_init_features * 1, n_init_features * 1, gn_groups),
            nn.Sequential(
                ResBlock(n_init_features * 1, n_init_features * 2, gn_groups),
                ResBlock(n_init_features * 2, n_init_features * 2, gn_groups)
            ),
            nn.Sequential(
                ResBlock(n_init_features * 2, n_init_features * 4, gn_groups),
                ResBlock(n_init_features * 4, n_init_features * 4, gn_groups)
            ),
            nn.Sequential(
                ResBlock(n_init_features * 4, n_init_features * 8, gn_groups),
                ResBlock(n_init_features * 8, n_init_features * 8, gn_groups),
                ResBlock(n_init_features * 8, n_init_features * 8, gn_groups),
                ResBlock(n_init_features * 8, n_init_features * 8, gn_groups)
            )])
        # Adaptive pooling targets are fixed from out_shape, so the pooled
        # sizes do not depend on the actual size of the tensor passed in.
        self.downsamples = nn.ModuleList([
            nn.AdaptiveMaxPool2d((height // 2, width // 2)),
            nn.AdaptiveMaxPool2d((height // 4, width // 4)),
            nn.AdaptiveMaxPool2d((height // 8, width // 8))
        ])

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(x, skips)`` where ``x`` is the output of the last pooling step
            and ``skips`` lists the tensor obtained after each pooling step,
            shallowest first. ``skips[-1]`` is the returned ``x`` itself.
        """
        x = self.initial_conv(x)
        x = self.dropout(x)
        x = self.blocks[0](x)
        skips = []
        for block, downsample in zip(self.blocks[1:], self.downsamples):
            x = block(x)
            x = downsample(x)
            skips.append(x)
        return x, skips

In [None]:
class UNetDecoder(nn.Module):
    """Segmentation decoder with additive skip connections.

    Each of the three stages applies a 1x1 convolution that halves the
    channel count, a 2x bilinear upsampling, an optional additive skip and a
    ``ResBlock``. With ``n = n_init_features`` the channel widths go
    ``8n -> 4n -> 2n -> n``. Two final convolutions (3x3 ``n -> n`` and 1x1
    ``n -> 1``) produce a single-channel map.

    Args:
        out_shape: Currently not read by this class.
        gn_groups: GroupNorm group count forwarded to every ``ResBlock``.
        n_init_features: Channel width ``n`` of the last decoder stage.
    """

    def __init__(self, out_shape=(1, H_in, W_in), gn_groups=8, n_init_features=32):
        super(UNetDecoder, self).__init__()
        self.blocks = nn.ModuleList([
            ResBlock(n_init_features * 4, n_init_features * 4, gn_groups),
            ResBlock(n_init_features * 2, n_init_features * 2, gn_groups),
            ResBlock(n_init_features * 1, n_init_features * 1, gn_groups)
        ])
        self.upsamples = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Upsample(scale_factor=2, mode='bilinear')
        ])
        self.downsize_features = nn.ModuleList([
            nn.Conv2d(n_init_features * 8, n_init_features * 4, kernel_size=1, padding=0),
            nn.Conv2d(n_init_features * 4, n_init_features * 2, kernel_size=1, padding=0),
            nn.Conv2d(n_init_features * 2, n_init_features * 1, kernel_size=1, padding=0)
        ])
        self.final_convs = nn.ModuleList([
            nn.Conv2d(n_init_features * 1, n_init_features * 1, kernel_size=3, stride=1, padding='same'),
            nn.Conv2d(n_init_features * 1, 1, kernel_size=1, padding='same')
        ])

    def forward(self, x, skips):
        """Decode the bottleneck ``x`` using the encoder's ``skips``.

        Args:
            x: Bottleneck feature map with ``8 * n_init_features`` channels.
            skips: Encoder skip tensors, shallowest first. The encoder in
                this notebook records each skip *after* pooling, so the last
                entry has the bottleneck's own resolution.

        Returns:
            A single-channel map at 8x the spatial size of ``x``.

        Raises:
            ValueError: If a skip tensor does not have the same shape as the
                upsampled decoder tensor it is added to.
        """
        # The last skip is at bottleneck resolution and can never match an
        # upsampled tensor, so it is dropped; the rest are consumed deepest
        # first. Stages without a matching skip (here the full-resolution
        # one) are padded with None and run without a skip connection.
        stage_skips = list(skips)[:-1][::-1]
        stage_skips += [None] * (len(self.blocks) - len(stage_skips))

        stages = zip(self.blocks, self.upsamples, self.downsize_features, stage_skips)
        for block, upsample, reduce_channels, skip in stages:
            x = reduce_channels(x)
            x = upsample(x)
            if skip is not None:
                # Fail loudly instead of letting broadcasting hide a mismatch.
                if skip.shape != x.shape:
                    raise ValueError(
                        f"Skip connection shape {tuple(skip.shape)} does not match "
                        f"decoder feature shape {tuple(x.shape)}"
                    )
                x = x + skip
            x = block(x)
        for final_conv in self.final_convs:
            x = final_conv(x)
        return x

In [None]:
class VAEDecoder(nn.Module):
    """Variational auto-encoder branch attached to the encoder bottleneck.

    The bottleneck is compressed to a 256-d vector, mapped to the mean and
    log-variance of a 128-d Gaussian, sampled with the reparameterisation
    trick, and decoded back to a single-channel image through four 2x
    bilinear upsampling steps.

    Args:
        out_shape: ``(C, H, W)``; ``H`` and ``W`` fix the sizes of the
            fully-connected layers. They should be divisible by 16 so the
            decoded image is exactly ``H x W``.
        gn_groups: GroupNorm group count (also forwarded to ``ResBlock``).
        n_init_features: Base channel width ``n``; the expected input has
            ``8 * n`` channels at ``(H // 8, W // 8)``, i.e. the encoder
            bottleneck.
    """

    def __init__(self, out_shape=(1, H_in, W_in), gn_groups=8, n_init_features=32):
        super(VAEDecoder, self).__init__()
        # Sizes come from out_shape, not from notebook globals, so a model
        # built for another resolution gets matching linear layers.
        height, width = out_shape[1], out_shape[2]

        # The input is the encoder bottleneck, not a first-stage feature map.
        bottleneck_channels = n_init_features * 8
        reduced_channels = n_init_features // 2
        # A 3x3 / stride-2 / pad-1 conv maps a side of length s to ceil(s / 2).
        reduced_h = (height // 8 + 1) // 2
        reduced_w = (width // 8 + 1) // 2

        self.initial_layers = nn.Sequential(
            nn.GroupNorm(gn_groups, bottleneck_channels),
            nn.ReLU(),
            nn.Conv2d(bottleneck_channels, reduced_channels, kernel_size=3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(reduced_channels * reduced_h * reduced_w, 256)
        )
        self.mu = nn.Linear(256, 128)
        self.logvar = nn.Linear(256, 128)

        # Shape of the feature map the latent vector is unflattened into.
        # Computed once with explicit grouping so the Linear output size and
        # the Unflatten target always agree.
        seed_channels = n_init_features // 4
        seed_h, seed_w = height // 16, width // 16

        self.upsample = nn.Sequential(
            nn.Linear(128, seed_channels * seed_h * seed_w),
            nn.ReLU(),
            nn.Unflatten(1, (seed_channels, seed_h, seed_w)),
            nn.Conv2d(seed_channels, n_init_features * 8, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(n_init_features * 8, n_init_features * 4, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            ResBlock(n_init_features * 4, n_init_features * 4, gn_groups),
            nn.Conv2d(n_init_features * 4, n_init_features * 2, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            ResBlock(n_init_features * 2, n_init_features * 2, gn_groups),
            nn.Conv2d(n_init_features * 2, n_init_features * 1, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode='bilinear'),
            ResBlock(n_init_features * 1, n_init_features * 1, gn_groups)
        )
        self.final_conv = nn.Sequential(
            nn.Conv2d(n_init_features * 1, n_init_features * 1, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(n_init_features * 1, 1, kernel_size=1)
        )

    def sample(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Defined as a method rather than a lambda attribute so the module
        remains picklable.
        """
        return mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the bottleneck ``x``."""
        x = self.initial_layers(x)
        mu, logvar = self.mu(x), self.logvar(x)
        x = self.sample(mu, logvar)
        x = self.upsample(x)
        x = self.final_conv(x)
        return x, mu, logvar

In [None]:
class UNet(nn.Module):
    """Encoder shared by a segmentation decoder and a VAE branch.

    Args:
        in_channels: Channels of the input image, passed to the encoder.
        out_shape: ``(C, H, W)`` forwarded to all three sub-modules.
        gn_groups: GroupNorm group count forwarded to all three sub-modules.
        n_init_features: Base channel width forwarded to all three sub-modules.
    """

    def __init__(self, in_channels, out_shape=(1, H_in, W_in), gn_groups=8, n_init_features=32):
        super().__init__()
        # Settings common to the encoder and both decoders.
        shared = dict(out_shape=out_shape, gn_groups=gn_groups, n_init_features=n_init_features)
        self.encoder = UNetEncoder(in_channels, **shared)
        self.decoder = UNetDecoder(**shared)
        self.vae_decoder = VAEDecoder(**shared)

    def forward(self, x):
        """Return ``(decoder_output, vae_output, mu, logvar)`` for input ``x``."""
        bottleneck, skips = self.encoder(x)
        # The VAE branch runs before the segmentation decoder; both read the
        # same bottleneck tensor.
        vae_output, mu, logvar = self.vae_decoder(bottleneck)
        return self.decoder(bottleneck, skips), vae_output, mu, logvar

In [None]:
# Build the network on the selected device and push one random batch through
# it as a smoke test of the forward pass.
model = UNet(modalities, (1, H_in, W_in))
model = model.to(device)
input_shape = (1, modalities, H_in, W_in)
model(torch.randn(*input_shape).to(device))

## Report the parameter count with thousands separators
n_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {n_params:,}")

RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [32] and input of shape [1, 256, 64, 64]