# Defining the VAE Class

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28 x 28 images.

    The flatten step in ``encode`` and the reshape in ``decode`` are fixed to
    64 x 7 x 7 feature maps, which is what the two stride-2 convolutions
    produce from a 1 x 28 x 28 input.

    Args:
        latent_dim: Number of dimensions of the latent vector ``z``.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        hidden_units = 400
        flat_features = 64 * 7 * 7

        # Encoder: each stride-2 convolution halves the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)  # -> 32 x 14 x 14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # -> 64 x 7 x 7
        self.fc1 = nn.Linear(flat_features, hidden_units)
        self.fc21 = nn.Linear(hidden_units, latent_dim)  # head for the mean
        self.fc22 = nn.Linear(hidden_units, latent_dim)  # head for the log-variance

        # Decoder: mirror image of the encoder.
        self.fc3 = nn.Linear(latent_dim, hidden_units)
        self.fc4 = nn.Linear(hidden_units, flat_features)
        self.conv_transpose1 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 32 x 14 x 14
        self.conv_transpose2 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 1 x 28 x 28

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(feature_maps.view(-1, 64 * 7 * 7)))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, 1)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to images with values in [0, 1] (sigmoid output)."""
        hidden = F.relu(self.fc3(z))
        feature_maps = F.relu(self.fc4(hidden)).view(-1, 64, 7, 7)
        upsampled = F.relu(self.conv_transpose1(feature_maps))
        return torch.sigmoid(self.conv_transpose2(upsampled))

    def forward(self, x):
        """Encode, sample a latent vector, decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the KL-divergence term.

        Both terms are summed over every element, so the value grows with the
        batch size.
        """
        reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_elements)
        return reconstruction_term + kl_term


In [2]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import plotly.graph_objects as go
import os

# VAE class definition
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28 x 28 images.

    The flatten step in ``encode`` and the reshape in ``decode`` are fixed to
    64 x 7 x 7 feature maps, which is what the two stride-2 convolutions
    produce from a 1 x 28 x 28 input.

    Args:
        latent_dim: Number of dimensions of the latent vector ``z``.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        hidden_units = 400
        flat_features = 64 * 7 * 7

        # Encoder: each stride-2 convolution halves the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)  # -> 32 x 14 x 14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # -> 64 x 7 x 7
        self.fc1 = nn.Linear(flat_features, hidden_units)
        self.fc21 = nn.Linear(hidden_units, latent_dim)  # head for the mean
        self.fc22 = nn.Linear(hidden_units, latent_dim)  # head for the log-variance

        # Decoder: mirror image of the encoder.
        self.fc3 = nn.Linear(latent_dim, hidden_units)
        self.fc4 = nn.Linear(hidden_units, flat_features)
        self.conv_transpose1 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 32 x 14 x 14
        self.conv_transpose2 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 1 x 28 x 28

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(feature_maps.view(-1, 64 * 7 * 7)))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, 1)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to images with values in [0, 1] (sigmoid output)."""
        hidden = F.relu(self.fc3(z))
        feature_maps = F.relu(self.fc4(hidden)).view(-1, 64, 7, 7)
        upsampled = F.relu(self.conv_transpose1(feature_maps))
        return torch.sigmoid(self.conv_transpose2(upsampled))

    def forward(self, x):
        """Encode, sample a latent vector, decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the KL-divergence term.

        Both terms are summed over every element, so the value grows with the
        batch size.
        """
        reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_elements)
        return reconstruction_term + kl_term

# Tuning the Learning Rate

In [3]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import plotly.graph_objects as go
import os

