# PyTorch Autoencoder Tutorial
## Complete Implementation from Scratch

This notebook covers:
- Basic Fully Connected Autoencoder
- Convolutional Autoencoder
- Variational Autoencoder (VAE)
- Training and Visualization

In [None]:
# Cell 1: Imports and Setup
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

# Fix the seed of PyTorch's default random number generator for reproducibility
torch.manual_seed(42)

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Cell 2: Configuration
# Optimisation settings
epochs = 10
batch_size = 128
learning_rate = 1e-3

# Model dimensions
input_dim = 28 * 28  # one flattened 28x28 MNIST image
latent_dim = 32

In [None]:
# Cell 3: Basic Fully Connected Autoencoder
class Autoencoder(nn.Module):
    """Fully connected autoencoder operating on flattened inputs.

    The encoder maps ``input_dim`` features through hidden layers of 256 and
    128 units down to ``latent_dim``. The decoder mirrors that path and ends
    with a sigmoid, so every reconstructed value lies in [0, 1].
    """

    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()

        # Compress: input_dim -> 256 -> 128 -> latent_dim
        self.encoder = self._mlp((input_dim, 256, 128, latent_dim))

        # Expand: latent_dim -> 128 -> 256 -> input_dim, squashed by a sigmoid
        self.decoder = self._mlp(
            (latent_dim, 128, 256, input_dim),
            final_activation=nn.Sigmoid(),
        )

    @staticmethod
    def _mlp(sizes, final_activation=None):
        """Chain Linear layers over ``sizes`` with a ReLU between each pair."""
        layers = []
        last = len(sizes) - 2
        for idx, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            layers.append(nn.Linear(n_in, n_out))
            if idx < last:
                layers.append(nn.ReLU())
        if final_activation is not None:
            layers.append(final_activation)
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return the latent code produced by the encoder for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return the reconstruction produced by the decoder for ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        return self.decode(self.encode(x))

print('Basic Autoencoder defined')

In [None]:
# Cell 4: Convolutional Autoencoder
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for image batches shaped (N, channels, H, W).

    The encoder applies three stride-2 convolutions (32, 64, 128 feature
    maps), each followed by batch normalisation and ReLU. The decoder mirrors
    it with three stride-2 transposed convolutions and ends in a sigmoid, so
    outputs lie in [0, 1].

    A stride-2 convolution maps both an even size ``2k`` and the odd size
    ``2k - 1`` to ``k``, so the transposed convolution cannot know which size
    to restore from its input alone. ``forward`` therefore records the spatial
    size entering every encoder convolution and passes it to the matching
    transposed convolution as ``output_size``. This makes the reconstruction
    match the input size (e.g. 28x28 MNIST: 28 -> 14 -> 7 -> 4 -> 7 -> 14 ->
    28) instead of always doubling (4 -> 8 -> 16 -> 32).

    Args:
        channels: Number of channels of the input (and reconstructed) images.
    """
    def __init__(self, channels=1):
        super(ConvAutoencoder, self).__init__()

        # Encoder: each block halves the spatial size (rounding up)
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )

        # Decoder: output_padding=1 is only the default used when a layer is
        # called without an explicit output_size; forward() always supplies one.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a reconstruction of ``x`` with the same shape as ``x``."""
        # Spatial size seen by each encoder convolution, outermost first.
        sizes = []
        hidden = x
        for layer in self.encoder:
            if isinstance(layer, nn.Conv2d):
                sizes.append(list(hidden.shape[-2:]))
            hidden = layer(hidden)

        # Undo the convolutions innermost-first, restoring each recorded size.
        for layer in self.decoder:
            if isinstance(layer, nn.ConvTranspose2d):
                hidden = layer(hidden, output_size=sizes.pop())
            else:
                hidden = layer(hidden)
        return hidden

print('Convolutional Autoencoder defined')

