In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import io
import torch
from torchvision.transforms.functional import to_pil_image
from models_mae_spectral import MaskedAutoencoderViT

# ----- CONFIG -----
# Folder that is scanned for input .tif images.
IMAGE_DIR = "sample_images"
# Pretrained SpectralGPT+ checkpoint file; it must be downloaded from git before running.
CHECKPOINT_PATH = "Checkpoints/SpectralGPT+.pth"

# Band positions used as (R, G, B) for display.
# Adjust these to match the Sentinel-2 band order of the input files.
rgb_idx = [3, 2, 1]

# Running totals used to report the average reconstruction loss.
total_loss, valid_samples = 0.0, 0

# ----- MODEL -----
# Build the masked autoencoder, load the pretrained weights and switch to
# inference mode. Runs on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MaskedAutoencoderViT(
    img_size=128,
    in_chans=1,
    patch_size=8,
    embed_dim=768,
    decoder_embed_dim=512,
    depth=12,
    decoder_depth=4,
    decoder_num_heads=16,
    num_heads=12,
    mlp_ratio=4,
    num_frames=12,
    pred_t_dim=12,
    t_patch_size=3,
    # NOTE(review): near-zero ratio — presumably so that almost nothing is
    # masked during reconstruction; confirm against the model implementation.
    mask_ratio=0.0001,
)

# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code. Only load checkpoints obtained from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu", weights_only=False)

# strict=False tolerates key mismatches, so a wrong or partial checkpoint would
# otherwise load without any sign of trouble. Report what did not line up.
load_result = model.load_state_dict(checkpoint["model"], strict=False)
if load_result.missing_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} model parameter(s) not found in the "
        f"checkpoint (left at their initial values), e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.unexpected_keys)} checkpoint entr(y/ies) not used by "
        f"the model, e.g. {load_result.unexpected_keys[:5]}"
    )

model = model.eval().to(device)

# ----- Load all .tif images -----
# Run every .tif in IMAGE_DIR through the model, print the per-image
# reconstruction loss and plot the input next to its reconstruction.

# Sorted so that row order and printed losses are reproducible across runs
# (os.listdir returns entries in arbitrary order).
all_files = sorted(
    os.path.join(IMAGE_DIR, f) for f in os.listdir(IMAGE_DIR) if f.endswith(".tif")
)
if not all_files:
    # plt.subplots(0, 2) would otherwise fail with an unrelated-looking ValueError.
    raise FileNotFoundError(f"No .tif images found in '{IMAGE_DIR}'")

# Put the inputs on whichever device the model already lives on.
device = next(model.parameters()).device

# Reset the accumulators so that re-running this code does not double count.
total_loss = 0.0
valid_samples = 0

NORM_EPS = 1e-6  # guards the divisions below against a zero value range


def _rescale_unit(t):
    """Rescale a tensor to [0, 1] using its global min and max (constant input maps to 0)."""
    lo, hi = t.min(), t.max()
    return (t - lo) / (hi - lo).clamp_min(NORM_EPS)


# squeeze=False keeps axs two-dimensional even when there is a single image.
fig, axs = plt.subplots(
    len(all_files), 2, figsize=(8, 4 * len(all_files)), squeeze=False
)

for row, file_path in enumerate(all_files):
    img = io.imread(file_path)

    # Convert to [C, H, W]; assumes the file is read as (H, W, C), e.g.
    # (128, 128, 12) -> (12, 128, 128) — TODO confirm for your reader/plugin.
    # Casting to float32 in NumPy first avoids torch's limited support for
    # some integer dtypes such as uint16.
    img = np.ascontiguousarray(img.transpose(2, 0, 1), dtype=np.float32)
    img_tensor = torch.from_numpy(img)

    # Per-channel min-max normalization to [0, 1].
    ch_min = img_tensor.amin(dim=(1, 2), keepdim=True)
    ch_max = img_tensor.amax(dim=(1, 2), keepdim=True)
    img_tensor = (img_tensor - ch_min) / (ch_max - ch_min + NORM_EPS)
    img_tensor = img_tensor.unsqueeze(0).to(device)  # [1, C, H, W]

    with torch.no_grad():
        loss, pred, mask = model(img_tensor)
        recon_all = model.unpatchify(pred).squeeze(0).cpu()

    loss_value = loss.item()
    total_loss += loss_value
    valid_samples += 1
    print(f"Image: {os.path.basename(file_path)} | Reconstruction Loss: {loss_value:.6f}")

    # Extract RGB for visualization
    low_rgb = _rescale_unit(img_tensor[0, rgb_idx].cpu())

    # NOTE(review): indexing [0] before the band selection implies recon_all
    # still has a leading singleton dimension after squeeze(0) (e.g.
    # [1, C, H, W]); verify against the shape returned by unpatchify.
    recon_rgb = _rescale_unit(recon_all[0][rgb_idx])

    # ----- PLOTTING -----
    axs[row, 0].imshow(to_pil_image(low_rgb))
    axs[row, 0].set_title("Ground Truth")
    axs[row, 0].axis("off")

    axs[row, 1].imshow(to_pil_image(recon_rgb))
    axs[row, 1].set_title("Reconstructed")
    axs[row, 1].axis("off")

if valid_samples > 0:
    avg_loss = total_loss / valid_samples
    print(f"\n🔍 Averaged Reconstruction Loss over {valid_samples} images: {avg_loss:.6f}")

plt.tight_layout()
plt.show()


ModuleNotFoundError: No module named 'torch'

In [None]:
# Tally the model's parameter counts in a single pass: every element counts
# towards the total, and those with requires_grad set also count as trainable.
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elements = param.numel()
    total_params += n_elements
    if param.requires_grad:
        trainable_params += n_elements

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