# VAE class definition
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28 x 28 images.

    The flatten step in ``encode`` and the reshape in ``decode`` are fixed to
    64 x 7 x 7 feature maps, which is what the two stride-2 convolutions
    produce from a 1 x 28 x 28 input.

    Args:
        latent_dim: Number of dimensions of the latent vector ``z``.
    """

    def __init__(self, latent_dim=20):
        super().__init__()
        hidden_units = 400
        flat_features = 64 * 7 * 7

        # Encoder: each stride-2 convolution halves the spatial size.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)  # -> 32 x 14 x 14
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  # -> 64 x 7 x 7
        self.fc1 = nn.Linear(flat_features, hidden_units)
        self.fc21 = nn.Linear(hidden_units, latent_dim)  # head for the mean
        self.fc22 = nn.Linear(hidden_units, latent_dim)  # head for the log-variance

        # Decoder: mirror image of the encoder.
        self.fc3 = nn.Linear(latent_dim, hidden_units)
        self.fc4 = nn.Linear(hidden_units, flat_features)
        self.conv_transpose1 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 32 x 14 x 14
        self.conv_transpose2 = nn.ConvTranspose2d(
            32, 1, kernel_size=3, stride=2, padding=1, output_padding=1
        )  # -> 1 x 28 x 28

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``."""
        feature_maps = F.relu(self.conv2(F.relu(self.conv1(x))))
        hidden = F.relu(self.fc1(feature_maps.view(-1, 64 * 7 * 7)))
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, 1)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent vectors to images with values in [0, 1] (sigmoid output)."""
        hidden = F.relu(self.fc3(z))
        feature_maps = F.relu(self.fc4(hidden)).view(-1, 64, 7, 7)
        upsampled = F.relu(self.conv_transpose1(feature_maps))
        return torch.sigmoid(self.conv_transpose2(upsampled))

    def forward(self, x):
        """Encode, sample a latent vector, decode.

        Returns:
            Tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

    def loss_function(self, recon_x, x, mu, logvar):
        """Summed binary cross-entropy plus the KL-divergence term.

        Both terms are summed over every element, so the value grows with the
        batch size.
        """
        reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
        kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_term = -0.5 * torch.sum(kl_elements)
        return reconstruction_term + kl_term

# Set up training environment
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 100  # epochs per learning rate; also sizes the x-axis of the loss plot
checkpoint_dir = './vae_checkpoints'
# exist_ok=True replaces the check-then-create pair, which could race with
# another process creating the directory between the check and the call.
os.makedirs(checkpoint_dir, exist_ok=True)

# Prepare MNIST dataset (downloaded into ./data on first use)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Function for training VAE with different learning rates
def train_vae_with_lr_and_save_checkpoints(learning_rates):
    """Train a fresh VAE per learning rate, checkpointing after every epoch.

    Reads the module-level ``train_dataset``, ``device``, ``epochs`` and
    ``checkpoint_dir``.

    Args:
        learning_rates: Iterable of Adam learning rates to try.

    Returns:
        Dict mapping each learning rate to its list of per-epoch losses,
        where each entry is the summed batch loss divided by the dataset size.
    """
    loss_results = {}
    for rate in learning_rates:
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
        model = VAE(latent_dim=20).to(device)
        optimizer = optim.Adam(model.parameters(), lr=rate)
        history = loss_results[rate] = []
        sample_count = len(train_loader.dataset)

        for epoch in range(epochs):
            model.train()
            running_loss = 0
            for images, _ in train_loader:
                images = images.to(device)
                optimizer.zero_grad()
                reconstruction, mu, logvar = model(images)
                batch_loss = model.loss_function(reconstruction, images, mu, logvar)
                batch_loss.backward()
                running_loss += batch_loss.item()
                optimizer.step()

            epoch_loss = running_loss / sample_count
            history.append(epoch_loss)
            print(f'LR: {rate}, Epoch: {epoch}, Loss: {epoch_loss}')

            # One checkpoint file per (learning rate, epoch) pair.
            snapshot = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }
            torch.save(
                snapshot,
                os.path.join(checkpoint_dir, f'vae_lr_{rate}_epoch_{epoch}.pth'),
            )
    return loss_results

# Learning rates to experiment with; one fresh VAE is trained per value
learning_rates = [1e-4, 1e-3, 1e-2]

# Train and collect loss results (dict: learning rate -> list of per-epoch losses)
lr_loss_results = train_vae_with_lr_and_save_checkpoints(learning_rates)

# Plot the results using Plotly: one line per learning rate.
# The x-axis is the 1-based epoch number and relies on each loss list holding
# exactly `epochs` entries (the module-level value used during training).
fig = go.Figure(layout=go.Layout(template="simple_white"))
for lr, losses in lr_loss_results.items():
    fig.add_trace(go.Scatter(x=np.arange(1, epochs+1), y=losses, mode='lines+markers', name=f'LR: {lr}'))
fig.update_layout(title='VAE Training Loss by Learning Rate', xaxis_title='Epoch', yaxis_title='Loss')
fig.show()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 110986356.84it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 52236176.72it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 24692028.15it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 7924512.80it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

