In [1]:
from torchsummary import summary
import torch
import torch.nn as nn
import torch.nn.functional as F

In [57]:
def conv(in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    """Build one Conv2d -> BatchNorm2d -> ReLU stage.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution (and
            normalised by the batch norm).
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero padding applied by the convolution.

    Returns:
        An ``nn.Sequential`` holding the three layers in that order.
    """
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    normalisation = nn.BatchNorm2d(out_channels)
    # The ReLU overwrites its input tensor instead of allocating a new one.
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalisation, activation)


class BetaUNet(nn.Module):
    """A mixture of BetaVAE and UNet.

    The encoder is five ``conv`` stages. The first four use stride 2 and
    padding 2, halving the spatial size each time; the fifth is an unpadded
    5x5 convolution. For a 3x96x96 input this yields a 512x2x2 feature map,
    matching the ``512*2*2`` input features of ``mu_fc`` / ``logvar_fc``.

    Layers for the variational bottleneck and a UNet-style decoder (transposed
    convolutions followed by ``conv`` stages sized for concatenated skip
    connections) are constructed as well, but ``forward`` currently stops
    after the encoder; see the note inside ``forward``.

    Args:
        beta: Weight applied to the KL-divergence term in ``loss``.
    """
    def __init__(self, beta=1):
        super().__init__()

        # `loss` reads self.beta, so the constructor argument must be stored;
        # without this assignment every call to `loss` raised AttributeError.
        self.beta = beta

        self.latent_dim = 2048

        # Encoder.
        self.dconv_down1 = conv(3, 32, kernel_size=5, stride=2, padding=2)
        self.dconv_down2 = conv(32, 64, kernel_size=5, stride=2, padding=2)
        self.dconv_down3 = conv(64, 128, kernel_size=5, stride=2, padding=2)
        self.dconv_down4 = conv(128, 256, kernel_size=5, stride=2, padding=2)
        self.dconv_down5 = conv(256, 512, kernel_size=5, stride=1, padding=0)

        # Bottleneck: one mean and one log-variance per latent dimension,
        # i.e. a diagonal Gaussian.
        self.mu_fc = nn.Linear(512*2*2, self.latent_dim)
        self.logvar_fc = nn.Linear(512*2*2, self.latent_dim)

        self.latent_fc = nn.Linear(self.latent_dim, 512*2*2)

        # Decoder. Each dconv_up* takes the upsampled channels plus the
        # channels of the encoder feature map it is meant to be concatenated
        # with (e.g. 512 + 256 for dconv_up5).
        self.up5 = nn.ConvTranspose2d(512, 512, kernel_size=6)
        self.dconv_up5 = conv(512+256, 256, kernel_size=3, stride=1)

        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dconv_up4 = conv(256+128, 128, kernel_size=3, stride=1)

        self.up3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.dconv_up3 = conv(128+64, 64, kernel_size=3, stride=1)

        self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dconv_up2 = conv(64+32, 32, kernel_size=3, stride=1)

        self.up1 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)
        self.conv_last = nn.Conv2d(32, 3, 1)


    def forward(self, x):
        """Run the encoder and return its final feature map.

        Args:
            x: Image batch with 3 channels (``dconv_down1`` expects 3 input
                channels).

        Returns:
            The output of ``dconv_down5`` (512 channels).
        """
        conv1 = self.dconv_down1(x)
        conv2 = self.dconv_down2(conv1)
        conv3 = self.dconv_down3(conv2)
        conv4 = self.dconv_down4(conv3)
        conv5 = self.dconv_down5(conv4)

        # NOTE(review): the bottleneck/decoder path below is disabled, so the
        # model currently returns encoder features only and its output cannot
        # be passed to `loss`. Before re-enabling it, resolve at least:
        #   * conv5.flatten() also collapses the batch dimension, while mu_fc
        #     expects 512*2*2 features per sample (flatten(start_dim=1) is
        #     presumably what was intended -- confirm).
        #   * latent_fc emits 512*2*2 features but the view uses
        #     (-1, 512, 1, 1), which would fold the extra factor of 4 into the
        #     batch dimension.
        # encoding = conv5.flatten()
        # print(encoding.shape)

        # mu = self.mu_fc(encoding)
        # logvar = self.logvar_fc(encoding)
        # std = torch.exp(0.5 * logvar)
        # latent_sample = std * torch.randn_like(logvar) + mu
        # latent_out = self.latent_fc(latent_sample)
        # latent_out = latent_out.view(-1, 512, 1, 1)

        # conv5_up = self.up5(latent_out)
        # dconv_up5 = self.dconv_up5(torch.cat([conv5_up, conv4], dim=1))
        # conv4_up = self.up4(dconv_up5)
        # dconv_up4 = self.dconv_up4(torch.cat([conv4_up, conv3], dim=1))
        # conv3_up = self.up3(dconv_up4)
        # dconv_up3 = self.dconv_up3(torch.cat([conv3_up, conv2], dim=1))
        # conv2_up = self.up2(dconv_up3)
        # dconv_up2 = self.dconv_up2(torch.cat([conv2_up, conv1], dim=1))
        # conv1_up = self.up1(dconv_up2)
        # out = self.conv_last(conv1_up)

        # return torch.sigmoid(out), mu, std**2
        return conv5

    def loss(self, prediction, original, mu, var):
        """Compute the beta-weighted VAE objective, summed over all elements.

        Args:
            prediction: Reconstruction with values in [0, 1], as required by
                binary cross-entropy.
            original: Target tensor with the same shape as ``prediction``.
            mu: Mean of the approximate posterior.
            var: Variance (not log-variance) of the approximate posterior;
                its logarithm is taken, so it must be strictly positive.

        Returns:
            Scalar tensor ``reconstruction_loss + beta * kl_divergence``.
        """
        reconstruction_loss = F.binary_cross_entropy(
            prediction, original, reduction="sum"
        )

        # KL divergence between N(mu, diag(var)) and the standard normal,
        # summed over every element of mu / var.
        kl_divergence = -0.5 * (1 + var.log() - mu**2 - var).sum()

        return reconstruction_loss + self.beta * kl_divergence



