In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from modules.SaveOutputs import save_reconstructions
from tqdm import tqdm

import numpy as np

# --- Hyperparameters ---
# The dimension of the latent space, a key hyperparameter for VAEs.
LATENT_DIM = 128
NUM_EPOCHS = 5
BATCH_SIZE = 128
LEARNING_RATE = 1e-3

# --- 1. Define the VAE Model ---
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # --- New Encoder Network with Split Paths ---
        # The encoder starts with a shared convolutional block.
        self.initial_encoder = nn.Sequential(
            # Input: [1, 64, 64]
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [32, 32, 32]
        )

        # Path 1: A simple convolutional path
        self.path1 = nn.Sequential(
            # Input: [32, 32, 32]
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [64, 16, 16]
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [128, 8, 8]
        )

        # Path 2: A "VGG-like" path with more layers for finer feature extraction
        self.path2 = nn.Sequential(
            # Input: [32, 32, 32]
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), # VGG-like conv
            nn.ReLU(),
            # Output: [64, 32, 32]
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [128, 16, 16]
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [256, 8, 8]
        )

        # A final convolutional layer to merge the features from both paths
        # Concatenated output will be [128 + 256, 8, 8] = [384, 8, 8]
        self.final_encoder = nn.Sequential(
            nn.Conv2d(384, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [512, 4, 4]
            nn.Flatten()
        )

        # Linear layers for latent space, adjusted for the new merged feature size
        self.fc_mu = nn.Linear(512 * 4 * 4, LATENT_DIM)
        self.fc_log_var = nn.Linear(512 * 4 * 4, LATENT_DIM)

        # --- Decoder Network ---
        # The decoder now takes a latent vector and reconstructs the image.
        self.decoder_fc = nn.Linear(LATENT_DIM, 512 * 4 * 4)

        # We use a series of transposed convolutional layers (ConvTranspose2d) to upsample.
        # This part of the decoder is now larger to match the new encoder output.
        self.decoder = nn.Sequential(
            # Input: [512, 4, 4]
            nn.Unflatten(1, (512, 4, 4)),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [256, 8, 8]
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [128, 16, 16]
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            # Output: [64, 32, 32]
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            # Final activation to output pixel values in the [0, 1] range.
            nn.Sigmoid()
            # Output: [1, 64, 64]
        )

    # The reparameterization trick allows us to backpropagate through the sampling process.
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)  # standard deviation
        epsilon = torch.randn_like(std) # sample from a standard normal distribution
        return mu + epsilon * std

    def forward(self, x):
        # Initial shared encoding
        h = self.initial_encoder(x)

        # Split into two paths
        h1 = self.path1(h)
        h2 = self.path2(h)

        # Concatenate features from both paths along the channel dimension
        h_merged = torch.cat((h1, h2), dim=1)

        # Final shared encoding and flatten
        h_final = self.final_encoder(h_merged)

        mu = self.fc_mu(h_final)
        log_var = self.fc_log_var(h_final)

        # Sample a point from the latent space using the reparameterization trick
        z = self.reparameterize(mu, log_var)

        # Decode the latent vector back into an image
        reconstructed_x = self.decoder(self.decoder_fc(z))

        return reconstructed_x, mu, log_var

In [19]:
# --- 2. Define the VAE Loss Function ---
# The VAE loss is a combination of two components:
# 1. Reconstruction Loss: How well the VAE reconstructs the input image.
# 2. KL Divergence Loss: A regularization term that keeps the latent space distribution
#    close to a standard normal distribution.
def vae_loss(reconstructed_x, x, mu, log_var):
    # Binary Cross-Entropy (BCE) for the reconstruction loss.
    # We use reduction='sum' to be consistent with the KL divergence term.
    reconstruction_loss = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')

    # KL Divergence between the learned latent distribution and a standard normal distribution.
    # The formula is 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_divergence_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    return reconstruction_loss + kl_divergence_loss

# --- 3. Data Loading and Preprocessing ---
# We use the MNIST dataset and resize the images to 64x64 pixels.


In [20]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()  # convert to tensor & scale to [0,1]
])
dataset = datasets.ImageFolder(root="G:\Temp", transform=transform)
train_test_split_var = 0.99
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(train_test_split_var*len(dataset)), len(dataset) - int(train_test_split_var*len(dataset))])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
n_train = len(train_loader.dataset)

In [21]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

    # Instantiate the VAE model and move it to the device
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training loop


Using device: cuda


VAE(
  (initial_encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
  )
  (path1): Sequential(
    (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
  )
  (path2): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU()
  )
  (final_encoder): Sequential(
    (0): Conv2d(384, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Flatten(start_dim=1, end_dim=-1)
  )
  (fc_mu): Linear(in_features=8192, out_features=128, bias=True)
  (fc_log_var): Linear(in_features=8192, out_features=128, bias=True)
  (decoder_fc): Linear(in_features=128, out_featur

In [27]:
# Check for GPU availability
for epoch in range(15):
    model.train()
    running_loss = 0.0
    for i, (images, _) in tqdm(enumerate(train_loader)):
            # Reshape and move images to the correct device
        images = images.to(device)

            # Forward pass
        reconstructed_images, mu, log_var = model(images)
        loss = vae_loss(reconstructed_images, images, mu, log_var)

            # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}")
    save_reconstructions(model=model, x=images, out_dir="G:\Python\VAE-latent-space-experiment\outputs", step = epoch+5, device=device, variant="")

print("Training finished!")


1567it [01:22, 18.91it/s]


Epoch [1/5], Loss: 6236.5850


1567it [01:23, 18.68it/s]


Epoch [2/5], Loss: 6235.7465


1567it [01:21, 19.23it/s]


Epoch [3/5], Loss: 6234.2730


1567it [01:21, 19.18it/s]


Epoch [4/5], Loss: 6233.7973


1567it [01:28, 17.64it/s]


Epoch [5/5], Loss: 6232.7996


1567it [01:25, 18.32it/s]


Epoch [6/5], Loss: 6232.1588


1567it [01:26, 18.11it/s]


Epoch [7/5], Loss: 6231.6171


1567it [01:25, 18.37it/s]


Epoch [8/5], Loss: 6230.9793


1567it [01:26, 18.09it/s]


Epoch [9/5], Loss: 6230.5429


1567it [01:28, 17.73it/s]


Epoch [10/5], Loss: 6230.0408


1567it [01:29, 17.57it/s]


Epoch [11/5], Loss: 6229.4530


1567it [01:28, 17.65it/s]


Epoch [12/5], Loss: 6229.1816


1567it [01:30, 17.31it/s]


Epoch [13/5], Loss: 6228.7009


1567it [01:28, 17.74it/s]


Epoch [14/5], Loss: 6228.2270


1567it [01:28, 17.65it/s]

Epoch [15/5], Loss: 6228.0413
Training finished!



