<a href="https://colab.research.google.com/github/MehrdadDastouri/mnist_autoencoder/blob/main/mnist_autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Device configuration: use a CUDA GPU when one is present, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size, learning_rate, epochs = 64, 0.001, 5

# Preprocessing: ToTensor scales pixels to [0, 1]; Normalize then maps
# x -> (x - 0.5) / 0.5, i.e. into [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split (downloaded into ./data if not already present),
# reshuffled every epoch by the loader.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Define the Autoencoder model
class Autoencoder(nn.Module):
    """Fully connected autoencoder for flattened images.

    The encoder maps ``input_dim -> *hidden_dims -> latent_dim`` with ReLU
    between linear layers (none after the last one). The decoder mirrors
    those sizes back to ``input_dim`` and ends in ``Tanh``, so reconstructions
    lie in [-1, 1].

    The defaults reproduce the original 784-128-64-32 architecture, with the
    same layer order and ``nn.Sequential`` indices, so existing checkpoints
    still load.

    Args:
        input_dim: Size of each flattened input vector (default 28*28).
        hidden_dims: Sizes of the encoder's hidden layers, outermost first.
            The decoder uses them in reverse order.
        latent_dim: Size of the bottleneck code.

    Raises:
        ValueError: If any layer size is not a positive integer.
    """

    def __init__(self, input_dim=28 * 28, hidden_dims=(128, 64), latent_dim=32):
        super().__init__()
        hidden_dims = list(hidden_dims)
        sizes = [input_dim, *hidden_dims, latent_dim]
        if any(int(s) != s or s < 1 for s in sizes):
            raise ValueError(f"layer sizes must be positive integers, got {sizes}")

        # Build the encoder before the decoder so parameters are initialised
        # in the same order (and consume RNG state identically) as before.
        self.encoder = nn.Sequential(*self._mlp_layers(sizes))
        self.decoder = nn.Sequential(*self._mlp_layers(sizes[::-1]), nn.Tanh())

    @staticmethod
    def _mlp_layers(sizes):
        """Return Linear layers over consecutive ``sizes``, with ReLU between them (not after the last)."""
        layers = []
        last = len(sizes) - 2
        for i, (d_in, d_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                layers.append(nn.ReLU())
        return layers

    def forward(self, x):
        """Encode then decode ``x``; the output has the same last dimension as the input."""
        return self.decoder(self.encoder(x))

# Build the autoencoder and move its parameters to the selected device
model = Autoencoder()
model = model.to(device)

# Objective: mean squared reconstruction error; optimiser: Adam
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
criterion = nn.MSELoss(reduction="mean")

# Training the model
losses = []  # per-epoch mean reconstruction loss, averaged over samples
num_samples = len(train_loader.dataset)
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for images, _ in train_loader:
        images = images.view(-1, 28 * 28).to(device)  # Flatten the images
        # Forward pass: compare reconstruction against the input itself
        outputs = model(images)
        loss = criterion(outputs, images)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch; weight it by batch size so
        # the smaller final batch does not skew the epoch average.
        epoch_loss += loss.item() * images.size(0)

    avg_loss = epoch_loss / num_samples
    losses.append(avg_loss)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")

# Plot the recorded per-epoch training loss (epochs numbered from 1)
fig, ax = plt.subplots(figsize=(8, 6))
epoch_axis = list(range(1, epochs + 1))
ax.plot(epoch_axis, losses, label="Training Loss")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training Loss Over Epochs")
ax.legend()
plt.show()

# Visualize reconstructed images
model.eval()
with torch.no_grad():
    # Take one shuffled batch. NOTE(review): these samples come from the
    # training split, so they show how well the model fits seen data rather
    # than how it generalises.
    examples = iter(train_loader)
    example_data, _ = next(examples)
    example_data = example_data.view(-1, 28 * 28).to(device)
    reconstructed = model(example_data).view(-1, 1, 28, 28).cpu()

originals = example_data.view(-1, 1, 28, 28).cpu()
# Show up to 10 pairs; avoids an IndexError when batch_size < 10.
n_show = min(10, originals.size(0))

plt.figure(figsize=(12, 6))
for i in range(n_show):
    # Inputs were normalised to [-1, 1] and the decoder ends in Tanh, so pin
    # both rows to that range. Otherwise imshow rescales each image on its own
    # and hides contrast differences between original and reconstruction.
    # Original images (top row)
    plt.subplot(2, n_show, i + 1)
    plt.imshow(originals[i].view(28, 28), cmap="gray", vmin=-1, vmax=1)
    plt.axis("off")
    # Reconstructed images (bottom row)
    plt.subplot(2, n_show, n_show + i + 1)
    plt.imshow(reconstructed[i].view(28, 28), cmap="gray", vmin=-1, vmax=1)
    plt.axis("off")
plt.suptitle("Original Images (Top) and Reconstructed Images (Bottom)")
plt.show()