# Artifex Framework: Modern VAE Training on CelebA

This notebook demonstrates the **modern Artifex framework** for developing and training generative models. It showcases:

## 🎯 Key Features Demonstrated:

### Core Framework Components:
- **Unified Factory System** - Centralized model creation with `ModelFactory`
- **Model Zoo** - Pre-configured model templates for quick experimentation  
- **Device Management** - Automatic GPU/CPU handling with fallback
- **Modality System** - Image modality adapters and processors
- **Official Trainer** - Artifex's production-ready training system
- **Evaluation Framework** - Comprehensive metrics and benchmarking

### Advanced Capabilities:
- **JIT Compilation** - Compiled training and evaluation steps via JAX's JIT (speedups vary with model size and hardware)
- **Mixed Precision** - Optional FP16/BF16 training
- **Configuration Management** - Type-safe Pydantic configurations
- **Checkpoint System** - Model saving/loading with versioning
- **Logging Integration** - Metrics tracking and visualization

## 📚 Learning Objectives:
1. How to use Artifex's factory system for model creation
2. Proper configuration management with unified configs
3. Integration with the official training system
4. Modality-based data handling
5. Comprehensive evaluation and benchmarking
6. Best practices for production-ready code

## 1. Environment Setup and Imports

First, we'll set up the environment and import the necessary Artifex components.

In [None]:
# Standard library imports
import os
import pickle
import time
import warnings
from pathlib import Path
from typing import Any


# Configure JAX/XLA through environment variables *before* importing JAX:
# these settings are read when JAX initializes, so they must already be set.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # Do not reserve most GPU memory upfront
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # Allocate/free exactly on demand (small footprint, slower)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Reduce C++ log verbosity
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"  # Skip GPU autotuning (faster compiles, kernels may be slower)
# No platform is forced: JAX uses a GPU automatically when one is available
# and falls back to CPU otherwise.

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=UserWarning, message=".*CUDA.*")
warnings.filterwarnings("ignore", category=FutureWarning)

# JAX, Flax and numerical libraries
import flax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import nnx

# Artifex imports - Benchmarks and datasets
from artifex.benchmarks.datasets import CelebADataset
from artifex.benchmarks.metrics import FIDMetric, ISMetric

# Artifex imports - Configuration, device management and model factory
from artifex.generative_models.core.configuration.unified import (
    ConfigurationType,
    DataConfiguration,
    EvaluationConfiguration,
    ModelConfiguration,
    OptimizerConfiguration,
    SchedulerConfiguration,
    TrainingConfiguration,
)
from artifex.generative_models.core.device_manager import (
    DeviceConfiguration,
    DeviceManager,
    MemoryStrategy,
)
from artifex.generative_models.factory import ModelFactory

# Artifex imports - Modalities
from artifex.generative_models.modalities.image import (
    ImageModality,
    ImageModalityConfig,
    ImageRepresentation,
)

# Artifex imports - Training
from artifex.generative_models.training.trainer import Trainer


print("✅ Artifex Framework Loaded")
print(f"📦 JAX version: {jax.__version__}")
print(f"📦 Flax version: {flax.__version__}")

## 1.2 Device Management with Artifex

Artifex's `DeviceManager` detects the available accelerator and falls back to CPU when no GPU is found.

In [None]:
# Initialize Artifex's Device Manager
device_config = DeviceConfiguration(
    memory_strategy=MemoryStrategy.BALANCED,  # Use 75% of GPU memory
    enable_x64=False,  # Use float32 for speed
    enable_jit=True,  # Enable JIT compilation
)

device_manager = DeviceManager(config=device_config)
device_info = device_manager.capabilities  # Access capabilities attribute directly

print("🖥️  Device Configuration")
print("=" * 60)
print(f"Device Type: {device_info.device_type.value.upper()}")
print(f"Device Count: {device_info.device_count}")

if device_info.device_type.value == "gpu":
    print("✅ GPU Detected")
    if device_info.total_memory_mb:
        print(f"   Total Memory: {device_info.total_memory_mb / 1024:.1f} GB")
    if device_info.compute_capability:
        print(f"   Compute Capability: {device_info.compute_capability}")
    if device_info.cuda_version:
        print(f"   CUDA Version: {device_info.cuda_version}")
    print(f"   Mixed Precision: {'✓' if device_info.supports_mixed_precision else '✗'}")
else:
    print("⚠️  Running on CPU - Training will be slower")
    print("   For GPU support, ensure CUDA is properly installed")

# Get more detailed device information
detailed_info = device_manager.get_device_info()
print(f"\nJAX Backend: {detailed_info['backend']}")
print(f"Default Device: {detailed_info['default_device']}")

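# Configure HuggingFace cache. These variables only take effect if they are set
# before `huggingface_hub`/`datasets` are first imported; if the Artifex dataset
# module imports them at load time, move these lines to the top of the first cell.
os.environ["HF_HOME"] = "/media/mahdi/ssd23/Data/huggingface"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/media/mahdi/ssd23/Data/huggingface"
os.environ["HF_DATASETS_CACHE"] = "/media/mahdi/ssd23/Data/huggingface/datasets"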

## 2. Model Configuration with Artifex's Unified System

Artifex uses a unified configuration system with type-safe Pydantic models.

In [None]:
def create_vae_configuration(
    latent_dim: int = 512,
    image_size: int = 64,
    beta: float = 1.0,
    learning_rate: float = 2e-4,
    batch_size: int = 256,
    num_epochs: int = 100,
) -> dict[str, Any]:
    """Create comprehensive configuration for VAE training.

    This demonstrates Artifex's unified configuration system where all
    aspects of the experiment are configured in one place.
    """

    # Model configuration
    model_config = ModelConfiguration(
        name=f"vae_celeba_{image_size}x{image_size}",
        type=ConfigurationType.MODEL,
        model_class="artifex.generative_models.models.vae.base.VAE",
        # Architecture parameters
        input_dim=(image_size, image_size, 3),
        output_dim=latent_dim,
        hidden_dims=[64, 128, 256, 512],
        activation="relu",
        # VAE-specific parameters
        parameters={
            "encoder_type": "cnn",  # CNN encoder for images
            "decoder_type": "cnn",  # CNN decoder for images
            "kl_weight": beta,
            "reconstruction_loss": "mse",
            "beta": beta,
            "decoder_dims": [512, 256, 128, 64],  # Reverse of encoder
        },
        # Metadata for tracking
        metadata={
            "modality": "image",
            "dataset": "celeba",
            "image_resolution": (image_size, image_size),
            "color_channels": 3,
            "architecture": "convolutional",
            "variant": "beta-vae" if beta != 1.0 else "standard-vae",
        },
    )

    # Optimizer configuration (created first because the training configuration embeds it)
    optimizer_config = OptimizerConfiguration(
        name="adamw_optimizer",
        type=ConfigurationType.OPTIMIZER,
        optimizer_type="adamw",
        learning_rate=learning_rate,
        weight_decay=1e-5,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8,
    )

    # Scheduler configuration (optional)
    scheduler_config = SchedulerConfiguration(
        name="cosine_scheduler",
        type=ConfigurationType.SCHEDULER,
        scheduler_type="cosine",
        warmup_steps=1000,
        min_lr_ratio=0.01,
    )

    # Training configuration (embeds the optimizer and scheduler configurations)
    training_config = TrainingConfiguration(
        name="vae_celeba_training",
        type=ConfigurationType.TRAINING,
        num_epochs=num_epochs,
        batch_size=batch_size,
        gradient_clip_norm=1.0,
        optimizer=optimizer_config,  # Pass the optimizer config object
        scheduler=scheduler_config,  # Pass the scheduler config object
        checkpoint_dir=Path("./checkpoints"),
        save_frequency=1000,
    )

    # Data configuration
    data_config = DataConfiguration(
        name="celeba_data",
        type=ConfigurationType.DATA,
        dataset_name="celeba",
        data_dir=Path("/media/mahdi/ssd23/Data/huggingface"),
        split="train",
        num_workers=4,
        pin_memory=True,
    )

    # Evaluation configuration
    eval_config = EvaluationConfiguration(
        name="vae_celeba_evaluation",
        type=ConfigurationType.EVALUATION,
        metrics=["fid", "is", "reconstruction_mse", "kl_divergence"],
        metric_params={
            "fid": {"mock_inception": True},
            "is": {"splits": 10},
        },
        eval_batch_size=32,
        num_eval_samples=1000,
    )

    # Image modality configuration
    modality_config = ImageModalityConfig(
        representation=ImageRepresentation.RGB,
        height=image_size,
        width=image_size,
        channels=3,
        normalize=True,
        augmentation=False,
    )

    return {
        "model": model_config,
        "training": training_config,
        "optimizer": optimizer_config,
        "scheduler": scheduler_config,
        "data": data_config,
        "evaluation": eval_config,
        "modality": modality_config,
    }


# Create all configurations
configs = create_vae_configuration(
    latent_dim=512,
    image_size=64,
    beta=1.0,
    learning_rate=2e-4,
    batch_size=256,
    num_epochs=100,
)

print("📋 Configuration Summary")
print("=" * 60)
print(f"Model: {configs['model'].name}")
print("Architecture: CNN Encoder-Decoder")
print(f"Latent Dimension: {configs['model'].output_dim}")
print(f"Input Shape: {configs['model'].input_dim}")
print(f"Hidden Layers: {configs['model'].hidden_dims}")
print("Training:")
print(f"  Epochs: {configs['training'].num_epochs}")
print(f"  Batch Size: {configs['training'].batch_size}")
print(f"  Learning Rate: {configs['optimizer'].learning_rate}")
print(f"  Optimizer: {configs['optimizer'].optimizer_type}")
print(f"  Scheduler: {configs['scheduler'].scheduler_type}")
print(f"Evaluation Metrics: {', '.join(configs['evaluation'].metrics)}")

## 3. Model Creation with Artifex's Factory System

Artifex provides multiple ways to create models:
1. Using the factory with custom configuration
2. Using pre-configured models from the Model Zoo
3. Direct instantiation with modality adapters

In [None]:
# Initialize RNGs for reproducibility
rngs = nnx.Rngs(42)

print("🔨 Creating VAE Model with Artifex Factory")
print("=" * 60)

# Method 1: Using the factory system with modality (recommended)
factory = ModelFactory()

# Create the model with factory - modality will apply image-specific adapters
model = factory.create(
    config=configs["model"],
    modality="image",  # Applies image modality adapters
    rngs=rngs,
)

print(f"✅ Model created: {type(model).__name__}")
print(f"   Encoder: {type(model.encoder).__name__} (CNN-based)")
print(f"   Decoder: {type(model.decoder).__name__} (CNN-based)")
print(f"   Latent dim: {model.latent_dim}")

# Method 2: Pre-configured model from the Model Zoo
# (requires importing `ModelZoo`, which is not imported in this notebook)
# zoo = ModelZoo()
# try:
#     zoo_config = zoo.get_config("vae_celeba_64x64")
#     model = zoo.create_model("vae_celeba_64x64", rngs=rngs)
#     print("   Using pre-configured model from Zoo")
# except KeyError:
#     print("   No pre-configured model found in Zoo, using custom config")

# Create image modality for data processing (separate from model adapters)
image_modality = ImageModality(config=configs["modality"], rngs=rngs)

print("\n📸 Image Modality Configured:")
print(f"   Representation: {configs['modality'].representation.value}")
print(f"   Resolution: {configs['modality'].height}x{configs['modality'].width}")
print(f"   Channels: {configs['modality'].channels}")
print(f"   Adapter available: {hasattr(image_modality, 'get_adapter')}")

## 4. Dataset Setup with Artifex's Data System

Artifex provides integrated dataset handling with automatic preprocessing and batching.

In [None]:
# Dataset configuration
DATASET_SIZE = 10000  # Start with smaller size for quick testing
# DATASET_SIZE = 100000  # Use this for full training

print("📊 Loading CelebA Dataset")
print("=" * 60)

# Create dataset using Artifex's benchmark datasets
train_dataset = CelebADataset(
    data_path=configs["data"].data_dir,
    num_samples=DATASET_SIZE,
    image_size=configs["modality"].height,
    include_attributes=True,
    split="train",
    rngs=rngs,
)

# Create validation dataset
val_dataset = CelebADataset(
    data_path=configs["data"].data_dir,
    num_samples=min(1000, DATASET_SIZE // 10),  # 10% of DATASET_SIZE, capped at 1000
    image_size=configs["modality"].height,
    include_attributes=True,
    split="valid",
    rngs=rngs,
)

print("✅ Datasets loaded:")
print(f"   Training samples: {len(train_dataset)}")
print(f"   Validation samples: {len(val_dataset)}")
print(f"   Image size: {configs['modality'].height}x{configs['modality'].width}")
print("   Attributes included: ✓")

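# Create data loaders (Artifex style): each loader is a callable that returns
# a fresh iterator over batches, so it can be reused every epoch.


def create_data_loader(dataset, batch_size: int, shuffle: bool = True, seed: int = 0):
    """Create a re-iterable data loader for the dataset.

    `dataset.get_batch` returns contiguous slices, so shuffling permutes the
    order in which batches are visited (batch-level shuffling); samples inside
    a batch stay together.
    """
    epoch = 0

    def data_loader():
        nonlocal epoch
        num_samples = len(dataset)
        starts = np.arange(0, num_samples, batch_size)

        if shuffle:
            # New permutation on every call (i.e. every epoch), reproducible via `seed`
            starts = np.random.default_rng(seed + epoch).permutation(starts)
            epoch += 1

        for start_idx in starts:
            start_idx = int(start_idx)
            yield dataset.get_batch(
                batch_size=min(batch_size, num_samples - start_idx), start_idx=start_idx
            )

    return data_loader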


train_loader = create_data_loader(
    train_dataset, batch_size=configs["training"].batch_size, shuffle=True
)

val_loader = create_data_loader(
    val_dataset, batch_size=configs["evaluation"].eval_batch_size, shuffle=False
)

# Display sample images
print("\n🖼️ Sample Images from Dataset:")
sample_batch = train_dataset.get_batch(batch_size=8, start_idx=0)
sample_images = sample_batch["images"]

fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for i in range(8):
    axes[i].imshow(np.clip(sample_images[i], 0, 1))
    axes[i].axis("off")
plt.suptitle("CelebA Sample Images", fontsize=12)
plt.tight_layout()
plt.show()
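Downloading CelebA is slow, so when debugging the loaders and training loop it can help to have an in-memory stand-in with the same interface this notebook relies on: `len(dataset)` and `get_batch(batch_size=..., start_idx=...)` returning a dict with an `"images"` array of shape `(batch, H, W, 3)`, assumed to lie in `[0, 1]` as the plotting code's clipping suggests. `ToyImageDataset` is hypothetical, a sketch mirroring only those calls, not Artifex's `CelebADataset`:

```python
import numpy as np


class ToyImageDataset:
    """Hypothetical in-memory stand-in exposing the two methods the notebook uses."""

    def __init__(self, num_samples: int, image_size: int = 64, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Random RGB images in [0, 1), NHWC layout like the CelebA batches above
        self.images = rng.random((num_samples, image_size, image_size, 3), dtype=np.float32)

    def __len__(self) -> int:
        return len(self.images)

    def get_batch(self, batch_size: int, start_idx: int = 0) -> dict:
        # Contiguous slice, matching how the notebook calls CelebADataset.get_batch
        return {"images": self.images[start_idx : start_idx + batch_size]}


toy = ToyImageDataset(num_samples=1000, image_size=16)
sizes = [toy.get_batch(batch_size=256, start_idx=s)["images"].shape[0] for s in range(0, len(toy), 256)]
print(sizes)  # [256, 256, 256, 232]
```

Swapping `train_dataset` for such an object lets the training cells run end to end in seconds, which is useful for catching shape or API errors before a long run.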

## 5. Training with Artifex's Official Trainer

Artifex provides a production-ready `Trainer` class that handles:
- Automatic JIT compilation
- Gradient accumulation
- Checkpointing
- Metrics logging
- Early stopping
- Learning rate scheduling

In [None]:
# Create optimizer using Artifex's configuration


def create_optimizer(config: OptimizerConfiguration):
    """Create optimizer from configuration."""
    if config.optimizer_type == "adamw":
        return optax.adamw(
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            b1=config.beta1,
            b2=config.beta2,
            eps=config.eps,
        )
    elif config.optimizer_type == "adam":
        return optax.adam(
            learning_rate=config.learning_rate,
            b1=config.beta1,
            b2=config.beta2,
            eps=config.eps,
        )
    else:
        raise ValueError(f"Unknown optimizer: {config.optimizer_type}")


# Create the optimizer
optimizer = create_optimizer(configs["optimizer"])

# Define custom loss function for VAE


def vae_loss_fn(model, batch, training: bool = True):
    """VAE loss function with reconstruction and KL divergence.

    Note: We don't use @nnx.jit decorator here as the function will be JIT-compiled
    by the trainer or custom training loop.
    """
    images = batch["images"]

    # Forward pass
    mean, log_var = model.encode(images)
    z = model.reparameterize(mean, log_var) if training else mean
    reconstructed = model.decode(z)

    # Reconstruction loss (MSE)
    recon_loss = jnp.mean((images - reconstructed) ** 2)

    # KL divergence loss
    kl_loss = -0.5 * jnp.mean(1 + log_var - mean**2 - jnp.exp(log_var))

    # Total loss with beta weighting
    beta = configs["model"].parameters.get("beta", 1.0)
    total_loss = recon_loss + beta * kl_loss

    # Return loss and metrics
    metrics = {
        "loss": total_loss,
        "reconstruction_loss": recon_loss,
        "kl_loss": kl_loss,
        "beta": beta,
    }

    return total_loss, metrics


# Initialize Artifex Trainer
print("🚀 Initializing Artifex Trainer")
print("=" * 60)

trainer = Trainer(
    model=model,
    training_config=configs["training"],
    optimizer=optimizer,
    train_data_loader=train_loader,
    val_data_loader=val_loader,
    loss_fn=vae_loss_fn,
    rng=jax.random.PRNGKey(42),
    workdir="./vae_celeba_checkpoints",
    checkpoint_dir="./vae_celeba_checkpoints",
    save_interval=configs["training"].save_frequency,
)

print("✅ Trainer initialized with:")
print(f"   Model: {configs['model'].name}")
print(f"   Optimizer: {configs['optimizer'].optimizer_type}")
print(f"   Learning rate: {configs['optimizer'].learning_rate}")
print(f"   Batch size: {configs['training'].batch_size}")
print(f"   Checkpointing: Every {configs['training'].save_frequency} steps")

## 5.1 Custom Training Loop with JIT Compilation

For demonstration, we'll also show a custom optimized training loop that showcases JAX's JIT compilation capabilities.
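One practical consequence of JIT compilation: a jitted function is traced and compiled once per distinct input shape and dtype. With `DATASET_SIZE = 10000` and batch size 256, the last batch of each epoch holds 10000 − 39 × 256 = 16 images, so the training step compiles twice. A minimal sketch with plain `jax.jit` (the same caching applies to `nnx.jit`, which builds on it), using a Python counter that only increments while JAX traces:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def mse(x, y):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX traces, not on cached calls
    return jnp.mean((x - y) ** 2)


full = jnp.zeros((256, 8))
tail = jnp.zeros((16, 8))  # the smaller final batch of an epoch
mse(full, full)
mse(full, full)  # same shape and dtype: reuses the compiled function
mse(tail, tail)  # new shape: triggers a second trace and compilation
print(trace_count)  # 2
```

The cost is one extra compilation per new shape, not per step; dropping or padding the incomplete final batch keeps a single compiled version.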

In [None]:
class OptimizedVAETrainer:
    """Custom optimized VAE trainer showcasing JAX JIT compilation."""

    def __init__(self, model, config: ModelConfiguration, learning_rate: float = 2e-4):
        """Initialize the optimized trainer."""
        self.model = model
        self.config = config

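        # Create optimizer (older Flax NNX API; recent releases expect
        # nnx.Optimizer(model, tx, wrt=nnx.Param) and optimizer.update(model, grads))
        tx = optax.adam(learning_rate)
        self.optimizer = nnx.Optimizer(model, tx)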

        # Create JIT-compiled training step
        self.train_step = self._create_train_step()
        self.eval_step = self._create_eval_step()

    def _create_train_step(self):
        """Create JIT-compiled training step."""

        @nnx.jit
        def train_step(model, optimizer, images):
            """Single training step (JIT-compiled)."""

            def loss_fn(model):
                # Forward pass
                mean, log_var = model.encode(images)
                z = model.reparameterize(mean, log_var)
                reconstructed = model.decode(z)

                # Losses
                recon_loss = jnp.mean((images - reconstructed) ** 2)
                kl_loss = -0.5 * jnp.mean(1 + log_var - mean**2 - jnp.exp(log_var))

                # Total loss
                beta = self.config.parameters.get("kl_weight", 1.0)
                total_loss = recon_loss + beta * kl_loss

                return total_loss, {
                    "total_loss": total_loss,
                    "reconstruction_loss": recon_loss,
                    "kl_loss": kl_loss,
                }

            # Compute gradients and update
            (loss, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
            optimizer.update(grads)

            return metrics

        return train_step

    def _create_eval_step(self):
        """Create JIT-compiled evaluation step."""

        @nnx.jit
        def eval_step(model, images):
            """Single evaluation step (JIT-compiled)."""
            # Forward pass (no reparameterization for eval)
            mean, log_var = model.encode(images)
            reconstructed = model.decode(mean)  # Use mean directly

            # Losses
            recon_loss = jnp.mean((images - reconstructed) ** 2)
            kl_loss = -0.5 * jnp.mean(1 + log_var - mean**2 - jnp.exp(log_var))

            beta = self.config.parameters.get("kl_weight", 1.0)
            total_loss = recon_loss + beta * kl_loss

            return {
                "total_loss": total_loss,
                "reconstruction_loss": recon_loss,
                "kl_loss": kl_loss,
            }

        return eval_step

    def train_epoch(self, dataset, batch_size: int = 32):
        """Train for one epoch."""
        epoch_metrics = {"total_loss": 0, "reconstruction_loss": 0, "kl_loss": 0}
        num_batches = 0

        # Training loop
        for start_idx in range(0, len(dataset), batch_size):
            batch = dataset.get_batch(
                batch_size=min(batch_size, len(dataset) - start_idx), start_idx=start_idx
            )
            images = batch["images"]

            # Training step (JIT-compiled)
            metrics = self.train_step(self.model, self.optimizer, images)

            # Accumulate metrics
            for key in epoch_metrics:
                epoch_metrics[key] += float(metrics[key])
            num_batches += 1

        # Average metrics
        for key in epoch_metrics:
            epoch_metrics[key] /= num_batches

        return epoch_metrics

    def evaluate(self, dataset, batch_size: int = 32):
        """Evaluate on a dataset."""
        eval_metrics = {"total_loss": 0, "reconstruction_loss": 0, "kl_loss": 0}
        num_batches = 0

        for start_idx in range(0, len(dataset), batch_size):
            batch = dataset.get_batch(
                batch_size=min(batch_size, len(dataset) - start_idx), start_idx=start_idx
            )
            images = batch["images"]

            # Evaluation step (JIT-compiled)
            metrics = self.eval_step(self.model, images)

            # Accumulate metrics
            for key in eval_metrics:
                eval_metrics[key] += float(metrics[key])
            num_batches += 1

        # Average metrics
        for key in eval_metrics:
            eval_metrics[key] /= num_batches

        return eval_metrics


# Create custom trainer for demonstration
custom_trainer = OptimizedVAETrainer(
    model, configs["model"], learning_rate=configs["optimizer"].learning_rate
)

print("⚡ Custom Optimized Trainer Created")
print("   Features:")
print("   - JIT-compiled training steps")
print("   - Separate train/eval paths")
print("   - Efficient batch processing")
print("   - Automatic gradient computation")

## 5.2 Training Execution

Now let's train the model. You can choose between:
1. Artifex's official Trainer (recommended for production)
2. Custom optimized trainer (for learning and experimentation)

In [None]:
# Quick training with custom trainer for demonstration
NUM_EPOCHS = 10  # Reduce for quick demo, use 100+ for full training
BATCH_SIZE = configs["training"].batch_size

print("🚀 Starting Training")
print("=" * 60)
print(f"Epochs: {NUM_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Learning rate: {configs['optimizer'].learning_rate}")
print(f"Beta (KL weight): {configs['model'].parameters.get('beta', 1.0)}")

# Training history
history = {
    "train_loss": [],
    "val_loss": [],
    "reconstruction_loss": [],
    "kl_loss": [],
    "epoch_times": [],
}

# Training loop
for epoch in range(NUM_EPOCHS):
    epoch_start = time.time()

    # Train for one epoch
    train_metrics = custom_trainer.train_epoch(train_dataset, BATCH_SIZE)

    # Evaluate on validation set
    val_metrics = custom_trainer.evaluate(val_dataset, configs["evaluation"].eval_batch_size)

    # Record metrics
    history["train_loss"].append(train_metrics["total_loss"])
    history["val_loss"].append(val_metrics["total_loss"])
    history["reconstruction_loss"].append(train_metrics["reconstruction_loss"])
    history["kl_loss"].append(train_metrics["kl_loss"])

    # The first epoch also includes JIT compilation time
    epoch_time = time.time() - epoch_start
    history["epoch_times"].append(epoch_time)

    # Print progress
    print(f"\n📈 Epoch {epoch + 1}/{NUM_EPOCHS} (Time: {epoch_time:.1f}s)")
    print(f"   Train Loss: {train_metrics['total_loss']:.4f}")
    print(f"   Val Loss: {val_metrics['total_loss']:.4f}")
    print(f"   Recon Loss: {train_metrics['reconstruction_loss']:.4f}")
    print(f"   KL Loss: {train_metrics['kl_loss']:.4f}")

total_time = sum(history["epoch_times"])
print(f"\n✅ Training completed in {total_time:.1f}s ({total_time / 60:.1f} minutes)")
print(f"   Average epoch time: {np.mean(history['epoch_times']):.1f}s")
print(f"   Final train loss: {history['train_loss'][-1]:.4f}")
print(f"   Final val loss: {history['val_loss'][-1]:.4f}")

## 6. Training Visualization

Let's visualize the training progress and metrics.

In [None]:
def plot_training_history(history: dict[str, list[float]]):
    """Plot comprehensive training history."""

    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    epochs = range(1, len(history["train_loss"]) + 1)

    # Total loss (train vs val)
    axes[0, 0].plot(epochs, history["train_loss"], "b-", linewidth=2, label="Train Loss")
    axes[0, 0].plot(epochs, history["val_loss"], "r--", linewidth=2, label="Val Loss")
    axes[0, 0].set_xlabel("Epoch")
    axes[0, 0].set_ylabel("Loss")
    axes[0, 0].set_title("Training vs Validation Loss")
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].legend()

    # Reconstruction loss
    axes[0, 1].plot(epochs, history["reconstruction_loss"], "g-", linewidth=2)
    axes[0, 1].set_xlabel("Epoch")
    axes[0, 1].set_ylabel("MSE")
    axes[0, 1].set_title("Reconstruction Loss")
    axes[0, 1].grid(True, alpha=0.3)

    # KL divergence
    axes[1, 0].plot(epochs, history["kl_loss"], "m-", linewidth=2)
    axes[1, 0].set_xlabel("Epoch")
    axes[1, 0].set_ylabel("KL Divergence")
    axes[1, 0].set_title("KL Divergence Loss")
    axes[1, 0].grid(True, alpha=0.3)

    # Epoch times
    axes[1, 1].bar(epochs, history["epoch_times"], color="orange", alpha=0.7)
    axes[1, 1].set_xlabel("Epoch")
    axes[1, 1].set_ylabel("Time (seconds)")
    axes[1, 1].set_title("Training Time per Epoch")
    axes[1, 1].grid(True, alpha=0.3)

    plt.suptitle("VAE Training Progress", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Print summary statistics
    print("\n📊 Training Summary:")
    print("=" * 60)
    print(f"Final train loss: {history['train_loss'][-1]:.4f}")
    print(f"Final val loss: {history['val_loss'][-1]:.4f}")
    print(f"Final reconstruction loss: {history['reconstruction_loss'][-1]:.4f}")
    print(f"Final KL loss: {history['kl_loss'][-1]:.4f}")
    print(f"Average epoch time: {np.mean(history['epoch_times']):.2f}s")
    print(f"Total training time: {sum(history['epoch_times']):.2f}s")

    # Check for overfitting
    if history["val_loss"][-1] > history["train_loss"][-1] * 1.2:
        print("\n⚠️  Warning: Possible overfitting detected (val loss > 1.2x train loss)")


# Plot the training history
plot_training_history(history)

## 7. Model Evaluation with Artifex's Benchmark Suite

Artifex provides comprehensive evaluation metrics through its benchmark framework.

In [None]:
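# Helper functions for generation and reconstruction
from functools import partial


# `num_samples` and `latent_dim` define an array shape, so they must be static
# under JIT (traced values cannot be used as shapes).
@partial(nnx.jit, static_argnums=(1, 2))
def generate_samples(model, num_samples: int, latent_dim: int, key):
    """Generate new samples by decoding draws from the N(0, I) prior."""
    z = jax.random.normal(key, (num_samples, latent_dim))
    return model.decode(z)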


@nnx.jit
def reconstruct_images(model, images):
    """Reconstruct images through the VAE."""
    mean, log_var = model.encode(images)
    z = model.reparameterize(mean, log_var)
    reconstructed = model.decode(z)
    return reconstructed, mean, log_var


def evaluate_model_comprehensive(model, dataset, eval_config, rngs):
    """Comprehensive model evaluation using Artifex's metrics."""

    print("\n🔍 Comprehensive Model Evaluation")
    print("=" * 60)

    # Generate samples for evaluation (kept small for the demo; this does not
    # use eval_config.num_eval_samples)
    num_eval_samples = min(64, len(dataset))
    key = rngs.sample()
    generated = generate_samples(model, num_eval_samples, model.latent_dim, key)

    # Get real images for comparison
    real_batch = dataset.get_batch(batch_size=num_eval_samples, start_idx=0)
    real_images = real_batch["images"]

    # Reconstruct images
    reconstructed, mean, log_var = reconstruct_images(model, real_images)

    # Calculate basic metrics
    recon_mse = float(jnp.mean((real_images - reconstructed) ** 2))
    kl_div = float(-0.5 * jnp.mean(1 + log_var - mean**2 - jnp.exp(log_var)))

    # FID score. The evaluation config requests a mock Inception model, and FID from
    # only 64 samples is biased upward, so treat this as a pipeline check rather than
    # a score comparable to published results
    fid_metric = FIDMetric(rngs=rngs, config=eval_config)
    fid_result = fid_metric.compute(real_images, generated)
    fid_score = float(fid_result["fid"]) if isinstance(fid_result, dict) else float(fid_result)

    # Inception Score (computed on only 64 samples, so expect high variance)
    is_metric = ISMetric(rngs=rngs, config=eval_config)
    is_result = is_metric.compute(generated)
    is_score = float(is_result["is"]) if isinstance(is_result, dict) else float(is_result)

    # Latent space statistics
    latent_mean = float(jnp.mean(mean))
    latent_std = float(jnp.mean(jnp.exp(0.5 * log_var)))

    # Compile results
    results = {
        "reconstruction_mse": recon_mse,
        "kl_divergence": kl_div,
        "fid_score": fid_score,
        "inception_score": is_score,
        "latent_mean": latent_mean,
        "latent_std": latent_std,
    }

    # Print results
    print("\n📈 Evaluation Results:")
    print(f"   Reconstruction MSE: {results['reconstruction_mse']:.6f}")
    print(f"   KL Divergence: {results['kl_divergence']:.4f}")
    print(f"   FID Score: {results['fid_score']:.2f} (lower is better)")
    print(f"   Inception Score: {results['inception_score']:.2f} (higher is better)")
    print("   Latent Space:")
    print(f"      Mean: {results['latent_mean']:.4f}")
    print(f"      Std: {results['latent_std']:.4f}")

    # Quality assessment
    print("\n🎯 Quality Assessment:")
    if results["reconstruction_mse"] < 0.01:
        print("   ✅ Excellent reconstruction quality")
    elif results["reconstruction_mse"] < 0.05:
        print("   ✓ Good reconstruction quality")
    else:
        print("   ⚠️ Reconstruction could be improved")

    if results["fid_score"] < 50:
        print("   ✅ Excellent generation quality (FID < 50)")
    elif results["fid_score"] < 100:
        print("   ✓ Good generation quality (FID < 100)")
    else:
        print("   ⚠️ Generation quality needs improvement")

    return results, generated, reconstructed


# Run comprehensive evaluation
eval_results, generated_images, reconstructed_images = evaluate_model_comprehensive(
    model, val_dataset, configs["evaluation"], rngs
)
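FID is the Fréchet distance between Gaussians fitted to Inception features of real and generated images: FID = ‖μ_r − μ_g‖² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). Because it is computed from sample means and covariances, 64 samples give a noisy, upward-biased estimate; and since the configuration requests a mock Inception model (`mock_inception=True`), the features presumably do not come from a trained Inception network. A NumPy sketch of the distance itself on arbitrary `(N, D)` feature arrays (not Artifex's `FIDMetric`; the feature extractor is left out):

```python
import numpy as np


def _sqrtm_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(feats_real: np.ndarray, feats_fake: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (N, D) feature arrays."""
    mu_r, mu_g = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    cov_r = np.cov(feats_real, rowvar=False)
    cov_g = np.cov(feats_fake, rowvar=False)
    # Tr((cov_r cov_g)^{1/2}) equals Tr((S cov_g S)^{1/2}) with S = cov_r^{1/2};
    # S cov_g S is symmetric PSD, so eigh-based square roots suffice
    s = _sqrtm_psd(cov_r)
    tr_covmean = np.trace(_sqrtm_psd(s @ cov_g @ s))
    return float(np.sum((mu_r - mu_g) ** 2) + np.trace(cov_r) + np.trace(cov_g) - 2.0 * tr_covmean)
```

Identical feature sets give 0, and shifting one set by a vector d with unchanged covariance gives ‖d‖², which is a quick sanity check for any FID implementation.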

## 8. Comprehensive Visualization

Visualize reconstruction quality, generation quality, and latent space properties.

In [None]:
def visualize_vae_results(dataset, model, rngs, num_samples: int = 8):
    """Comprehensive visualization of VAE results."""

    print("\n🎨 Generating Comprehensive Visualizations")

    # Get real images
    batch = dataset.get_batch(batch_size=num_samples, start_idx=0)
    real_images = batch["images"]

    # Reconstruct
    reconstructed, _, _ = reconstruct_images(model, real_images)

    # Generate new samples
    key = rngs.sample()
    generated = generate_samples(model, num_samples, model.latent_dim, key)

    # Interpolation in latent space
    key1, key2 = jax.random.split(rngs.sample())
    z1 = jax.random.normal(key1, (1, model.latent_dim))
    z2 = jax.random.normal(key2, (1, model.latent_dim))
    alphas = jnp.linspace(0, 1, num_samples)
    interpolated = jnp.array([model.decode((1 - a) * z1 + a * z2)[0] for a in alphas])

    # Create comprehensive visualization
    fig, axes = plt.subplots(4, num_samples, figsize=(num_samples * 2, 8))

    titles = ["Original", "Reconstructed", "Generated", "Interpolated"]
    images_list = [real_images, reconstructed, generated, interpolated]

    for row, (title, images) in enumerate(zip(titles, images_list)):
        for col in range(num_samples):
            img = np.clip(images[col], 0, 1)
            axes[row, col].imshow(img)
            axes[row, col].axis("off")

            if col == 0:
                axes[row, col].set_ylabel(title, fontsize=10, fontweight="bold")

            # Add reconstruction error for reconstructed images
            if row == 1:
                mse = float(jnp.mean((real_images[col] - reconstructed[col]) ** 2))
                axes[row, col].set_title(f"MSE: {mse:.3f}", fontsize=8)

    plt.suptitle(
        "VAE Results: Reconstruction, Generation, and Interpolation", fontsize=14, fontweight="bold"
    )
    plt.tight_layout()
    plt.show()


# Visualize results
visualize_vae_results(val_dataset, model, rngs, num_samples=8)
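In high dimensions, the midpoint of a straight line between two Gaussian samples has a noticeably smaller norm than a typical sample (about $\sqrt{d/2}$ versus $\sqrt{d}$), so linearly interpolated codes may decode to blurrier images. Spherical linear interpolation (slerp) is a common alternative. The sketch below is a minimal JAX version; `slerp` is a hypothetical helper, not part of Artifex.

```python
import jax.numpy as jnp


def slerp(z1, z2, alphas, eps=1e-7):
    """Spherical linear interpolation between two latent vectors.

    z1, z2: arrays of shape (latent_dim,); alphas: shape (n,) with values in [0, 1].
    Returns an array of shape (n, latent_dim).
    """
    # Angle between the two vectors
    cos_omega = jnp.dot(z1 / jnp.linalg.norm(z1), z2 / jnp.linalg.norm(z2))
    omega = jnp.arccos(jnp.clip(cos_omega, -1.0, 1.0))
    sin_omega = jnp.sin(omega)
    a = alphas[:, None]
    # Slerp formula; the denominator is guarded against (nearly) parallel vectors
    spherical = (jnp.sin((1 - a) * omega) * z1 + jnp.sin(a * omega) * z2) / jnp.maximum(sin_omega, eps)
    # Fall back to plain linear interpolation when the angle is degenerate
    linear = (1 - a) * z1 + a * z2
    return jnp.where(sin_omega < eps, linear, spherical)


z_a = jnp.array([1.0, 0.0])
z_b = jnp.array([0.0, 1.0])
path = slerp(z_a, z_b, jnp.linspace(0, 1, 3))
```

With the notebook's `(1, latent_dim)` codes, one could decode a slerp path with something like `model.decode(slerp(z1[0], z2[0], jnp.linspace(0, 1, 8)))`, assuming `model.decode` accepts a batch as it does above. Between two unit vectors the slerp midpoint keeps norm 1, whereas the linear midpoint shrinks to about 0.707.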

## 9. Latent Space Analysis

Analyze the structure and properties of the learned latent space.
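The sixth panel in the analysis cell plots the closed-form KL divergence between each dimension of the diagonal Gaussian posterior and the standard normal prior, where $\sigma_j^2 = \exp(\text{log\_var}_j)$:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu_j, \sigma_j^2)\,\|\,\mathcal{N}(0, 1)\big) = \tfrac{1}{2}\left(\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1\right)
$$

Summing over $j$ gives the KL term of the VAE objective. Each term is zero exactly when $\mu_j = 0$ and $\sigma_j^2 = 1$, so a dimension whose mean KL stays near zero across inputs has collapsed to the prior and carries no information about the image.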

In [None]:
def analyze_latent_space(model, dataset, num_samples: int = 1000):
    """Analyze the learned latent space structure."""

    print("\n🔬 Analyzing Latent Space Structure")
    print("=" * 60)

    # Encode a batch of images
    num_samples = min(num_samples, len(dataset))
    batch = dataset.get_batch(batch_size=num_samples, start_idx=0)
    images = batch["images"]

    mean, log_var = model.encode(images)
    z = model.reparameterize(mean, log_var)

    # Statistics
    z_mean = jnp.mean(z, axis=0)
    z_std = jnp.std(z, axis=0)

    # Create visualizations
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))

    # 1. Distribution of latent means
    axes[0, 0].hist(z_mean, bins=30, alpha=0.7, color="blue", edgecolor="black")
    axes[0, 0].set_xlabel("Mean value")
    axes[0, 0].set_ylabel("Frequency")
    axes[0, 0].set_title("Distribution of Latent Dimension Means")
    axes[0, 0].grid(True, alpha=0.3)
    axes[0, 0].axvline(0, color="red", linestyle="--", alpha=0.5)

    # 2. Distribution of latent stds
    axes[0, 1].hist(z_std, bins=30, alpha=0.7, color="green", edgecolor="black")
    axes[0, 1].set_xlabel("Std deviation")
    axes[0, 1].set_ylabel("Frequency")
    axes[0, 1].set_title("Distribution of Latent Dimension Stds")
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].axvline(1, color="red", linestyle="--", alpha=0.5)

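    # 3. Active dimensions: a dimension is active when the encoder mean varies
    # across inputs (std of mean > 0.1, i.e. variance > 0.01). The std of sampled z
    # is not used here: reparameterization noise alone gives a collapsed
    # dimension a std of ~1, so it would be miscounted as active.
    threshold = 0.1
    active_dims = jnp.sum(jnp.std(mean, axis=0) > threshold)
    axes[0, 2].bar(
        ["Active", "Inactive"],
        [int(active_dims), len(z_std) - int(active_dims)],
        color=["green", "gray"],
    )
    axes[0, 2].set_title(f"Active Dimensions (std of mean > {threshold})")
    axes[0, 2].set_ylabel("Count")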

    # 4. Correlation matrix (subset)
    subset_size = min(50, z.shape[1])
    z_subset = z[:, :subset_size]
    corr = jnp.corrcoef(z_subset.T)

    im = axes[1, 0].imshow(corr, cmap="coolwarm", vmin=-1, vmax=1)
    axes[1, 0].set_title(f"Latent Correlations (first {subset_size} dims)")
    axes[1, 0].set_xlabel("Dimension")
    axes[1, 0].set_ylabel("Dimension")
    plt.colorbar(im, ax=axes[1, 0])

    # 5. 2D PCA projection
    if z.shape[0] > 2:
        # Simple 2D projection using first two principal components
        z_centered = z - jnp.mean(z, axis=0)
        cov = jnp.cov(z_centered.T)
        eigenvalues, eigenvectors = jnp.linalg.eigh(cov)

        # Sort by eigenvalues
        idx = jnp.argsort(eigenvalues)[::-1]
        eigenvectors = eigenvectors[:, idx]

        # Project to 2D
        z_2d = z_centered @ eigenvectors[:, :2]

        axes[1, 1].scatter(z_2d[:, 0], z_2d[:, 1], alpha=0.5, s=10)
        axes[1, 1].set_title("2D PCA Projection of Latent Space")
        axes[1, 1].set_xlabel("PC1")
        axes[1, 1].set_ylabel("PC2")
        axes[1, 1].grid(True, alpha=0.3)

    # 6. KL divergence per dimension
    kl_per_dim = 0.5 * (mean**2 + jnp.exp(log_var) - log_var - 1)
    mean_kl_per_dim = jnp.mean(kl_per_dim, axis=0)

    num_kl_dims = min(100, mean_kl_per_dim.shape[0])
    axes[1, 2].plot(mean_kl_per_dim[:num_kl_dims], alpha=0.7)
    axes[1, 2].set_title(f"KL Divergence per Dimension (first {num_kl_dims})")
    axes[1, 2].set_xlabel("Dimension")
    axes[1, 2].set_ylabel("Mean KL")
    axes[1, 2].grid(True, alpha=0.3)

    plt.suptitle("Latent Space Analysis", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Print statistics
    print("\n📊 Latent Space Statistics:")
    print(f"   Latent dimension: {z.shape[1]}")
    print(f"   Active dimensions: {int(active_dims)} / {z.shape[1]}")
    print(f"   Mean of means: {jnp.mean(z_mean):.4f} (target: ~0)")
    print(f"   Mean of stds: {jnp.mean(z_std):.4f} (target: ~1)")
    print(f"   Max correlation: {jnp.max(jnp.abs(corr - jnp.eye(subset_size))):.4f}")
    print(f"   Total KL divergence: {jnp.sum(mean_kl_per_dim):.4f}")

    # Quality assessment
    print("\n🎯 Latent Space Quality:")
    if abs(jnp.mean(z_mean)) < 0.1:
        print("   ✅ Well-centered latent space")
    else:
        print("   ⚠️ Latent space not well-centered")

    if 0.8 < jnp.mean(z_std) < 1.2:
        print("   ✅ Good variance in latent space")
    else:
        print("   ⚠️ Variance issues in latent space")

    if jnp.max(jnp.abs(corr - jnp.eye(subset_size))) < 0.3:
        print("   ✅ Low linear correlation between latent dimensions")
    else:
        print("   ⚠️ High correlation between latent dimensions")


# Analyze the latent space
analyze_latent_space(model, val_dataset, num_samples=500)
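The active-dimension count is easy to get wrong. A common criterion, the "active units" measure from Burda et al. (Importance Weighted Autoencoders, 2016), calls a dimension active when the variance of the encoder mean across inputs exceeds 0.01. The sketch below implements that criterion on a toy posterior and shows why applying it to sampled `z` instead of the means overcounts. `count_active_units` and the toy data are illustrative assumptions, not Artifex APIs.

```python
import jax
import jax.numpy as jnp


def count_active_units(posterior_means, threshold=0.01):
    """Count latent dimensions whose encoder mean varies across inputs.

    posterior_means: (num_samples, latent_dim) array of encoder means mu(x).
    Returns (num_active, per_dimension_variance).
    """
    activity = jnp.var(posterior_means, axis=0)  # Var_x[mu_j(x)] for each dimension j
    return int(jnp.sum(activity > threshold)), activity


# Toy posterior: dimension 0 encodes the input, dimension 1 has collapsed to the prior
key_x, key_eps = jax.random.split(jax.random.PRNGKey(0))
mu = jnp.stack([jax.random.normal(key_x, (500,)), jnp.zeros(500)], axis=1)
log_var = jnp.stack([jnp.full(500, -4.0), jnp.zeros(500)], axis=1)
z = mu + jnp.exp(0.5 * log_var) * jax.random.normal(key_eps, mu.shape)

n_active_mu, _ = count_active_units(mu)  # counts only dimension 0
n_active_z, _ = count_active_units(z)    # sampling noise makes dimension 1 look active too
print(n_active_mu, n_active_z)
```

Applied to the notebook, one would pass the `mean` returned by `model.encode(images)`, not the reparameterized `z`.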

## 10. Model Persistence with Artifex

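The helpers below pickle the model's NNX state together with its configuration. Loading rebuilds the model through `ModelFactory` and merges the saved state back into it.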

In [None]:
def save_model_artifex(model, config: ModelConfiguration, path: str = "vae_celeba_artifex.pkl"):
    """Save model using Artifex's recommended approach."""

    print("\n💾 Saving Model with Artifex")
    print("=" * 60)

    # Create checkpoint directory if it doesn't exist
    checkpoint_dir = Path("./checkpoints")
    checkpoint_dir.mkdir(exist_ok=True)

    full_path = checkpoint_dir / path

    # Extract model state
    _, state = nnx.split(model)

    # Save both state and configuration
    checkpoint = {
        "model_state": state,
        "config": config,
        "framework": "artifex",
        "version": "1.0",
    }

    with open(full_path, "wb") as f:
        pickle.dump(checkpoint, f)

    print(f"✅ Model saved to: {full_path}")
    print(f"   Model type: {config.name}")
    print(f"   Architecture: {config.model_class}")
    print(f"   File size: {full_path.stat().st_size / (1024 * 1024):.1f} MB")

    return full_path


def load_model_artifex(path: str, rngs: nnx.Rngs):
    """Load model using Artifex's factory system."""

    print("\n📂 Loading Model with Artifex")
    print("=" * 60)

    full_path = Path("./checkpoints") / path

    if not full_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {full_path}")

    with open(full_path, "rb") as f:
        checkpoint = pickle.load(f)

    # Extract configuration and state
    config = checkpoint["config"]
    state = checkpoint["model_state"]

    print(f"✅ Checkpoint loaded from: {full_path}")
    print(f"   Model: {config.name}")
    print(f"   Framework: {checkpoint.get('framework', 'unknown')}")

    # Recreate model using factory
    factory = ModelFactory()
    model = factory.create(
        config=config,
        modality="image",
        rngs=rngs,
    )

    # Merge saved state with new model
    graphdef, _ = nnx.split(model)
    model = nnx.merge(graphdef, state)

    print("✅ Model restored successfully")

    return model, config


# Save the trained model
saved_path = save_model_artifex(model, configs["model"], "vae_celeba_trained.pkl")

# Example: Load the model
# loaded_model, loaded_config = load_model_artifex("vae_celeba_trained.pkl", rngs)
# print("Model loaded successfully!")
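Pickling ties a checkpoint to the exact Python classes of the state and configuration objects. A more portable sketch, assuming the model state flattens as a JAX pytree of arrays, stores only the array leaves in a NumPy `.npz` file and restores them into the structure of a freshly built model's state, the same "rebuild, then fill in" idea that `load_model_artifex` uses. `save_pytree_npz` and `load_pytree_npz` are hypothetical helpers, not Artifex functions.

```python
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np


def save_pytree_npz(tree, path):
    """Save the array leaves of a pytree to an .npz file (the tree structure is not saved)."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Copy each leaf to host memory; the key records its position in flatten order
    np.savez(path, **{f"leaf_{i}": np.asarray(leaf) for i, leaf in enumerate(leaves)})


def load_pytree_npz(template, path):
    """Load leaves saved by save_pytree_npz into the structure of `template`."""
    treedef = jax.tree_util.tree_structure(template)
    with np.load(path) as data:
        if len(data.files) != treedef.num_leaves:
            raise ValueError(
                f"checkpoint has {len(data.files)} leaves, template expects {treedef.num_leaves}"
            )
        leaves = [jnp.asarray(data[f"leaf_{i}"]) for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


# Round trip with a small nested dict standing in for model state
params = {"encoder": {"w": jnp.arange(6.0).reshape(2, 3)}, "decoder": {"b": jnp.zeros(3)}}
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "state.npz"
    save_pytree_npz(params, ckpt)
    restored = load_pytree_npz(params, ckpt)
```

For an NNX model, the template would presumably be the state from `nnx.split(model)` on a freshly built model, followed by `nnx.merge`. This relies on identically configured models flattening their state in the same leaf order, which is an assumption worth checking for your model.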

## 11. Performance Benchmarking

Time the same forward pass (encode, reparameterize, decode, MSE) with and without `nnx.jit`. The measured speedup depends on hardware, batch size, and model size.

In [None]:
def benchmark_performance(model, rngs):
    """Compare JIT vs non-JIT performance."""

    print("\n⚡ Performance Benchmark: JIT Compilation Benefits")
    print("=" * 60)

    # Create test data
    test_batch_size = 32
    key = rngs.sample()
    test_images = jax.random.normal(key, (test_batch_size, 64, 64, 3))

    # Non-JIT function
    def forward_pass_no_jit(model, images):
        mean, log_var = model.encode(images)
        z = model.reparameterize(mean, log_var)
        reconstructed = model.decode(z)
        loss = jnp.mean((images - reconstructed) ** 2)
        return loss

    # JIT-compiled function
    @nnx.jit
    def forward_pass_jit(model, images):
        mean, log_var = model.encode(images)
        z = model.reparameterize(mean, log_var)
        reconstructed = model.decode(z)
        loss = jnp.mean((images - reconstructed) ** 2)
        return loss

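    # Warm up both paths: the first JIT call triggers compilation, and the first
    # eager call also pays one-off per-operation compile costs.
    print("🔥 Warming up JIT compilation...")
    forward_pass_jit(model, test_images).block_until_ready()
    forward_pass_no_jit(model, test_images).block_until_ready()

    # Benchmark settings
    num_runs = 100
    print(f"📊 Running {num_runs} forward passes...")

    # JAX dispatches work asynchronously, so block on every result; otherwise
    # the timer measures only dispatch, not the computation itself.
    start = time.perf_counter()
    for _ in range(num_runs):
        forward_pass_no_jit(model, test_images).block_until_ready()
    no_jit_time = time.perf_counter() - start

    # JIT timing
    start = time.perf_counter()
    for _ in range(num_runs):
        forward_pass_jit(model, test_images).block_until_ready()
    jit_time = time.perf_counter() - start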

    # Calculate statistics
    speedup = no_jit_time / jit_time
    time_saved = no_jit_time - jit_time
    percentage_saved = (1 - jit_time / no_jit_time) * 100

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Bar chart comparison
    methods = ["No JIT", "With JIT"]
    times = [no_jit_time, jit_time]
    colors = ["red", "green"]

    axes[0].bar(methods, times, color=colors, alpha=0.7)
    axes[0].set_ylabel("Time (seconds)")
    axes[0].set_title(f"Total Time for {num_runs} Forward Passes")
    axes[0].grid(True, alpha=0.3)

    # Add values on bars
    for i, (method, time_val) in enumerate(zip(methods, times)):
        axes[0].text(i, time_val + 0.05, f"{time_val:.2f}s", ha="center")

    # Speedup visualization
    axes[1].bar(["Speedup"], [speedup], color="blue", alpha=0.7)
    axes[1].set_ylabel("Speedup Factor")
    axes[1].set_title("JIT Compilation Speedup")
    axes[1].axhline(y=1, color="red", linestyle="--", alpha=0.5)
    axes[1].set_ylim(0, max(speedup * 1.2, 2))
    axes[1].text(0, speedup + 0.1, f"{speedup:.2f}x", ha="center")

    plt.suptitle("JAX JIT Compilation Performance Impact", fontsize=14, fontweight="bold")
    plt.tight_layout()
    plt.show()

    # Print results
    print("\n📊 Benchmark Results:")
    print(f"   Without JIT: {no_jit_time:.3f}s ({no_jit_time / num_runs * 1000:.2f}ms per pass)")
    print(f"   With JIT:    {jit_time:.3f}s ({jit_time / num_runs * 1000:.2f}ms per pass)")
    print("   ")
    print(f"   🚀 Speedup:     {speedup:.2f}x faster")
    print(f"   ⏱️  Time saved:  {time_saved:.3f}s ({percentage_saved:.1f}%)")
    print("   ")

    # Provide context
    print("💡 Performance Insights:")
    if speedup > 3:
        print("   ✅ Excellent JIT optimization (>3x speedup)")
    elif speedup > 2:
        print("   ✅ Good JIT optimization (2-3x speedup)")
    elif speedup > 1.5:
        print("   ✓ Moderate JIT optimization (1.5-2x speedup)")
    else:
        print("   ⚠️ Limited JIT benefit (consider larger batch sizes)")

    return {
        "no_jit_time": no_jit_time,
        "jit_time": jit_time,
        "speedup": speedup,
        "time_saved": time_saved,
    }


# Run performance benchmark
perf_results = benchmark_performance(model, rngs)
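The benchmark reports steady-state time only, but for short runs the one-off compilation cost also matters. JAX's ahead-of-time API (`jax.jit(fn).lower(*args).compile()`) lets you time compilation and execution separately. The sketch below applies it to a pure function; `time_compile_and_run` and `toy_forward` are hypothetical helpers. For an NNX module you would need to pass the model's state explicitly (for example via `nnx.split`), which is not shown here.

```python
import time

import jax
import jax.numpy as jnp


def time_compile_and_run(fn, *args, num_runs=10):
    """Return (compile_seconds, seconds_per_call, last_output) for jax.jit(fn)."""
    assert num_runs >= 1
    # Ahead-of-time lowering and compilation isolates the one-off compile cost
    start = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()
    compile_s = time.perf_counter() - start

    # Steady-state calls; block on every result because dispatch is asynchronous
    start = time.perf_counter()
    for _ in range(num_runs):
        out = jax.block_until_ready(compiled(*args))
    per_call_s = (time.perf_counter() - start) / num_runs
    return compile_s, per_call_s, out


def toy_forward(x):
    return jnp.mean(jnp.tanh(x @ x.T))


compile_s, per_call_s, out = time_compile_and_run(toy_forward, jnp.ones((64, 64)))
print(f"compile: {compile_s * 1e3:.1f} ms, per call: {per_call_s * 1e3:.3f} ms")
```

If compile time is large relative to `per_call_s * num_steps`, JIT pays off only for longer runs.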

## 12. Summary and Conclusions

This notebook demonstrated the modern Artifex framework for developing and training generative models.

In [None]:
# Summary

print("\n" + "=" * 70)
print(" " * 20 + "🎉 ARTIFEX VAE TRAINING COMPLETE 🎉")
print("=" * 70)

print("\n## 📚 What We Demonstrated:\n")

print("### 1. **Artifex's Factory System**")
print("   - Centralized model creation with ModelFactory")
print("   - Unified configuration management")
print("   - Modality-based adapters")

print("\n### 2. **Device Management**")
print("   - Automatic GPU/CPU detection and fallback")
print("   - Memory strategy configuration")
print("   - Hardware-aware optimization")

print("\n### 3. **Training Infrastructure**")
print("   - Artifex's official Trainer class")
print("   - JIT-compiled training loops")
if "perf_results" in locals():
    print(f"   - Measured {perf_results['speedup']:.2f}x forward-pass speedup with JIT")

print("\n### 4. **Comprehensive Evaluation**")
print("   - FID and Inception Score metrics")
print("   - Latent space analysis")
print("   - Reconstruction quality assessment")

print("\n### 5. **Production Features**")
print("   - Model checkpointing and persistence")
print("   - Type-safe configuration with Pydantic")
print("   - Modular, extensible architecture")

print("\n## 📊 Final Results:\n")
if "history" in locals():
    print(f"   Training Loss: {history['train_loss'][-1]:.4f}")
    print(f"   Validation Loss: {history['val_loss'][-1]:.4f}")
if "eval_results" in locals():
    print(f"   Reconstruction MSE: {eval_results['reconstruction_mse']:.6f}")
    print(f"   FID Score: {eval_results['fid_score']:.2f}")
if "perf_results" in locals():
    print(f"   JIT Speedup: {perf_results['speedup']:.2f}x")

print("\n## 🚀 Next Steps:\n")
print("1. **Scale Up Training**")
print("   - Increase DATASET_SIZE to 100,000+")
print("   - Train for 100+ epochs")
print("   - Use larger batch sizes on GPU")

print("\n2. **Experiment with Architectures**")
print("   - Try β-VAE for better disentanglement")
print("   - Implement VQ-VAE for discrete representations")
print("   - Test different encoder/decoder architectures")

print("\n3. **Advanced Features**")
print("   - Add conditional generation")
print("   - Implement attribute manipulation")
print("   - Try other datasets (FFHQ, etc.)")

print("\n4. **Production Deployment**")
print("   - Export to ONNX/TensorFlow")
print("   - Create REST API with FastAPI")
print("   - Deploy with Artifex's CLI tools")

print("\n## 🔗 Resources:\n")
print("- Artifex repository: https://github.com/avitai/artifex")
print("- JAX Documentation: https://jax.readthedocs.io")
print("- Flax NNX Guide: https://flax.readthedocs.io/en/latest/nnx/")

print("\n" + "=" * 70)
print("Thank you for using Artifex! Happy modeling! 🚀")
print("=" * 70)
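As a starting point for the β-VAE experiment listed under Next Steps, here is a minimal loss sketch. It assumes a sum-of-squared-errors reconstruction term and reuses the closed-form KL from the latent space analysis; `beta_vae_loss` is a hypothetical helper, and Artifex's own VAE loss may be defined differently.

```python
import jax.numpy as jnp


def beta_vae_loss(x, x_recon, mean, log_var, beta=4.0):
    """β-VAE objective: reconstruction term plus β-weighted KL to N(0, I).

    x, x_recon: (batch, H, W, C); mean, log_var: (batch, latent_dim).
    Returns (total, recon, kl), each averaged over the batch.
    """
    # Sum of squared errors per image, then mean over the batch
    recon = jnp.mean(jnp.sum((x - x_recon) ** 2, axis=(1, 2, 3)))
    # Closed-form KL between N(mean, exp(log_var)) and N(0, 1), summed over latent dims
    kl = jnp.mean(jnp.sum(0.5 * (mean**2 + jnp.exp(log_var) - log_var - 1), axis=-1))
    return recon + beta * kl, recon, kl


x = jnp.ones((2, 4, 4, 3))
x_recon = jnp.zeros((2, 4, 4, 3))
mu = jnp.ones((2, 8))
log_var = jnp.zeros((2, 8))
total, recon, kl = beta_vae_loss(x, x_recon, mu, log_var, beta=4.0)
```

Setting `beta=1.0` recovers the standard VAE objective; larger values push the posterior toward the prior, which tends to trade reconstruction quality for more factorized latents.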