In [None]:
%%javascript
// Override the output area's auto-scroll check so it always returns false: long outputs
// (e.g. the many figures produced below) are shown in full instead of in a scroll box.
// NOTE(review): relies on the classic-Notebook `IPython` JS global; presumably has no
// effect (or errors) in JupyterLab — confirm for the frontend in use.
IPython.OutputArea.prototype._should_scroll = function(lines){
    return false;
}

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
import time
from aepsf import VAELoss, VAE, make_mock
import pickle
import torchvision

In [None]:
# Load MNIST (downloaded into ./local_data/ if not already present).
data = torchvision.datasets.MNIST('./local_data/', transform=torchvision.transforms.ToTensor(), download=True)

# Flatten every 28x28 image into a 784-vector scaled to [0, 1].
# Converting the stored uint8 tensor in one shot reproduces ToTensor() (uint8 -> float32 / 255)
# without a per-sample PIL round trip over the whole dataset.
normed_Z = data.data.numpy().reshape(len(data), -1).astype(np.float32) / np.float32(255)

assert normed_Z.shape == (len(data), 28 * 28)
assert normed_Z.dtype == np.float32

In [None]:
# Preview the first ten samples: un-flatten each back to 28x28 and show it in its own figure.
for k in range(10):
    image = normed_Z[k].reshape(28, 28)
    plt.figure()
    plt.imshow(image)
    plt.colorbar()

In [None]:
train_ratio = 0.8
batch_size = 64
device = 'cpu'
# device = 'mps'
SEED = 0

# Seed the global torch RNG: DataLoader(shuffle=True) draws from it, and presumably so does
# the VAE's weight initialisation in the next cell (if it uses torch's default initialisers).
torch.manual_seed(SEED)

dataset = TensorDataset(torch.tensor(normed_Z, dtype=torch.float32).to(device))

# Split the dataset into training and validation sets, reproducibly via a seeded generator
train_size = int(train_ratio * len(dataset))
val_size = len(dataset) - train_size
assert train_size > 0 and val_size > 0, "train_ratio must leave both splits non-empty"
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

# Create data loaders for the training and validation sets.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# One batch holding the entire validation set: later cells index x_val / x_hat_val / mu_val
# from the last validation batch, so this makes them cover every validation sample.
val_loader = DataLoader(val_dataset, batch_size=val_size, shuffle=False)

In [None]:
# Instantiate the VAE on flattened 28x28 inputs with a 2-D latent space
vae = VAE(N_pixel=28*28, layers_n_hidden_units=[512, 256, 128], latent_dim=2).to(device)

# VAE objective from aepsf; called below as vae_loss(x, x_hat, mu, log_var)
vae_loss = VAELoss().to(device)

optimizer = optim.Adam(vae.parameters(), lr=1e-3)

train_size = len(train_loader.dataset)
print(train_size)

epochs = 30
total_loss = []  # mean training mini-batch loss, one entry per epoch
val_loss = []    # mean validation batch loss, one entry per epoch

TIMEA = time.perf_counter()

print('start')
for epoch in range(epochs):
    # ---- training pass ----
    vae.train()
    running_loss = 0.0
    for x, in train_loader:
        optimizer.zero_grad()
        x_hat, mu, log_var = vae(x)
        loss = vae_loss(x, x_hat, mu, log_var)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # ---- validation pass: no parameter updates, so skip building the autograd graph ----
    vae.eval()
    running_vloss = 0.0
    with torch.no_grad():
        for x_val, in val_loader:
            # repam=False: presumably decodes from mu without sampling — confirm in aepsf.VAE
            x_hat_val, mu_val, log_var_val = vae(x_val, repam=False)
            running_vloss += vae_loss(x_val, x_hat_val, mu_val, log_var_val).item()

    # Average over mini-batches so the training value is on a per-batch scale rather than a
    # sum over all batches. NOTE(review): this is comparable to the validation value only if
    # VAELoss averages over the batch (not sums) — confirm in aepsf.
    epoch_loss = running_loss / len(train_loader)
    epoch_vloss = running_vloss / len(val_loader)
    total_loss.append(epoch_loss)
    val_loss.append(epoch_vloss)

    print(f"Epoch {epoch+1} loss: {epoch_loss:.6f} validation loss: {epoch_vloss:.6f}")

TIMEB = time.perf_counter()

print(TIMEB-TIMEA)

In [None]:
# Training vs validation loss per epoch (one entry per epoch in each list).
plt.plot(total_loss, label='train')
plt.plot(val_loss, label='validation')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
# plt.yscale('log')  # uncomment for a log-scale view of the curves

In [None]:
def plot_reconstruction(x, x_hat, resid_scale=200, resid_lim=10):
    """Show an image, its reconstruction and their relative residual side by side.

    Parameters
    ----------
    x, x_hat : 2-D numpy arrays of the same shape (28x28 here).
    resid_scale : factor applied to (x_hat - x) / x.
        NOTE(review): a percentage would use 100; 200 kept from the original — confirm.
    resid_lim : symmetric colour limit for the residual panel.
    """
    # Most MNIST pixels are exactly 0, so dividing by x directly yields inf/nan and
    # RuntimeWarnings. Compute the ratio only where x != 0; elsewhere leave NaN, which
    # imshow renders blank.
    rel_resid = np.full_like(x, np.nan)
    np.divide(x_hat - x, x, out=rel_resid, where=x != 0)
    rel_resid *= resid_scale

    panels = [
        (x, 1, "input"),
        (x_hat, 1, "reconstruction"),
        (rel_resid, resid_lim, f"(x_hat - x) / x * {resid_scale}"),
    ]

    plt.figure(figsize=(12, 5))
    plt.subplots_adjust(wspace=0.4, left=0.05, right=0.95)
    for k, (img, lim, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, k)
        plt.imshow(img, cmap=plt.cm.seismic, vmin=-lim, vmax=lim)
        plt.colorbar()
        plt.gca().invert_yaxis()
        plt.title(title)
        if k > 1:
            # Panels share the row axis with the first one; hide redundant tick labels.
            plt.yticks([], [])


# First 20 validation samples from the last validation pass of the training cell.
for i in range(20):
    plot_reconstruction(
        x_val[i].detach().numpy().reshape(28, 28),
        x_hat_val[i].detach().numpy().reshape(28, 28),
    )

In [None]:
# Scatter the latent means of the validation set for every pair of latent dimensions.
# Looping over A.shape[1] keeps this valid for any latent_dim; fixed indices such as
# A[:, 2] raise IndexError when latent_dim == 2.
A = mu_val.detach().numpy()
n_latent = A.shape[1]
for i in range(n_latent):
    for j in range(i + 1, n_latent):
        plt.figure()
        plt.scatter(A[:, i], A[:, j])
        plt.xlabel(f"z[{i}]")
        plt.ylabel(f"z[{j}]")

In [None]:
VAE?