In [4]:
import torch, torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import DataLoader, random_split

import numpy as np
import matplotlib.pyplot as plt

import random
import pandas as pd
import seaborn as sns

In [2]:
# Load the MNIST train and test splits from the local "dataset" directory;
# download = True asks torchvision to fetch any files that are not there yet.
mnist_root = "dataset"
train_dataset = datasets.MNIST(root = mnist_root, train = True, download = True)
test_dataset = datasets.MNIST(root = mnist_root, train = False, download = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting dataset\MNIST\raw\train-images-idx3-ubyte.gz to dataset\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting dataset\MNIST\raw\train-labels-idx1-ubyte.gz to dataset\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting dataset\MNIST\raw\t10k-images-idx3-ubyte.gz to dataset\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz


100.0%


Extracting dataset\MNIST\raw\t10k-labels-idx1-ubyte.gz to dataset\MNIST\raw



In [6]:
# Both splits only need the PIL image -> tensor conversion (no augmentation).
train_transform = transforms.Compose([ToTensor(), ])
test_transform = transforms.Compose([ToTensor(), ])

train_dataset.transform = train_transform
test_dataset.transform = test_transform

len_data = len(train_dataset)

# 80/20 train/validation split. The validation size is taken as the remainder
# so the two lengths always add up to len_data; truncating both products
# independently can leave samples unassigned and make random_split raise.
train_len = int(len_data * 0.8)
val_len = len_data - train_len

# A dedicated, seeded generator makes the split identical on every
# Restart & Run All without touching the global RNG state.
split_generator = torch.Generator().manual_seed(0)
train_data, val_data = random_split(
    train_dataset,
    [train_len, val_len],
    generator = split_generator,
)
batch_size = 256

In [7]:
# Reshuffle the training samples at every epoch so that the mini-batches seen
# by the optimizer differ from one epoch to the next.
train_loader = DataLoader(train_data, batch_size = batch_size, shuffle = True)
# Validation batches are only evaluated, so their order is left fixed.
valid_loader = DataLoader(val_data, batch_size = batch_size)
# NOTE(review): shuffling the test set is kept unchanged for backward
# compatibility; it should not matter for an aggregated loss - confirm that no
# later cell relies on the random order before removing it.
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = True)

