In [2]:
import torch
import os
from tensorboardX import SummaryWriter

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Image pre-processing: convert each PIL image to a float tensor.
# The pipeline gets its own name so the imported `torchvision.transforms`
# module is not overwritten; rebinding `transforms` made this cell fail with
# an AttributeError when re-run and hid the module from later cells.
mnist_transform = transforms.Compose([transforms.ToTensor()])

# MNIST train/test splits, downloaded into ./data2 on first use.
train_dataset = datasets.MNIST(
    root="data2", train=True, transform=mnist_transform, download=True
)

test_dataset = datasets.MNIST(
    root="data2", train=False, transform=mnist_transform, download=True
)

# NOTE: the name `train_loder` (sic) is kept because later cells refer to it.
# Training batches are reshuffled every epoch; test batches keep dataset order.
train_loder = DataLoader(
    train_dataset, batch_size=100, shuffle=True, num_workers=4, pin_memory=False
)

test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data2\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 7544831.55it/s] 


Extracting data2\MNIST\raw\train-images-idx3-ubyte.gz to data2\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data2\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<?, ?it/s]


Extracting data2\MNIST\raw\train-labels-idx1-ubyte.gz to data2\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data2\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 5597541.61it/s]


Extracting data2\MNIST\raw\t10k-images-idx3-ubyte.gz to data2\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data2\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]

Extracting data2\MNIST\raw\t10k-labels-idx1-ubyte.gz to data2\MNIST\raw





In [4]:
class Encoder(nn.Module):
    """MLP encoder producing the parameters of a diagonal Gaussian.

    Args:
        input_dim: size of each flattened input vector.
        hidden_dim: width of the two hidden layers.
        latent_dim: size of the latent vector.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Shared trunk: two fully connected layers.
        self.input1 = nn.Linear(input_dim, hidden_dim)
        self.input2 = nn.Linear(hidden_dim, hidden_dim)
        # Output heads. `var` is used as a log-variance by `forward`.
        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.training = True

    def forward(self, x):
        """Return ``(mean, log_var)`` for a batch of flattened inputs ``x``."""
        features = x
        for layer in (self.input1, self.input2):
            features = self.LeakyReLU(layer(features))
        return self.mean(features), self.var(features)


class Decoder(nn.Module):
    """MLP decoder mapping latent vectors to values in (0, 1).

    Args:
        latent_dim: size of each latent input vector.
        hidden_dim: width of the two hidden layers.
        output_dim: size of each reconstructed output vector.
    """

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.input1 = nn.Linear(latent_dim, hidden_dim)
        self.input2 = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)
        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Decode a batch of latent vectors ``x`` into reconstructions."""
        first = self.LeakyReLU(self.input1(x))
        second = self.LeakyReLU(self.input2(first))
        logits = self.output(second)
        # The sigmoid keeps every output strictly between 0 and 1.
        return torch.sigmoid(logits)


