In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch

# Where the dataset is cached and how many images go into each batch.
DATA_ROOT = "./data"
BATCH_SIZE = 64

# flatten each 1x28x28 image into a 784-element vector (28 * 28 = 784)
transform = transforms.Compose([
	transforms.ToTensor(),
	# (x - 0.5) / 0.5: shifts and rescales the single grayscale channel
	transforms.Normalize((0.5,), (0.5,)),
	# torch.flatten is passed directly; wrapping it in a lambda adds nothing
	# and a lambda cannot be pickled if worker processes are enabled later
	transforms.Lambda(torch.flatten),
])

# download the Fashion-MNIST dataset (train and test splits share the transform)
train_set = datasets.FashionMNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# prepare the data: only the training split is shuffled
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
from torch import nn
import torch

# use same architecture for both AutoEncoder and Variational AutoEncoder
class AutoEncoder(nn.Module):
	"""Fully connected (variational) autoencoder for flattened 28x28 inputs.

	The encoder and decoder are mirror-image MLPs with two hidden layers.

	Args:
		latent_dim: size of the latent code fed to the decoder.
		hidden_dim: width of every hidden layer.
		resample: if False the model is a plain autoencoder. If True the
			encoder predicts a mean and a standard deviation per latent
			dimension and the latent code is sampled with the
			reparameterisation trick (VAE behaviour).
	"""

	def __init__(self, latent_dim=10, hidden_dim=128, resample=False):
		super().__init__()

		self.resample = resample

		# In VAE mode the encoder output is split into (mean, std dev), so it
		# must be twice as wide; otherwise each half would only hold
		# latent_dim / 2 values and would not fit the decoder's input layer.
		encoder_out_dim = 2 * latent_dim if resample else latent_dim

		# two hidden dimensions into the latent dim
		self.encoder = nn.Sequential(
			nn.Linear(28 * 28, hidden_dim),
			nn.ReLU(),
			nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(),
			nn.Linear(hidden_dim, encoder_out_dim),
		)

		self.decoder = nn.Sequential(
			nn.Linear(latent_dim, hidden_dim),
			nn.ReLU(),
			nn.Linear(hidden_dim, hidden_dim),
			nn.ReLU(),
			nn.Linear(hidden_dim, 28 * 28),
		)

	def encode(self, x):
		"""Encode ``x`` into the latent space.

		Returns:
			With ``resample=False``: the latent tensor.
			With ``resample=True``: a tuple ``(z, mu, std_dev)`` where ``z`` is
			the sampled latent code and ``mu`` / ``std_dev`` parameterise the
			distribution it was drawn from (needed for the KL term).
		"""
		latent = self.encoder(x)

		if not self.resample:
			return latent

		# use first half as mean, second half as std dev; chunk on the LAST
		# dim so both batched (N, 784) and single (784,) inputs are accepted
		mu, raw_std = torch.chunk(latent, 2, dim=-1)
		# softplus maps the unconstrained output to a positive value; the small
		# offset keeps it strictly positive so log(std_dev ** 2) stays finite
		std_dev = nn.functional.softplus(raw_std) + 1e-6
		# reparameterisation trick: z = mu + eps * std_dev with eps ~ N(0, 1)
		epsilon = torch.randn_like(std_dev)
		return mu + epsilon * std_dev, mu, std_dev

	def decode(self, z):
		"""Map a latent code ``z`` back to a flattened 28 * 28 output."""
		return self.decoder(z)

	def forward(self, x):
		"""Encode then decode ``x``; always returns only the reconstruction."""
		encoded = self.encode(x)
		if self.resample:
			# encode() returned (z, mu, std_dev); only z is decoded
			encoded = encoded[0]
		return self.decode(encoded)

In [None]:
import matplotlib.pyplot as plt
import torch

# Seed torch's global RNG so weight initialisation and latent sampling are
# repeatable from one run of this cell to the next.
SEED = 42
torch.manual_seed(SEED)

# train on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# variational AE: resample=True switches on latent sampling in the encoder
vae_model = AutoEncoder(latent_dim=10, hidden_dim=64, resample=True).to(device)
optimizer = torch.optim.Adam(vae_model.parameters(), lr=1e-3)

def loss_function(reconstructed, original, mean, sigma):
	"""VAE loss: reconstruction error plus KL divergence to a unit Gaussian.

	Both terms are summed over the last (feature / latent) dimension and then
	averaged over the remaining dimensions, so they share a per-sample scale.

	Args:
		reconstructed: decoder output.
		original: target with the same shape as ``reconstructed``.
		mean: mean of the approximate posterior.
		sigma: standard deviation of the approximate posterior.

	Returns:
		Scalar tensor ``reconstruction + KL``; the KL part is never negative.
	"""
	# squared error summed over features, averaged over samples
	recon_loss = (reconstructed - original).pow(2).sum(dim=-1).mean()

	# KL(N(mean, sigma^2) || N(0, 1)) = 0.5 * sum(mean^2 + sigma^2 - log(sigma^2) - 1)
	variance = sigma.pow(2)
	# clamp only inside the log so sigma == 0 cannot produce -inf
	log_variance = torch.log(torch.clamp(variance, min=1e-12))
	kl_loss = 0.5 * (mean.pow(2) + variance - log_variance - 1).sum(dim=-1).mean()

	return recon_loss + kl_loss

for epoch in range(10):
	vae_model.train()
	running_loss = 0.0

	for images, _ in train_loader:
		images = images.to(device)

		# The optimisation step lives INSIDE the batch loop: one gradient
		# update per batch, not one per epoch on whatever batch came last.
		optimizer.zero_grad()
		encoded, mean, sigma = vae_model.encode(images)
		reconstructed = vae_model.decode(encoded)
		loss = loss_function(reconstructed, images, mean, sigma)
		loss.backward()
		optimizer.step()

		running_loss += loss.item()

	# report the mean loss over all batches of the epoch
	print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

# Switch to evaluation mode and measure the mean reconstruction MSE per batch.
vae_model.eval()

total_loss = 0

# no gradients are needed while scoring the test split
with torch.no_grad():
	for batch, _ in test_loader:
		batch = batch.to(device)
		batch_mse = torch.nn.functional.mse_loss(vae_model(batch), batch)
		total_loss += batch_mse.item()

# average of the per-batch losses
num_batches = len(test_loader)
average_loss = total_loss / num_batches

print(f"Test Loss: {average_loss:.4f}")

# Show up to ten originals from the first test batch next to their reconstructions.
images, _ = next(iter(test_loader))
images = images.to(device)
num_examples = min(10, images.shape[0])

# Run the model once on a slice that keeps the batch dimension (indexing a
# single image would hand the model a 1-D tensor), without tracking gradients.
with torch.no_grad():
	reconstructions = vae_model(images[:num_examples])

for i in range(num_examples):
	image = images[i].view(28, 28).cpu().numpy()
	reconstructed = reconstructions[i].view(28, 28).cpu().numpy()

	fig, axes = plt.subplots(1, 2)
	axes[0].imshow(image, cmap='gray')
	axes[0].set_title('Original Image')
	axes[1].imshow(reconstructed, cmap='gray')
	axes[1].set_title('Reconstructed Image')
	plt.show()
	# release the figure so figures created in the loop do not accumulate
	plt.close(fig)
	
		

Epoch 1, Loss: 0.10277780145406723
Epoch 2, Loss: 0.08608654886484146


KeyboardInterrupt: 