In [6]:
import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from ae import UNetAutoencoder
from torch.utils.data import DataLoader

In [None]:
# Build the UNetAutoencoder, restore its trained weights from a checkpoint, and
# pull one (unshuffled) batch from the dataset recorded in the checkpoint's config.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): these hyperparameters must match the ones the checkpoint was
# trained with, otherwise load_state_dict raises on missing/unexpected keys or
# shape mismatches. If `checkpoint['config']` records them, prefer reading them
# from there — TODO confirm which fields the config carries.
model = UNetAutoencoder(base_ch=256, skip_connections=False).to(device)

# Load the checkpoint.
checkpoint_path = "path"  # TODO: replace the placeholder with a real checkpoint path
# The checkpoint stores a full Python `config` object (used below via
# `.dataset`, `.batch_size`, `.data_workers`), not only tensors, so it must be
# fully unpickled; torch >= 2.6 defaults to weights_only=True, which rejects it.
# Unpickling can execute arbitrary code — only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)

# Training-time configuration saved alongside the weights.
config = checkpoint["config"]

# Take the first batch of the recorded dataset; shuffle=False keeps it
# deterministic across re-runs.
dataloader = DataLoader(
    config.dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.data_workers,
)
# Each batch unpacks as (images, indices); the indices are not used in this notebook.
im, idx = next(iter(dataloader))
images = im.to(device)


In [None]:

# Encode the batch, project the latent codes onto their first two principal
# components, then compare original images with their reconstructions.
N_SHOW = 8  # max original/reconstruction pairs to plot; None shows the whole batch

# Extract latent representations (first output of `encode`; the second is unused here).
with torch.no_grad():
    latents, _ = model.encode(images)

# Flatten each sample's latent to one vector for PCA. `reshape` copies when the
# tensor is non-contiguous, where `.view` would raise.
latents_flat = latents.reshape(latents.size(0), -1).cpu().numpy()

# 2-component PCA; needs at least 2 samples and 2 latent features.
pca = PCA(n_components=2)
latents_pca = pca.fit_transform(latents_flat)
var_pc1, var_pc2 = pca.explained_variance_ratio_

fig_pca, ax_pca = plt.subplots()
ax_pca.scatter(latents_pca[:, 0], latents_pca[:, 1])
ax_pca.set_title("PCA of Latent Space")
ax_pca.set_xlabel(f"Principal Component 1 ({var_pc1:.1%} of variance)")
ax_pca.set_ylabel(f"Principal Component 2 ({var_pc2:.1%} of variance)")
plt.show()

# Reconstruct images from the latent codes alone.
with torch.no_grad():
    reconstructed_images = model.decode(latents)


def _to_display(img):
    """Convert one image tensor into a numpy array that `imshow` accepts.

    Singleton dims are squeezed (e.g. a 1-channel image becomes (H, W)). A
    3- or 4-channel channels-first array is moved to (H, W, C); the layout is
    an assumption — TODO confirm the dataset's channel order.
    """
    arr = img.squeeze().cpu().numpy()
    if arr.ndim == 3 and arr.shape[0] in (3, 4):
        arr = arr.transpose(1, 2, 0)
    return arr


# Visualize original (top row) and reconstructed (bottom row) images.
n_show = len(images) if N_SHOW is None else min(len(images), N_SHOW)
# squeeze=False keeps `axes` 2-D even when only one column is drawn.
fig, axes = plt.subplots(2, n_show, figsize=(15, 5), squeeze=False)
rows = ((images, "Original"), (reconstructed_images, "Reconstructed"))
for i in range(n_show):
    for row, (batch, title) in enumerate(rows):
        ax = axes[row, i]
        ax.imshow(_to_display(batch[i]), cmap="gray")  # cmap is ignored for RGB(A)
        ax.set_title(title)
        ax.axis("off")
plt.tight_layout()
plt.show()