class Model_network(nn.Module):
    """Variational autoencoder that chains an encoder and a decoder.

    Args:
        Encoder: module whose forward pass returns ``(mean, log_var)``.
        Decoder: module that maps a latent sample to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Sample ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        Despite its name, ``var`` is used as a standard deviation: ``forward``
        passes ``exp(0.5 * log_var)``.
        """
        # randn_like already creates the noise on var's device and dtype, so
        # no transfer to a global device is needed; an explicit transfer could
        # even move the noise away from where `mean` and `var` live.
        epsilon = torch.randn_like(var)
        z = mean + var * epsilon
        return z

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(x_hat, mean, log_var)``: the reconstruction and the
            encoder outputs.
        """
        mean, log_var = self.Encoder(x)
        # Convert the log-variance to a standard deviation before sampling.
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))
        x_hat = self.Decoder(z)
        return x_hat, mean, log_var

In [5]:
# Hyperparameters.
x_dim = 784  # length of one flattened input image
hidden_dim = 400  # width of the hidden layers in encoder and decoder
latent_dim = 200  # size of the latent vector
epochs = 30
batch_size = 100

# Seed PyTorch's random number generators before the weights are initialised
# so that a fresh Restart & Run All starts from the same initial model.
seed = 42
torch.manual_seed(seed)

# Build the VAE and move every sub-module to the selected device.
encoder = Encoder(x_dim, hidden_dim, latent_dim).to(device)
decoder = Decoder(latent_dim, hidden_dim, x_dim).to(device)
model = Model_network(encoder, decoder).to(device)

In [6]:
def loss_function(x, x_hat, mean, log_var):
    """Return the two VAE loss terms, each summed over the whole batch.

    Args:
        x: target values.
        x_hat: reconstructed values, same shape as ``x``.
        mean: latent means produced by the encoder.
        log_var: latent log-variances produced by the encoder.

    Returns:
        Tuple ``(reproduction_loss, KLD)``: the summed binary cross-entropy
        between ``x_hat`` and ``x``, and the KL-divergence term.
    """
    reproduction_loss = torch.nn.functional.binary_cross_entropy(
        x_hat, x, reduction="sum"
    )
    # Per-element KL contributions, reduced with a single sum.
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    KLD = -0.5 * torch.sum(kl_terms)
    return reproduction_loss, KLD


# Adam optimiser over every parameter of `model` (encoder and decoder),
# with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [7]:
# Directory that receives the TensorBoard event files.
saved_loc = "scalar/"
# Shared summary writer used by the train and test functions defined below.
writer = SummaryWriter(saved_loc)

# Put the model in training mode before the first epoch.
model.train()


def train(epoch, model, train_loder, optimizer):
    """Train ``model`` for one epoch and log per-batch losses to TensorBoard.

    Relies on the notebook-level ``device``, ``writer`` and ``loss_function``.

    Args:
        epoch: zero-based epoch index, used for the log step and messages.
        model: VAE whose forward pass returns ``(x_hat, mean, log_var)``.
        train_loder: DataLoader yielding ``(images, labels)`` batches.
        optimizer: optimiser that updates the parameters of ``model``.

    Returns:
        The average loss per sample over the epoch.
    """
    # Evaluation switches the model to eval mode, so select training mode
    # explicitly at the start of every epoch.
    model.train()

    batches_per_epoch = len(train_loder)
    train_loss = 0
    for batch_idx, (x, _) in enumerate(train_loder):
        # Flatten with the actual batch size so a short final batch, or a
        # loader with a different batch size, still works.
        x = x.view(x.size(0), -1)
        x = x.to(device)

        optimizer.zero_grad()
        x_hat, mean, log_var = model(x)
        BCE, KLD = loss_function(x, x_hat, mean, log_var)
        loss = BCE + KLD

        # Integer step that increases monotonically across epochs.
        global_step = epoch * batches_per_epoch + batch_idx
        writer.add_scalar("Train/Reconstruction Error", BCE.item(), global_step)
        writer.add_scalar("Train/KL-Divergence", KLD.item(), global_step)
        writer.add_scalar("Train/Total Loss", loss.item(), global_step)

        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(
                f"Train Epoch: {epoch} [{batch_idx * len(x)}/{len(train_loder.dataset)} ({100. * batch_idx / len(train_loder):.0f}%)]\tLoss: {loss.item() / len(x):.6f}"
            )

    average_loss = train_loss / len(train_loder.dataset)
    print("====> Epoch: {} Average loss: {:.4f}".format(epoch, average_loss))
    return average_loss

In [8]:
def test(epoch, model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and log the results to TensorBoard.

    Relies on the notebook-level ``device``, ``writer`` and ``loss_function``.
    Per-batch losses are logged as scalars, and the first batch is logged as
    an image grid with the inputs on the top row and their reconstructions
    on the bottom row.

    Args:
        epoch: zero-based epoch index, used for the log step.
        model: VAE whose forward pass returns ``(x_hat, mean, log_var)``.
        test_loader: DataLoader yielding ``(images, labels)`` batches.

    Returns:
        The average loss per sample over the test set.
    """
    # Remember the current mode so evaluation does not leave the model in
    # eval mode for whatever runs next.
    was_training = model.training
    model.eval()

    batches_per_epoch = len(test_loader)
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (x, _) in enumerate(test_loader):
            # Keep the per-image shape so the image grid can be rebuilt.
            image_shape = x.shape[1:]
            # Flatten with the actual batch size rather than a global one.
            x = x.view(x.size(0), -1)
            x = x.to(device)
            x_hat, mean, log_var = model(x)
            BCE, KLD = loss_function(x, x_hat, mean, log_var)
            loss = BCE + KLD

            # Integer step that increases monotonically across epochs.
            global_step = epoch * batches_per_epoch + batch_idx
            writer.add_scalar("Test/Reconstruction Error", BCE.item(), global_step)
            writer.add_scalar("Test/KL-Divergence", KLD.item(), global_step)
            writer.add_scalar("Test/Total Loss", loss.item(), global_step)
            test_loss += loss.item()

            if batch_idx == 0:
                n = min(x.size(0), 8)
                # Restore the image layout before building the grid; the
                # flattened rows would otherwise be drawn as one wide strip
                # instead of individual images.
                originals = x[:n].view(n, *image_shape)
                reconstructions = x_hat[:n].view(n, *image_shape)
                comparison = torch.cat([originals, reconstructions])
                # nrow=n puts the originals on the first row and the
                # reconstructions on the second.
                grid = torchvision.utils.make_grid(comparison.cpu(), nrow=n)
                writer.add_image(
                    "Test image - Above: real data, below: reconstructed data",
                    grid,
                    epoch,
                )

    model.train(was_training)

    average_loss = test_loss / len(test_loader.dataset)
    print("====> Test set loss: {:.4f}".format(average_loss))
    return average_loss

