In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Reproducibility: seed torch's global RNG (all devices) before any data
# shuffling, weight initialisation or latent sampling happens.
SEED = 0
torch.manual_seed(SEED)

# ToTensor() turns PIL images into float tensors scaled to [0, 1]; the
# binary-cross-entropy reconstruction loss used later needs targets in that range.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Load Fashion-MNIST (auto-downloads into ./data on first run)
train_dataset = datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

test_dataset = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform
)

# Shuffle only the training data; keep the test order fixed so the
# reconstruction/latent-space cells always see the same first batch.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

latent_dim = 20  # default latent size; VAE() falls back to this when z_dim is not given


class VAE(nn.Module):
    """Fully connected variational autoencoder for 28x28 single-channel images.

    Encoder: 784 -> 512 -> 256 -> (mu, logvar), each of size ``self.latent_dim``.
    Decoder: latent -> 256 -> 512 -> 784, followed by a sigmoid so every output
    pixel lies in [0, 1] (matching a binary cross-entropy reconstruction loss).

    Args:
        z_dim: Size of the latent space. If None (the default), the module-level
            ``latent_dim`` is read at construction time, exactly as ``VAE()``
            always did, so existing callers are unaffected.
    """

    def __init__(self, z_dim=None):
        super().__init__()
        # Resolve the latent size once and keep it on the instance so callers
        # (e.g. prior sampling) can ask the model instead of a mutable global.
        self.latent_dim = latent_dim if z_dim is None else z_dim

        # Encoder: flattened image -> 256-d hidden features
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )

        # Heads producing the mean and log-variance of the approximate posterior
        self.fc_mu = nn.Linear(256, self.latent_dim)
        self.fc_logvar = nn.Linear(256, self.latent_dim)

        # Decoder: latent code -> 784 pixel probabilities
        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I) (reparameterization trick).

        Writing the sample this way keeps z differentiable w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Args:
            x: Image tensor with 28*28 elements per sample, e.g. (B, 1, 28, 28).

        Returns:
            (recon, mu, logvar): ``recon`` has shape (B, 1, 28, 28); ``mu`` and
            ``logvar`` have shape (B, self.latent_dim). A fresh z is sampled on
            every call, so ``recon`` is stochastic even in eval mode.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well
        x = x.reshape(-1, 28 * 28)
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        out = self.decoder(z)
        return out.view(-1, 1, 28, 28), mu, logvar


In [None]:
def vae_loss(recon_x, x, mu, logvar):
    """Negative ELBO for a Bernoulli-decoder VAE, summed over the batch.

    Args:
        recon_x: Decoder outputs (pixel probabilities), same shape as ``x``.
        x: Target images with values in [0, 1].
        mu: Posterior means.
        logvar: Posterior log-variances (same shape as ``mu``).

    Returns:
        Scalar tensor: summed binary cross-entropy plus the closed-form
        KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
    """
    bce = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Per-element KL integrand of the Gaussian-vs-standard-normal divergence.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * kl_terms.sum()
    return bce + kld


In [None]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epochs = 50
losses = []  # mean per-sample loss (summed BCE + KL / dataset size) per epoch

for epoch in range(epochs):
    model.train()
    running = 0

    for batch, _ in train_loader:
        batch = batch.to(device)

        # Clear stale gradients up front; the forward pass does not touch them.
        optimizer.zero_grad()
        batch_recon, batch_mu, batch_logvar = model(batch)
        batch_loss = vae_loss(batch_recon, batch, batch_mu, batch_logvar)
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    # vae_loss sums over samples, so divide by the number of images, not batches.
    avg_loss = running / len(train_dataset)
    losses.append(avg_loss)

    print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.2f}")


In [None]:
import matplotlib.pyplot as plt

# Draw 16 codes from a standard normal (the prior the KL term pulls towards)
# and decode them. The code size is read from the trained model rather than the
# global `latent_dim`, so re-running this cell after `latent_dim` is rebound
# still matches the model's decoder input.
n_samples = 16
model.eval()
with torch.no_grad():
    z = torch.randn(n_samples, model.fc_mu.out_features).to(device)
    generated = model.decoder(z).view(-1, 1, 28, 28)

