In [1]:
#Imports

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [2]:
transform = transforms.ToTensor()

train_dataset = datasets.MNIST(root ='./data', download = True, train = True, transform = transform)
test_dataset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)
val_dataset = datasets.MNIST(root='./data', download=True, train=False, transform=transform)

In [3]:
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

In [4]:
dataiter = iter(train_loader)

images, labels = next(dataiter)

print(torch.min(images), torch.max(images))

tensor(0.) tensor(1.)


In [5]:
#Create Network

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 392),   # --> N,784 to N, 392
            nn.ReLU(),
            nn.Linear(392, 196)    # N,196
        ) 
       
        self.decoder = nn.Sequential(
            nn.Linear(196, 392),  # --> N,196 to N, 392
            nn.ReLU(),
            nn.Linear(392, 28*28),    # N,784
            nn.Sigmoid()
            
        )
        self.dropout = nn.Dropout(p=0.5)    
            
    def forward(self, x):
        encoded = self.encoder(x)
        encoded = self.dropout(encoded)
        decoded = self.decoder(encoded)
        return decoded
    
    def evaluate(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

In [6]:
model = Autoencoder()

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

num_epochs = 50


train_losses = []
val_losses = []

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs = inputs.reshape(-1, 28*28)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, inputs)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)

    if epoch % 10 == 9:
        print(f"Epoch {epoch + 1}, Training Loss: {epoch_loss}")

    # Validation loss
    val_running_loss = 0.0
    for val_data in val_loader:
        val_inputs, val_labels = val_data
        val_inputs = val_inputs.reshape(-1, 28*28)
        val_outputs = model(val_inputs)
        val_loss = criterion(val_outputs, val_inputs)
        val_running_loss += val_loss.item()

    val_epoch_loss = val_running_loss / len(val_loader)
    val_losses.append(val_epoch_loss)

    if epoch % 10 == 9:
        print(f"Epoch {epoch + 1}, Validation Loss: {val_epoch_loss}")
    

Epoch 10, Training Loss: 0.012863974796850353
Epoch 10, Validation Loss: 0.012573721658462171
Epoch 20, Training Loss: 0.012100392907881724
Epoch 20, Validation Loss: 0.011878984387086075
Epoch 30, Training Loss: 0.011833279907925804
Epoch 30, Validation Loss: 0.011800331994891167
Epoch 40, Training Loss: 0.011663642814204192
Epoch 40, Validation Loss: 0.01152657856868141
Epoch 50, Training Loss: 0.011577261872947026
Epoch 50, Validation Loss: 0.011307002709597159


In [1]:
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

NameError: name 'plt' is not defined

In [None]:
with torch.no_grad():
    model.eval()
    for inputs, _ in test_loader:
        inputs = inputs.reshape(-1, 28*28)
        outputs = model(inputs)
        break  # Break after processing one batch

# Plot some original and reconstructed images
plt.figure(figsize=(10, 2))
for i in range(5):
    plt.subplot(2, 5, i + 1)
    plt.imshow(inputs[i].reshape(28, 28).cpu().numpy(), cmap='gray')
    plt.title('Original')
    plt.axis('off')

    plt.subplot(2, 5, i + 6)
    plt.imshow(outputs[i].reshape(28, 28).cpu().numpy(), cmap='gray')
    plt.title('Reconstructed')
    plt.axis('off')

plt.show()