In [None]:
class VAE_unet(nn.Module):
    """Convolutional VAE whose decoder also receives U-Net style skip connections.

    The encoder is four stride-2 convolutions followed by fully connected
    layers that produce ``mu`` and ``logvar``. The decoder mirrors it with
    transposed convolutions; before each one the matching encoder activation
    is concatenated on the channel axis, and the raw input is concatenated
    before the final convolution.

    The fully connected layers are sized for a 96 x 3 x 5 bottleneck, which is
    what an input of shape (batch, in_channels, 54, 81) produces.

    Args:
        in_channels: Number of channels of the input tensor. The same value is
            used as the size of the latent vector.
        out__channels: Number of channels of the reconstructed output.
    """

    def __init__(self, in_channels, out__channels):
        # Zero-argument super() cannot drift out of sync with the class name.
        super().__init__()

        # Encoder. With kernel 4 / stride 2 / padding 1 a 54 x 81 input becomes
        # 27x40 -> 13x20 -> 6x10 -> 3x5.
        self.enc_conv1 = nn.Conv2d(in_channels, 21, kernel_size=4, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(21, 32, kernel_size=4, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)
        self.enc_conv4 = nn.Conv2d(64, 96, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(96 * 3 * 5, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc2_mu = nn.Linear(128, in_channels)
        self.fc2_logvar = nn.Linear(128, in_channels)

        # Decoder. Every transposed conv takes twice the channels of its
        # encoder counterpart because a skip tensor is concatenated first.
        self.fc3 = nn.Linear(in_channels, 128)
        self.fc4 = nn.Linear(128, 96 * 3 * 5)
        self.dec_conv1 = nn.ConvTranspose2d(96 * 2, 64, kernel_size=4, stride=2, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(64 * 2, 32, kernel_size=4, stride=2, padding=1)
        self.dec_conv3 = nn.ConvTranspose2d(32 * 2, 21, kernel_size=4, stride=2, padding=1)
        self.dec_conv4 = nn.ConvTranspose2d(21 * 2, 3, kernel_size=4, stride=2, padding=1)
        # The last layer sees dec_conv4's 3 channels concatenated with the raw
        # input, i.e. 3 + in_channels channels (20 when in_channels == 17).
        self.last_conv = nn.ConvTranspose2d(
            3 + in_channels, out__channels, kernel_size=3, stride=1, padding=1
        )

    @staticmethod
    def _pad_to_match(h, ref):
        """Zero-pad ``h`` so its height and width equal those of ``ref``.

        Missing rows are split between top and bottom (an odd leftover row goes
        to the bottom); missing columns are all appended on the right.
        """
        diff_h = ref.shape[2] - h.shape[2]
        diff_w = ref.shape[3] - h.shape[3]
        top = diff_h // 2
        # F.pad order for 4-D input: (left, right, top, bottom).
        return F.pad(h, (0, diff_w, top, diff_h - top))

    def encode(self, x):
        """Encode ``x`` into latent statistics and keep the skip activations.

        Returns:
            Tuple ``(mu, logvar, layers)`` where ``layers`` is
            ``[h4, h3, h2, h1]``: the conv activations, deepest first.
        """
        h1 = torch.relu(self.enc_conv1(x))
        h2 = torch.relu(self.enc_conv2(h1))
        h3 = torch.relu(self.enc_conv3(h2))
        h4 = torch.relu(self.enc_conv4(h3))
        # Flatten per sample so a wrong bottleneck size fails in fc1 instead of
        # silently spreading one sample across several rows.
        h = torch.flatten(h4, 1)
        h = torch.relu(self.fc1(h))
        h = torch.relu(self.fc2(h))
        return self.fc2_mu(h), self.fc2_logvar(h), [h4, h3, h2, h1]

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` drawn from N(0, 1)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, in_x, layers):
        """Decode latent ``z`` using the encoder skips and the raw input.

        Args:
            z: Latent sample of shape (batch, in_channels).
            in_x: The tensor that was fed to the encoder.
            layers: ``[h4, h3, h2, h1]`` as returned by :meth:`encode`.

        Returns:
            Tensor with ``out__channels`` channels and the spatial size of
            ``in_x``. No activation is applied to it.
        """
        h = torch.relu(self.fc3(z))
        h = torch.relu(self.fc4(h))
        h = h.view(-1, 96, 3, 5)
        h = torch.cat((h, layers[0]), dim=1)
        h = torch.relu(self.dec_conv1(h))
        h = torch.cat((h, layers[1]), dim=1)
        h = torch.relu(self.dec_conv2(h))
        # Each transposed conv exactly doubles the size, so wherever the
        # encoder halved an odd dimension the decoder comes out one short and
        # must be padded before concatenation (12 -> 13 rows here).
        h = torch.cat((self._pad_to_match(h, layers[2]), layers[2]), dim=1)
        h = torch.relu(self.dec_conv3(h))
        h = torch.cat((self._pad_to_match(h, layers[3]), layers[3]), dim=1)  # 26 -> 27 rows
        h = self.dec_conv4(h)
        h_cat = torch.cat((self._pad_to_match(h, in_x), in_x), dim=1)  # 80 -> 81 columns
        return self.last_conv(h_cat)

    def forward(self, x):
        """Run the full model and return ``(reconstruction, mu, logvar)``."""
        mu, logvar, layers = self.encode(x)
        z = self.reparameterize(mu, logvar)
        # Nothing modifies x in place, so it can be reused without cloning.
        return self.decode(z, x, layers), mu, logvar

def loss_function_mse(recon_x, x, mu, logvar):
    """Compute the VAE objective with an MSE reconstruction term.

    Args:
        recon_x: Reconstructed tensor.
        x: Target tensor.
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Tuple ``(total, mse)``: ``total`` is the reconstruction error plus the
        KL term, ``mse`` is the reconstruction error alone.
    """
    # Mean squared error (default reduction averages over all elements).
    reconstruction = nn.functional.mse_loss(recon_x, x)

    # KL divergence to a standard normal, summed over every element.
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * torch.sum(kl_elements)

    total = reconstruction + kl_divergence
    return total, reconstruction

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels leaving the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when not given.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            # The convolution carries no bias; batch norm follows it directly.
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        reduced = self.maxpool_conv(x)
        return reduced


class Up(nn.Module):
    """Decoder step: upsample, align with the skip tensor, concatenate, DoubleConv.

    Args:
        in_channels: Channels of the concatenated (skip + upsampled) tensor.
        out_channels: Channels produced by the step.
        bilinear: Use bilinear interpolation for upsampling when true,
            otherwise a learned transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            # Learned upsampling halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Interpolation keeps the channel count, so the following
            # convolutions do the reduction through a narrower middle layer.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Tensors are indexed as (N, C, H, W): pad the upsampled map so its
        # height and width match the skip tensor x2.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left = gap_w // 2
        top = gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])
        # If you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

        # Skip features come first on the channel axis.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution mapping features to output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

    
class UNet(nn.Module):
    """Small five-level U-Net whose output head also sees the raw input.

    Args:
        n_channels: Channels of the input tensor.
        n_classes: Channels of the output tensor.
        bilinear: Passed to every ``Up`` step to choose bilinear upsampling
            instead of transposed convolutions.
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # With bilinear upsampling the channel reduction is done by the
        # convolutions, so the deep layers are built half as wide.
        factor = 2 if bilinear else 1

        # Contracting path: 4 -> 8 -> 16 -> 32 -> 64 -> 128 // factor channels.
        self.inc = DoubleConv(n_channels, 4)
        self.down1 = Down(4, 8)
        self.down2 = Down(8, 16)
        self.down3 = Down(16, 32)
        self.down4 = Down(32, 64)
        self.down5 = Down(64, 128 // factor)

        # Expanding path, mirroring the encoder.
        self.up1 = Up(128, 64 // factor, bilinear)
        self.up2 = Up(64, 32 // factor, bilinear)
        self.up3 = Up(32, 16 // factor, bilinear)
        self.up4 = Up(16, 8 // factor, bilinear)
        self.up5 = Up(8, 4, bilinear)

        # The head receives the 4 decoder channels plus the raw input channels.
        self.outc = OutConv(n_channels + 4, n_classes)

    def forward(self, x):
        # Encoder: keep every activation for the skip connections.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))
        features = self.down5(skips[-1])

        # Decoder: consume the skips from deepest to shallowest.
        ups = (self.up1, self.up2, self.up3, self.up4, self.up5)
        for up, skip in zip(ups, reversed(skips)):
            features = up(features, skip)

        # Input first, decoder features second on the channel axis.
        stacked = torch.cat([x, features], dim=1)
        return self.outc(stacked)