In [2]:
## https://medium.com/the-innovation/autoencoder-in-pytorch-for-the-fashion-mnist-dataset-66f4fb9465b4

In [11]:
import numpy as np
import matplotlib.pyplot as plt

import time

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ExponentialLR

In [21]:
train = torchvision.datasets.FashionMNIST(root = "./fashion_mnist_data", train = True, download = True, transform = transforms.ToTensor())
test = torchvision.datasets.FashionMNIST(root = "./fashion_mnist_data", train = False, download = True, transform = transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(train, batch_size=8, shuffle = True)
test_loader = torch.utils.data.DataLoader(test, batch_size=8, shuffle=False)

In [22]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.Conv2d(8, 16, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.Conv2d(16, 32, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=(3,3))
        )

    def forward(self, x):
        return self.encoder(x)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(64, 32, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(16, 8, kernel_size=(3,3)),
            nn.LeakyReLU(),
            nn.ConvTranspose1d(8, 1, kernel_size=(3,3)),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.decoder(x)

class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [23]:
model = Autoencoder()

model_name = f'ae_{time.time()}'
print(model_name)

# Tensorboard
tb_writer = SummaryWriter('logs/' + model_name)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
scheduler = ExponentialLR(optimizer, gamma=0.999)

# Loss function
loss_func = nn.MSELoss()

ae_1605976851.892132


In [24]:
epochs = 10

for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    current_time = time.time()

    # Train
    model.train()
    loss_tr = []
    for step, (images_raw, images_out) in enumerate(train_loader):
        print(step, images_raw.shape, images_out)
        p = model(images_raw)
        batch_loss = loss_func(images_out, p)
        loss_tr.append(batch_loss.detach().item())
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
    print(f"{epoch}, training_loss {np.mean(loss_tr)}, {time.time() - current_time} secs")
    current_time = time.time()
    
    # Validation
    model.eval()
    loss_ts = []
    for step, (images_raw, images_out) in enumerate(test_loader):
        p = model(images_raw.float())
        batch_loss = loss_func(images_out, p)       
        loss_ts.append(batch_loss.detach().cpu().numpy())
    print(f"{epoch}, validation_loss {np.mean(loss_ts)}, {time.time() - current_time} secs")
    scheduler.step()

    tb_writer.add_scalar("Test Loss", np.mean(loss_ts), epoch)
    tb_writer.add_scalar("Training Loss", np.mean(loss_tr), epoch)
    tb_writer.add_scalar("Learning Rate", scheduler.get_lr()[0], epoch)
    img_grid = torchvision.utils.make_grid(images_out[:4])
    tb_writer.add_image('orig_fashion_mnist_images', img_grid)
    img_grid = torchvision.utils.make_grid(p[:4])
    tb_writer.add_image('recons_fashion_mnist_images', img_grid)
    tb_writer.flush()

Epoch: 0
0 torch.Size([8, 1, 28, 28]) tensor([7, 5, 1, 1, 0, 4, 7, 4])


RuntimeError: The size of tensor a (8) must match the size of tensor b (28) at non-singleton dimension 3