In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os

from torchinfo import summary


In [None]:

# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class VAEEncoder(nn.Module):
    """Label-conditioned 3D convolutional encoder returning posterior parameters.

    The (already embedded) label vector is broadcast over every voxel of the
    input and concatenated as extra channels before the first convolution.

    Args:
        in_channels: channels of the input volume.
        latent_dim: channels of the returned ``mu`` and ``log_var`` maps.
        num_classes: number of rows in ``label_emb``.
        label_embedding_dim: width of one label embedding vector.
    """

    def __init__(self, in_channels=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # Owned here so callers can embed integer labels via ``encoder.label_emb``;
        # ``forward`` itself expects an already-embedded vector.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        conditioned_channels = in_channels + label_embedding_dim
        self.conv1 = nn.Conv3d(conditioned_channels, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1)
        # Two parallel heads read the same features: posterior mean and log-variance.
        self.conv4_mu = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)
        self.conv4_log_var = nn.Conv3d(64, latent_dim, kernel_size=4, stride=1, padding=0)

    def forward(self, x, label_embedding):
        """Encode ``x`` conditioned on ``label_embedding``.

        Args:
            x: input volume; dims 2, 3 and 4 are treated as spatial.
            label_embedding: per-sample embedding, assumed ``(B, label_embedding_dim)``
                (TODO confirm with callers).

        Returns:
            ``(mu, log_var)`` produced by the two output heads.
        """
        spatial = (x.shape[2], x.shape[3], x.shape[4])
        # Append three singleton dims, then broadcast the label over the volume.
        label_planes = label_embedding[..., None, None, None].expand(-1, -1, *spatial)
        features = torch.cat([x, label_planes], dim=1)
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))
        return self.conv4_mu(features), self.conv4_log_var(features)


