<a href="https://colab.research.google.com/github/R0bk/ml_replications/blob/main/04_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

In [None]:
def _load_fashion_mnist(train):
    """Return one FashionMNIST split (downloading it if needed) with ToTensor applied."""
    return datasets.FashionMNIST(
        root="data",
        train=train,
        download=True,
        transform=ToTensor(),
    )


# Training split first, then the held-out test split, both stored under "data".
training_data = _load_fashion_mnist(train=True)
test_data = _load_fashion_mnist(train=False)

In [None]:
batch_size = 128

# Create data loaders. The training loader reshuffles every epoch so successive
# optimiser steps see differently composed mini-batches (without shuffle=True
# every epoch replays the exact same batch sequence). The test loader keeps a
# fixed order so evaluation and plots are reproducible.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single test batch to confirm the tensor layout.
X, y = next(iter(test_dataloader))
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)

Shape of X [N, C, H, W]:  torch.Size([128, 1, 28, 28])
Shape of y:  torch.Size([128]) torch.int64


In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

class NeuralNetwork(nn.Module):
  def __init__(self):
    super(NeuralNetwork, self).__init__()
    self.flatten = nn.Flatten()

    self.encoder = nn.Sequential(
        nn.Linear(28*28, 400),
        nn.ReLU(),
        nn.Linear(400, 64),
        nn.ReLU(),
        nn.Linear(64, 40),
        nn.ReLU()
    )
    self.decoder = nn.Sequential(
        nn.Linear(20, 64),
        nn.ReLU(),
        nn.Linear(64, 400),
        nn.ReLU(),
        nn.Linear(400, 28*28),
        nn.ReLU()
    )

  def sample(self, x):
    mu = x[:, 0, :]
    if self.training:
      stds = x[:, 1, :]

      eps = torch.normal(0., 1., stds.shape).to(device)
      x = stds.exp().mul(eps).add(mu)
      
      return x, mu, stds
    return mu, None, None
    

  def forward(self, x, batch_size):
    x = self.flatten(x)
    x = self.encoder(x)
    x, mu, stds = self.sample(torch.reshape(x, (batch_size, 2, 20)))
    x = self.decoder(x)
    return x, mu, stds

# Build the VAE, move its parameters onto the selected device and show its layers.
model = NeuralNetwork()
model = model.to(device)
print(model)

Using cuda device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=400, bias=True)
    (1): ReLU()
    (2): Linear(in_features=400, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=40, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=400, bias=True)
    (3): ReLU()
    (4): Linear(in_features=400, out_features=784, bias=True)
    (5): ReLU()
  )
)


In [None]:
# Summed (not averaged) squared error; reduction="sum" replaces the
# deprecated size_average=False with identical semantics.
mse = nn.MSELoss(reduction="sum")
def loss_fn(y, pred, means, stds):
  """Negative ELBO for a Gaussian-posterior VAE with a standard-normal prior.

  Args:
    y: reconstruction target, same shape as `pred`.
    pred: decoder output.
    means: posterior means mu.
    stds: posterior *log* standard deviations log(sigma); NeuralNetwork.sample
      uses stds.exp() as sigma.

  Returns:
    (total, kld): summed squared error plus KL term, and the KL term alone.
  """
  recon = mse(pred, y)
  # KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(sigma^2 + mu^2 - 1 - 2*log(sigma)).
  # The previous expression computed the reversed divergence
  # KL(N(0, 1) || N(mu, sigma^2)), which is not the ELBO regulariser.
  two_log_std = stds.mul(2)
  kld = 0.5 * (two_log_std.exp() + means.pow(2) - 1 - two_log_std).sum()
  return recon + kld, kld



