# Module 06: Unsupervised Learning - Autoencoders

**Learning Objectives:**
- Understand representation learning without labels
- Implement basic and sparse autoencoders
- Visualize learned latent spaces
- Connect autoencoders to sparse representations

**Prerequisites:** Modules 01-05 (Neural networks, supervised learning, graph basics, topology, sparse networks)

---

## 1. Introduction to Unsupervised Learning

So far, we've worked with **supervised learning** where we have input-output pairs (X, y).

In **unsupervised learning**, we only have inputs X and want to discover:
- Hidden structure in the data
- Compressed representations
- Generative models

**Why does this matter for our architecture?**
- Multi-modal learning often requires learning shared representations
- Sparse autoencoders connect to our topology work
- Representation learning is key to cross-modal binding

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import sys
sys.path.insert(0, '../..')

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[OK] Using device: {device}")

## 2. The Autoencoder Concept

An autoencoder learns to:
1. **Encode**: Compress input X into a smaller latent representation Z
2. **Decode**: Reconstruct input X' from Z

```
Input X  -->  [Encoder]  -->  Latent Z  -->  [Decoder]  -->  Output X'
(784)           |             (32)              |            (784)
                |                               |
           Compression                    Reconstruction
```

The key insight: if X' is close to X, then Z must capture the essential information!

In [None]:
# Load MNIST for experiments
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root='../../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='../../data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"[OK] Loaded {len(train_dataset)} training images")
print(f"[OK] Loaded {len(test_dataset)} test images")

## 3. Implementing a Basic Autoencoder

In [None]:
class BasicAutoencoder(nn.Module):
    """Simple fully-connected autoencoder."""
    
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=32):
        super().__init__()
        
        # Encoder: compress input to latent space
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, latent_dim),
        )
        
        # Decoder: reconstruct from latent space
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # Output in [0, 1] range
        )
        
        self.latent_dim = latent_dim
    
    def encode(self, x):
        """Encode input to latent representation."""
        return self.encoder(x)
    
    def decode(self, z):
        """Decode latent representation to reconstruction."""
        return self.decoder(z)
    
    def forward(self, x):
        """Full forward pass: encode then decode."""
        z = self.encode(x)
        x_recon = self.decode(z)
        return x_recon, z

# Create model
autoencoder = BasicAutoencoder(latent_dim=32).to(device)
print(f"[OK] Created autoencoder with {sum(p.numel() for p in autoencoder.parameters()):,} parameters")
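As a sanity check on the parameter count printed above, the total can be worked out by hand: each `nn.Linear(n_in, n_out)` layer holds `n_in * n_out` weights plus `n_out` biases. The sketch below is plain Python and assumes the default layer widths used in this module (784, 256, 128, 32); `linear_params` is a helper name introduced only for this illustration.

```python
def linear_params(n_in, n_out):
    """Parameters in a fully-connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out

# Layer widths for BasicAutoencoder(input_dim=784, hidden_dim=256, latent_dim=32)
encoder_sizes = [784, 256, 128, 32]
decoder_sizes = encoder_sizes[::-1]  # the decoder mirrors the encoder

encoder_total = sum(linear_params(a, b)
                    for a, b in zip(encoder_sizes, encoder_sizes[1:]))
decoder_total = sum(linear_params(a, b)
                    for a, b in zip(decoder_sizes, decoder_sizes[1:]))

print(f"Encoder parameters: {encoder_total:,}")
print(f"Decoder parameters: {decoder_total:,}")
print(f"Total parameters:   {encoder_total + decoder_total:,}")
print(f"Compression ratio:  {784 / 32:.1f}x")
```

The encoder and decoder differ slightly in size because the bias vectors sit on the output side of each layer, so the mirrored decoder carries a 784-element bias where the encoder carries a 32-element one.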

## 4. Training the Autoencoder

The loss function is **reconstruction error**: how different is X' from X?

Common choices:
- **MSE Loss**: Mean squared error (good for continuous data)
- **BCE Loss**: Binary cross-entropy (good for binary data or pixel intensities scaled to [0, 1])
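To make the difference concrete, here is a minimal plain-Python sketch that evaluates both losses on four toy "pixels". The helper names `mse` and `bce` and the toy values are assumptions made for this illustration; PyTorch's `nn.MSELoss` and `nn.BCELoss` likewise average over all elements by default.

```python
import math

def mse(target, recon):
    """Mean squared error over all pixels."""
    return sum((t - r) ** 2 for t, r in zip(target, recon)) / len(target)

def bce(target, recon, eps=1e-12):
    """Mean binary cross-entropy; recon values must lie in [0, 1]."""
    return -sum(
        t * math.log(r + eps) + (1 - t) * math.log(1 - r + eps)
        for t, r in zip(target, recon)
    ) / len(target)

target = [0.0, 1.0, 1.0, 0.0]  # four toy "pixels"
good = [0.1, 0.9, 0.8, 0.2]    # close reconstruction
bad = [0.9, 0.1, 0.2, 0.8]     # confidently wrong reconstruction

for name, recon in [("good", good), ("bad", bad)]:
    print(f"{name}: MSE={mse(target, recon):.4f}  BCE={bce(target, recon):.4f}")
```

Both losses rank the two reconstructions the same way. The practical difference is that squared error on [0, 1] data is capped at 1 per pixel, while cross-entropy grows without bound as a prediction approaches the wrong extreme.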

In [None]:
def train_autoencoder(model, train_loader, epochs=10, lr=0.001):
    """Train autoencoder with reconstruction loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    history = {'loss': []}
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for data, _ in train_loader:
            # Flatten images
            data = data.view(data.size(0), -1).to(device)
            
            # Forward pass
            recon, z = model(data)
            loss = criterion(recon, data)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        avg_loss = total_loss / len(train_loader)
        history['loss'].append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")
    
    return history

# Train the model
print("Training autoencoder...")
history = train_autoencoder(autoencoder, train_loader, epochs=10)
print("[OK] Training complete!")

In [None]:
# Visualize training progress
plt.figure(figsize=(10, 4))
plt.plot(history['loss'], 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Reconstruction Loss')
plt.title('Autoencoder Training')
plt.grid(True, alpha=0.3)
plt.show()

## 5. Visualizing Reconstructions

In [None]:
def visualize_reconstructions(model, test_loader, n_samples=10):
    """Show original images and their reconstructions."""
    model.eval()
    
    # Get a batch
    data, labels = next(iter(test_loader))
    data = data[:n_samples]
    data_flat = data.view(data.size(0), -1).to(device)
    
    with torch.no_grad():
        recon, z = model(data_flat)
    
    # Plot
    fig, axes = plt.subplots(2, n_samples, figsize=(15, 3))
    
    for i in range(n_samples):
        # Original
        axes[0, i].imshow(data[i].squeeze().numpy(), cmap='gray')
        axes[0, i].axis('off')
        if i == 0:
            axes[0, i].set_title('Original', fontsize=10)
        
        # Reconstruction
        axes[1, i].imshow(recon[i].view(28, 28).cpu().numpy(), cmap='gray')
        axes[1, i].axis('off')
        if i == 0:
            axes[1, i].set_title('Reconstructed', fontsize=10)
    
    plt.suptitle('Autoencoder Reconstructions', fontsize=12)
    plt.tight_layout()
    plt.show()

visualize_reconstructions(autoencoder, test_loader)

## 6. Exploring the Latent Space

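The latent space Z holds the compressed code that the decoder reconstructs from, so its structure shows what the model has learned. Let's visualize it using t-SNE, which projects the 32-dimensional codes down to 2D.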

In [None]:
from sklearn.manifold import TSNE

def visualize_latent_space(model, test_loader, n_samples=2000):
    """Visualize latent space using t-SNE."""
    model.eval()
    
    latents = []
    labels_list = []
    
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.view(data.size(0), -1).to(device)
            z = model.encode(data)
            latents.append(z.cpu().numpy())
            labels_list.append(labels.numpy())
            
            if sum(len(l) for l in latents) >= n_samples:
                break
    
    latents = np.concatenate(latents)[:n_samples]
    labels_arr = np.concatenate(labels_list)[:n_samples]
    
    # t-SNE projection
    print("Computing t-SNE projection...")
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    latents_2d = tsne.fit_transform(latents)
    
    # Plot
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(latents_2d[:, 0], latents_2d[:, 1], 
                          c=labels_arr, cmap='tab10', alpha=0.6, s=10)
    plt.colorbar(scatter, label='Digit')
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')
    plt.title('Latent Space Visualization (t-SNE)')
    plt.show()

visualize_latent_space(autoencoder, test_loader)

## 7. Sparse Autoencoder

Now let's connect this to our sparse network work! A **sparse autoencoder** encourages the latent representation to be sparse (mostly zeros).

This is done by adding a **sparsity penalty** to the loss:

```
Loss = Reconstruction_Loss + lambda * Sparsity_Penalty
```
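One common choice for the penalty, and the one used in the code of this section, is the KL divergence between a target activation rate `rho` and a latent unit's mean activation `rho_hat`, both treated as Bernoulli probabilities. The plain-Python sketch below evaluates that term for a single unit; `kl_bernoulli` is a name introduced here for illustration only.

```python
import math

def kl_bernoulli(rho, rho_hat):
    """KL divergence between Bernoulli(rho) and Bernoulli(rho_hat)."""
    return (rho * math.log(rho / rho_hat)
            + (1 - rho) * math.log((1 - rho) / (1 - rho_hat)))

rho = 0.05  # target mean activation
for rho_hat in [0.05, 0.10, 0.25, 0.50]:
    print(f"rho_hat={rho_hat:.2f}  KL={kl_bernoulli(rho, rho_hat):.4f}")
```

The penalty is zero when a unit's mean activation equals the target and increases as the unit becomes active more often, which is what pushes most latent units toward being off for most inputs.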

In [None]:
class SparseAutoencoder(nn.Module):
    """Autoencoder with sparse latent representation."""
    
    def __init__(self, input_dim=784, hidden_dim=256, latent_dim=64):
        super().__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
            # No activation here: the sparsity penalty applies a sigmoid to these values
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )
        
        self.latent_dim = latent_dim
    
    def encode(self, x):
        return self.encoder(x)
    
    def decode(self, z):
        return self.decoder(z)
    
    def forward(self, x):
        z = self.encode(x)
        x_recon = self.decode(z)
        return x_recon, z


def sparsity_penalty(z, target_sparsity=0.05):
    """KL divergence sparsity penalty.
    
    Encourages average activation to match target_sparsity.
    """
    # Average activation per neuron across batch
    rho_hat = torch.sigmoid(z).mean(dim=0)
    rho = torch.tensor(target_sparsity).to(z.device)
    
    # KL divergence
    kl = rho * torch.log(rho / (rho_hat + 1e-8)) + \
         (1 - rho) * torch.log((1 - rho) / (1 - rho_hat + 1e-8))
    
    return kl.sum()


def train_sparse_autoencoder(model, train_loader, epochs=10, lr=0.001, 
                              sparsity_weight=0.1, target_sparsity=0.05):
    """Train with sparsity constraint."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    
    history = {'loss': [], 'recon_loss': [], 'sparsity_loss': []}
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_recon = 0
        total_sparse = 0
        
        for data, _ in train_loader:
            data = data.view(data.size(0), -1).to(device)
            
            recon, z = model(data)
            
            # Reconstruction loss
            recon_loss = criterion(recon, data)
            
            # Sparsity penalty
            sparse_loss = sparsity_penalty(z, target_sparsity)
            
            # Combined loss
            loss = recon_loss + sparsity_weight * sparse_loss
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_recon += recon_loss.item()
            total_sparse += sparse_loss.item()
        
        n_batches = len(train_loader)
        history['loss'].append(total_loss / n_batches)
        history['recon_loss'].append(total_recon / n_batches)
        history['sparsity_loss'].append(total_sparse / n_batches)
        
        print(f"Epoch {epoch+1}/{epochs} - Loss: {total_loss/n_batches:.4f} "
              f"(Recon: {total_recon/n_batches:.4f}, Sparse: {total_sparse/n_batches:.4f})")
    
    return history

# Train sparse autoencoder
sparse_ae = SparseAutoencoder(latent_dim=64).to(device)
print("\nTraining sparse autoencoder...")
sparse_history = train_sparse_autoencoder(sparse_ae, train_loader, epochs=10)
print("[OK] Training complete!")

In [None]:
def compare_sparsity(model1, model2, test_loader, name1='Basic', name2='Sparse'):
    """Compare activation sparsity of two models."""
    model1.eval()
    model2.eval()
    
    data, _ = next(iter(test_loader))
    data = data.view(data.size(0), -1).to(device)
    
    with torch.no_grad():
        z1 = model1.encode(data)
        z2 = model2.encode(data)
    
    # Compute sparsity with the same metric for both models: the fraction of
    # sigmoid activations below the threshold (the quantity the KL penalty targets)
    threshold = 0.1
    a1 = torch.sigmoid(z1)
    a2 = torch.sigmoid(z2)
    sparsity1 = (a1 < threshold).float().mean().item()
    sparsity2 = (a2 < threshold).float().mean().item()
    
    fig, axes = plt.subplots(1, 3, figsize=(16, 4))
    
    # Histogram of raw latent values
    axes[0].hist(z1.cpu().numpy().flatten(), bins=50, alpha=0.7, label=name1)
    axes[0].hist(z2.cpu().numpy().flatten(), bins=50, alpha=0.7, label=name2)
    axes[0].set_xlabel('Latent Value (pre-sigmoid)')
    axes[0].set_ylabel('Count')
    axes[0].set_title('Latent Activation Distribution')
    axes[0].legend()
    
    # Sample latent vectors as heatmaps. The panels are separate because the
    # two models can have different latent sizes (here 32 vs. 64), so their
    # activations cannot be concatenated into one array.
    for ax, a, name in zip(axes[1:], (a1, a2), (name1, name2)):
        im = ax.imshow(a[:8].cpu().numpy(), aspect='auto', cmap='viridis',
                       vmin=0, vmax=1)
        ax.set_xlabel('Latent Dimension')
        ax.set_ylabel('Sample')
        ax.set_title(f'{name} Latent Activations (sigmoid)')
        plt.colorbar(im, ax=ax)
    
    plt.tight_layout()
    plt.show()
    
    print(f"{name1} model - Activations below {threshold}: {sparsity1*100:.1f}%")
    print(f"{name2} model - Activations below {threshold}: {sparsity2*100:.1f}%")

compare_sparsity(autoencoder, sparse_ae, test_loader)

## 8. Exercise: Implement a Convolutional Autoencoder

**TODO:** Complete the convolutional autoencoder below. This architecture preserves spatial structure better than fully-connected layers.

In [None]:
class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for image data.
    
    TODO: Complete the encoder and decoder architectures.
    
    Encoder should:
    - Use Conv2d layers to downsample
    - End with a flatten and linear to latent_dim
    
    Decoder should:
    - Start with linear from latent_dim
    - Use ConvTranspose2d to upsample
    - Output same size as input (1, 28, 28)
    """
    
    def __init__(self, latent_dim=32):
        super().__init__()
        
        # TODO: Implement encoder
        # Hint: Conv2d(1, 32, 3, stride=2, padding=1) -> (32, 14, 14)
        #       Conv2d(32, 64, 3, stride=2, padding=1) -> (64, 7, 7)
        #       Flatten -> Linear(64*7*7, latent_dim)
        self.encoder = nn.Sequential(
            # Your code here
            nn.Identity()  # Placeholder
        )
        
        # TODO: Implement decoder
        # Hint: Linear(latent_dim, 64*7*7) -> nn.Unflatten(1, (64, 7, 7))
        #       ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1) -> (32, 14, 14)
        #       ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1) -> (1, 28, 28)
        #       Sigmoid to keep outputs in [0, 1]
        self.decoder = nn.Sequential(
            # Your code here
            nn.Identity()  # Placeholder
        )
        
        self.latent_dim = latent_dim
    
    def forward(self, x):
        z = self.encoder(x)
        x_recon = self.decoder(z)
        return x_recon, z

# Test your implementation
# conv_ae = ConvAutoencoder(latent_dim=32).to(device)
# test_input = torch.randn(4, 1, 28, 28).to(device)
# output, latent = conv_ae(test_input)
# print(f"Input shape: {test_input.shape}")
# print(f"Output shape: {output.shape}")
# print(f"Latent shape: {latent.shape}")
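The shape hints in the exercise follow from the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with dilation 1. The plain-Python sketch below checks the arithmetic without giving away the solution; `conv2d_out` and `conv_transpose2d_out` are helper names introduced only for this illustration.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output height/width of nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose2d_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output height/width of nn.ConvTranspose2d (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder hints: 28 -> 14 -> 7
h = 28
for _ in range(2):
    h = conv2d_out(h, kernel=3, stride=2, padding=1)
    print("after Conv2d:", h)

# Decoder hints: 7 -> 14 -> 28
for _ in range(2):
    h = conv_transpose2d_out(h, kernel=3, stride=2, padding=1, output_padding=1)
    print("after ConvTranspose2d:", h)
```

The `output_padding=1` argument matters: without it the first transposed convolution would produce 13 rather than 14, and the reconstruction would not match the 28x28 input.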

## 9. Connection to Multi-Modal Learning

Autoencoders are foundational for multi-modal learning because:

1. **Shared Latent Space**: Different modalities can be encoded to the same latent space
2. **Cross-Modal Translation**: Decode from one modality's latent to another's output
3. **Representation Alignment**: Learn representations that align across modalities

In our capstone architecture, each modality encoder can be thought of as the encoder half of an autoencoder!
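The shared-latent-space idea from point 1 can be sketched in a few lines of PyTorch. This is an illustrative sketch, not the capstone architecture: the class name `SharedLatentEncoders`, the layer sizes, the 128-dimensional audio features, and the use of MSE as an alignment loss are all assumptions chosen to keep the example small.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedLatentEncoders(nn.Module):
    """Two modality-specific encoders that map into one latent space."""

    def __init__(self, image_dim=784, audio_dim=128, latent_dim=32):
        super().__init__()
        self.image_encoder = nn.Sequential(
            nn.Linear(image_dim, 64), nn.ReLU(), nn.Linear(64, latent_dim))
        self.audio_encoder = nn.Sequential(
            nn.Linear(audio_dim, 64), nn.ReLU(), nn.Linear(64, latent_dim))

    def forward(self, image, audio):
        return self.image_encoder(image), self.audio_encoder(audio)

torch.manual_seed(0)
model = SharedLatentEncoders()
image = torch.rand(4, 784)  # placeholder batch of flattened images
audio = torch.rand(4, 128)  # placeholder batch of audio features

z_image, z_audio = model(image, audio)

# Alignment loss: pull the two encodings of each paired sample together
align_loss = F.mse_loss(z_image, z_audio)
print(z_image.shape, z_audio.shape, f"align_loss={align_loss.item():.4f}")
```

Because both encoders emit vectors of the same size, a decoder trained on one modality's codes can in principle be fed the other's, which is what makes cross-modal translation possible once the codes are aligned.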

In [None]:
# Preview: Multi-modal autoencoder concept
print("""
Multi-Modal Autoencoder Architecture:

    Image  -->  [Visual Encoder]  --+
                                    |                 +--> [Visual Decoder] --> Image'
    Audio  -->  [Audio Encoder]   --+--> [Shared Z] --+
                                    |                 +--> [Audio Decoder] --> Audio'
    Text   -->  [Text Encoder]    --+                 +--> [Text Decoder]  --> Text'

The shared latent space Z enables:
- Cross-modal retrieval (find audio matching an image)
- Modal translation (generate image from text)
- Joint representation learning
""")

## 10. Summary

**Key Takeaways:**

1. **Autoencoders** learn compressed representations without labels
2. **Latent space** captures essential data structure
3. **Sparse autoencoders** encourage interpretable, sparse representations
4. **Connection to multi-modal**: Each modality encoder learns a representation

**Next Module:** We'll dive into dynamic sparse training with SET and DEEP R algorithms.

---

**[->] Continue to Module 07: Dynamic Sparse Training**