# Creating a simple sparse auto-encoder

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [5]:
# Define Sparse Autoencoder
class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, sparsity_lambda=1e-3):
        super(SparseAutoencoder, self).__init__()
        # Encoder layer: Reduces input dimensions
        self.encoder = nn.Linear(input_dim, hidden_dim)
        # Decoder layer: Reconstructs input from encoded representation
        self.decoder = nn.Linear(hidden_dim, input_dim)
        self.sparsity_lambda = sparsity_lambda  # Regularization strength
        
    def forward(self, x):
        # Encode input using ReLU activation
        encoded = torch.relu(self.encoder(x))
        # Decode back to input space
        decoded = self.decoder(encoded)
        return encoded, decoded
    
    def sparsity_loss(self, encoded):
        # Compute average activation of hidden units
        rho_hat = torch.mean(encoded, dim=0)
        rho = 0.05  # Desired average activation
        # KL divergence for sparsity constraint
        return self.sparsity_lambda * torch.sum(rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat)))

In [6]:
# Hyperparameters
input_dim = 20  # Input feature dimension
hidden_dim = 5   # Hidden layer size
lr = 0.01        # Learning rate
epochs = 100     # Number of training iterations

In [8]:
# Initialize model, loss function, and optimizer
model = SparseAutoencoder(input_dim, hidden_dim)
criterion = nn.MSELoss()  # Mean Squared Error loss for reconstruction
optimizer = optim.Adam(model.parameters(), lr=lr)

In [9]:
# Dummy Data (Random tensor simulating input data)
X = torch.rand(100, input_dim)

In [10]:
# Training Loop
for epoch in range(epochs):
    optimizer.zero_grad()  # Reset gradients
    encoded, decoded = model(X)  # Forward pass
    loss = criterion(decoded, X) + model.sparsity_loss(encoded)  # Compute loss with sparsity constraint
    loss.backward()  # Backpropagation
    optimizer.step()  # Update weights
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')  # Print progress

Epoch 0, Loss: 0.4296
Epoch 10, Loss: inf
Epoch 20, Loss: inf
Epoch 30, Loss: inf
Epoch 40, Loss: inf
Epoch 50, Loss: inf
Epoch 60, Loss: inf
Epoch 70, Loss: inf
Epoch 80, Loss: inf
Epoch 90, Loss: inf
