# Import pacakage

In [3]:
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision

# Setting 

In [4]:
# Settings
epochs = 10
batch_size = 128
lr = 0.008


# DataLoader
train_set = torchvision.datasets.MNIST(
    root='mnist',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Model structure

In [5]:
# AutoEncoder (Encoder + Decoder)
class AutoEncoder(nn.Module):
    '''
    MNISR image shape = (1,28,28)
    784 = 28*28 
    '''
    def __init__(self):
        super(AutoEncoder, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(784, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 16),
            nn.Tanh(),
            nn.Linear(16, 2),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(2, 16),
            nn.Tanh(),
            nn.Linear(16, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 784),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        codes = self.encoder(inputs)
        decoded = self.decoder(codes)

        return codes, decoded

# Optimizer and loss function

In [6]:
# Optimizer and loss function
model = AutoEncoder()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_function = nn.MSELoss()

# Training

In [12]:
for data, labels in train_loader:
    print(data.shape,labels.shape)
    break

torch.Size([128, 1, 28, 28]) torch.Size([128])


In [13]:
# Train
for epoch in range(epochs):
    for data, labels in train_loader:
        inputs = data.view(-1, 784)

        # Forward
        codes, decoded = model(inputs)

        # Backward
        optimizer.zero_grad()
        loss = loss_function(decoded, inputs)
        loss.backward()
        optimizer.step()

    # Show progress
    print('[{}/{}] Loss:'.format(epoch+1, epochs), loss.item())

[1/10] Loss: 0.053307946771383286
[2/10] Loss: 0.044026024639606476
[3/10] Loss: 0.042683281004428864
[4/10] Loss: 0.0400223545730114
[5/10] Loss: 0.04179670289158821
[6/10] Loss: 0.040596917271614075
[7/10] Loss: 0.040502242743968964
[8/10] Loss: 0.039834607392549515
[9/10] Loss: 0.040662188082933426
[10/10] Loss: 0.040009863674640656


# Save Model

In [14]:
# Save
torch.save(model, 'autoencoder.pth')

# Import Model