In [9]:
from tqdm.auto import tqdm

# Alternate one training pass and one evaluation pass per epoch.
try:
    for epoch in tqdm(range(epochs)):
        train(epoch, model, train_loder, optimizer)
        test(epoch, model, test_loader)
        print("\n")
finally:
    # Close the event file even if training fails or is interrupted, so
    # everything logged up to that point is flushed to disk.
    writer.close()

  0%|          | 0/30 [00:00<?, ?it/s]

====> Epoch: 0 Average loss: 173.6330

====> Epoch: 1 Average loss: 128.3812
====> Epoch: 2 Average loss: 116.5594
====> Epoch: 3 Average loss: 112.1319
====> Epoch: 4 Average loss: 109.7575
====> Epoch: 5 Average loss: 108.0817
====> Epoch: 6 Average loss: 106.9279
====> Epoch: 7 Average loss: 105.9735

====> Epoch: 8 Average loss: 105.3685

====> Epoch: 9 Average loss: 104.6922

====> Epoch: 10 Average loss: 104.2365

====> Epoch: 11 Average loss: 103.7784

====> Epoch: 12 Average loss: 103.3555

====> Epoch: 13 Average loss: 103.0026

====> Epoch: 14 Average loss: 102.6596

====> Epoch: 15 Average loss: 102.3887

====> Epoch: 16 Average loss: 102.1218

====> Epoch: 17 Average loss: 101.9588

====> Epoch: 18 Average loss: 101.7313

====> Epoch: 19 Average loss: 101.5428

====> Epoch: 20 Average loss: 101.3697

====> Epoch: 21 Average loss: 101.1782

====> Epoch: 22 Average loss: 101.0778

====> Epoch: 23 Average loss: 100.9537

====> Epoch: 24 Average loss: 100.8101

====> Epoch: 25 

In [10]:
# Load the TensorBoard notebook extension, then serve the event files found
# under the `scalar` log directory on port 6013.
%load_ext tensorboard
%tensorboard --logdir scalar --port=6013

Launching TensorBoard...