In [None]:
import lucid
import lucid.nn as nn
import lucid.nn.functional as F
import lucid.optim as optim
import lucid.data as data
import lucid.datasets as datasets
import lucid.models as models

from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Optimisation hyperparameters
batch_size = 128
num_epochs = 10
learning_rate = 0.001

# Size of the VAE latent code (read by the VAE class at construction time)
latent_dim = 2

In [None]:
# Both splits live under the same root directory.
fashion_mnist_root = "../../data/fashion_mnist"
train_set = datasets.FashionMNIST(root=fashion_mnist_root, train=True)
test_set = datasets.FashionMNIST(root=fashion_mnist_root, train=False)

In [None]:
# Shared loader settings; only the training split is shuffled,
# the test split is iterated in a fixed order.
loader_kwargs = dict(batch_size=batch_size)
train_loader = data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = data.DataLoader(test_set, shuffle=False, **loader_kwargs)

In [None]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over 784-feature inputs.

    Encoder: 784 -> 500 (ReLU) -> (mean, logvar), each of size ``latent_dim``.
    Decoder: ``latent_dim`` -> 500 (ReLU) -> 784 (sigmoid).

    The latent size is taken from the module-level ``latent_dim`` when an
    instance is constructed.
    """

    def __init__(self):
        super().__init__()
        # Encoder: shared hidden layer, then separate heads for mean and log-variance.
        self.fc1 = nn.Linear(784, 500)
        self.fc2_mean = nn.Linear(500, latent_dim)
        self.fc2_logvar = nn.Linear(500, latent_dim)
        # Decoder: latent code back up to the 784-feature output.
        self.fc3 = nn.Linear(latent_dim, 500)
        self.fc4 = nn.Linear(500, 784)

    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``.

        The log-variance is clipped to [-10, 10], presumably to keep
        ``exp(logvar)`` numerically bounded.
        """
        hidden = F.relu(self.fc1(x))
        mu = self.fc2_mean(hidden)
        log_var = self.fc2_logvar(hidden).clip(-10.0, 10.0)
        return mu, log_var

    def reparameterize(self, mean, logvar):
        """Sample ``z = mean + eps * exp(0.5 * logvar)`` with ``eps`` drawn from ``randn``."""
        sigma = lucid.exp(0.5 * logvar)
        noise = lucid.random.randn(sigma.shape)
        return mean + noise * sigma

    def decode(self, z):
        """Map a latent code to sigmoid-activated 784-feature outputs."""
        hidden = F.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return F.sigmoid(logits)

    def forward(self, x):
        """Return ``(reconstruction, mean, logvar)`` for the input batch."""
        mean, logvar = self.encode(x)
        reconstruction = self.decode(self.reparameterize(mean, logvar))
        return reconstruction, mean, logvar

In [None]:
def loss_function(recon_x, x, mean, logvar):
    """Negative ELBO for a batch: summed reconstruction BCE plus KL divergence.

    Args:
        recon_x: decoder outputs (probabilities).
        x: targets the reconstruction is compared against.
        mean: posterior means from the encoder.
        logvar: posterior log-variances from the encoder.

    Returns:
        Scalar ``BCE + KL``, both summed over every element of the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)).
    kl_term = -0.5 * lucid.sum(1 + logvar - mean ** 2 - lucid.exp(logvar))
    total = reconstruction_term + kl_term
    return total

In [None]:
def normalize(x):
    """Cast ``x`` to ``lucid.Float``, divide by 255 and flatten to shape (-1, 784).

    Presumably ``x`` holds 0-255 pixel intensities, so the result lies in
    [0, 1] (TODO confirm the raw dataset dtype).
    """
    as_float = x.astype(lucid.Float)
    scaled = as_float / 255.0
    return scaled.reshape(-1, 784)

In [None]:
def train(model, train_loader, optimizer, num_epochs):
    """Optimise ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Each batch is passed through ``normalize`` before the forward pass. The
    loss is ``loss_function`` (summed BCE + KL over the batch). After each
    epoch the average loss per sample is printed.

    Args:
        model: the VAE to train; switched to training mode first.
        train_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        List of per-batch loss values (Python floats), in training order.
    """
    losses = []
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_samples = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        # Named `images` rather than `data` so the `lucid.data` module alias is not shadowed.
        for images, _ in progress_bar:
            optimizer.zero_grad()
            images = normalize(images)
            recon_batch, mean, logvar = model(images)

            loss = loss_function(recon_batch, images, mean, logvar)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            losses.append(batch_loss)
            epoch_loss += batch_loss
            num_samples += images.shape[0]
            progress_bar.set_postfix(loss=batch_loss)

        # The loss is summed over each batch, so divide by the sample count
        # (not the batch count) to get a per-example figure; guard empty loaders.
        avg_loss = epoch_loss / num_samples if num_samples else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs} - average loss per sample: {avg_loss:.4f}")

    return losses

In [None]:
def test(model, test_loader):
    model.eval()
    test_loss = 0
    losses = []
    with lucid.no_grad():
        progress_bar = tqdm(test_loader, desc="Testing")
        for data, _ in progress_bar:
            recon_batch, mean, logvar = model(data)
            loss = loss_function(recon_batch, data, mean, logvar).eval()
            batch_loss = loss.item()

            test_loss += batch_loss
            losses.append(batch_loss)
            progress_bar.set_postfix(loss=batch_loss)
    
    avg_loss = test_loss / len(test_loader)
    print(f"\nAverage Test Loss: {avg_loss:.4f}")
    return losses

In [None]:
model = VAE()
trainable_params = model.parameters()
optimizer = optim.Adam(trainable_params, lr=learning_rate)

# Summarize the network for a single flattened 784-feature sample.
summary_input_shape = (1, 784)
models.summarize(model, input_shape=summary_input_shape)

In [None]:
train_losses = train(
    model,
    train_loader,
    optimizer,
    num_epochs=num_epochs,
)