class VAEDecoder(nn.Module):
    """Label-conditioned 3D transposed-convolution decoder.

    The (already embedded) label vector is broadcast over the latent grid and
    concatenated as extra channels. Four transposed convolutions follow, with
    GELU between them, and a sigmoid maps the result into [0, 1].

    Args:
        latent_dim: channels of the latent input ``z``.
        out_channels: channels of the reconstructed volume.
        num_classes: number of rows in ``label_emb``.
        label_embedding_dim: width of one label embedding vector.
    """

    def __init__(self, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        # NOTE(review): not used by ``forward`` (which receives an already-embedded
        # label); kept so parameter counts and state_dict keys stay unchanged.
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)
        conditioned_channels = latent_dim + label_embedding_dim
        self.conv_trans1 = nn.ConvTranspose3d(conditioned_channels, 32, kernel_size=4, stride=1, padding=0)
        self.conv_trans2 = nn.ConvTranspose3d(32, 16, kernel_size=4, stride=2, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(16, 4, kernel_size=3, stride=1, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(4, out_channels, kernel_size=4, stride=2, padding=1)
        self.output_layer = nn.Sigmoid()

    def forward(self, z, label_embedding):
        """Decode ``z`` conditioned on ``label_embedding``.

        Args:
            z: latent tensor; dims 2, 3 and 4 are treated as spatial.
            label_embedding: per-sample embedding, assumed ``(B, label_embedding_dim)``
                (TODO confirm with callers).

        Returns:
            Sigmoid-activated reconstruction.
        """
        spatial = (z.shape[2], z.shape[3], z.shape[4])
        label_planes = label_embedding[..., None, None, None].expand(-1, -1, *spatial)
        hidden = torch.cat([z, label_planes], dim=1)
        for layer in (self.conv_trans1, self.conv_trans2, self.conv_trans3):
            hidden = F.gelu(layer(hidden))
        # Last layer has no GELU: its output goes straight into the sigmoid.
        return self.output_layer(self.conv_trans4(hidden))


class VAE(nn.Module):
    """Conditional VAE: encode, draw a reparameterised sample, decode.

    The same ``label_embedding`` tensor conditions both the encoder and the
    decoder.

    Args:
        in_channels: channels of the input volume.
        latent_dim: channels of the latent map.
        out_channels: channels of the reconstruction.
        num_classes: number of label classes (passed to both sub-networks).
        label_embedding_dim: width of one label embedding vector.
    """

    def __init__(self, in_channels=1, latent_dim=32, out_channels=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.encoder = VAEEncoder(in_channels, latent_dim, num_classes, label_embedding_dim)
        self.decoder = VAEDecoder(latent_dim, out_channels, num_classes, label_embedding_dim)

    def forward(self, x, label_embedding):
        """Return ``(recon_x, mu, log_var)`` for input ``x``."""
        posterior_mu, posterior_log_var = self.encoder(x, label_embedding)
        latent_sample = self.reparameterize(posterior_mu, posterior_log_var)
        reconstruction = self.decoder(latent_sample, label_embedding)
        return reconstruction, posterior_mu, posterior_log_var

    def reparameterize(self, mu, log_var):
        """Sample ``mu + sigma * eps`` with ``sigma = exp(log_var / 2)`` and ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and
        ``log_var``. It always samples, in train and eval mode alike.
        """
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + sigma * noise


class VAEGenerator(nn.Module):
    """Generator wrapper around a conditional :class:`VAE`.

    ``in_channel`` / ``out_channel`` map onto the VAE's ``in_channels`` /
    ``out_channels``; the wrapped model is available as ``self.vae``.
    """

    def __init__(self, in_channel=1, out_channel=1, latent_dim=32, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.vae = VAE(
            in_channels=in_channel,
            latent_dim=latent_dim,
            out_channels=out_channel,
            num_classes=num_classes,
            label_embedding_dim=label_embedding_dim,
        )

    def forward(self, x, label_embedding):
        """Return ``(recon_x, mu, log_var)`` from the wrapped VAE as a 3-tuple."""
        reconstruction, posterior_mu, posterior_log_var = self.vae(x, label_embedding)
        return reconstruction, posterior_mu, posterior_log_var



class Discriminator(torch.nn.Module):
    """Label-conditioned 3D convolutional critic.

    The (already embedded) label vector is broadcast over every voxel and
    concatenated as extra channels. Four Conv3d -> InstanceNorm3d -> LeakyReLU(0.2)
    stages follow, then a final unnormalised Conv3d. There is no global pooling,
    so the output is a spatial grid of scores.

    Args:
        in_channel: channels of the input volume.
        out_channel: channels of the score map.
        num_classes: number of rows in ``label_emb``.
        label_embedding_dim: width of one label embedding vector.
    """

    def __init__(self, in_channel=1, out_channel=1, num_classes=3, label_embedding_dim=16):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, label_embedding_dim)

        # (out_channels, kernel_size, stride, padding) of each normalised stage.
        hidden_stages = [
            (64, 4, 2, 1),
            (32, 4, 2, 1),
            (16, 3, 1, 1),
            (4, 3, 1, 1),
        ]
        layers = []
        channels = in_channel + label_embedding_dim
        for width, kernel, stride, pad in hidden_stages:
            layers += [
                nn.Conv3d(channels, width, kernel_size=kernel, stride=stride, padding=pad),
                nn.InstanceNorm3d(width),
                nn.LeakyReLU(0.2),
            ]
            channels = width
        # Final projection: no normalisation or activation, so scores are unbounded.
        layers.append(nn.Conv3d(channels, out_channel, kernel_size=4, stride=1, padding=0))
        # Same layer order as a hand-written Sequential, so state_dict keys
        # (conv_net.0, conv_net.3, ..., conv_net.12) are identical.
        self.conv_net = nn.Sequential(*layers)

    def forward(self, z, label_embedding):
        """Score ``z`` conditioned on ``label_embedding``.

        Args:
            z: input volume; dims 2, 3 and 4 are treated as spatial.
            label_embedding: per-sample embedding, assumed ``(B, label_embedding_dim)``
                (TODO confirm with callers).
        """
        spatial = (z.shape[2], z.shape[3], z.shape[4])
        label_planes = label_embedding[..., None, None, None].expand(-1, -1, *spatial)
        return self.conv_net(torch.cat([z, label_planes], dim=1))


def gradient_penalty(critic, real, fake, label_embedding, device=None):
    """WGAN-GP gradient penalty evaluated at random real/fake interpolates.

    For each sample a mixing weight ``beta ~ U[0, 1)`` is drawn and the point
    ``beta * real + (1 - beta) * fake`` is scored by ``critic``. The penalty is
    the batch mean of ``(||d score / d input||_2 - 1) ** 2``.

    Args:
        critic: callable ``critic(samples, label_embedding) -> scores``.
        real: batch of real samples, shape ``(B, ...)`` (any rank >= 1).
        fake: batch of generated samples with the same shape as ``real``. It is
            detached, so no gradient reaches the generator through the penalty.
        label_embedding: conditioning passed through to ``critic`` unchanged.
        device: device for the random mixing weights. Defaults to
            ``real.device`` so CUDA batches work without passing it explicitly.

    Returns:
        Scalar tensor. It is differentiable w.r.t. the critic's parameters
        (``create_graph=True``), so it can be added to the critic loss.
    """
    if device is None:
        device = real.device
    batch_size = real.shape[0]

    # One mixing weight per sample, broadcast over all remaining dims; match
    # real's dtype so half/double critics receive inputs of their own dtype.
    beta_shape = (batch_size,) + (1,) * (real.dim() - 1)
    beta = torch.rand(beta_shape, device=device, dtype=real.dtype)

    # Detach both endpoints so the interpolate is a fresh leaf we differentiate against.
    interpolated_images = real.detach() * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    mixed_scores = critic(interpolated_images, label_embedding)

    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,  # keep the graph so the penalty itself can be backpropagated
        retain_graph=True,
    )[0]
    # reshape (not view): the gradient may be non-contiguous, e.g. for channels_last inputs.
    gradient_norm = gradient.reshape(batch_size, -1).norm(2, dim=1)

    return torch.mean((gradient_norm - 1) ** 2)

In [None]:


# Seed before building the models so weight init and dummy inputs are reproducible.
SEED = 42
torch.manual_seed(SEED)

# Shared hyperparameters: the models and the dummy labels must agree on these.
NUM_CLASSES = 3
LATENT_DIM = 32
LABEL_EMBEDDING_DIM = 16

# (batch, channels, D, H, W) of the dummy volume; single source for the names below.
input_shape = (2, 1, 128, 128, 128)
batch_size, in_channels, D, H, W = input_shape

vae_generator = VAEGenerator(
    in_channel=in_channels,
    out_channel=1,
    latent_dim=LATENT_DIM,
    num_classes=NUM_CLASSES,
    label_embedding_dim=LABEL_EMBEDDING_DIM,
).to(device)

discriminator = Discriminator(
    in_channel=in_channels,
    out_channel=1,
    num_classes=NUM_CLASSES,
    label_embedding_dim=LABEL_EMBEDDING_DIM,
).to(device)

# Allocate directly on the target device rather than on CPU followed by a copy.
x_dummy = torch.randn(input_shape, device=device)
label_dummy = torch.randint(0, NUM_CLASSES, (batch_size,), device=device)
# Each network is fed a label embedded by its own embedding table.
label_embedding_dummy_gen = vae_generator.vae.encoder.label_emb(label_dummy)
label_embedding_dummy_disc = discriminator.label_emb(label_dummy)


def calculate_model_memory(model):
    """Return the parameter count of ``model`` and the memory those parameters occupy.

    Each parameter's actual element size is used rather than assuming 4 bytes,
    so float16 / float64 models are reported correctly. Buffers are not counted.

    Args:
        model: an ``nn.Module``.

    Returns:
        ``(total_params, total_size_MB)``: the number of parameter elements and
        their storage size in MiB (1024 ** 2 bytes).
    """
    params = list(model.parameters())
    total_params = sum(p.numel() for p in params)
    total_size_bytes = sum(p.numel() * p.element_size() for p in params)
    total_size_MB = total_size_bytes / (1024 ** 2)
    return total_params, total_size_MB


# torchinfo prints each summary itself (verbose=2). Binding the returned
# ModelStatistics objects stops Jupyter from echoing the last one a second
# time as the cell's output, and keeps them around for later inspection.
model_summaries = {}
for title, model, model_inputs in (
    ("VAEGenerator:", vae_generator, (x_dummy, label_embedding_dummy_gen)),
    ("\nDiscriminator:", discriminator, (x_dummy, label_embedding_dummy_disc)),
):
    print(title)
    model_summaries[title.strip().rstrip(":")] = summary(
        model,
        input_data=model_inputs,
        depth=5,
        verbose=2,
    )