In [10]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to an ``encoder_dim`` code.

    Two unpadded 5x5 convolutions (1 -> 4 -> 8 channels), each followed by a
    ReLU, feed a flatten step and a small MLP (3200 -> 128 -> encoder_dim).
    The 3200 input features of the MLP are 8 channels of 20x20 maps, so the
    layer sizes are fixed for single-channel 28x28 inputs.

    Args:
        encoder_dim: size of the latent code produced for every sample.
    """

    def __init__(self, encoder_dim):
        super().__init__()

        # Every unpadded 5x5 convolution shrinks each side by 4: 28 -> 24 -> 20.
        conv_stack = [
            nn.Conv2d(in_channels = 1, out_channels = 4, kernel_size = 5),
            nn.ReLU(inplace = True),
            nn.Conv2d(in_channels = 4, out_channels = 8, kernel_size = 5),
            nn.ReLU(inplace = True),
        ]
        self.encoder_cnn = nn.Sequential(*conv_stack)

        # Collapse (channels, height, width) into one feature axis per sample.
        self.flatten = nn.Flatten(start_dim = 1)

        flat_features = 8 * 20 * 20
        dense_stack = [
            nn.Linear(in_features = flat_features, out_features = 128),
            nn.ReLU(inplace = True),
            nn.Linear(in_features = 128, out_features = encoder_dim),
        ]
        self.encoder_fc = nn.Sequential(*dense_stack)

    def forward(self, x):
        """Return the latent code of shape (batch, encoder_dim) for ``x``."""
        feature_maps = self.encoder_cnn(x)
        return self.encoder_fc(self.flatten(feature_maps))

In [12]:
class Decoder(nn.Module):
    """Convolutional decoder mapping an ``encoder_dim`` code back to an image.

    An MLP (encoder_dim -> 128 -> 3200, ReLU after each layer) is reshaped
    into 8 feature maps of 20x20 and passed through two unpadded 5x5
    transposed convolutions (8 -> 10 -> 1 channels); a sigmoid squashes the
    result into (0, 1).  No activation sits between the two transposed
    convolutions.

    Args:
        encoder_dim: size of the latent code accepted as input.
    """

    def __init__(self, encoder_dim):
        super().__init__()

        map_shape = (8, 20, 20)
        flat_features = 8 * 20 * 20

        dense_stack = [
            nn.Linear(in_features = encoder_dim, out_features = 128),
            nn.ReLU(inplace = True),
            nn.Linear(in_features = 128, out_features = flat_features),
            nn.ReLU(inplace = True),
        ]
        self.decoder_fc = nn.Sequential(*dense_stack)

        # Turn the flat feature vector back into (channels, height, width).
        self.unflatten = nn.Unflatten(dim = 1, unflattened_size = map_shape)

        # Every unpadded 5x5 transposed convolution grows each side by 4:
        # 20 -> 24 -> 28.
        deconv_stack = [
            nn.ConvTranspose2d(in_channels = 8, out_channels = 10, kernel_size = 5),
            nn.ConvTranspose2d(in_channels = 10, out_channels = 1, kernel_size = 5),
        ]
        self.decoder_cnn = nn.Sequential(*deconv_stack)

    def forward(self, x):
        """Return the reconstruction of shape (batch, 1, 28, 28) for code ``x``."""
        feature_maps = self.unflatten(self.decoder_fc(x))
        return torch.sigmoid(self.decoder_cnn(feature_maps))

In [16]:
# Reconstruction objective: mean squared error between output and input image.
loss_fn = nn.MSELoss()
lr = 0.01

# Seed the global RNG before the models are built so that their initial
# weights are the same on every run.
torch.manual_seed(0)

encoder_dim = 4

encoder = Encoder(encoder_dim = encoder_dim)
decoder = Decoder(encoder_dim = encoder_dim)

# One parameter group per network; both are updated by the same optimizer.
params_optimize = [
    {'params' : encoder.parameters()},
    {'params' : decoder.parameters()}
]

# Build the optimizer through torch.optim rather than the imported `optim`
# name: the original `optim = optim.Adam(...)` replaced the module with the
# optimizer instance, so running this cell a second time raised
# AttributeError ('Adam' object has no attribute 'Adam').
optimizer = torch.optim.Adam(params_optimize, lr = lr, weight_decay = 1e-05)
# Backward-compatible alias for cells that still refer to the optimizer as `optim`.
optim = optimizer

In [18]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Selected device {0}".format(device))

Selected device cpu


In [20]:
# Move the encoder to the selected device. As the last expression of the cell,
# the module returned by .to() is echoed, which prints the layer summary below.
encoder.to(device)

Encoder(
  (encoder_cnn): Sequential(
    (0): Conv2d(1, 4, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(4, 8, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (encoder_fc): Sequential(
    (0): Linear(in_features=3200, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=4, bias=True)
  )
)

In [21]:
# Move the decoder to the selected device. As the last expression of the cell,
# the module returned by .to() is echoed, which prints the layer summary below.
decoder.to(device)

Decoder(
  (decoder_fc): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=3200, bias=True)
    (3): ReLU(inplace=True)
  )
  (unflatten): Unflatten(dim=1, unflattened_size=(8, 20, 20))
  (decoder_cnn): Sequential(
    (0): ConvTranspose2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (1): ConvTranspose2d(10, 1, kernel_size=(5, 5), stride=(1, 1))
  )
)

In [22]:
def train_epoch(encoder, decoder, device, dataloader, loss_fn, optimizer):
    """Train the autoencoder for one pass over ``dataloader``.

    Args:
        encoder: module mapping an image batch to its latent code.
        decoder: module mapping a latent code back to an image batch.
        device: device every batch is moved to before the forward pass.
        dataloader: iterable of ``(image_batch, label)`` pairs; labels are ignored.
        loss_fn: callable ``loss_fn(reconstruction, target)`` returning a scalar tensor.
        optimizer: optimizer holding the parameters to update.

    Returns:
        The mean of the per-batch losses as a Python float, or NaN when
        ``dataloader`` yields no batches.
    """
    # Put both networks in training mode.
    encoder.train()
    decoder.train()

    batch_losses = []
    for image_batch, _ in dataloader:
        image_batch = image_batch.to(device)

        # The input image is its own reconstruction target.
        reconstruction = decoder(encoder(image_batch))
        loss = loss_fn(reconstruction, image_batch)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # The per-batch losses used to be collected and then dropped (the function
    # returned None); report their mean so callers can track training progress.
    if not batch_losses:
        return float("nan")
    return sum(batch_losses) / len(batch_losses)


In [24]:
def test_epoch(encoder, decoder, device, dataloader, loss_fn):
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        conc_out = []
        conc_label = []
        for image_batch, _ in dataloader:
            image_batch = image_batch.to(device)
            encoder_data = encoder(image_batch)
            decoder_data = decoder(encoder_data)
            conc_out.append(decoder_data.cpu())
            conc_label.append(image_batch.cpu())
        
        conc_out = torch.cat(conc_out)
        conc_label = torch.cat(conc_label)
        val_loss = loss_fn(conc_out, conc_label)
    
    return val_loss.data