<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/VAEs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Define the VAE class
class VAE(nn.Module):
    """Variational autoencoder with a single hidden layer on each side.

    The encoder turns an input into two vectors of size ``latent_dim``
    (``mu`` and ``logvar``); a latent sample is drawn from them and the
    decoder maps that sample back to a vector of size ``input_dim``,
    passed through a sigmoid.

    Args:
        input_dim: Size of the last dimension of the input and of the output.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Encoder: one shared layer feeding two parallel heads
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # head that outputs mu
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # head that outputs logvar
        # Decoder: latent -> hidden -> input space
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return the pair ``(mu, logvar)`` computed from ``x``."""
        hidden = self.fc1(x).relu()
        mean = self.fc21(hidden)
        log_variance = self.fc22(hidden)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal."""
        sigma = (0.5 * logvar).exp()  # std = exp(logvar / 2)
        noise = torch.randn_like(sigma)  # fresh noise on every call
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent tensor ``z`` to a sigmoid-squashed reconstruction."""
        hidden = self.fc3(z).relu()
        return self.fc4(hidden).sigmoid()

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input ``x``."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent)
        return reconstruction, mu, logvar

# Smoke test: one forward pass through an untrained VAE on random input
vae = VAE(784, 400, 20)  # positional: input_dim, hidden_dim, latent_dim
input_data = torch.randn((64, 784))  # 64 rows of 784 standard-normal features
recon_data, mu, logvar = vae(input_data)

# Show the reconstruction's shape; it should equal the input's, [64, 784]
print(recon_data.size())