In [1]:
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
import torchvision.utils as vutils

In [2]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 256x256 images.

    The encoder applies six Conv2d(kernel 4, stride 2, padding 1) stages, each
    halving height and width, so a 256x256 input becomes a 512x4x4 feature map.
    The linear heads and the decoder are sized for exactly that map, so inputs
    must be 256x256.

    Args:
        latent_dim: size of the latent vector z.
        img_channels: channels of the input image; the reconstruction has the
            same number of channels (the decoder ends in Tanh, so outputs lie
            in [-1, 1]).
    """

    # Encoder output feature map (channels, height, width) for a 256x256 input:
    # 256 / 2**6 = 4.
    _FEAT_SHAPE = (512, 4, 4)
    _FEAT_DIM = 512 * 4 * 4

    def __init__(self, latent_dim = 128, img_channels = 3):
        super(VAE, self).__init__()

        # Encoder: each Conv2d(4, stride 2, pad 1) halves H and W.
        # Layers are interleaved conv / LeakyReLU so Sequential indices (and thus
        # state_dict keys such as "encoder.0.weight") stay stable.
        enc_channels = [img_channels, 32, 64, 128, 256, 512, 512]
        enc_layers = []
        for c_in, c_out in zip(enc_channels[:-1], enc_channels[1:]):
            enc_layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1))
            enc_layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*enc_layers)

        # Heads producing the Gaussian posterior parameters q(z|x) = N(mu, exp(logvar)).
        self.fc_mu = nn.Linear(self._FEAT_DIM, latent_dim)
        self.fc_logvar = nn.Linear(self._FEAT_DIM, latent_dim)

        # Decoder mirrors the encoder: each ConvTranspose2d(4, stride 2, pad 1)
        # doubles H and W, taking 4x4 back to 256x256. The final layer emits
        # img_channels (not a hard-coded 3) so the reconstruction matches the input.
        self.fc_decode = nn.Linear(latent_dim, self._FEAT_DIM)
        dec_channels = [512, 512, 256, 128, 64, 32, img_channels]
        num_dec_stages = len(dec_channels) - 1
        dec_layers = []
        for stage, (c_in, c_out) in enumerate(zip(dec_channels[:-1], dec_channels[1:])):
            dec_layers.append(nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1))
            # Tanh on the last stage bounds outputs to [-1, 1].
            is_last = stage == num_dec_stages - 1
            dec_layers.append(nn.Tanh() if is_last else nn.LeakyReLU(0.2))
        self.decoder = nn.Sequential(*dec_layers)

    def encode(self, x):
        """Map images (B, img_channels, 256, 256) to posterior parameters (mu, logvar), each (B, latent_dim)."""
        features = self.encoder(x)
        features = features.view(features.size(0), self._FEAT_DIM)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I), keeping the sample differentiable in mu and logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, x):
        """Map latent vectors x of shape (B, latent_dim) to images (B, img_channels, 256, 256) in [-1, 1]."""
        features = self.fc_decode(x)
        features = features.view(features.size(0), *self._FEAT_SHAPE)
        return self.decoder(features)

    def forward(self, x):
        """Encode, sample z via the reparameterization trick, and decode.

        Returns:
            (recon, mu, logvar): the reconstruction and the posterior parameters.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

In [3]:
def loss_function(x, recon, mu, logvar, beta=0.1):
    """Beta-VAE objective: reconstruction error plus beta-weighted KL divergence.

    Both terms are summed over each sample's elements and then averaged over the
    batch. The two terms are therefore on the same per-sample scale, so ``beta``
    really is their relative weight. The result is also independent of batch
    size, so a smaller final batch is not weighted differently.

    Args:
        x: input images, batch along dim 0.
        recon: reconstructions with the same shape as ``x``.
        mu: posterior means, shape (B, latent_dim).
        logvar: posterior log-variances, shape (B, latent_dim).
        beta: weight of the KL term (default 0.1).

    Returns:
        Scalar tensor with the mean per-sample loss.
    """
    batch_size = x.size(0)
    # Sum squared error over all pixels/channels of a sample, then average over the batch.
    recon_loss = F.mse_loss(recon, x, reduction='sum') / batch_size
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over the batch.
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon_loss + beta * kl_div