LR: 0.0001, Epoch: 0, Loss: 216.19542233072917
LR: 0.0001, Epoch: 1, Loss: 137.79010704752605
LR: 0.0001, Epoch: 2, Loss: 120.91865783284506
LR: 0.0001, Epoch: 3, Loss: 113.65512817382813
LR: 0.0001, Epoch: 4, Loss: 110.04836251220704
LR: 0.0001, Epoch: 5, Loss: 107.84220789388021
LR: 0.0001, Epoch: 6, Loss: 106.2928292561849
LR: 0.0001, Epoch: 7, Loss: 105.2018078125
LR: 0.0001, Epoch: 8, Loss: 104.31523302815755
LR: 0.0001, Epoch: 9, Loss: 103.60809580078126
LR: 0.0001, Epoch: 10, Loss: 102.9880675048828
LR: 0.0001, Epoch: 11, Loss: 102.49077422688802
LR: 0.0001, Epoch: 12, Loss: 102.04498166503906
LR: 0.0001, Epoch: 13, Loss: 101.64449479166667
LR: 0.0001, Epoch: 14, Loss: 101.26167803548176
LR: 0.0001, Epoch: 15, Loss: 100.95743692626954
LR: 0.0001, Epoch: 16, Loss: 100.60335036621093
LR: 0.0001, Epoch: 17, Loss: 100.33059466552734
LR: 0.0001, Epoch: 18, Loss: 100.0812346110026
LR: 0.0001, Epoch: 19, Loss: 9

# Plotting the result 

In [4]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Set device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare MNIST dataset for testing (train=False selects the held-out split)
transform = transforms.ToTensor()
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Define test_loader: batches of 10 match the num_images shown by the
# visualizer, and shuffle=False makes the first batch the same on every pass.
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# Function to visualize reconstructions
def visualize_reconstructions_for_lr(learning_rates, checkpoint_dir, test_loader, device):
    """Show an original digit and VAE reconstructions for each learning rate.

    One subplot row is drawn per learning rate. Column 1 holds the first
    original image of the batch; columns 2-10 hold the reconstructions of
    images 0-8 from the same batch. The model is restored from the epoch-9
    checkpoint written by the learning-rate training run.

    Reconstructions are stochastic: ``VAE.forward`` samples the latent vector
    on every call, including in eval mode.

    Args:
        learning_rates: Learning rates whose checkpoints should be loaded.
        checkpoint_dir: Directory containing ``vae_lr_{lr}_epoch_9.pth`` files.
        test_loader: DataLoader yielding ``(images, labels)`` batches with at
            least 10 single-channel images per batch.
        device: Torch device used for inference.
    """
    num_images = 10  # Number of images to display
    rows = len(learning_rates)

    # Create a subplot grid; the titles only label the first row.
    fig = make_subplots(rows=rows, cols=num_images, subplot_titles=['Original'] + ['Reconstructed'] * (num_images - 1))

    # Fetch the batch once so every row is built from the same images, even
    # when the loader shuffles.
    batch, _ = next(iter(test_loader))
    batch = batch[:num_images].to(device)
    # Index the channel axis explicitly rather than squeeze(), which would
    # also drop the batch axis for a single-image batch.
    original_images = batch.cpu().numpy()[:, 0]

    for i, lr in enumerate(learning_rates):
        # Load the trained model from checkpoint
        model = VAE(latent_dim=20).to(device)
        checkpoint_path = os.path.join(checkpoint_dir, f'vae_lr_{lr}_epoch_9.pth')  # Adjust based on your saved checkpoints
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            recon, _, _ = model(batch)
        reconstructed_images = recon.cpu().numpy()[:, 0]

        # Originals for columns 2+ used to be added and then hidden behind a
        # reconstruction in the same cell; only the visible image is drawn now.
        row_images = [original_images[0], *reconstructed_images[:num_images - 1]]
        for col, image in enumerate(row_images, start=1):
            fig.add_trace(
                go.Heatmap(
                    z=image,
                    colorscale='Greys',
                    showscale=False,
                    zmin=0, zmax=1
                ),
                row=i + 1, col=col
            )

    # Heatmaps put z[0] at the bottom; flip the y-axes so digits are upright.
    fig.update_yaxes(autorange='reversed')

    # Update layout
    fig.update_layout(height=150 * rows, width=800, title_text="Original and Reconstructed Images for Different Learning Rates", template="simple_white" )
    fig.show()