In [None]:
# Cell 5: Variational Autoencoder (VAE)
class VAE(nn.Module):
    """Variational autoencoder with a single 400-unit hidden layer per side.

    The encoder produces the mean and log-variance of the latent Gaussian;
    a latent sample is drawn with the reparameterization trick and decoded
    back to ``input_dim`` values squashed into [0, 1] by a sigmoid.
    """

    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()
        hidden_dim = 400

        # Encoder: shared hidden layer, then separate heads for mu and log-variance
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder: latent -> hidden -> input space
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        hidden = self.fc1(x).relu()
        mu = self.fc_mu(hidden)
        logvar = self.fc_logvar(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = logvar.mul(0.5).exp()  # std = exp(logvar / 2)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions in [0, 1]."""
        hidden = self.fc3(z).relu()
        return self.fc4(hidden).sigmoid()

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

def vae_loss(recon_x, x, mu, logvar):
    """Return the VAE objective: reconstruction loss plus KL divergence.

    Both terms are summed over every element of the batch (no averaging).

    Args:
        recon_x: Reconstructed values in [0, 1].
        x: Target values, same shape as ``recon_x``.
        mu: Latent means produced by the encoder.
        logvar: Latent log-variances produced by the encoder.

    Returns:
        A scalar tensor holding the summed loss.
    """
    reconstruction = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL term: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_divergence = -0.5 * kl_terms.sum()
    return reconstruction + kl_divergence

print('VAE defined')

In [None]:
# Cell 6: Load MNIST Dataset
# Convert every image to a tensor before it reaches the model
transform = transforms.Compose([
    transforms.ToTensor(),
])

# MNIST train/test splits, downloaded into ./data when not already present
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Shuffle only the training split; both loaders use two worker processes
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False,
)

print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

In [None]:
# Cell 7: Training Functions
def train_autoencoder(model, train_loader, criterion, optimizer, epochs, device, flatten=True):
    """Train a reconstruction model and return the mean batch loss per epoch.

    Each batch is fed to ``model`` and the output is compared against the
    same batch with ``criterion`` (the model learns to reproduce its input).
    Progress is printed every 100 batches and at the end of every epoch.

    Args:
        model: Module mapping a batch to a reconstruction of the same shape.
        train_loader: Sized iterable yielding ``(images, labels)`` batches;
            the labels are ignored.
        criterion: Callable ``criterion(outputs, targets)`` returning a
            scalar loss tensor.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.
        flatten: When True (default) each batch is reshaped to
            ``(batch, -1)`` before the forward pass, as the fully connected
            autoencoder expects. Pass False for models that consume the
            images in their original shape, such as a convolutional
            autoencoder.

    Returns:
        A list with one float per epoch: the loss averaged over the batches
        of that epoch.
    """
    model.to(device)
    losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            # The input doubles as the reconstruction target.
            inputs = images.view(images.size(0), -1) if flatten else images

            outputs = model(inputs)
            loss = criterion(outputs, inputs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{epochs}] Average Loss: {avg_loss:.4f}')

    return losses

def train_vae(model, train_loader, optimizer, epochs, device):
    """Train a VAE with ``vae_loss`` and return the per-sample loss per epoch.

    Every batch is flattened to ``(batch, -1)`` before the forward pass.
    ``vae_loss`` returns a summed loss, so the running total is divided by
    the dataset size to report an average per sample. Progress is printed
    every 100 batches and at the end of every epoch.

    Args:
        model: Module returning ``(reconstruction, mu, logvar)`` for a batch.
        train_loader: Loader yielding ``(images, labels)`` batches and
            exposing the underlying ``dataset``; the labels are ignored.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and every batch are moved to.

    Returns:
        A list with one float per epoch: the summed loss divided by
        ``len(train_loader.dataset)``.
    """
    model.to(device)
    epoch_losses = []

    for epoch in range(epochs):
        model.train()
        running_loss = 0

        for batch_idx, (images, _) in enumerate(train_loader):
            images = images.to(device)
            flat_batch = images.view(images.size(0), -1)

            # Clear stale gradients, then run forward and backward
            optimizer.zero_grad()
            reconstruction, mu, logvar = model(flat_batch)
            loss = vae_loss(reconstruction, flat_batch, mu, logvar)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            if batch_idx % 100 != 0:
                continue
            print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item()/len(images):.4f}')

        avg_loss = running_loss / len(train_loader.dataset)
        epoch_losses.append(avg_loss)
        print(f'Epoch [{epoch+1}/{epochs}] Average Loss: {avg_loss:.4f}')

    return epoch_losses

print('Training functions defined')

In [None]:
# Cell 8: Initialize and Train Model
# Choose model type: 'basic', 'conv', or 'vae'
model_type = 'basic'  # Change this to experiment

# Build and train the model selected by model_type.
# Unsupported selections raise instead of printing a hint: otherwise
# 'Training complete!' would be shown although nothing was trained, and later
# cells would run against an undefined (or stale, from an earlier run) model.
if model_type == 'basic':
    model = Autoencoder(input_dim=input_dim, latent_dim=latent_dim)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    losses = train_autoencoder(model, train_loader, criterion, optimizer, epochs, device)

