In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Data loading 
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', download=True, train=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)
val_dataset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

# k-Sparse Autoencoder 
class KSparseAutoencoder(nn.Module):
    def __init__(self, k):
        super(KSparseAutoencoder, self).__init__()
        self.k = k
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 4 * self.k),  # N, 784 to N, 40
            nn.ReLU(),
            nn.Linear(4 * self.k, self.k)  # N, 40 to N, 10
        )

        self.decoder = nn.Sequential(
            nn.Linear(self.k, 4 * self.k),  # N, 10 to N, 40
            nn.ReLU(),
            nn.Linear(4 * self.k, 28 * 28),  # N, 40 to N, 784
            nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        
       
        values, indices = torch.topk(encoded, self.k, dim=1)
        
        
        mask = torch.zeros_like(encoded)
        mask.scatter_(1, indices, 1)
        
        
        sparse_encoded = encoded * mask
        
        decoded = self.decoder(sparse_encoded)
        return decoded, mask

# Model training
k_sparse_model = KSparseAutoencoder(k=10)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(k_sparse_model.parameters(), lr=1e-3, weight_decay=1e-5)

num_epochs = 50

train_losses = []
val_losses = []

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, _ in train_loader:
        inputs = inputs.view(-1, 28 * 28)
        optimizer.zero_grad()
        outputs, _ = k_sparse_model(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)

    if epoch % 10 == 9:
        print(f"Epoch {epoch + 1}, Training Loss: {epoch_loss}")

    # Validation loss
    val_running_loss = 0.0
    for val_inputs, _ in val_loader:
        val_inputs = val_inputs.view(-1, 28 * 28)
        val_outputs, _ = k_sparse_model(val_inputs)
        val_loss = criterion(val_outputs, val_inputs)
        val_running_loss += val_loss.item()

    val_epoch_loss = val_running_loss / len(val_loader)
    val_losses.append(val_epoch_loss)

    if epoch % 10 == 9:
        print(f"Epoch {epoch + 1}, Validation Loss: {val_epoch_loss}")

Epoch 10, Training Loss: 0.023122952289895207
Epoch 10, Validation Loss: 0.02257089330488519
Epoch 20, Training Loss: 0.021570703253022898
Epoch 20, Validation Loss: 0.02105235578907523
Epoch 30, Training Loss: 0.02089336153858506
Epoch 30, Validation Loss: 0.02050130059171444
Epoch 40, Training Loss: 0.020500831979154144
Epoch 40, Validation Loss: 0.020077689325401358
Epoch 50, Training Loss: 0.020321921794923512
Epoch 50, Validation Loss: 0.019969520046357894