# Visualize reconstructions for the three learning rates trained earlier,
# reading their checkpoints from the learning-rate checkpoint directory
visualize_reconstructions_for_lr([1e-4, 1e-3, 1e-2], './vae_checkpoints', test_loader, device)


# Tuning the Batch Size

In [5]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import plotly.graph_objects as go
import os

# Set device (GPU when available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset (training split, converted to tensors)
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Directory to save checkpoints.
# NOTE: this rebinds the module-level `checkpoint_dir` that the learning-rate
# training function also reads, so anything run after this point writes here.
checkpoint_dir = './vae_batchsize_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

# Training function
def train_vae_with_batch_size(batch_sizes, epochs=100):
    """Train a fresh VAE per batch size, checkpointing after every epoch.

    Reads the module-level ``train_dataset``, ``device`` and
    ``checkpoint_dir``. Adam with a learning rate of 1e-3 is used throughout.

    Args:
        batch_sizes: Iterable of DataLoader batch sizes to try.
        epochs: Number of epochs to train each model for.

    Returns:
        Dict mapping each batch size to its list of per-epoch losses, where
        each entry is the summed batch loss divided by the dataset size.
    """

    def run_epoch(model, optimizer, loader):
        # One full pass over the loader; returns the summed batch losses.
        model.train()
        accumulated = 0
        for images, _ in loader:
            images = images.to(device)
            optimizer.zero_grad()
            reconstruction, mu, logvar = model(images)
            step_loss = model.loss_function(reconstruction, images, mu, logvar)
            step_loss.backward()
            accumulated += step_loss.item()
            optimizer.step()
        return accumulated

    loss_results = {}
    for batch_size in batch_sizes:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        model = VAE(latent_dim=20).to(device)
        optimizer = optim.Adam(model.parameters(), lr=1e-3)
        loss_results[batch_size] = []

        for epoch in range(epochs):
            avg_loss = run_epoch(model, optimizer, train_loader) / len(train_loader.dataset)
            loss_results[batch_size].append(avg_loss)
            print(f'Batch size: {batch_size}, Epoch: {epoch}, Loss: {avg_loss}')

            # One checkpoint file per (batch size, epoch) pair.
            target = os.path.join(checkpoint_dir, f'vae_batchsize_{batch_size}_epoch_{epoch}.pth')
            payload = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }
            torch.save(payload, target)

    return loss_results

# Batch sizes to experiment with; one fresh VAE is trained per value
batch_sizes = [32, 64, 128]

# Train and get results (dict: batch size -> list of per-epoch losses)
batch_loss_results = train_vae_with_batch_size(batch_sizes)

# Plot training results: one line per batch size
fig = go.Figure(layout=go.Layout(template="simple_white"))
for batch_size, losses in batch_loss_results.items():
    # Size the x-axis from the recorded losses. The training function has its
    # own `epochs` argument, so the module-level `epochs` may not match the
    # number of points actually collected.
    fig.add_trace(go.Scatter(x=np.arange(1, len(losses) + 1), y=losses, mode='lines+markers', name=f'Batch Size: {batch_size}'))
fig.update_layout(title='VAE Training Loss by Batch Size', xaxis_title='Epoch', yaxis_title='Loss')
fig.show()


Batch size: 32, Epoch: 0, Loss: 134.71615651041665
Batch size: 32, Epoch: 1, Loss: 107.75895514729818
Batch size: 32, Epoch: 2, Loss: 104.29156040852864
Batch size: 32, Epoch: 3, Loss: 102.72250392659505
Batch size: 32, Epoch: 4, Loss: 101.65598631184896
Batch size: 32, Epoch: 5, Loss: 100.83129790039062
Batch size: 32, Epoch: 6, Loss: 100.2652395711263
Batch size: 32, Epoch: 7, Loss: 99.80134552815755
Batch size: 32, Epoch: 8, Loss: 99.3098961710612
Batch size: 32, Epoch: 9, Loss: 99.02629256184896
Batch size: 32, Epoch: 10, Loss: 98.65876465657553
Batch size: 32, Epoch: 11, Loss: 98.3900671101888
Batch size: 32, Epoch: 12, Loss: 98.1846992553711
Batch size: 32, Epoch: 13, Loss: 97.86719619140625
Batch size: 32, Epoch: 14, Loss: 97.69376309407552
Batch size: 32, Epoch: 15, Loss: 97.5147716796875
Batch size: 32, Epoch: 16, Loss: 97.33063371988932
Batch size: 32, Epoch: 17, Loss: 97.11468103434245
Batch size: 32, Epoch: 18, Loss: 96.97634073079428
Batch size: 32, Epoch: 19, Loss: 96.836

