# convolutional-autoencoder-pytorch

A minimal, customizable PyTorch package for building and training convolutional autoencoders based on a simplified U-Net architecture (without skip connections). Ideal for representation learning, image compression, and reconstruction tasks.
## Features

- 📦 Modular architecture (`Encoder`, `Decoder`, `AutoEncoder`)
- 🔁 Symmetric U-Net-like design without skip connections
- ⚡ Tanh output activation for stable image reconstruction
- 🧠 Residual blocks with RMS normalization and SiLU activation
- 📱 Designed for image inputs (`3×H×W`) with configurable channels and latent dimension
- 🧪 Works with batched input tensors (e.g., `torch.Tensor` of shape `[B, C, H, W]`)
## Installation

```bash
pip install convolutional-autoencoder-pytorch
```
## Project Structure

```
convolutional-autoencoder-pytorch/
├── convolutional_autoencoder_pytorch/
│   ├── __init__.py
│   └── module.py        # All architecture classes and logic
├── pyproject.toml
├── LICENSE
└── README.md
```
## Usage

```python
import torch
from convolutional_autoencoder_pytorch import AutoEncoder

model = AutoEncoder(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    dim_latent=128,
    image_channels=3,
)

images = torch.randn(8, 3, 128, 128)  # batch of images
reconstructed, latent = model(images)

# Or just get the reconstruction
recon = model.reconstruct(images)
```
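At inference time you would typically switch the model to eval mode and disable gradient tracking. The sketch below also shows one way to reuse the latent code as a learned representation; the latent's exact shape is not documented here, so the linear probe is sized at runtime (the probe and its 10 classes are purely illustrative, not part of this package):

```python
import torch

model.eval()
with torch.no_grad():
    recon, latent = model(images)  # images as in the usage example above

# Reuse the latent code as a representation, e.g. for a linear probe.
features = latent.flatten(1)                    # works for 2D or 4D latents
probe = torch.nn.Linear(features.shape[1], 10)  # hypothetical 10-class probe
logits = probe(features)
```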
## Training

```python
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_step(images):
    model.train()
    optimizer.zero_grad()
    recon, _ = model(images)
    loss = F.mse_loss(recon, images)
    loss.backward()
    optimizer.step()
    return loss.item()
```
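A full training loop just calls `train_step` over batches. Here is a minimal sketch, assuming `torchvision` is available (it is not a dependency of this package) and an image folder at a placeholder path; inputs are normalized to `[-1, 1]` to match the Tanh output range:

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Hypothetical data pipeline: normalize images to [-1, 1] so reconstruction
# targets match the Tanh output range of the decoder.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),                                # scales to [0, 1]
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),  # shifts to [-1, 1]
])
dataset = datasets.ImageFolder("path/to/images", transform=transform)  # placeholder path
loader = DataLoader(dataset, batch_size=8, shuffle=True)

for epoch in range(10):
    for images, _ in loader:
        loss = train_step(images)
    print(f"epoch {epoch}: last batch loss = {loss:.4f}")
```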
## Configuration

| Parameter | Description | Default |
|---|---|---|
| `dim` | Base channel size | `64` |
| `dim_mults` | Multipliers for the down/up blocks | `(1, 2, 4, 8)` |
| `dim_latent` | Latent bottleneck dimension | `64` |
| `image_channels` | Input/output image channels (e.g., 3) | `3` |
| `dropout` | Dropout probability | `0.0` |
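All of these parameters are passed at construction time. For example, a lighter model for grayscale inputs with some regularization might look like this (the values are illustrative, not recommendations):

```python
from convolutional_autoencoder_pytorch import AutoEncoder

model = AutoEncoder(
    dim=32,               # smaller base channel size
    dim_mults=(1, 2, 4),  # one fewer down/up stage
    dim_latent=64,
    image_channels=1,     # grayscale input/output
    dropout=0.1,
)
```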
## Credits

Developed by Mehran Bazrafkan.

This project is an original implementation of a simplified autoencoder architecture. Some ideas and design inspirations were drawn from the open-source `denoising-diffusion-pytorch` project by Phil Wang, but the code and architecture were written independently.
## Contributing

Contributions, issues, and feedback are welcome via GitHub Issues.

## License

This project is licensed under the terms of the MIT License. See the LICENSE file for details.