# Testing VAE Model Module

This notebook smoke-tests the newly extracted VAE model in `src.models.vae`. It covers model construction, a forward pass, loss computation, and embedding extraction.

In [1]:
# Setup: make the project root importable, import the VAE module under test,
# and seed torch so every stochastic result below is reproducible.
import sys
from pathlib import Path

import torch

# Project root = nearest directory (cwd or an ancestor) containing `src/models`.
# Falls back to the parent of cwd, which is the original assumption that the
# notebook runs from a first-level subdirectory such as `notebooks/`.
project_root = next(
    (
        candidate
        for candidate in (Path.cwd(), *Path.cwd().parents)
        if (candidate / "src" / "models").is_dir()
    ),
    Path.cwd().parent,
)
# Only insert once, so re-running this cell does not keep growing sys.path.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import the extracted VAE module
from src.models import VariationalAutoencoder, variational_autoencoder_loss_function

# Seed torch before any model is built: weight initialisation, the random test
# batch, and any sampling inside the model then repeat across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

print('✓ Imports successful!')

✓ Imports successful!


In [2]:
# Create a VAE model sized for the synthetic test batch used below.
vae = VariationalAutoencoder(
    input_dimension=500,      # Number of species
    latent_dimension=16,      # Embedding size
    hidden_dimension=512,     # Hidden layer size
)

# Count every trainable and non-trainable parameter tensor element.
total_parameter_count = sum(parameter.numel() for parameter in vae.parameters())

print("Model created successfully!")
print(f"Total parameters: {total_parameter_count:,}")

Model created successfully!
Total parameters: 1,588,756


In [3]:
# Test forward pass with binary data (0 or 1, like species presence/absence)
batch_size = 32
# 500 must match `input_dimension` of the model created above.
input_data = torch.randint(0, 2, (batch_size, 500)).float()  # Binary data: 0 or 1

reconstructed, mu, logvar = vae(input_data)

# Fail loudly on a shape regression instead of relying on eyeballing the prints.
assert reconstructed.shape == input_data.shape, "reconstruction must match the input shape"
assert mu.shape == logvar.shape, "latent mean and log-variance must have the same shape"
assert mu.shape[0] == batch_size, "latent batch dimension must match the input batch"

print(f"Input shape: {input_data.shape}")
print(f"Reconstructed shape: {reconstructed.shape}")
print(f"Latent mean shape: {mu.shape}")
print(f"Latent logvar shape: {logvar.shape}")

Input shape: torch.Size([32, 500])
Reconstructed shape: torch.Size([32, 500])
Latent mean shape: torch.Size([32, 16])
Latent logvar shape: torch.Size([32, 16])


In [4]:
# Test loss computation on the forward-pass outputs from the previous cell.
total_loss, recon_loss, kl_loss = variational_autoencoder_loss_function(
    reconstructed_input=reconstructed,
    original_input=input_data,
    latent_mean=mu,
    latent_log_variance=logvar,
)

# Sanity checks: every term must be finite (a NaN/inf here usually means a
# numerically unstable reconstruction or log-variance).
assert all(
    torch.isfinite(loss_term).all() for loss_term in (total_loss, recon_loss, kl_loss)
), "loss terms must be finite"
# The recorded output shows total = reconstruction + KL with the default
# weighting; pin that so a silent change to the weighting is caught.
torch.testing.assert_close(total_loss, recon_loss + kl_loss)

print(f"Total loss: {total_loss.item():.2f}")
print(f"Reconstruction loss: {recon_loss.item():.2f}")
print(f"KL divergence loss: {kl_loss.item():.2f}")

Total loss: 11093.29
Reconstruction loss: 11092.56
KL divergence loss: 0.73


In [5]:
# Test getting embeddings: the first value returned by `encode` (presumably the
# latent mean) is used as the deterministic embedding.
# `no_grad` only disables autograd; `eval()` is also needed so layers such as
# dropout / batch-norm (if the model has any) run in inference mode. The
# previous mode is restored so later cells see the model as they left it.
was_training = vae.training
vae.eval()
try:
    with torch.no_grad():
        embeddings, embedding_log_variance = vae.encode(input_data)
finally:
    vae.train(was_training)

assert embeddings.shape == mu.shape, "embeddings must match the latent mean shape"
assert embedding_log_variance.shape == embeddings.shape
assert not embeddings.requires_grad, "embeddings should be detached from autograd"

print(f"Embeddings shape: {embeddings.shape}")
print(f"Sample embedding: {embeddings[0]}")

Embeddings shape: torch.Size([32, 16])
Sample embedding: tensor([-0.0048, -0.0111, -0.0474,  0.0668,  0.0026, -0.0388, -0.0028,  0.0004,
        -0.0043,  0.0419,  0.0473, -0.0585,  0.0088,  0.0506,  0.0463, -0.0226])
