In [1]:
import torch
import torch.nn as nn

# Define placeholder classes for VAE_ResidualBlock and VAE_AttentionBlock
class VAE_ResidualBlock(nn.Module):
    """Placeholder residual block that only records its channel configuration.

    No sub-layers and no ``forward`` are defined; the class exists so that a
    layer stack containing it can be built and inspected.

    Args:
        in_channels: Value stored as ``self.in_channels``.
        out_channels: Value stored as ``self.out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Pure bookkeeping: both values are kept unchanged for later inspection.
        self.in_channels, self.out_channels = in_channels, out_channels

class VAE_AttentionBlock(nn.Module):
    """Placeholder attention block that only records its channel count.

    No sub-layers and no ``forward`` are defined here.

    Args:
        channels: Value stored unchanged as ``self.channels``.
    """

    def __init__(self, channels):
        # Register as a proper nn.Module before attaching any attribute.
        super().__init__()

        self.channels = channels

class VAE_Encoder(nn.Sequential):
    """Sequential stack of the VAE encoder layers.

    The stack starts from a 3-channel input, widens the channel count
    128 -> 256 -> 512 through residual blocks, interleaves three stride-2
    convolutions, and ends with GroupNorm, SiLU and two convolutions that
    produce 8 output channels.

    No ``forward`` is overridden, so the layers run in the listed order via
    ``nn.Sequential``.
    NOTE(review): the stride-2 convolutions use ``padding=0``; presumably the
    input is meant to be padded before them somewhere else -- confirm.
    """

    def __init__(self):
        # Layers are created stage by stage, in the exact order in which they
        # are registered in the container.
        stem = [
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),
        ]
        down_to_256 = [
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(128, 256),
            VAE_ResidualBlock(256, 256),
        ]
        down_to_512 = [
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(256, 512),
            VAE_ResidualBlock(512, 512),
        ]
        bottleneck = [
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
        ]
        head = [
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        ]
        super().__init__(*stem, *down_to_256, *down_to_512, *bottleneck, *head)

# Build the encoder so that its layer stack can be inspected.
encoder = VAE_Encoder()

# Report every layer in registration order: a header line, the hyperparameters
# relevant to that layer type, then a blank separator line.
for idx, layer in enumerate(encoder):
    lines = [f"Layer {idx + 1}: {layer.__class__.__name__}"]
    if isinstance(layer, nn.Conv2d):
        lines += [
            f"  Input channels: {layer.in_channels}",
            f"  Output channels: {layer.out_channels}",
            f"  Kernel size: {layer.kernel_size}",
            f"  Stride: {layer.stride}",
            f"  Padding: {layer.padding}",
        ]
    elif isinstance(layer, VAE_ResidualBlock):
        lines += [
            f"  In channels: {layer.in_channels}",
            f"  Out channels: {layer.out_channels}",
        ]
    elif isinstance(layer, VAE_AttentionBlock):
        lines.append(f"  Channels: {layer.channels}")
    elif isinstance(layer, nn.GroupNorm):
        lines += [
            f"  Number of groups: {layer.num_groups}",
            f"  Number of channels: {layer.num_channels}",
        ]
    elif isinstance(layer, nn.SiLU):
        lines.append("  Activation function: SiLU (Swish)")
    # The trailing empty entry yields the blank line that separates layers.
    lines.append("")
    print("\n".join(lines))

Layer 1: Conv2d
  Input channels: 3
  Output channels: 128
  Kernel size: (3, 3)
  Stride: (1, 1)
  Padding: (1, 1)

Layer 2: VAE_ResidualBlock
  In channels: 128
  Out channels: 128

Layer 3: VAE_ResidualBlock
  In channels: 128
  Out channels: 128

Layer 4: Conv2d
  Input channels: 128
  Output channels: 128
  Kernel size: (3, 3)
  Stride: (2, 2)
  Padding: (0, 0)

Layer 5: VAE_ResidualBlock
  In channels: 128
  Out channels: 256

Layer 6: VAE_ResidualBlock
  In channels: 256
  Out channels: 256

Layer 7: Conv2d
  Input channels: 256
  Output channels: 256
  Kernel size: (3, 3)
  Stride: (2, 2)
  Padding: (0, 0)

Layer 8: VAE_ResidualBlock
  In channels: 256
  Out channels: 512

Layer 9: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 10: Conv2d
  Input channels: 512
  Output channels: 512
  Kernel size: (3, 3)
  Stride: (2, 2)
  Padding: (0, 0)

Layer 11: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 12: VAE_ResidualBlock
  In channels: 512
  Out ch

In [None]:
# Use the %pip magic rather than a shell escape (!pip): %pip installs into the
# environment of the running kernel, whereas !pip uses whichever pip is on PATH.
%pip install torchsummary




In [2]:
import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    """Placeholder self-attention module that only records its configuration.

    No projection layers and no ``forward`` are defined here.

    Args:
        n_heads: Value stored as ``self.n_heads``.
        d_embed: Value stored as ``self.d_embed``.
    """

    def __init__(self, n_heads, d_embed):
        super().__init__()
        # Both settings are kept unchanged for later inspection.
        self.n_heads, self.d_embed = n_heads, d_embed

class VAE_AttentionBlock(nn.Module):
    """Attention block skeleton: a GroupNorm followed by single-head self-attention.

    Only the sub-modules are created; no ``forward`` is defined here.

    Args:
        channels: Channel count. It is used as the GroupNorm channel count
            (with 32 groups) and as the embedding width of the attention
            module, and is also stored as ``self.channels``.
    """

    def __init__(self, channels):
        super().__init__()
        # Keep the width directly on the block, the same way VAE_ResidualBlock
        # keeps in_channels/out_channels, so callers do not have to reach into
        # a sub-module to read it.
        self.channels = channels
        self.groupnorm = nn.GroupNorm(32, channels)
        # A single attention head spanning the full channel width.
        self.attention = SelfAttention(1, channels)

class VAE_ResidualBlock(nn.Module):
    """Residual block skeleton: two GroupNorm + 3x3 convolution pairs and a shortcut.

    Only the sub-modules are created; no ``forward`` is defined here.

    Args:
        in_channels: Channel count fed to the first GroupNorm/convolution.
        out_channels: Channel count produced by the first convolution and
            kept by the second GroupNorm/convolution.

    Attributes:
        residual_layer: ``nn.Identity`` when the channel counts match,
            otherwise a 1x1 convolution from ``in_channels`` to
            ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels

        # First normalisation/convolution pair changes the channel count.
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Second pair keeps the channel count at out_channels.
        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # A 1x1 projection is only created when the shortcut must change width.
        needs_projection = in_channels != out_channels
        self.residual_layer = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if needs_projection
            else nn.Identity()
        )

