# Structured Autoencoder: 2D Content + 6D Transform

**Minimal Training Notebook - Updated for Modular Structure**

- **2D Content Latent**: Digit identity/shape clustering
- **6D Transform Latent**: Spatial transformations  
- **Cloud Ready**: Targets CUDA or CPU, with optional mixed precision
- **Simplified Loss**: Affine loss plus KL divergence, weighted by `alpha` and `beta`
- **Explicit Imports**: Modules are referenced by name (`structured_2d6d_autoencoder`, `affine_autoencoder_shared`)
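
As a rough sketch of how an affine+KL objective of this kind is commonly combined: the helper names below are hypothetical and not taken from `structured_2d6d_autoencoder`. The affine term is treated as an opaque scalar, and the KL term assumes a diagonal-Gaussian latent compared against a standard normal, which this notebook does not confirm.

```python
import math

def kl_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for a single latent vector."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))

def combined_loss(affine_loss, mu, logvar, alpha=1.0, beta=0.001):
    """Hypothetical helper: alpha * affine term + beta * KL term."""
    return alpha * affine_loss + beta * kl_standard_normal(mu, logvar)

# Example with a 2D latent; the default weights mirror CONFIG['alpha'] and CONFIG['beta']
print(combined_loss(0.5, mu=[1.0, 0.0], logvar=[0.0, 0.0]))
```

With a small `beta`, the KL term acts as a light regulariser on the latent rather than dominating the objective.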

In [1]:
import torch
# Explicit imports from modular structure
import structured_2d6d_autoencoder as s2d6d
import affine_autoencoder_shared as shared

# 🚀 CONFIG - Updated for simplified loss
CONFIG = {
    'content_latent_dim': 2, 'transform_latent_dim': 6, 'total_latent_dim': 8,
    'epochs': 50, 'learning_rate': 1e-3, 'batch_size_train': 256, 'batch_size_test': 128,
    'alpha': 1.0,  # Affine loss weight
    'beta': 0.001,  # KL divergence weight (reduced for 2D content latent)
    'force_cuda': True, 'mixed_precision': True, 'gradient_clip': 1.0,
    'pin_memory': True, 'num_workers': 4, 'weight_decay': 1e-5,
    'lr_scheduler': True, 'early_stopping': True, 'patience': 10,
    'data_dir': '../data', 'save_dir': './', 'checkpoint_freq': 10
}
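
The three latent-size entries in `CONFIG` are redundant, since the total should equal content plus transform. A small sanity check can catch a mismatch before training starts. `check_latent_dims` is an illustrative helper, not part of either imported module.

```python
def check_latent_dims(config):
    """Raise if total_latent_dim disagrees with content + transform dims."""
    expected = config['content_latent_dim'] + config['transform_latent_dim']
    if config['total_latent_dim'] != expected:
        raise ValueError(
            f"total_latent_dim={config['total_latent_dim']} but "
            f"content + transform = {expected}"
        )
    return expected

# Stand-in for the relevant CONFIG entries above
example_config = {'content_latent_dim': 2, 'transform_latent_dim': 6, 'total_latent_dim': 8}
check_latent_dims(example_config)
```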

In [2]:
# 🌩️ SETUP - Using shared utilities with explicit module reference
device = shared.get_cloud_device(CONFIG)
# GradScaler only applies on CUDA; on MPS or CPU the scaler stays None
scaler = torch.cuda.amp.GradScaler() if CONFIG['mixed_precision'] and device.type == 'cuda' else None
# Forward only the loader-related settings from CONFIG
loader_keys = ('batch_size_train', 'batch_size_test', 'data_dir', 'pin_memory', 'num_workers')
train_loader, test_loader = shared.get_cloud_mnist_loaders(**{k: CONFIG[k] for k in loader_keys})

🍎 Apple MPS device
📊 Train batches: 235, Test batches: 79


In [3]:
# 🏗️ MODEL - Using structured_2d6d_autoencoder module explicitly
model = s2d6d.StructuredAffineInvariantAutoEncoder(
    content_dim=CONFIG['content_latent_dim'],
    transform_dim=CONFIG['transform_latent_dim']
).to(device)
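
A 6D transform latent is a natural fit for the six entries of a 2x3 affine matrix, and the `grid_sampler` operator named in the training error further down suggests the model resamples images on an affine grid. How `StructuredAffineInvariantAutoEncoder` actually decodes its transform latent is not shown here, so the sketch below is an assumption. It illustrates the row-major layout `[[a, b, tx], [c, d, ty]]` applied to a single normalised coordinate.

```python
def apply_affine(theta, x, y):
    """Map (x, y) through the 2x3 affine matrix [[a, b, tx], [c, d, ty]]."""
    a, b, tx, c, d, ty = theta
    return (a * x + b * y + tx, c * x + d * y + ty)

IDENTITY = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0)      # leaves coordinates unchanged
SHIFT = (1.0, 0.0, 0.5, 0.0, 1.0, -0.5)        # pure translation by (0.5, -0.5)

print(apply_affine(IDENTITY, 0.25, -0.75))
print(apply_affine(SHIFT, 0.0, 0.0))
```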

In [5]:
# 📂 OPTIONAL: LOAD EXISTING MODEL (comment out to train new)
# shared.list_saved_models()  # List available models (from affine_autoencoder_shared)
# model, CONFIG, device = shared.load_model_cloud("structured_model_YYYYMMDD_HHMMSS.pth")
# Note: Use shared.* functions from affine_autoencoder_shared module

In [4]:
# 🚀 TRAIN - Using simplified affine+KL loss from structured_2d6d_autoencoder
losses_dict = s2d6d.train_structured_autoencoder_simplified(model, train_loader, test_loader, device, CONFIG)

Epoch 1:   0%|          | 0/235 [00:05<?, ?it/s]



NotImplementedError: The operator 'aten::grid_sampler_2d_backward' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
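
The error message itself points at a workaround: set `PYTORCH_ENABLE_MPS_FALLBACK=1` so the unsupported op runs on the CPU. The variable generally needs to be in place before `torch` is first imported, so in a notebook it belongs at the very top of the first cell, followed by a kernel restart. A minimal sketch, where `enable_mps_fallback` is an illustrative helper name:

```python
import os
import sys

def enable_mps_fallback():
    """Set PYTORCH_ENABLE_MPS_FALLBACK=1.

    Returns True if torch had not been imported yet, i.e. the setting
    was put in place early enough to be picked up on import.
    """
    torch_already_loaded = 'torch' in sys.modules
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    return not torch_already_loaded

if not enable_mps_fallback():
    print("torch is already imported; restart the kernel and run this first")
```

As the message warns, ops that fall back to the CPU will run slower than native MPS execution.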

In [None]:
# 📈 VISUALIZE - Using explicit module references
s2d6d.plot_simplified_training_progress_structured(losses_dict)
content_data, transform_data, label_data = s2d6d.visualize_structured_latent_space(model, test_loader, device)

In [None]:
# 🎨 COMPREHENSIVE VISUALIZATIONS - Using structured_2d6d_autoencoder module
s2d6d.comprehensive_visualization_structured(model, test_loader, device, CONFIG)

In [None]:
# 💾 SAVE - Using shared save functionality from affine_autoencoder_shared
# Basic save (model, config, losses only - no extra visualization data)
model_file, metadata_file = shared.save_model_cloud(model, CONFIG, losses_dict, device)

In [None]:
# 💾 ALTERNATIVE: Save for Visualization (includes all data needed for viz)
# This saves everything needed to recreate visualizations later
viz_filename = shared.save_model_for_viz(
    model, 
    model_type="structured",
    config=CONFIG,
    losses=losses_dict,
    extra_data={
        'content_data': content_data,
        'transform_data': transform_data,
        'label_data': label_data
    },
    name="2d6d_simplified"  # Custom name for easy identification
)

In [None]:
# 📁 LOAD MODEL FOR VISUALIZATION
# Run this cell to load a previously saved model instead of training from scratch

# Option 1: Load specific model by filename
# loaded_model, viz_data = shared.load_model_for_viz("structured_2d6d_simplified_20250720_123456.pth", 
#                                                    s2d6d.StructuredAffineInvariantAutoEncoder, device)

# Option 2: Quick load most recent model
# loaded_model, viz_data = shared.quick_load_viz(s2d6d.StructuredAffineInvariantAutoEncoder, 
#                                                model_type="structured", name="2d6d_simplified", device=device)

# After loading, you can access:
# - loaded_model: The trained model ready for inference
# - viz_data['config']: Original training configuration  
# - viz_data['losses']: Training loss history
# - viz_data['extra_data']: Latent embeddings and other visualization data

print("💡 Uncomment the lines above to load a saved model for visualization")
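
How `shared.quick_load_viz` decides which file is "most recent" is not shown in this notebook. One plausible approach, assuming filenames follow the `<prefix>_YYYYMMDD_HHMMSS.pth` pattern from the commented example above, is to parse the timestamp and take the maximum. `latest_checkpoint` is a hypothetical helper for illustration only.

```python
from datetime import datetime

def latest_checkpoint(filenames, prefix):
    """Return the name with the newest YYYYMMDD_HHMMSS stamp, or None."""
    stamped = []
    for name in filenames:
        if not (name.startswith(prefix + "_") and name.endswith(".pth")):
            continue
        stamp = name[len(prefix) + 1:-len(".pth")]
        try:
            stamped.append((datetime.strptime(stamp, "%Y%m%d_%H%M%S"), name))
        except ValueError:
            continue  # skip files whose suffix is not a valid timestamp
    return max(stamped)[1] if stamped else None

# Made-up filenames for demonstration
names = [
    "structured_2d6d_simplified_20250720_123456.pth",
    "structured_2d6d_simplified_20250721_010203.pth",
    "notes.txt",
]
print(latest_checkpoint(names, "structured_2d6d_simplified"))
```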