elif model_type == 'vae':
    # The VAE keeps its own latent size (20) rather than the shared latent_dim
    model = VAE(input_dim=input_dim, latent_dim=20)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    losses = train_vae(model, train_loader, optimizer, epochs, device)

elif model_type == 'conv':
    # This cell trains on flattened images, which a convolutional model cannot consume
    raise NotImplementedError('For ConvAutoencoder, use 2D images instead of flattened')

else:
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'basic', 'conv', or 'vae'")

print('Training complete!')

In [None]:
# Cell 9: Visualization Functions
def visualize_reconstructions(model, dataloader, device, num_images=10, model_type='basic'):
    """Plot input images (top row) above the model's reconstructions (bottom row).

    The first batch of ``dataloader`` is used. ``model`` is switched to eval
    mode and is expected to already live on ``device``.

    Args:
        model: Trained model to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the images are moved to before the forward pass.
        num_images: Maximum number of images to show; fewer are shown when
            the first batch is smaller.
        model_type: ``'vae'`` if the model returns ``(recon, mu, logvar)``,
            ``'conv'`` if it consumes unflattened images; any other value
            feeds flattened images and expects a single tensor back.
    """
    model.eval()
    images, _ = next(iter(dataloader))
    images = images[:num_images].to(device)
    # The batch may contain fewer samples than requested.
    num_images = images.size(0)

    with torch.no_grad():
        if model_type == 'conv':
            # Convolutional models take the images in their original shape.
            reconstructed = model(images)
        else:
            images_flat = images.view(images.size(0), -1)
            if model_type == 'vae':
                reconstructed, _, _ = model(images_flat)
            else:
                reconstructed = model(images_flat)
            # Undo the flattening so each reconstruction is an image again.
            reconstructed = reconstructed.view_as(images)

    originals = images.cpu()
    reconstructed = reconstructed.cpu()

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, num_images, figsize=(15, 3), squeeze=False)
    for col in range(num_images):
        for row, picture in enumerate((originals[col], reconstructed[col])):
            ax = axes[row, col]
            ax.imshow(picture.squeeze(), cmap='gray')
            # Hide ticks and frame individually: axis('off') would also hide
            # the row labels set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    axes[0, 0].set_ylabel('Original', size=14)
    axes[1, 0].set_ylabel('Reconstructed', size=14)
    plt.suptitle('Autoencoder Reconstructions', fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_training_loss(losses):
    """Show the given loss values as a line chart with the x-axis labelled 'Epoch'."""
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(losses, linewidth=2)
    ax.set_title('Training Loss over Time', fontsize=14)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.grid(True, alpha=0.3)  # faint grid to make values easier to read
    plt.show()

print('Visualization functions defined')

In [None]:
# Cell 10: Visualize Results
# Loss curve first, then originals vs. reconstructions on the first test batch
plot_training_loss(losses=losses)
visualize_reconstructions(
    model=model,
    dataloader=test_loader,
    device=device,
    num_images=10,
    model_type=model_type,
)

In [None]:
# Cell 11: Save Model
# Store only the learned parameters (state_dict), in a file named after the model type
model_path = f'{model_type}_autoencoder.pth'
torch.save(obj=model.state_dict(), f=model_path)
print(f'Model saved to {model_path}')

In [None]:
# Cell 12: Generate Samples (VAE only)
def generate_samples(model, num_samples=10, device='cpu'):
    """Decode random latent vectors with a VAE and plot the results in one row.

    Args:
        model: VAE exposing ``decode`` and an ``fc3`` layer whose
            ``in_features`` equals the latent size.
        num_samples: Number of latent vectors to draw and images to show.
        device: Device the latent vectors are moved to; ``model`` is
            expected to already live there.
    """
    model.eval()
    with torch.no_grad():
        # Latent vectors are drawn from a standard normal distribution.
        z = torch.randn(num_samples, model.fc3.in_features).to(device)
        generated = model.decode(z)
        generated = generated.view(-1, 1, 28, 28)

    # squeeze=False keeps `axes` an array even when num_samples == 1.
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 2), squeeze=False)
    for ax, image in zip(axes[0], generated.cpu()):
        ax.imshow(image.squeeze(), cmap='gray')
        ax.axis('off')
    plt.suptitle('Generated Samples from VAE', fontsize=14)
    plt.tight_layout()
    plt.show()


if model_type == 'vae':
    generate_samples(model, num_samples=10, device=device)
else:
    print('Sample generation only available for VAE')