class VAE_Decoder(nn.Sequential):
    """Sequential stack of the VAE decoder layers.

    The stack starts from a 4-channel input, expands it to 512 channels,
    applies residual and attention blocks, then alternates three
    ``nn.Upsample(scale_factor=2)`` stages with convolutions and residual
    blocks while narrowing 512 -> 256 -> 128 channels. It ends with
    GroupNorm, SiLU and a convolution producing 3 channels.

    No ``forward`` is overridden, so the layers run in the listed order via
    ``nn.Sequential``.
    """

    def __init__(self):
        # Layers are created stage by stage, in the exact order in which they
        # are registered in the container.
        entry = [
            nn.Conv2d(4, 4, kernel_size=1, padding=0),
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
        ]
        middle = [
            VAE_ResidualBlock(512, 512),
            VAE_AttentionBlock(512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
        ]
        up_first = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
            VAE_ResidualBlock(512, 512),
        ]
        up_second = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            VAE_ResidualBlock(512, 256),
            VAE_ResidualBlock(256, 256),
            VAE_ResidualBlock(256, 256),
        ]
        up_third = [
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            VAE_ResidualBlock(256, 128),
            VAE_ResidualBlock(128, 128),
            VAE_ResidualBlock(128, 128),
        ]
        output_head = [
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        ]
        super().__init__(
            *entry, *middle, *up_first, *up_second, *up_third, *output_head
        )

# Build the decoder so that its layer stack can be inspected.
decoder = VAE_Decoder()

# Report every layer in registration order: a header line, the hyperparameters
# relevant to that layer type, then a blank separator line.
for idx, layer in enumerate(decoder):
    lines = [f"Layer {idx + 1}: {layer.__class__.__name__}"]
    if isinstance(layer, nn.Conv2d):
        lines += [
            f"  Input channels: {layer.in_channels}",
            f"  Output channels: {layer.out_channels}",
            f"  Kernel size: {layer.kernel_size}",
            f"  Stride: {layer.stride}",
            f"  Padding: {layer.padding}",
        ]
    elif isinstance(layer, VAE_ResidualBlock):
        lines += [
            f"  In channels: {layer.in_channels}",
            f"  Out channels: {layer.out_channels}",
        ]
    elif isinstance(layer, VAE_AttentionBlock):
        # The channel count is read from the block's GroupNorm sub-module.
        lines.append(f"  Channels: {layer.groupnorm.num_channels}")
    elif isinstance(layer, nn.Upsample):
        lines.append(f"  Scale factor: {layer.scale_factor}")
    elif isinstance(layer, nn.GroupNorm):
        lines += [
            f"  Number of groups: {layer.num_groups}",
            f"  Number of channels: {layer.num_channels}",
        ]
    elif isinstance(layer, nn.SiLU):
        lines.append("  Activation function: SiLU (Swish)")
    # The trailing empty entry yields the blank line that separates layers.
    lines.append("")
    print("\n".join(lines))

Layer 1: Conv2d
  Input channels: 4
  Output channels: 4
  Kernel size: (1, 1)
  Stride: (1, 1)
  Padding: (0, 0)

Layer 2: Conv2d
  Input channels: 4
  Output channels: 512
  Kernel size: (3, 3)
  Stride: (1, 1)
  Padding: (1, 1)

Layer 3: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 4: VAE_AttentionBlock
  Channels: 512

Layer 5: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 6: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 7: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 8: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 9: Upsample
  Scale factor: 2.0

Layer 10: Conv2d
  Input channels: 512
  Output channels: 512
  Kernel size: (3, 3)
  Stride: (1, 1)
  Padding: (1, 1)

Layer 11: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 12: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 13: VAE_ResidualBlock
  In channels: 512
  Out channels: 512

Layer 14: Upsample
  Scal