# Plotting the Result

In [6]:
# Function for visualization of reconstructions, similar to the one used for learning rate visualization
import plotly.graph_objects as go
def visualize_reconstructions_for_batch_size(batch_sizes, checkpoint_dir, test_loader, device, epochs):
    """Show an original digit and VAE reconstructions for each batch size.

    One subplot row is drawn per batch size. Column 1 holds the first original
    image of the batch; columns 2-10 hold the reconstructions of images 0-8
    from the same batch.

    Reconstructions are stochastic: ``VAE.forward`` samples the latent vector
    on every call, including in eval mode.

    Args:
        batch_sizes: Batch sizes whose checkpoints should be loaded.
        checkpoint_dir: Directory containing
            ``vae_batchsize_{bs}_epoch_{n}.pth`` files.
        test_loader: DataLoader yielding ``(images, labels)`` batches with at
            least 10 single-channel images per batch.
        device: Torch device used for inference.
        epochs: Number of completed training epochs to visualize. Checkpoints
            are numbered from 0, so the file for epoch ``epochs - 1`` is
            loaded (``epochs=10`` loads ``..._epoch_9.pth``). This argument
            was previously ignored in favour of a hard-coded epoch 9.
    """
    num_images = 10  # Number of images to display
    rows = len(batch_sizes)
    last_epoch = epochs - 1  # checkpoint files are 0-indexed

    # Create a subplot grid; the titles only label the first row.
    fig = make_subplots(rows=rows, cols=num_images, subplot_titles=['Original'] + ['Reconstructed'] * (num_images - 1))

    # Fetch the batch once so every row is built from the same images, even
    # when the loader shuffles.
    batch, _ = next(iter(test_loader))
    batch = batch[:num_images].to(device)
    # Index the channel axis explicitly rather than squeeze(), which would
    # also drop the batch axis for a single-image batch.
    original_images = batch.cpu().numpy()[:, 0]

    for i, bs in enumerate(batch_sizes):
        # Load the trained model from checkpoint
        model = VAE(latent_dim=20).to(device)
        checkpoint_path = os.path.join(checkpoint_dir, f'vae_batchsize_{bs}_epoch_{last_epoch}.pth')
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            recon, _, _ = model(batch)
        reconstructed_images = recon.cpu().numpy()[:, 0]

        # Originals for columns 2+ used to be added and then hidden behind a
        # reconstruction in the same cell; only the visible image is drawn now.
        row_images = [original_images[0], *reconstructed_images[:num_images - 1]]
        for col, image in enumerate(row_images, start=1):
            fig.add_trace(
                go.Heatmap(
                    z=image,
                    colorscale='Greys',
                    showscale=False,
                    zmin=0, zmax=1
                ),
                row=i + 1, col=col
            )

    # Heatmaps put z[0] at the bottom; flip the y-axes so digits are upright.
    fig.update_yaxes(autorange='reversed')

    # Update layout (the title previously said "Learning Rates" by mistake)
    fig.update_layout(height=150 * rows, width=800, title_text="Original and Reconstructed Images for Different Batch Sizes", template="simple_white" )
    fig.show()

# Build the loader from the held-out MNIST test split. The previous version
# wrapped train_dataset with shuffle=True, so "test" images were training
# samples and the visualizer could draw a different batch for every row.
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# Visualize reconstructions for different batch sizes
visualize_reconstructions_for_batch_size(batch_sizes, checkpoint_dir, test_loader, device, epochs=10)

# Tuning the Latent Dimensions

In [7]:
# import torch
# import torch.optim as optim
# from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
# import numpy as np
# import plotly.graph_objects as go
# import os

# # Set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # MNIST dataset
# transform = transforms.ToTensor()
# train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# # Directory to save checkpoints
# checkpoint_dir = './vae_latent_checkpoints'
# os.makedirs(checkpoint_dir, exist_ok=True)

# # Training function
# def train_vae_with_latent_dims(latent_dims, epochs=100):
#     loss_results = {}
#     for latent_dim in latent_dims:
#         train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
#         model = VAE(latent_dim=latent_dim).to(device)
#         optimizer = optim.Adam(model.parameters(), lr=1e-3)
#         loss_results[latent_dim] = []
        