In [58]:
# Instantiate the network with its default beta and display a per-layer
# summary (output shapes and parameter counts) for a 3x96x96 input.
model = BetaUNet()
summary(model, (3, 96, 96))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 32, 48, 48]          --
|    └─Conv2d: 2-1                       [-1, 32, 48, 48]          2,432
|    └─BatchNorm2d: 2-2                  [-1, 32, 48, 48]          64
|    └─ReLU: 2-3                         [-1, 32, 48, 48]          --
├─Sequential: 1-2                        [-1, 64, 24, 24]          --
|    └─Conv2d: 2-4                       [-1, 64, 24, 24]          51,264
|    └─BatchNorm2d: 2-5                  [-1, 64, 24, 24]          128
|    └─ReLU: 2-6                         [-1, 64, 24, 24]          --
├─Sequential: 1-3                        [-1, 128, 12, 12]         --
|    └─Conv2d: 2-7                       [-1, 128, 12, 12]         204,928
|    └─BatchNorm2d: 2-8                  [-1, 128, 12, 12]         256
|    └─ReLU: 2-9                         [-1, 128, 12, 12]         --
├─Sequential: 1-4                        [-1, 256, 6, 6]           --
|

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 32, 48, 48]          --
|    └─Conv2d: 2-1                       [-1, 32, 48, 48]          2,432
|    └─BatchNorm2d: 2-2                  [-1, 32, 48, 48]          64
|    └─ReLU: 2-3                         [-1, 32, 48, 48]          --
├─Sequential: 1-2                        [-1, 64, 24, 24]          --
|    └─Conv2d: 2-4                       [-1, 64, 24, 24]          51,264
|    └─BatchNorm2d: 2-5                  [-1, 64, 24, 24]          128
|    └─ReLU: 2-6                         [-1, 64, 24, 24]          --
├─Sequential: 1-3                        [-1, 128, 12, 12]         --
|    └─Conv2d: 2-7                       [-1, 128, 12, 12]         204,928
|    └─BatchNorm2d: 2-8                  [-1, 128, 12, 12]         256
|    └─ReLU: 2-9                         [-1, 128, 12, 12]         --
├─Sequential: 1-4                        [-1, 256, 6, 6]           --
|

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 3, 96, 96]           --
|    └─Conv2d: 2-1                       [-1, 64, 48, 48]          3,136
|    └─LeakyReLU: 2-2                    [-1, 64, 48, 48]          --
|    └─Conv2d: 2-3                       [-1, 64, 24, 24]          65,600
|    └─BatchNorm2d: 2-4                  [-1, 64, 24, 24]          128
|    └─LeakyReLU: 2-5                    [-1, 64, 24, 24]          --
|    └─Conv2d: 2-6                       [-1, 128, 12, 12]         131,200
|    └─BatchNorm2d: 2-7                  [-1, 128, 12, 12]         256
|    └─LeakyReLU: 2-8                    [-1, 128, 12, 12]         --
|    └─Conv2d: 2-9                       [-1, 256, 6, 6]           524,544
|    └─BatchNorm2d: 2-10                 [-1, 256, 6, 6]           512
|    └─LeakyReLU: 2-11                   [-1, 256, 6, 6]           --
|    └─Conv2d: 2-12                      [-1, 512, 3, 3]         

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 3, 96, 96]           --
|    └─Conv2d: 2-1                       [-1, 64, 48, 48]          3,136
|    └─LeakyReLU: 2-2                    [-1, 64, 48, 48]          --
|    └─Conv2d: 2-3                       [-1, 64, 24, 24]          65,600
|    └─BatchNorm2d: 2-4                  [-1, 64, 24, 24]          128
|    └─LeakyReLU: 2-5                    [-1, 64, 24, 24]          --
|    └─Conv2d: 2-6                       [-1, 128, 12, 12]         131,200
|    └─BatchNorm2d: 2-7                  [-1, 128, 12, 12]         256
|    └─LeakyReLU: 2-8                    [-1, 128, 12, 12]         --
|    └─Conv2d: 2-9                       [-1, 256, 6, 6]           524,544
|    └─BatchNorm2d: 2-10                 [-1, 256, 6, 6]           512
|    └─LeakyReLU: 2-11                   [-1, 256, 6, 6]           --
|    └─Conv2d: 2-12                      [-1, 512, 3, 3]         

In [12]:
import torch.nn as nn

# Minimal BCELoss demo: squash random logits with a sigmoid, score them
# against an all-ones target, and backpropagate to the logits.
m = nn.Sigmoid()
loss = nn.BCELoss()
# Named `logits` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
logits = torch.randn(3, requires_grad=True)
target = torch.ones(logits.shape)
print(target)
output = loss(m(logits), target)
output.backward()

tensor([1., 1., 1.])