In [4]:
def train_model(model, optimizer, dataloader, device, num_epoch):
    """Train ``model`` for ``num_epoch`` full passes over ``dataloader``.

    Every batch is ``(images, labels)``; labels are ignored. Each step does a
    forward pass, computes ``loss_function(images, recon, mu, logvar)``,
    backpropagates and takes one optimizer step. After each epoch the mean of
    that epoch's batch losses is printed and recorded.

    Returns:
        (model, losses): the same model object (trained in place) and the list
        of per-epoch mean losses.
    """
    model.train()
    epoch_losses = []
    for epoch_idx in range(1, num_epoch + 1):
        running_sum = 0
        for batch, _labels in dataloader:
            batch = batch.to(device)

            recon, mu, logvar = model(batch)
            optimizer.zero_grad()
            batch_loss = loss_function(batch, recon, mu, logvar)
            batch_loss.backward()
            optimizer.step()

            running_sum += batch_loss.item()

        # Mean over batches (not samples) for this epoch.
        mean_loss = running_sum / len(dataloader)
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch_idx}/{num_epoch}  loss: {mean_loss:.4f}")
    return model, epoch_losses

In [5]:
def evaluate_model(model, dataloader, device, num_samples=10):
    """Plot the first ``num_samples`` images of one batch above their VAE reconstructions.

    Args:
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        dataloader: yields ``(images, labels)`` batches; only the first batch is used.
        device: device the model lives on.
        num_samples: number of images to show (fewer if the first batch is smaller).

    Note:
        The model is switched to eval mode and left there.
    """
    model.eval()
    real_images, _ = next(iter(dataloader))
    samples = real_images[:num_samples].to(device)

    # Pure inference: no autograd graph is needed for plotting.
    with torch.no_grad():
        recon, _, _ = model(samples)

    def to_grid(batch):
        # Assumes pixels lie in [-1, 1] (Normalize(0.5, 0.5) inputs / Tanh outputs);
        # map back to [0, 1] and clamp so imshow never sees out-of-range floats.
        return vutils.make_grid((batch * 0.5 + 0.5).clamp(0, 1).cpu(), nrow=5, padding=2)

    # The bottom row shows reconstructions of the real images, not samples from the prior.
    panels = [("Real Images", to_grid(samples)), ("Reconstructed Images", to_grid(recon))]

    fig, axes = plt.subplots(nrows=2, figsize=(12, 6))
    for ax, (title, grid) in zip(axes, panels):
        ax.imshow(grid.permute(1, 2, 0))
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
if __name__ == "__main__":
    # --- configuration ---
    batch_size = 64
    lr = 0.0001
    num_epoch = 300
    seed = 0
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Seed torch's global RNG (weight init, DataLoader shuffling, reparameterization
    # noise) so runs are repeatable; some CUDA kernels may still be nondeterministic.
    torch.manual_seed(seed)

    # 256x256 is required: the VAE's six stride-2 convs reduce it to the 4x4 map its
    # linear layers expect. Normalize maps pixels to [-1, 1], matching the decoder's Tanh.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    train_data = torchvision.datasets.Flowers102(root="data/", split="train", download=True, transform=transform)
    train_dataloader = DataLoader(train_data, batch_size, shuffle=True)

    model = VAE(latent_dim=256, img_channels=3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    trained_model, losses = train_model(model, optimizer, train_dataloader, device, num_epoch)

    # Derive the x-axis from the recorded losses so it always matches num_epoch.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(range(1, len(losses) + 1), losses, linestyle='-', color='red')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("VAE Training Loss vs Epochs")
    plt.show()

    evaluate_model(trained_model, train_dataloader, device, num_samples=10)

