In [None]:
import requests

# URL of the pre-trained VAE weights
url = "https://github.com/antonio-f/VAE-MNIST-Pytorch/raw/main/models/vae_mnist.pth"

# The timeout keeps the cell from hanging indefinitely on a stalled connection.
response = requests.get(url, timeout=60)

# Only write the file (and report success) when the server actually returned
# the payload; otherwise an HTML error page would be saved under a .pth name
# while the cell still claimed the download worked.
if response.ok:
    # Save the file
    with open("vae_mnist.pth", "wb") as f:
        f.write(response.content)
    print("Pre-trained VAE weights downloaded!")
else:
    print(f"Download failed (HTTP {response.status_code}); vae_mnist.pth was not written.")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Define the VAE model
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps each flattened sample to ``2 * latent_dim`` values that
    ``forward`` splits into a mean and a log-variance; the decoder maps a
    latent sample back to ``input_dim`` values squashed into (0, 1) by a
    sigmoid.

    Args:
        latent_dim: Size of the latent vector.
        input_dim: Number of values per flattened sample (784 = 28 * 28).
        hidden_dim: Width of the single hidden layer in encoder and decoder.

    The defaults reproduce the original 784-400-latent layout, and the layer
    positions inside each ``nn.Sequential`` are unchanged, so state-dict keys
    stay the same.
    """

    def __init__(self, latent_dim=20, input_dim=784, hidden_dim=400):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),  # Mean and log variance
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),  # Output is between 0 and 1
        )

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * std`` with ``eps`` drawn from a standard normal.

        ``std`` is ``exp(0.5 * logvar)``. Noise is drawn on every call,
        regardless of train/eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``input_dim`` values (e.g. ``(batch, 1, 28, 28)``).

        Returns:
            Tuple ``(recon_x, mu, logvar)`` where ``recon_x`` has shape
            ``(rows, input_dim)`` and ``mu`` / ``logvar`` have shape
            ``(rows, latent_dim)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs
        x = x.reshape(-1, self.input_dim)  # Flatten the input
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)  # Split into mean and log variance
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar

# Define the loss function
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: Reconstruction with values in [0, 1], shape ``(batch, features)``.
        x: Target with the same number of elements as ``recon_x``
            (e.g. ``(batch, 1, 28, 28)`` for a 784-feature reconstruction).
        mu: Latent means.
        logvar: Latent log-variances.

    Returns:
        Scalar tensor, summed (not averaged) over every element of the batch.
    """
    # Match the target to the reconstruction's shape instead of assuming
    # 784 features, so the loss also works for other input sizes.
    target = x.reshape(recon_x.shape)
    bce = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# MNIST train/test splits; tensor conversion is the only preprocessing step
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 128; only the training split is shuffled
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Seed PyTorch's global RNG before any weights are created so that, on the
# same hardware and software, Restart & Run All starts from the same
# initialisation (and the same subsequent random draws).
torch.manual_seed(42)

# Initialize the VAE model
vae = VAE(latent_dim=20).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

# Training loop
def train_vae(model, train_loader, epochs=10, optimizer=None):
    """Train ``model`` in place and print the average per-sample loss each epoch.

    Args:
        model: Module whose forward pass returns ``(recon, mu, logvar)``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
            Must expose ``.dataset`` so the loss can be averaged per sample.
        epochs: Number of full passes over ``train_loader``.
        optimizer: Optimizer to step. When omitted, an Adam optimizer
            (lr=1e-3) is created for ``model``'s own parameters, so the model
            passed in is always the one being updated. Pass an optimizer
            explicitly to keep its state across several calls.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Send batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` variable.
    model_device = next(model.parameters()).device

    model.train()
    for epoch in range(epochs):
        train_loss = 0
        for data, _ in train_loader:
            data = data.to(model_device)
            optimizer.zero_grad()
            recon_data, mu, logvar = model(data)
            loss = loss_function(recon_data, data, mu, logvar)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Batch losses are sums, so dividing by the dataset size gives a per-sample mean
        print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}")

# Fit the VAE for 10 epochs, announcing the start and end of training
print("Training the VAE...")
train_vae(vae, train_loader, epochs=10)
print("Training complete!")

# Persist the learned parameters (the state dict) to disk
trained_state = vae.state_dict()
torch.save(trained_state, "vae_mnist_trained.pth")
print("Trained VAE model saved!")

# Read the checkpoint back into the same model (optional) and switch to eval mode
reloaded_state = torch.load("vae_mnist_trained.pth", map_location=device)
vae.load_state_dict(reloaded_state)
vae.eval()

# Test the VAE on the test dataset
def reconstruction_error(model, data_loader):
    """Return the squared reconstruction error per dataset sample.

    Each batch contributes the sum of squared differences between the model's
    reconstruction and the flattened (784-value) input; the grand total is
    divided by ``len(data_loader.dataset)``. The model is put in eval mode and
    no gradients are tracked.
    """
    model.eval()
    batch_errors = []
    with torch.no_grad():
        for images, _labels in data_loader:
            images = images.to(device)
            reconstructed, _mu, _logvar = model(images)
            squared_error = nn.functional.mse_loss(
                reconstructed, images.view(-1, 784), reduction='sum'
            )
            batch_errors.append(squared_error.item())
    return sum(batch_errors) / len(data_loader.dataset)

# Note: the value reported here is the squared error summed over all pixels and
# averaged per test image (reconstruction_error uses reduction='sum' and divides
# by the dataset size), not a per-pixel mean.
test_error = reconstruction_error(vae, test_loader)
print(f"Reconstruction Error (MSE) on Test Dataset: {test_error:.4f}")

# Visualize reconstructions
def visualize_reconstructions(model, data_loader, num_images=5):
    """Plot inputs from the first batch above their reconstructions.

    Args:
        model: Module whose forward pass returns ``(recon, mu, logvar)``.
        data_loader: Iterable of ``(data, label)`` batches; only the first
            batch is used and labels are ignored.
        num_images: Number of image pairs to draw. Capped at the size of the
            first batch so a small batch cannot cause an index error.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    model.eval()

    # Only the first batch is visualized
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        raise ValueError("data_loader yielded no batches to visualize")
    data, _ = first_batch

    with torch.no_grad():
        recon_data, _, _ = model(data.to(device))
    data = data.cpu()
    recon_data = recon_data.cpu()

    # Never index past the end of the batch
    num_images = min(num_images, data.size(0))

    plt.figure(figsize=(10, 4))
    for i in range(num_images):
        original = data[i].squeeze()
        # The reconstruction is flat; give it the same 2-D shape as the input
        reconstructed = recon_data[i].reshape(original.shape)

        # Original Image (top row)
        plt.subplot(2, num_images, i + 1)
        plt.imshow(original, cmap='gray')
        plt.title("Original")
        plt.axis('off')

        # Reconstructed Image (bottom row)
        plt.subplot(2, num_images, i + 1 + num_images)
        plt.imshow(reconstructed, cmap='gray')
        plt.title("Reconstructed")
        plt.axis('off')
    plt.show()

