In [2]:
# %%
import os
import sys

import torch
import torchvision.transforms as T
from torchvision.utils import make_grid
from torchvision.utils import save_image

# The local ImagiNav checkout must be on sys.path before its modules are imported.
sys.path.append(r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav")
from modules.autoencoder import AutoEncoder
from modules.datasets import LayoutDataset, collate_skip_none


In [5]:

# %%
# --- CONFIG ---
# Every input location hangs off a single project root, so the notebook can be
# pointed at another checkout by setting the IMAGINAV_ROOT environment variable
# instead of editing three absolute paths. The fallback is the original
# hard-coded checkout, so the default paths on that machine are unchanged.
project_root = os.environ.get(
    "IMAGINAV_ROOT",
    r"C:\Users\Hagai.LAPTOP-QAG9263N\Desktop\Thesis\repositories\ImagiNav",
)
index_dir = os.path.join(project_root, "indexes")
ae_ckpt_dir = os.path.join(index_dir, "autoencoder_512_64x64x4", "checkpoints")

ae_ckpt_path = os.path.join(ae_ckpt_dir, "ae_epoch_50.pt")    # weights read by torch.load
ae_cfg_path = os.path.join(ae_ckpt_dir, "model_config.yaml")  # passed to AutoEncoder.from_config
manifest_path = os.path.join(index_dir, "layouts.csv")        # passed to LayoutDataset

# Output images are written relative to the current working directory.
output_dir = "test/output"
os.makedirs(output_dir, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# %%
# --- LOAD AUTOENCODER ---
# Build the model from its config file, then restore the checkpointed weights.
autoencoder = AutoEncoder.from_config(ae_cfg_path)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints you trust; if this file is a plain state dict,
# consider passing weights_only=True where the installed PyTorch supports it.
# map_location="cpu" loads the tensors onto the CPU; the model is moved to
# `device` afterwards.
state = torch.load(ae_ckpt_path, map_location="cpu")

# strict=False tolerates key mismatches. Without the report below, a checkpoint
# that matches few (or none) of the model's parameters would appear to load
# fine while leaving the model's initial weights in place.
load_result = autoencoder.load_state_dict(state, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", None) or [])
unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
if missing_keys or unexpected_keys:
    print(
        "WARNING: checkpoint does not fully match the model "
        f"(missing keys: {len(missing_keys)}, unexpected keys: {len(unexpected_keys)})"
    )
    for label, keys in (("missing", missing_keys), ("unexpected", unexpected_keys)):
        if keys:
            # Show only the first few names so the cell output stays readable.
            print(f"  {label}: {keys[:10]}{' ...' if len(keys) > 10 else ''}")

# Switch to evaluation mode, then move the model to the target device.
autoencoder.eval().to(device)

print("AutoEncoder loaded.")
print("Latent channels:", autoencoder.encoder.latent_channels)
print("Latent base:", autoencoder.encoder.latent_base)

# %%
# --- LOAD SAMPLE DATA ---
# torchvision's ToTensor is handed to the dataset as its transform.
transform = T.ToTensor()
dataset = LayoutDataset(manifest_path, mode="all", transform=transform)

# Take the "layout" entry of the first item and prepend a batch axis so it can
# be passed to the encoder as a batch of one.
first_layout = dataset[0]["layout"]
sample = first_layout.unsqueeze(0)
sample = sample.to(device)

print("Sample shape:", sample.size())

# %%
# --- ENCODE + DECODE TEST ---
# Round-trip the single sample through encoder and decoder with autograd
# disabled, then write the reconstruction to disk.
recon_path = os.path.join(output_dir, "reconstruction.png")

with torch.no_grad():
    z = autoencoder.encoder(sample)   # latent code for the sample
    x_rec = autoencoder.decoder(z)    # image decoded back from the latent

print("Latent shape:", z.size())
save_image(x_rec, recon_path)
print(f"Saved reconstruction → {output_dir}/reconstruction.png")

# %%
# --- LATENT STATISTICS ---
# Collapse every axis after the batch axis so each row is one latent vector.
z_flat = torch.flatten(z, start_dim=1)

# Mean and standard deviation are taken over all elements of all rows.
mean = torch.mean(z_flat).item()
std = torch.std(z_flat).item()

# Norm of each flattened latent (one value per row), averaged over the rows.
per_row_norm = torch.norm(z_flat, dim=1)
norms = per_row_norm.mean().item()

print(f"Latent mean={mean:.4f}, std={std:.4f}, avg_norm={norms:.4f}")

# %%
# --- MULTIPLE SAMPLES GRID ---
# Reconstruct a small batch and save the inputs and the reconstructions as two
# image grids for side-by-side inspection.
num_grid_samples = min(8, len(dataset))  # at most 8, fewer if the dataset is smaller
grid_columns = 4                         # images per row in each saved grid

samples = torch.stack([dataset[i]["layout"] for i in range(num_grid_samples)])
samples = samples.to(device)

# NOTE: this rebinds `z` and `x_rec`, which held the single-sample results from
# the encode/decode cell. Re-running the latent-statistics cell after this one
# will therefore report statistics for this batch instead.
with torch.no_grad():
    z = autoencoder.encoder(samples)
    x_rec = autoencoder.decoder(z)

grid_in = make_grid(samples, nrow=grid_columns)
grid_rec = make_grid(x_rec, nrow=grid_columns)

save_image(grid_in, os.path.join(output_dir, "originals.png"))
save_image(grid_rec, os.path.join(output_dir, "recon_grid.png"))

print("Saved originals.png and recon_grid.png")


AutoEncoder loaded.
Latent channels: 4
Latent base: 64
Sample shape: torch.Size([1, 3, 512, 512])
Latent shape: torch.Size([1, 4, 64, 64])
Saved reconstruction → test/output/reconstruction.png
Latent mean=0.0020, std=0.0075, avg_norm=1.0000
Saved originals.png and recon_grid.png
