In [9]:
import tqdm

import numpy as np
import torch
from torch import nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

In [2]:
class Autoencoder(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def forward(self, x):
        """
        Take a mini-batch as an input, encode it to the latent space and decode back to the original space
        x_out = decoder(encoder(x))
        :param x: torch.tensor, (MB, x_dim)
        :return: torch.tensor, (MB, x_dim)
        """
        return decoder(encoder(x))

In [3]:
def train(epochs, net, criterion, optimizer, train_loader, val_loader,scheduler=None, verbose=True, save_dir=None):
    net.to(device)
    for epoch in range(1, epochs+1):
        net.train()
        loss_avg = 0
        for X, _ in train_loader:
            X = X.to(device)
            optimizer.zero_grad()
            X_new = net(X)
            loss = criterion(X, X_new)
            loss_avg += loss.item()
            loss.backward()
            optimizer.step()
        loss_avg /= len(train_loader)
        
        # define NN evaluation, i.e. turn off dropouts, batchnorms, etc.
        net.eval()
        val_loss = 0
        with torch.no_grad():
            for X, _ in val_loader:
                X = X.to(device)
                X_new = net(X)
                val_loss += criterion(X, X_new)
        val_loss /= len(val_loader)
         
        if scheduler is not None:
            scheduler.step(val_loss)
        freq = max(epochs//20,1)
        if verbose and epoch%freq==0:
            print('Epoch {}/{} || Loss:  Train {:.4f} | Validation {:.4f}'.format(epoch, epochs, loss_avg, val_loss.item()))

In [4]:
class Reshape(nn.Module):
    def __init__(self, *shape):
        super().__init__()
        self.shape = shape
        
    def forward(self, input):
        return input.view(*self.shape)
    
class Debugger(nn.Module):
    def __init__(self, name=""):
        super().__init__()
        self.name = name
        
    def forward(self, input):
        print(self.name, input.shape)
        return input
    
class Interpolator(nn.Module):
    def __init__(self, *shape, mode="bilinear"):
        super().__init__()
        self.shape = shape
        self.mode = mode
        
    def forward(self, input):
        return torch.nn.functional.interpolate(input, self.shape, mode=self.mode,
                                               align_corners=False)

In [5]:
torch.manual_seed(1)

n_hidden = 30
encoder = nn.Sequential(nn.Conv2d(1, 8, 3, stride=1, padding=1),
                        nn.BatchNorm2d(8),
                        nn.SELU(),
                        nn.Conv2d(8, 16, 3, stride=2, padding=1),
                        nn.BatchNorm2d(16),
                        nn.SELU(),
                        nn.Conv2d(16, 32, 3, stride=2, padding=1),
                        nn.BatchNorm2d(32),
                        nn.SELU(),
                        nn.Conv2d(32, 64, 3, stride=2, padding=1),
                        nn.BatchNorm2d(64),
                        nn.SELU(),
                        nn.Conv2d(64, 128, 3, stride=2, padding=1),
                        nn.BatchNorm2d(128),
                        nn.SELU(),
                        nn.Conv2d(128, 256, 3, stride=2, padding=1),
                        nn.BatchNorm2d(256),
                        nn.SELU(),
                        Reshape(-1, 256*1**2),
                        nn.Linear(256*1**2, n_hidden)
                       )  
decoder = nn.Sequential(nn.Linear(n_hidden, 256),
                        Reshape(-1, 256, 1, 1),
                        nn.ConvTranspose2d(256, 128, kernel_size=3,
                                           stride=2, padding=0),
                        nn.BatchNorm2d(128),
                        nn.SELU(),
                        nn.ConvTranspose2d(128, 64, kernel_size=3,
                                           stride=2, padding=0),
                        nn.BatchNorm2d(64),
                        nn.SELU(),
                        nn.ConvTranspose2d(64, 32, kernel_size=3,
                                           stride=2, padding=0),
                        nn.BatchNorm2d(32),
                        nn.SELU(),
                        nn.ConvTranspose2d(32, 16, kernel_size=3,
                                           stride=2, padding=0),
                        nn.BatchNorm2d(16),
                        nn.SELU(),
                        nn.ConvTranspose2d(16, 8, kernel_size=3,
                                           stride=2, padding=0),
                        nn.BatchNorm2d(8),
                        nn.SELU(),
                        nn.Conv2d(8, 1, kernel_size=3,
                                           stride=2, padding=0),
                        nn.Sigmoid(),
                        Interpolator(28, 28)
)

In [6]:
mnist_tr = transforms.Compose([transforms.ToTensor(), 
                               ])

In [7]:
mnist_train = MNIST('data/mnist', download=True, transform=mnist_tr, train=True)
mnist_test = MNIST('data/mnist', download=True, transform=mnist_tr, train=False)

In [12]:
try:
    device = torch.device("cuda")
except:
    device = torch.device("cpu")

net = Autoencoder(encoder, decoder)  
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5)

train_loader = DataLoader(mnist_train, batch_size=1000, shuffle=True,
                                     pin_memory=True) 
val_loader = DataLoader(mnist_test, batch_size=500, shuffle=False,
                                   pin_memory=True) 

In [13]:
epochs = 20
train(epochs, net, criterion, optimizer, train_loader, val_loader, scheduler)

AssertionError: Torch not compiled with CUDA enabled