fig, axes = plt.subplots(4, 4, figsize=(8, 8))
for ax, img in zip(axes.flat, generated):
    ax.imshow(img[0].cpu(), cmap="gray")
    ax.axis("off")
fig.suptitle("Images decoded from random prior samples")
plt.show()

In [None]:
# Compare the first 8 test images (top row) with their reconstructions (bottom row).
model.eval()  # set explicitly instead of relying on an earlier cell
imgs, _ = next(iter(test_loader))
imgs = imgs.to(device)

with torch.no_grad():
    # forward() decodes a freshly sampled z, so reconstructions are stochastic.
    recon, _, _ = model(imgs)

n_show = 8
fig, axes = plt.subplots(2, n_show, figsize=(10, 4))
for col in range(n_show):
    axes[0, col].imshow(imgs[col][0].cpu(), cmap="gray")
    axes[0, col].axis("off")

    axes[1, col].imshow(recon[col][0].cpu(), cmap="gray")
    axes[1, col].axis("off")

axes[0, 0].set_title("Original", loc="left")
axes[1, 0].set_title("Reconstruction", loc="left")
fig.suptitle("VAE reconstructions on FMNIST test images")
plt.show()

In [None]:
# Per-sample training loss per epoch, plotted against 1-based epoch numbers so
# the x-axis matches the "Epoch [k/N]" lines printed during training.
epoch_numbers = list(range(1, len(losses) + 1))
fig, ax = plt.subplots()
ax.plot(epoch_numbers, losses)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss per sample (BCE + KL)")
ax.set_title("VAE Training Loss on FMNIST")
plt.show()

In [None]:
# Train a VAE whose latent space really is 2-D.
# Rebinding `latent_dim` alone does not touch the already-trained 20-D `model`:
# its encoder would still output 20-D means and the next cells would only plot
# the first two of them. VAE() reads `latent_dim` when it is constructed, so a
# model built after this assignment has a 2-D bottleneck.
latent_dim = 2   # for 2-D visualization

model = VAE().to(device)  # replaces the 20-D model; the latent-space cells below use `model`
assert model.fc_mu.out_features == 2, "expected a 2-D latent space"
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses_2d = []

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for batch_imgs, _ in train_loader:
        batch_imgs = batch_imgs.to(device)
        recon_2d, mu_2d, logvar_2d = model(batch_imgs)
        loss = vae_loss(recon_2d, batch_imgs, mu_2d, logvar_2d)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # vae_loss is summed over samples, so normalise by the dataset size.
    losses_2d.append(epoch_loss / len(train_dataset))
    print(f"[2-D VAE] Epoch [{epoch+1}/{epochs}] Loss: {losses_2d[-1]:.2f}")

In [None]:

model.eval()

# Collect the encoder's posterior means for the whole test set, batch by batch.
mu_batches, label_batches = [], []

with torch.no_grad():
    for batch_imgs, batch_labels in test_loader:
        # Use the mean μ rather than a sampled z: it is a deterministic function
        # of the image, so the embedding does not jitter between runs.
        _, batch_mu, _ = model(batch_imgs.to(device))
        mu_batches.append(batch_mu.cpu())
        label_batches.append(batch_labels)

latent_vectors = torch.cat(mu_batches).numpy()
labels = torch.cat(label_batches).numpy()

In [None]:
import matplotlib.pyplot as plt

# Scatter the test-set posterior means, coloured by their true class.
n_latent = latent_vectors.shape[1]
if n_latent == 2:
    title = "2-D Latent Space Visualization of FMNIST (VAE)"
else:
    # Only the first two coordinates are drawn; say so instead of implying 2-D.
    title = f"First 2 of {n_latent} Latent Dimensions of FMNIST (VAE)"

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(
    latent_vectors[:, 0],
    latent_vectors[:, 1],
    c=labels,
    cmap="tab10",
    s=5
)

# Label the colorbar with class names instead of bare 0-9 indices.
class_names = test_dataset.classes
cbar = fig.colorbar(scatter, ax=ax, ticks=range(len(class_names)))
cbar.ax.set_yticklabels(class_names)

ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_title(title)
plt.show()