In [3]:
import torch
from utils.plotting import plot_latent_tsne_grid
from models.models import SplitEncoder, SplitDecoder
from data.loader import get_dataloader
from utils.seed import set_seed

In [11]:
# Seed the project's RNGs (utils.seed.set_seed) before any model is built or
# any data is sampled, then run on the GPU when CUDA is available.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [12]:
# Model dimensions. These must agree with the shapes stored in the checkpoint
# loaded in the next cell, otherwise load_state_dict will reject it.
input_dim = output_dim = 784   # 784 = 28 * 28, presumably flattened MNIST pixels
latent_dim, signal_dim = 64, 32  # NOTE(review): signal_dim semantics live in SplitEncoder — confirm
num_classes = 10               # not referenced by any cell shown here

# Build the encoder first, then the decoder (order kept so RNG-driven
# initialisation stays identical run to run).
encoder = SplitEncoder(
    input_dim=input_dim,
    latent_dim=latent_dim,
    signal_dim=signal_dim,
).to(device)
decoder = SplitDecoder(
    latent_dim=latent_dim,
    output_dim=output_dim,
).to(device)

In [13]:
# Restore the pretrained weights onto whichever device was selected above.
# SECURITY NOTE(review): torch.load unpickles the file, which can execute
# arbitrary code — only load checkpoints from a trusted source. If this
# checkpoint holds only state dicts/tensors, consider weights_only=True
# (confirm against how mnist_pretrained.pt was saved before switching).
ckpt = torch.load("artifacts/mnist/mnist_pretrained.pt", map_location=device)
encoder.load_state_dict(ckpt["encoder"])
decoder.load_state_dict(ckpt["decoder"])

# The modules are only used for inference/visualisation from here on, so put
# them in evaluation mode; otherwise dropout / batch-norm layers (if the models
# contain any) would behave as in training. Trailing `;` suppresses the module
# repr that the last expression would otherwise display.
encoder.eval()
decoder.eval();

<All keys matched successfully>

In [14]:
# Non-training (train=False) MNIST loader used as the source of points to embed.
loader = get_dataloader(
    "mnist",
    train=False,
    batch_size=256,
)

In [15]:
# Embed up to 1000 samples from `loader` with t-SNE via the project plotting
# helper and write the resulting figure to disk.
plot_latent_tsne_grid(
    encoder=encoder,
    decoder=decoder,
    device=device,
    dataloader=loader,
    n_samples=1000,
    title_prefix="MNIST Pretrained Model",
    save_path="artifacts/plots/mnist_latent_grid.png",
)

[t-SNE] Running joint projection...
✅ t-SNE plot saved to artifacts/plots/mnist_latent_grid.png
