In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Mini-batch size used by the data loader.
batch_size = 64

# MNIST training split stored under ./.datasets (download=True asks torchvision
# to fetch it when needed); every sample goes through transforms.ToTensor().
mnist_data = datasets.MNIST(
    root='.datasets',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Shuffled mini-batch loader using four worker processes.
train_loader = torch.utils.data.DataLoader(
    mnist_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Pull a single batch and print the smallest and largest pixel value it contains,
# as a quick check of the value range the model will be trained on.
data, targets = next(iter(train_loader))
print(data.min(), data.max())

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so that
# every later `.to(device)` call still works on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder.

    The encoder maps a flattened input of width ``input_size`` (784 for a
    28x28 image) through 128 -> 64 -> 32 units down to a ``hidden_size`` code.
    The decoder mirrors those widths back up to ``input_size`` and ends with a
    sigmoid, so every reconstructed value lies between 0 and 1.

    Args:
        input_size: Number of features in each flattened input vector.
        hidden_size: Width of the bottleneck code produced by the encoder.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        encoder_widths = [input_size, 128, 64, 32, hidden_size]
        decoder_widths = encoder_widths[::-1]

        # (batch, input_size) -> (batch, hidden_size)
        self.encoder = nn.Sequential(*self._dense_stack(encoder_widths))
        # (batch, hidden_size) -> (batch, input_size), squashed by a sigmoid
        self.decoder = nn.Sequential(*self._dense_stack(decoder_widths), nn.Sigmoid())

    @staticmethod
    def _dense_stack(widths):
        """Return Linear layers joining consecutive widths, with a ReLU between each pair."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # The last Linear layer is not followed by a ReLU.
        return layers[:-1]

    def forward(self, x):
        """Encode ``x`` and return its reconstruction, shaped (batch, input_size)."""
        return self.decoder(self.encoder(x))

In [None]:
# Hyperparameters for the fully connected autoencoder.
input_size = 28 * 28  # 784 features per flattened image
hidden_size = 16      # width of the bottleneck code
lr = 1e-3
epochs = 40

# Model, reconstruction loss and optimizer.
model = AutoEncoder(input_size, hidden_size).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=1e-5,
)

In [None]:
# Train the fully connected autoencoder.
# `outputs` collects one (epoch, last input batch, its reconstruction) tuple per epoch.
outputs = []
for epoch in range(epochs):
    # The description is constant within an epoch, so it is set once here.
    loop = tqdm(train_loader, total=len(train_loader), leave=False, desc=f'Epoch [{epoch}]')
    for data, _ in loop:
        # Move the batch to the training device and flatten each image to a vector.
        data = data.to(device).reshape(-1, input_size)
        reconstruction = model(data)

        # The target of the reconstruction is the input itself.
        loss = criterion(reconstruction, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())
    # Store detached CPU copies: keeping the live tensors would retain the
    # autograd graph and device memory of the last batch of every epoch.
    outputs.append((epoch, data.detach().cpu(), reconstruction.detach().cpu()))

In [None]:
# One figure per recorded epoch: input images on the top row, their
# reconstructions on the bottom row.
# Iterating over `outputs` (rather than range(epochs)) keeps this cell working
# when training was interrupted before all epochs completed.
images_per_row = 9
for epoch_idx, batch, recon_batch in outputs:
    fig = plt.figure(figsize=(9, 2))
    plt.gray()
    fig.suptitle(f'Epoch {epoch_idx}')
    for row, tensor in enumerate((batch, recon_batch)):
        items = tensor.detach().cpu().numpy()[:images_per_row]
        for col, item in enumerate(items):
            ax = fig.add_subplot(2, images_per_row, row * images_per_row + col + 1)
            # Restore the 28x28 layout from the flattened vector used by the linear model.
            ax.imshow(item.reshape(-1, 28, 28)[0])
    # Flush each figure as it is completed instead of keeping one open per
    # epoch, which would exceed matplotlib's open-figure warning threshold.
    plt.show()

In [None]:
class AutoEncoderCNN(nn.Module):
    """Convolutional autoencoder for single-channel images.

    The shape comments below assume a (batch, 1, 28, 28) input. The encoder
    reduces it to a (batch, hidden_size, 1, 1) code with strided convolutions;
    the decoder expands that code back to (batch, 1, 28, 28) with transposed
    convolutions and ends with a sigmoid, so outputs lie between 0 and 1.

    Args:
        hidden_size: Number of channels in the 1x1 bottleneck code.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Keyword sets shared by the stride-2 layers.
        k3_s2 = dict(kernel_size=3, stride=2, padding=1)
        k4_s2 = dict(kernel_size=4, stride=2, padding=1)

        encoder_layers = [
            nn.Conv2d(1, 16, **k3_s2),                  # 28x28x1 -> 14x14x16
            nn.ReLU(),
            nn.Conv2d(16, 32, **k3_s2),                 # -> 7x7x32
            nn.ReLU(),
            nn.Conv2d(32, 64, **k3_s2),                 # -> 4x4x64
            nn.ReLU(),
            nn.Conv2d(64, hidden_size, kernel_size=4),  # -> 1x1xhidden_size
        ]
        decoder_layers = [
            nn.ConvTranspose2d(hidden_size, 64, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4x64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, **k3_s2),        # -> 7x7x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, **k4_s2),        # -> 14x14x16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, **k4_s2),         # -> 28x28x1
            nn.Sigmoid(),
        ]

        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction with the same shape as the input."""
        return self.decoder(self.encoder(x))

In [None]:
# Shape smoke test: push a random batch of 16 single-channel 28x28 inputs
# through an untrained CNN autoencoder and print the output shape.
x = torch.randn(16, 1, 28, 28)
model = AutoEncoderCNN(hidden_size=128)
print(model(x).shape)

In [None]:
# Hyperparameters for the convolutional autoencoder.
hidden_size = 128  # channels in the 1x1 bottleneck code
lr = 1e-3
epochs = 10

# Model, reconstruction loss and optimizer.
model = AutoEncoderCNN(hidden_size).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=1e-5,
)

In [None]:
# Train the convolutional autoencoder.
# `outputs` collects one (epoch, last input batch, its reconstruction) tuple per epoch.
outputs = []
for epoch in range(epochs):
    # The description is constant within an epoch, so it is set once here.
    loop = tqdm(train_loader, total=len(train_loader), leave=False, desc=f'Epoch [{epoch}]')
    for data, _ in loop:
        # The batch is passed to the convolutional model as loaded (no flattening).
        data = data.to(device)
        reconstruction = model(data)

        # The target of the reconstruction is the input itself.
        loss = criterion(reconstruction, data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())
    # Store detached CPU copies: keeping the live tensors would retain the
    # autograd graph and device memory of the last batch of every epoch.
    outputs.append((epoch, data.detach().cpu(), reconstruction.detach().cpu()))

In [None]:
# One figure for every 4th recorded epoch: input images on the top row, their
# reconstructions on the bottom row.
# Slicing `outputs` (rather than indexing it with range(epochs)) keeps this
# cell working when training was interrupted before all epochs completed.
images_per_row = 9
for epoch_idx, batch, recon_batch in outputs[::4]:
    fig = plt.figure(figsize=(9, 2))
    plt.gray()
    fig.suptitle(f'Epoch {epoch_idx}')
    for row, tensor in enumerate((batch, recon_batch)):
        items = tensor.detach().cpu().numpy()[:images_per_row]
        for col, item in enumerate(items):
            ax = fig.add_subplot(2, images_per_row, row * images_per_row + col + 1)
            # Each item keeps its leading channel axis; plot channel 0.
            ax.imshow(item[0])
    # Flush each figure as soon as it is complete.
    plt.show()