#         for epoch in range(epochs):
#             model.train()
#             total_loss = 0
#             for batch_idx, (data, _) in enumerate(train_loader):
#                 data = data.to(device)
#                 optimizer.zero_grad()
#                 recon_batch, mu, logvar = model(data)
#                 loss = model.loss_function(recon_batch, data, mu, logvar)
#                 loss.backward()
#                 total_loss += loss.item()
#                 optimizer.step()
            
#             avg_loss = total_loss / len(train_loader.dataset)
#             loss_results[latent_dim].append(avg_loss)
#             print(f'Latent dim: {latent_dim}, Epoch: {epoch}, Loss: {avg_loss}')
            
#             # Save checkpoint
#             checkpoint_path = os.path.join(checkpoint_dir, f'vae_latent_{latent_dim}_epoch_{epoch}.pth')
#             torch.save({
#                 'epoch': epoch,
#                 'model_state_dict': model.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'loss': avg_loss,
#             }, checkpoint_path)
            
#     return loss_results

# # Latent dimensions to experiment with
# latent_dims = [2, 10, 20]

# # Train and get results
# latent_loss_results = train_vae_with_latent_dims(latent_dims)

# # Plot training results
# fig = go.Figure(layout=go.Layout(template="simple_white"))
# for latent_dim, losses in latent_loss_results.items():
#     fig.add_trace(go.Scatter(x=np.arange(1, epochs+1), y=losses, mode='lines+markers', name=f'Latent Dim: {latent_dim}'))
# fig.update_layout(title='VAE Training Loss by Latent Dimension', xaxis_title='Epoch', yaxis_title='Loss')
# fig.show()


# Plotting the results

In [8]:
# import torch
# from torchvision import datasets, transforms
# from torch.utils.data import DataLoader
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots
# import numpy as np

# # Set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Define your VAE class here (I'm assuming it's already defined and ready to use)
# # from your_vae_module import VAE  # Uncomment and modify this if your VAE class is in a separate file

# # Prepare MNIST dataset for testing
# transform = transforms.ToTensor()
# test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

# # Function to visualize reconstructions for different latent dimensions
# def visualize_reconstructions_for_latent_dims(latent_dims, checkpoint_dir, test_loader, device):
#     num_images = 10  # Number of images to display
#     rows = len(latent_dims)
    
#     # Create a subplot
#     fig = make_subplots(rows=rows, cols=num_images, subplot_titles=['Original'] + ['Reconstructed'] * (num_images - 1))

#     for i, latent_dim in enumerate(latent_dims):
#         # Load the trained model from checkpoint
#         model = VAE(latent_dim=latent_dim).to(device)  # Make sure your VAE class accepts latent_dim as a parameter
#         checkpoint_path = os.path.join(checkpoint_dir, f'vae_latent_{latent_dim}_epoch_9.pth')
#         checkpoint = torch.load(checkpoint_path, map_location=device)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         model.eval()

#         # Get a batch of data
#         for batch, _ in test_loader:
#             batch = batch.to(device)
#             recon, _, _ = model(batch[:num_images])
#             break  # Only need one batch

#         # Prepare original and reconstructed images for plotting
#         original_images = batch[:num_images].cpu().numpy().squeeze()
#         reconstructed_images = recon.detach().cpu().numpy().squeeze()

#         # Plot original and reconstructed images
#         for j in range(num_images):
#             # Original images (first row of each section)
#             fig.add_trace(
#                 go.Heatmap(
#                     z=original_images[j],
#                     colorscale='Greys',
#                     showscale=False,
#                     zmin=0, zmax=1
#                 ),
#                 row=i + 1, col=j + 1
#             )
#             # Reconstructed images
#             fig.add_trace(
#                 go.Heatmap(
#                     z=reconstructed_images[j],
#                     colorscale='Greys',
#                     showscale=False,
#                     zmin=0, zmax=1
#                 ),
#                 row=i + 1, col=j + 1
#             )

#     # Update layout
#     fig.update_layout(height=150 * rows, width=800, title_text="Original and Reconstructed Images for Different Latent Dimensions", template="simple_white")
#     fig.show()

# # Visualize reconstructions for different latent dimensions
# visualize_reconstructions_for_latent_dims([2, 10, 20], './vae_latent_checkpoints', test_loader, device)
