In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn import TransformerDecoder, TransformerDecoderLayer

In [2]:
transform_enc = TransformerEncoder(
    TransformerEncoderLayer(d_model=512, nhead=8),
    num_layers=2,
)


In [3]:
transform_enc(torch.rand(1, 32, 512)).shape

torch.Size([1, 32, 512])

In [4]:
class ConvEncoder(nn.Module):
    def __init__(self, embed_size, image_size=64, image_channels=3):
        super(ConvEncoder, self).__init__()
        self.embed_size = embed_size
        self.conv_layers = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=4, stride=2, padding=1),  # Output: [64, 32, 32]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # Output: [128, 16, 16]
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # Output: [256, 8, 8]
            nn.ReLU(),
        )

        # Calculate the size of the features after convolutional layers
        conv_output_size = 256 * (image_size // 8) * (image_size // 8)
        self.fc = nn.Linear(conv_output_size, embed_size)

        # Transformer Encoder Layer
        encoder_layers = TransformerEncoderLayer(d_model=embed_size, nhead=8)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers=1)

    def forward(self, x):
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        x = x.unsqueeze(1)  # Add sequence dimension
        x = self.transformer_encoder(x)
        return x


In [5]:
class ConvDecoder(nn.Module):
    def __init__(self, embed_size, image_size=64, image_channels=3):
        super(ConvDecoder, self).__init__()
        self.embed_size = embed_size
        self.image_size = image_size

        # Transformer Decoder Layer
        decoder_layers = TransformerDecoderLayer(d_model=embed_size, nhead=8)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers=1)

        self.fc = nn.Linear(embed_size, 256 * (image_size // 8) * (image_size // 8))

        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, image_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, memory):
        x = self.transformer_decoder(x, memory)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)
        x = x.view(-1, 256, self.image_size // 8, self.image_size // 8)  # Reshape to match the conv layers input
        x = self.conv_layers(x)
        return x

In [12]:
class ConvTransformerAutoencoder(nn.Module):
    def __init__(self, embed_size, image_size=64, image_channels=3):
        super(ConvTransformerAutoencoder, self).__init__()
        self.encoder = ConvEncoder(embed_size, image_size, image_channels)
        self.decoder = ConvDecoder(embed_size, image_size, image_channels)

    def forward(self, x):
        memory = self.encoder(x)
        print(memory.shape)
        x = self.decoder(memory, memory)
        return x


In [13]:
autoencoder = ConvTransformerAutoencoder(embed_size=256, image_size=64, image_channels=3)
print(summary(autoencoder, depth=2))
input_image = torch.rand(1, 3, 64, 64)  # Example input image
output_image = autoencoder(input_image)
print("Output Image Shape:", output_image.shape)


Layer (type:depth-idx)                                                 Param #
ConvTransformerAutoencoder                                             --
├─ConvEncoder: 1-1                                                     --
│    └─Sequential: 2-1                                                 658,880
│    └─Linear: 2-2                                                     4,194,560
│    └─TransformerEncoder: 2-3                                         1,315,072
├─ConvDecoder: 1-2                                                     --
│    └─TransformerDecoder: 2-4                                         1,578,752
│    └─Linear: 2-5                                                     4,210,688
│    └─Sequential: 2-6                                                 658,627
Total params: 12,616,579
Trainable params: 12,616,579
Non-trainable params: 0
torch.Size([1, 1, 256])
Output Image Shape: torch.Size([1, 3, 64, 64])


In [14]:
autoencoder = ConvTransformerAutoencoder(embed_size=256, image_size=32, image_channels=3)
print(summary(autoencoder, depth=2))
input_image = torch.rand(1, 3, 32, 32)  # Example input image
output_image = autoencoder(input_image)
print("Output Image Shape:", output_image.shape)

Layer (type:depth-idx)                                                 Param #
ConvTransformerAutoencoder                                             --
├─ConvEncoder: 1-1                                                     --
│    └─Sequential: 2-1                                                 658,880
│    └─Linear: 2-2                                                     1,048,832
│    └─TransformerEncoder: 2-3                                         1,315,072
├─ConvDecoder: 1-2                                                     --
│    └─TransformerDecoder: 2-4                                         1,578,752
│    └─Linear: 2-5                                                     1,052,672
│    └─Sequential: 2-6                                                 658,627
Total params: 6,312,835
Trainable params: 6,312,835
Non-trainable params: 0
torch.Size([1, 1, 256])
Output Image Shape: torch.Size([1, 3, 32, 32])


In [15]:
autoencoder = ConvTransformerAutoencoder(embed_size=512, image_size=32, image_channels=3)
print(summary(autoencoder, depth=2))
input_image = torch.rand(1, 3, 32, 32)  # Example input image
output_image = autoencoder(input_image)
print("Output Image Shape:", output_image.shape)

Layer (type:depth-idx)                                                 Param #
ConvTransformerAutoencoder                                             --
├─ConvEncoder: 1-1                                                     --
│    └─Sequential: 2-1                                                 658,880
│    └─Linear: 2-2                                                     2,097,664
│    └─TransformerEncoder: 2-3                                         3,152,384
├─ConvDecoder: 1-2                                                     --
│    └─TransformerDecoder: 2-4                                         4,204,032
│    └─Linear: 2-5                                                     2,101,248
│    └─Sequential: 2-6                                                 658,627
Total params: 12,872,835
Trainable params: 12,872,835
Non-trainable params: 0
torch.Size([1, 1, 512])
Output Image Shape: torch.Size([1, 3, 32, 32])


In [16]:
autoencoder = ConvTransformerAutoencoder(embed_size=256, image_size=32, image_channels=6)
print(summary(autoencoder, depth=2))
input_image = torch.rand(1, 6, 32, 32)  # Example input image
output_image = autoencoder(input_image)
print("Output Image Shape:", output_image.shape)

Layer (type:depth-idx)                                                 Param #
ConvTransformerAutoencoder                                             --
├─ConvEncoder: 1-1                                                     --
│    └─Sequential: 2-1                                                 661,952
│    └─Linear: 2-2                                                     1,048,832
│    └─TransformerEncoder: 2-3                                         1,315,072
├─ConvDecoder: 1-2                                                     --
│    └─TransformerDecoder: 2-4                                         1,578,752
│    └─Linear: 2-5                                                     1,052,672
│    └─Sequential: 2-6                                                 661,702
Total params: 6,318,982
Trainable params: 6,318,982
Non-trainable params: 0
torch.Size([1, 1, 256])
Output Image Shape: torch.Size([1, 6, 32, 32])
