# Autoencoders


In [53]:
import numpy as np
import torch 
import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader 
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [22]:
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 128), 
            nn.ReLU(inplace=True), 
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12), 
            nn.ReLU(True),
            nn.Linear(12, 3)
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128), 
            nn.ReLU(True), 
            nn.Linear(128, 28*28), 
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
        

In [8]:
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))
])

In [30]:

train_dataset = datasets.MNIST(root='./data', train=True, download=False, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [25]:
model = AutoEncoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [50]:
num_epochs = 2

for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        #print('imgsize', img.size(0), img[0].shape, img.shape)
        img = img.reshape(img.size(0), -1)
        #print(img.shape)
        output = model(img)
        loss = criterion(output, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [1/2], Loss: 0.9241
Epoch [2/2], Loss: 0.9237


In [54]:
def imshow(img):
    img = img / 2 + 0.5 
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), cmap='gray') 