In [None]:
# Adam over every trainable parameter of the VAE, learning rate 1e-3.
optimiser = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
def train(dataloader, model, loss_fn, optimiser, device=None):
  """Run one training epoch of the VAE over `dataloader`.

  Args:
    dataloader: yields (images, labels); labels are ignored.
    model: called as model(X, batch_size) -> (reconstruction, mu, log_std).
    loss_fn: called as loss_fn(target, reconstruction, mu, log_std) and must
      return (total_loss, kl_loss) tensors.
    optimiser: optimiser over the model's parameters.
    device: where to place each batch; defaults to the device of the
      model's parameters.

  Every 100 batches it prints the batch loss and the KL term as returned by
  `loss_fn` (sums over the batch for this notebook's loss_fn).
  """
  if device is None:
    device = next(model.parameters()).device
  size = len(dataloader.dataset)

  model.train()
  for batch, (X, _) in enumerate(dataloader):
    X = X.to(device)
    pred, means, stds = model(X, len(X))
    # flatten(1) keeps the batch axis even for a one-image batch, where
    # X.squeeze() would also drop dim 0 and mis-shape the target.
    loss, kl_loss = loss_fn(X.flatten(1), pred, means, stds)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if batch % 100 == 0:
      current = batch * len(X)
      print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]  kl: {kl_loss.item():>7f}")

In [None]:
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
epochs = 40
for epoch_idx in range(epochs):
    # Epoch numbers shown to the reader start at 1.
    header = f"Epoch {epoch_idx+1}\n-------------------------------"
    print(header)
    train(train_dataloader, model, loss_fn, optimiser)
    # Per-epoch evaluation is currently switched off:
    # test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 18773.445312  [    0/60000] tensor(11.3471, device='cuda:0', grad_fn=<SumBackward0>)
loss: 7575.714844  [12800/60000] tensor(686.3396, device='cuda:0', grad_fn=<SumBackward0>)
loss: 6618.272949  [25600/60000] tensor(812.8885, device='cuda:0', grad_fn=<SumBackward0>)
loss: 5534.042480  [38400/60000] tensor(806.9263, device='cuda:0', grad_fn=<SumBackward0>)
loss: 5589.187988  [51200/60000] tensor(881.5921, device='cuda:0', grad_fn=<SumBackward0>)
Epoch 2
-------------------------------
loss: 5684.200195  [    0/60000] tensor(781.9841, device='cuda:0', grad_fn=<SumBackward0>)
loss: 5463.739258  [12800/60000] tensor(892.3912, device='cuda:0', grad_fn=<SumBackward0>)
loss: 5461.671875  [25600/60000] tensor(851.7177, device='cuda:0', grad_fn=<SumBackward0>)
loss: 5205.003906  [38400/60000] tensor(817.7628, device='cuda:0', grad_fn=<SumBackward0>)
loss: 4961.550293  [51200/60000] tensor(905.8105, device='cuda:0', grad_fn=<SumBackward0>)
Epoch 3
--

In [None]:
def plot(dataloader, model, loss_fn, n_batches=10, device=None):
    """Show input images side by side with the model's reconstructions.

    For each of the first `n_batches` batches of `dataloader`, the first image
    and its reconstruction are drawn next to each other in grayscale. The model
    is switched to eval mode (for NeuralNetwork this decodes the posterior mean).

    Args:
        dataloader: yields (images, labels) with images shaped (N, C, H, W);
            labels are ignored.
        model: called as model(X, batch_size) -> (reconstruction, mu, log_std),
            with the reconstruction flattened per sample.
        loss_fn: unused; kept so existing calls keep working.
        n_batches: number of batches to visualise (default 10, as before).
        device: where to run the model; defaults to the device of the
            model's parameters.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        for i, (X, _) in enumerate(dataloader):
            if i >= n_batches:
                break
            X = X.to(device)
            pred, _, _ = model(X, len(X))
            fig, (ax_in, ax_out) = plt.subplots(1, 2)
            ax_in.imshow(X[0].squeeze().cpu().numpy(), cmap="gray")
            ax_in.set_title("input")
            # Un-flatten the reconstruction back to the input's H x W grid.
            ax_out.imshow(pred[0].reshape(X.shape[-2:]).cpu().numpy(), cmap="gray")
            ax_out.set_title("reconstruction")
            for ax in (ax_in, ax_out):
                ax.axis("off")
            plt.show()
            # Release the figure so long loops don't accumulate open figures.
            plt.close(fig)
# Visualise reconstructions of a few training images.
plot(dataloader=train_dataloader, model=model, loss_fn=loss_fn)

torch.Size([128, 1, 28, 28])


ValueError: ignored