In [1]:
# Import necessary libraries
import torch                              # PyTorch core library
import torch.nn as nn                     # For building neural network modules
import torch.optim as optim               # For optimization algorithms
from torch.utils.data import DataLoader   # For creating data loaders
from torchvision import datasets, transforms  # For MNIST dataset and image transformations
from torchvision.utils import save_image # For saving generated images

# Check if CUDA (GPU) is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [2]:
# Set hyperparameters
batch_size = 128      # Number of images per batch
epochs = 10           # Number of training epochs
learning_rate = 1e-3  # Learning rate for the optimizer
log_interval = 100    # How often to log training progress

# Define a transform to convert images to tensor (values in [0,1])
transform = transforms.ToTensor()

# Download and load the MNIST training dataset
train_dataset = datasets.MNIST(
    root='./data',         # Directory to store data
    train=True,            # Use the training split
    download=True,         # Download if not already available
    transform=transform    # Apply the transformation
)

# Download and load the MNIST test dataset (not used in training here, but available)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Create data loaders for training and testing datasets
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 73550441.99it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 2408982.68it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz





Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 19582218.85it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz





Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6731635.61it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [None]:
# Define the Variational Autoencoder (VAE) model as a subclass of nn.Module
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Encoder layers:
        self.fc1 = nn.Linear(784, 400)   # Fully connected layer from input (28x28=784) to hidden dimension 400
        self.fc21 = nn.Linear(400, 20)   # Layer to output the mean (mu) of the latent space
        self.fc22 = nn.Linear(400, 20)   # Layer to output the log variance (logvar) of the latent space

        # Decoder layers:
        self.fc3 = nn.Linear(20, 400)    # Fully connected layer from latent dimension 20 to hidden dimension 400
        self.fc4 = nn.Linear(400, 784)   # Layer to output the reconstructed image (flattened)

    def encode(self, x):
        """
        Encodes the input by passing through a fully connected layer and then splitting
        into mean and log variance vectors.
        """
        h1 = torch.relu(self.fc1(x))   # Apply ReLU activation after the first layer
        return self.fc21(h1), self.fc22(h1)  # Return the mean and log variance

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).
        """
        std = torch.exp(0.5 * logvar)      # Compute the standard deviation
        eps = torch.randn_like(std)        # Sample random noise from a standard normal distribution
        return mu + eps * std              # Return the sampled latent vector

    def decode(self, z):
        """
        Decodes the latent vector z to reconstruct the image.
        """
        h3 = torch.relu(self.fc3(z))       # Apply ReLU activation after the first decoder layer
        return torch.sigmoid(self.fc4(h3)) # Apply sigmoid to ensure output is between 0 and 1

    def forward(self, x):
        """
        Defines the forward pass through the network.
        """
        # Flatten the input image and encode it to obtain mean and log variance
        mu, logvar = self.encode(x.view(-1, 784))
        # Sample a latent vector using the reparameterization trick
        z = self.reparameterize(mu, logvar)
        print(z.min(), z.max())
        # Decode the latent vector to get the reconstructed image
        recon_x = self.decode(z)
        return recon_x, mu, logvar

In [11]:
# Initialize the VAE model and move it to the appropriate device (CPU or GPU)
model = VAE().to(device)

# Set up the optimizer (Adam optimizer)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def loss_function(recon_x, x, mu, logvar):
    """
    Computes the VAE loss function as the sum of the reconstruction loss (BCE) and
    the Kullback-Leibler divergence loss.
    """
    # Flatten the input image for computing reconstruction loss
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # Compute the KL divergence between the learned latent distribution and a standard normal distribution
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD  # Total loss

def train(epoch):
    """
    Trains the VAE for one epoch.
    """
    model.train()  # Set model to training mode
    train_loss = 0  # Initialize training loss for the epoch

    # Loop over all batches in the training dataset
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)       # Move data to the device (GPU/CPU)
        optimizer.zero_grad()        # Reset the gradients from the previous iteration

        # Forward pass: compute reconstructed image, mean, and log variance
        recon_batch, mu, logvar = model(data)

        # Compute the loss using the loss_function defined above
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()              # Backpropagate to compute gradients
        train_loss += loss.item()    # Accumulate the batch loss
        optimizer.step()             # Update the model parameters

        # Log progress every 'log_interval' batches
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} "
                  f"({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item() / len(data):.6f}")

    # Print average loss for the epoch
    print(f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader.dataset):.4f}")

def generate_images(num_images=64, filename='sample.png'):
    """
    Generates new images by sampling from the latent space and decoding them.
    The generated images are saved as a grid in a single image file.
    """
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient computation for evaluation
        # Sample latent vectors from a standard normal distribution
        z = torch.randn(num_images, 20).to(device)
        # Decode the latent vectors to generate images
        sample = model.decode(z).cpu()
        # Reshape and save the generated images in a grid
        save_image(sample.view(num_images, 1, 28, 28), filename)
        print(f"Generated images saved to {filename}")

In [12]:
for epoch in range(1, epochs + 1):
    train(epoch)  # Train for one epoch
    # Generate and save a batch of images from the VAE's latent space
    # generate_images(filename=f'sample_epoch_{epoch}.png')

tensor(-3.6934, device='cuda:0', grad_fn=<MinBackward1>) tensor(3.3829, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-3.2855, device='cuda:0', grad_fn=<MinBackward1>) tensor(3.6798, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-3.7604, device='cuda:0', grad_fn=<MinBackward1>) tensor(4.2675, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-3.5199, device='cuda:0', grad_fn=<MinBackward1>) tensor(3.8213, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-3.6660, device='cuda:0', grad_fn=<MinBackward1>) tensor(4.0352, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-3.4893, device='cuda:0', grad_fn=<MinBackward1>) tensor(4.6281, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-4.2769, device='cuda:0', grad_fn=<MinBackward1>) tensor(4.6026, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-4.9335, device='cuda:0', grad_fn=<MinBackward1>) tensor(5.6856, device='cuda:0', grad_fn=<MaxBackward1>)
tensor(-5.6726, device='cuda:0', grad_fn=<MinBackward1>) tensor(6.0590, device='cuda:0', grad_fn

KeyboardInterrupt: 