# Show 5 images from the first test batch (top row) above their VAE
# reconstructions (bottom row)
visualize_reconstructions(vae, test_loader, num_images=5)

In [None]:
# Pick the compute device: start from the CPU and upgrade to CUDA when available
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
print(f"Using device: {device}")

# Define the Generator
class Generator(nn.Module):
    """MLP generator mapping a latent vector to 784 values in [-1, 1].

    Three hidden layers (256, 512, 1024 units), each followed by
    LeakyReLU(0.2), feed a 784-unit output layer with a Tanh activation.
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        # Hidden stack: one Linear + LeakyReLU pair per consecutive width pair
        widths = (latent_dim, 256, 512, 1024)
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU(0.2))

        # Output head: Tanh keeps the output between -1 and 1
        layers.append(nn.Linear(1024, 784))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run latent batch ``z`` through the network and return the result."""
        return self.model(z)

# Download pre-trained GAN weights
url = "https://github.com/znahman/gan-mnist-pytorch/raw/main/generator_mnist.pth"

# Fail fast on a bad response: this file is loaded straight afterwards, so an
# HTTP error page saved under a .pth name would otherwise only surface later
# as an obscure load error. The timeout stops the cell from hanging on a
# stalled connection.
response = requests.get(url, timeout=60)
response.raise_for_status()

with open("generator_mnist.pth", "wb") as f:
    f.write(response.content)

print("Pre-trained GAN weights downloaded!")

# Build the generator and restore the downloaded weights
latent_dim = 100
generator = Generator(latent_dim).to(device)
# NOTE(review): torch.load unpickles the file, and generator_mnist.pth was
# fetched from a remote URL -- only run this against a source you trust
# (consider weights_only=True where the installed PyTorch supports it).
pretrained_state = torch.load("generator_mnist.pth", map_location=device)
generator.load_state_dict(pretrained_state)
generator.eval()  # Evaluation mode for inference
print("Pre-trained GAN model loaded!")

# Generate and visualize samples
def generate_samples(generator, num_samples=10):
    """Sample latent noise, run it through ``generator``, and plot the images.

    Args:
        generator: Module mapping ``(num_samples, latent size)`` noise to
            784 values per sample in [-1, 1].
        num_samples: Number of images to generate and show in a single row.

    The noise width is taken from ``generator.latent_dim`` when the generator
    exposes it, so the noise always matches the network being sampled; the
    notebook-level ``latent_dim`` is used only as a fallback.
    """
    z_dim = getattr(generator, "latent_dim", None)
    if z_dim is None:
        z_dim = latent_dim

    z = torch.randn(num_samples, z_dim).to(device)
    with torch.no_grad():
        samples = generator(z).cpu()

    # Rescale images from [-1, 1] to [0, 1] for visualization
    samples = (samples + 1) / 2

    plt.figure(figsize=(10, 2))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(samples[i].view(28, 28).squeeze(), cmap='gray')
        plt.axis('off')
    plt.show()

# Sample 10 new images from the pre-trained generator and display them in a row
generate_samples(generator, num_samples=10)