In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
import os

# Set device: run on a CUDA GPU when one is visible to PyTorch, else on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Make the `images` directory
def make_dir(image_dir='../outputs/images'):
    """Ensure the output directory for saved images exists.

    Args:
        image_dir: Directory to create, including any missing parent
            directories. Defaults to the relative path '../outputs/images'.

    Returns:
        The directory path that now exists.
    """
    # exist_ok=True replaces the separate os.path.exists() check, which could
    # race with another process creating the directory in between, and makes
    # repeated calls harmless.
    os.makedirs(image_dir, exist_ok=True)
    return image_dir


make_dir()

# Sparse Autoencoder Model
class SparseAutoencoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    Layer widths are 784 -> 3136 -> 1568 -> 3136 -> 784, so the hidden
    representation is wider than the input. Every layer, including the
    output layer, is followed by a ReLU.
    """

    def __init__(self):
        super().__init__()
        # Registration order matters: sparse_loss() walks model.children()
        # in this order to replay the forward pass.
        self.enc1 = nn.Linear(784, 3136)
        self.enc2 = nn.Linear(3136, 1568)
        self.dec1 = nn.Linear(1568, 3136)
        self.dec2 = nn.Linear(3136, 784)

    def forward(self, x):
        """Encode and reconstruct ``x``; its last dimension must be 784."""
        for layer in (self.enc1, self.enc2, self.dec1, self.dec2):
            x = F.relu(layer(x))
        return x

# Function to save decoded images
def save_decoded_image(img, name):
    """Save a batch of flattened 28x28 images to the file ``name``.

    Args:
        img: Tensor whose first dimension is the batch size; each item must
            hold exactly 28 * 28 values, viewed as a 1x28x28 image.
        name: Output file path handed to torchvision's ``save_image``.
    """
    num_images = img.size(0)
    batch_as_images = img.view(num_images, 1, 28, 28)
    save_image(batch_as_images, name)

# Sparse loss function
def sparse_loss(model, inputs):
    """Return the L1 activation penalty of ``model`` on ``inputs``.

    Replays the network layer by layer over ``model.children()`` in
    registration order, applying a ReLU after each layer exactly as
    ``SparseAutoencoder.forward`` does. It then sums the mean absolute
    activation of every layer.

    Args:
        model: Module whose direct children, applied in order with a ReLU
            after each, reproduce its forward pass.
        inputs: Batch fed to the first child.

    Returns:
        A scalar tensor: the sum over layers of ``mean(|relu(layer(x))|)``.
        It is a tensor zero if ``model`` has no children.
    """
    # Start from a tensor zero so the result is always a tensor, even for a
    # model without children.
    loss = inputs.new_zeros(())
    values = inputs
    for layer in model.children():
        # Apply the ReLU here too. Without it, the penalty is taken on
        # pre-activations the model never propagates, and each following
        # layer is fed inputs that differ from the real forward pass.
        values = F.relu(layer(values))
        loss = loss + torch.mean(torch.abs(values))
    return loss

# Training function
def train(model, train_loader, criterion, optimizer, k_sparse, sparse_weight=0.1):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is flattened to (-1, 784), moved to the module-level
    ``device``, reconstructed by ``model``, and scored as
    ``criterion(outputs, inputs) + sparse_weight * sparse_loss(model, inputs)``.

    Args:
        model: Autoencoder being trained; switched to train mode.
        train_loader: Iterable of ``(images, labels)`` batches; labels are
            ignored. Must be non-empty, since its ``len`` divides the total.
        criterion: Reconstruction loss, called as ``criterion(outputs, inputs)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        k_sparse: Unused; kept so existing call sites keep working.
            NOTE(review): no top-k sparsity is implemented here.
        sparse_weight: Weight of the activation penalty. Defaults to 0.1,
            the previously hard-coded value.

    Returns:
        Average total loss over the batches of this epoch.
    """
    model.train()
    running_loss = 0.0

    for inputs, _ in tqdm(train_loader, desc="Training", leave=False):
        inputs = inputs.view(-1, 28 * 28).to(device)
        optimizer.zero_grad()

        outputs = model(inputs)

        # Reconstruction loss
        recon_loss = criterion(outputs, inputs)

        # Activation penalty; sparse_loss replays the layers to collect the
        # per-layer activations.
        sparse_loss_value = sparse_loss(model, inputs)

        # Total loss
        loss = recon_loss + sparse_weight * sparse_loss_value

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(train_loader)

# Hyperparameters
num_epochs = 30          # number of passes over the training set
batch_size = 64          # samples per training batch
learning_rate = 0.001    # Adam learning rate
weight_decay = 0.00001   # Adam weight decay
k_sparse = 10            # intended number of active neurons; NOTE(review): train() ignores it

# DataLoader
# ToTensor turns each MNIST image into a float tensor of shape (1, 28, 28).
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', download=True, train=True, transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)
# Reuse the shared batch_size hyperparameter rather than repeating the literal,
# so changing it in one place keeps both loaders consistent.
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Model, optimizer, and criterion
criterion = nn.MSELoss()  # pixel-wise reconstruction loss
model = SparseAutoencoder().to(device)
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Training loop: one train() call per epoch, recording every epoch's loss.
train_losses = []
for epoch in range(num_epochs):
    train_losses.append(train(model, train_loader, criterion, optimizer, k_sparse))
    train_loss = train_losses[-1]

    # Log progress every 10 epochs; epochs are reported 1-based.
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}, Training Loss: {train_loss}")

# Save and visualize reconstructed images (first test batch only)
model.eval()
with torch.no_grad():
    batch = next(iter(test_loader))
    inputs, _ = batch
    inputs = inputs.view(-1, 28 * 28).to(device)
    outputs = model(inputs)
    # Tensors produced under no_grad() carry no autograd history, so .cpu()
    # alone suffices; the legacy `.data` accessor is not needed.
    save_decoded_image(inputs.cpu(), "../outputs/images/input_image.png")
    save_decoded_image(outputs.cpu(), "../outputs/images/reconstructed_image.png")

# Report completion and the loss recorded for the last epoch.
print("Training complete!")
print("Final Training Loss:", train_losses[-1])

# Plot the training loss curve
import matplotlib.pyplot as plt

# The progress prints number epochs from 1, so label the x-axis the same way
# instead of plotting the 0-based list indices.
epoch_axis = range(1, len(train_losses) + 1)
plt.plot(epoch_axis, train_losses, label='Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

                                                           

Epoch 10, Training Loss: 0.04995959337506849


                                                           

Epoch 20, Training Loss: 0.10429519001068845


                                                           

Epoch 30, Training Loss: 0.10896705435727959
Training complete!
Final Training Loss: 0.10896705435727959
