# MAE Checkpoint Visualization

This notebook loads a trained Masked Autoencoder (MAE) checkpoint and visualizes its reconstructions of sample Galaxy10 images.


In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from mae_model import create_mae_model
from data import get_dataloaders
from visualization import denormalize

%matplotlib inline


  from .autonotebook import tqdm as notebook_tqdm


## Configuration


In [2]:
# Trained MAE weights to visualize -- example path; point this at your own checkpoint.
CHECKPOINT_PATH = (
    "models/Mae_Galaxy_Vit_Base_Epoch_400.pth"
)

In [3]:


# Model configuration -- must match the config the checkpoint was trained with,
# otherwise load_state_dict will fail on shape/key mismatches.
model_config = {
    "image_size": 256,
    "patch_size": 16,
    "embed_dim": 768,
    "encoder_depth": 12,
    "encoder_heads": 12,
    "decoder_embed_dim": 512,
    "decoder_depth": 8,
    "decoder_heads": 16
}

# Seed torch's RNGs (CPU and, in recent torch versions, CUDA/MPS too) so the random
# patch mask drawn during the MAE forward pass is the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'
print(f"Using device: {DEVICE}")


Using device: mps


## Load Data


In [4]:
print("Loading data...")
# get_dataloaders returns five loaders followed by the training-set mean and std.
(
    mae_loader,
    probe_train_loader,
    probe_test_loader,
    finetune_train_loader,
    finetune_test_loader,
    train_mean,
    train_std,
) = get_dataloaders(
    batch_size=32,
    image_size=model_config["image_size"],
    num_workers=0,  # 0 = load in the main process, as recommended when running from a notebook
)

print(f"Mean: {train_mean}\nStd: {train_std}")


Loading data...
Loading Galaxy10 dataset from Hugging Face...
Calculating dataset statistics (mean and std)...


Calculating Stats: 100%|██████████| 499/499 [00:54<00:00,  9.11it/s]


Calculated Mean: [0.16750076413154602, 0.16260723769664764, 0.15888828039169312]
Calculated Std: [0.12320292741060257, 0.11179731786251068, 0.1046297699213028]

Dataloaders created successfully.
  - MAE loader: No augmentation (for reconstruction)
  - Probe loaders: No augmentation (for evaluation)
  - Fine-tune loaders: WITH augmentation (for training)
Mean: [0.16750076413154602, 0.16260723769664764, 0.15888828039169312]
Std: [0.12320292741060257, 0.11179731786251068, 0.1046297699213028]





## Load Model & Checkpoint


In [5]:
# Build the architecture first, then load the trained weights into it.
print("Creating model...")
mae_model = create_mae_model(**model_config)

print(f"Loading checkpoint from {CHECKPOINT_PATH}...")
# The file is used directly as a state_dict below, so restrict unpickling to tensors and
# plain containers: weights_only=True refuses arbitrary pickled objects, whose loading can
# execute code. Load onto CPU and move the whole model to DEVICE once afterwards, instead
# of putting the tensors on the accelerator only to copy them into CPU-resident parameters.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=True)
mae_model.load_state_dict(checkpoint)
mae_model.to(DEVICE)
mae_model.eval()  # inference mode for visualization

print("✓ Model loaded successfully!")


Creating model...
Creating a randomly initialized MAE model...
Loading checkpoint from models/Mae_Galaxy_Vit_Base_Epoch_400.pth...
✓ Model loaded successfully!


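## Visualize the Reconstruction

For a single image, the figure shows four panels:
- the original image;
- the masked input, with masked patches blanked out;
- the decoder's prediction for every patch;
- a hybrid that keeps the visible patches and fills only the masked ones with predictions.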


In [None]:
def _image_to_display(image, mean, std):
    """Convert one normalized (C, H, W) image tensor into an (H, W, C) numpy array for imshow.

    The tensor is copied to CPU before ``denormalize`` so the caller's tensor cannot be
    modified, even if ``denormalize`` happens to work in place (not visible from here).
    Values are clamped to [0, 1], the range imshow expects for float RGB data; imshow
    would otherwise clip them itself and emit a warning.
    """
    display_tensor = denormalize(image.detach().cpu().clone(), mean, std)
    return display_tensor.clamp(0, 1).permute(1, 2, 0).numpy()


def _patches_to_display(model, patches, mean, std):
    """Unpatchify a (1, N, P) patch tensor and convert the single image for imshow."""
    return _image_to_display(model.unpatchify(patches).squeeze(0), mean, std)


def _unnormalize_pred_patches(model, pred_patches, target_patches):
    """Undo per-patch target normalization on predictions when the model uses norm_pix_loss.

    With norm_pix_loss the decoder is trained to predict each patch standardized by that
    patch's own mean/std. The predictions are therefore rescaled with the statistics of the
    corresponding ground-truth patches (the only source of those statistics available here).
    This brings them back to the dataset-normalized pixel space of ``target_patches``.
    """
    if not getattr(model.config, "norm_pix_loss", True):
        return pred_patches
    patch_mean = target_patches.mean(dim=-1, keepdim=True)  # (1, N, 1)
    # The default (unbiased) torch.var estimator and the 1e-6 epsilon mirror the
    # normalization in HF ViTMAE / reference MAE forward_loss.
    # NOTE(review): confirm this if mae_model defines its own loss.
    patch_var = target_patches.var(dim=-1, keepdim=True)
    return pred_patches * (patch_var + 1e-6).sqrt() + patch_mean


def visualize_reconstruction(model, image, mean, std, device, title="MAE Reconstruction"):
    """Plot an MAE's reconstruction of a single image as four side-by-side panels.

    Panels:
        1. the original image;
        2. the masked input (masked patches zeroed in normalized space);
        3. the model's prediction for every patch;
        4. a hybrid that keeps visible patches from the original and fills masked ones
           with the prediction.

    Args:
        model: MAE model exposing ``patchify``/``unpatchify``, a ``config`` object, and a
            forward pass taking ``pixel_values`` and returning ``.mask`` and ``.logits``.
            This is the HF ``ViTMAEForPreTraining`` interface, presumably what
            ``create_mae_model`` returns (confirm).
        image: A single normalized image tensor, (C, H, W).
        mean: Per-channel statistic used to normalize ``image``; passed to ``denormalize``.
        std: Per-channel statistic used to normalize ``image``; passed to ``denormalize``.
        device: Device to run the forward pass on.
        title: Figure super-title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    model.eval()
    with torch.no_grad():
        image_batch = image.unsqueeze(0).to(device)
        outputs = model(pixel_values=image_batch)

        original_patches = model.patchify(image_batch)  # (1, N, P)
        # outputs.mask marks masked patches with 1. Cast it to the patch dtype so the
        # arithmetic below also works if the mask comes back boolean.
        mask = outputs.mask.to(original_patches.dtype).unsqueeze(-1)  # (1, N, 1)
        visible = 1 - mask

        # outputs.logits holds the decoder's prediction for ALL patches, masked or not.
        pred_patches = _unnormalize_pred_patches(model, outputs.logits, original_patches)

        masked_patches = original_patches * visible                        # keep visible, zero masked
        hybrid_patches = original_patches * visible + pred_patches * mask  # visible from input, masked from prediction

        mask_ratio = getattr(model.config, "mask_ratio", 0.75)
        panels = [
            ("Original", _image_to_display(image, mean, std)),
            (f"Masked ({mask_ratio:.0%})", _patches_to_display(model, masked_patches, mean, std)),
            ("Full Reconstruction", _patches_to_display(model, pred_patches, mean, std)),
            ("Hybrid (Visible + Recon)", _patches_to_display(model, hybrid_patches, mean, std)),
        ]

    fig, axs = plt.subplots(1, len(panels), figsize=(5 * len(panels), 5))
    fig.suptitle(title, fontsize=16)
    for ax, (panel_title, panel_image) in zip(axs, panels):
        ax.imshow(panel_image)
        ax.set_title(panel_title)
        ax.axis('off')

    plt.tight_layout()
    plt.show()


## Get a Sample Image


In [9]:
# Use the first example of the first probe-test batch as the image to reconstruct.
test_batch = next(iter(probe_test_loader))
test_image, test_label = test_batch['pixel_values'][0], test_batch['label'][0].item()

print(f"Image shape: {test_image.shape}\nGalaxy class: {test_label}")


Image shape: torch.Size([3, 256, 256])
Galaxy class: 7




In [None]:
# Reconstruct the sample galaxy, titling the figure with its class label.
visualize_reconstruction(
    model=mae_model,
    image=test_image,
    mean=train_mean,
    std=train_std,
    device=DEVICE,
    title=f"Galaxy Class {test_label}",
)
