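# Supervised Adversarial Autoencoder Experiment

Trains a label-conditioned (supervised) adversarial autoencoder, then visualizes test-set reconstructions and class-conditional samples decoded from the Gaussian prior.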

In [13]:
import torch
from torch import nn

from src.adversarial_autoencoder.supervised import SupervisedAdversarialAutoencoder
from src.adversarial_autoencoder.supervised import load_data

### Paper Config

In [14]:
# --- Training length ---
NUM_EPOCHS = 2000

# --- Data / architecture ---
INPUT_DIM = 784          # flattened image size (images are reshaped to 28x28 for plotting)
BATCH_SIZE = 100
AE_HIDDEN = 1000         # passed to the model as ae_hidden
DC_HIDDEN = 1000         # passed to the model as dc_hidden
LATENT_DIM = 15
NUM_CLASSES = 10
PRIOR_STD = 1.0          # std of the Gaussian latent prior (samples are torch.randn * PRIOR_STD)

# --- Optimisation ---
recon_loss = nn.MSELoss()
init_recon_lr = 0.001                # alternative value previously tried: 0.01
init_gen_lr = init_disc_lr = 0.0005  # alternative value previously tried: 0.1
use_decoder_sigmoid = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Reproducibility ---
# Seed torch's RNGs (torch.manual_seed seeds all devices, including CUDA) so weight
# init and any torch-based sampling/shuffling repeat across Restart & Run All.
# NOTE: some cuDNN kernels can still be nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

### Load Training and Test Data

In [15]:
# Build the train and test loaders; BATCH_SIZE is passed as the first argument
# (presumably the loader batch size).
# NOTE(review): the meaning of the second argument (-1) is defined inside load_data —
# presumably "no limit / use all samples"; TODO confirm and name it in the config cell.
train_loader, test_loader = load_data(BATCH_SIZE, -1)

### Model Setup

In [16]:
# Build the supervised AAE from the hyperparameters in the config cell.
aae = SupervisedAdversarialAutoencoder(
    input_dim=INPUT_DIM,
    ae_hidden=AE_HIDDEN,
    dc_hidden=DC_HIDDEN,
    latent_dim=LATENT_DIM,
    recon_loss_fn=recon_loss,
    num_classes=NUM_CLASSES,
    init_recon_lr=init_recon_lr,
    init_gen_lr=init_gen_lr,
    init_disc_lr=init_disc_lr,
    use_decoder_sigmoid=use_decoder_sigmoid,
    # Reuse the device chosen in the config cell rather than re-deriving it here, so
    # changing `device` there (e.g. to "cuda:1") actually takes effect. str() keeps the
    # argument a plain string such as "cuda" / "cpu".
    device=str(device),
)

### Train Model

In [None]:
from pathlib import Path

# Checkpoint path for this configuration, e.g. "weights/supervised_15_2000". The exact
# file name(s) written depend on aae.save_weights.
WEIGHTS_PATH = f"weights/supervised_{LATENT_DIM}_{NUM_EPOCHS}"

# Create the output directory *before* the long training run, so a missing directory
# fails fast instead of after hours of training (no-op if it already exists).
Path(WEIGHTS_PATH).parent.mkdir(parents=True, exist_ok=True)

# Optional: start from an earlier checkpoint instead of fresh weights.
# aae.load_weights("weights/supervised_2_2000")
aae.train_mbgd(
    data_loader=train_loader,
    epochs=NUM_EPOCHS,
    prior_std=PRIOR_STD,
)
aae.save_weights(WEIGHTS_PATH)

Training Epochs:   0%|          | 7/2000 [00:30<2:31:05,  4.55s/epoch]

### Visualize

In [None]:
import matplotlib.pyplot as plt

def display_reconstructions(model, test_loader, num_images=5):
    """Show test images (top row) above their label-conditioned reconstructions (bottom row).

    Only the first batch of ``test_loader`` is used. Each image is flattened and
    encoded. The latent code is concatenated with the one-hot label, and the result
    is decoded.

    Args:
        model: trained model exposing ``eval()``, ``encoder``, ``decoder``, ``device``
            and ``num_classes``.
        test_loader: iterable of ``(images, labels)`` batches with integer class labels.
            Each image is assumed to hold 28 * 28 = 784 values, since it is reshaped to
            28x28 for display. TODO confirm for non-MNIST data.
        num_images: number of original/reconstruction pairs to plot; clamped to the
            batch size.

    Raises:
        ValueError: if fewer than one image would be displayed.
    """
    model.eval()
    images, labels = next(iter(test_loader))

    # Never index past the end of the batch.
    num_images = min(num_images, images.size(0))
    if num_images < 1:
        raise ValueError(f"need at least one image to display, got num_images={num_images}")

    images = images.to(model.device)
    images_flattened = images.view(images.size(0), -1)
    labels_one_hot = torch.nn.functional.one_hot(labels, num_classes=model.num_classes).float().to(model.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        z = model.encoder(images_flattened)
        # The decoder input is [z, one-hot label], i.e. it is conditioned on the class.
        recon_images = model.decoder(torch.cat([z, labels_one_hot], dim=1))

    # squeeze=False keeps ``axes`` 2-D even when num_images == 1.
    fig, axes = plt.subplots(2, num_images, figsize=(12, 6), squeeze=False)

    originals = images.detach().cpu().numpy()
    reconstructions = recon_images.detach().cpu().numpy()
    for i in range(num_images):
        label = labels[i].item()

        # Original image
        ax = axes[0, i]
        ax.imshow(originals[i].reshape(28, 28), cmap='gray')
        ax.axis('off')
        ax.set_title(f"Original {label}")

        # Reconstructed image
        ax = axes[1, i]
        ax.imshow(reconstructions[i].reshape(28, 28), cmap='gray')
        ax.axis('off')
        ax.set_title(f"Reconstruction {label}")

    plt.tight_layout()
    plt.show()

# Compare the first few test images with their reconstructions.
display_reconstructions(aae, test_loader, num_images=5)


In [None]:
def generate_image_grid(aae, latent_dim, prior_std, n_classes=10, fixed_z=True):
    """Decode one prior sample per class label and plot the results in a single row.

    Args:
        aae: trained model exposing ``eval()``, ``parameters()`` and ``decoder``; the
            decoder input is ``[z, one-hot label]``.
        latent_dim: size of the latent vector ``z``.
        prior_std: standard deviation of the Gaussian prior (``z = randn * prior_std``).
        n_classes: one-hot width; one panel is drawn per class.
        fixed_z: if True (default), a single ``z`` is shared by every panel, so only the
            label changes between panels; this isolates the effect of the label. If
            False, a fresh ``z`` is drawn for every class after the first.

    Each decoded output is assumed to hold 28 * 28 values (it is reshaped to 28x28).
    It is shown on a fixed [0, 1] intensity scale.
    """
    aae.eval()
    device = next(aae.parameters()).device
    # squeeze=False keeps ``axes`` 2-D, so indexing also works when n_classes == 1.
    fig, axes = plt.subplots(1, n_classes, figsize=(15, 2), squeeze=False)

    z = torch.randn(1, latent_dim).to(device) * prior_std
    for i in range(n_classes):
        if not fixed_z and i > 0:
            z = torch.randn(1, latent_dim).to(device) * prior_std

        # One-hot label selecting class i.
        y = torch.zeros(1, n_classes).to(device)
        y[0, i] = 1

        with torch.no_grad():
            x_hat = aae.decoder(torch.cat([z, y], dim=1))
            x_hat = x_hat.view(28, 28).cpu()

        ax = axes[0, i]
        # Fixed [0, 1] range so panels are comparable (no per-image auto-scaling).
        ax.imshow(x_hat, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
        ax.set_title(f'{i}')

    plt.tight_layout()
    plt.show()

# Three independent prior samples: each figure decodes one z with every class label.
for _ in range(3):
    generate_image_grid(aae, latent_dim=LATENT_DIM, prior_std=PRIOR_STD)