In [1]:
import os
# NOTE: this assignment MUST run before torch is imported.
# It restricts the process to the GPU that the CUDA runtime enumerates as
# index 1; inside this process that GPU is then addressed as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"


In [None]:
#!/usr/bin/env python
import math
from pathlib import Path

import numpy as np
import polars as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# =========================
# CONFIG
# =========================

# Parquet file to read objects from (one file of the dataset).
PARQUET_FILE = "/home/astrodados4/downloads/hypercube/datacube_HYDRA-0011.parquet"

# Trained model weights (a state_dict) produced by the training script.
MODEL_PATH = "conv_autoencoder_splus_qinco_spatial.pt"

# How many objects to visualize (one figure row per object).
N_SAMPLES = 4

# Bands fed to the model. The list order is the channel order of the input
# tensor, and each band is read from the column "splus_cut_<band>".
BANDS = [
    "F378", "F395", "F410", "F430",
    "F515", "F660", "F861",
    "U", "G", "R", "I", "Z",
]
# Which band to show in the plots.
PLOT_BAND = "R"  # must be in BANDS
# Side length in pixels of the cutouts; the latent map is IMG_SIZE // 8 wide
# and the object mask is centred at IMG_SIZE // 2.
IMG_SIZE = 96

# QINCo & latent config (must match training, otherwise the saved
# state_dict will not fit the model built here).
LATENT_DIM = 8       # channels of the latent map
QINCO_USE = True     # build the model with the QINCo quantizer
QINCO_STEPS = 4      # number of quantization steps (M)
QINCO_K = 1024       # codebook entries per step (K)

# Require F378 not null, as in training
REQUIRE_F378_NOT_NULL = True

# =========================
# UTILS
# =========================

_GRID_CACHE = {}


def get_grid(H: int, W: int, device: torch.device):
    """Return cached float32 pixel-coordinate grids ``(yy, xx)``.

    Both tensors have shape ``(H, W)``; ``yy[i, j] == i`` and
    ``xx[i, j] == j`` ("ij" indexing).

    Args:
        H: Number of rows.
        W: Number of columns.
        device: Device the grids live on (``torch.device`` or a device
            string such as ``"cpu"``).

    Returns:
        Tuple ``(yy, xx)``. The same tensor objects are returned on every
        call with the same arguments, so callers must not modify them
        in place.
    """
    device = torch.device(device)
    # Key on the device index as well as the type: keying on the type alone
    # would hand a grid living on cuda:0 to a caller asking for cuda:1.
    key = (H, W, device.type, device.index)
    grids = _GRID_CACHE.get(key)
    if grids is None:
        yy, xx = torch.meshgrid(
            torch.arange(H, device=device),
            torch.arange(W, device=device),
            indexing="ij",
        )
        grids = (yy.float(), xx.float())
        _GRID_CACHE[key] = grids
    return grids


def _to_image_torch(flat) -> torch.Tensor:
    """Convert flattened array-like into a square image (H, W)."""
    arr = torch.tensor(flat, dtype=torch.float32)

    if arr.ndim == 2:
        return arr

    if arr.ndim == 1:
        n = arr.numel()
        side = int(math.isqrt(n))
        if side * side != n:
            raise ValueError(f"Cannot reshape length {n} into a square image")
        return arr.view(side, side)

    raise ValueError(f"Unexpected ndim={arr.ndim} for image data")


def elliptical_mask(H, W, x0, y0, a, b, theta, device="cpu", expand_factor=4.0):
    """Return an (H, W) float mask: 1.0 inside an enlarged ellipse, else 0.0.

    The ellipse is centred at column ``x0`` / row ``y0``, has semi-axes
    ``a * expand_factor`` and ``b * expand_factor``, and is rotated by
    ``theta``. ``theta`` is passed to ``torch.cos`` / ``torch.sin``, so it
    must be a tensor holding an angle in radians. Pixels exactly on the
    boundary count as inside.
    """
    rows, cols = get_grid(H, W, device=torch.device(device))

    # Pixel offsets from the ellipse centre.
    dx = cols - x0
    dy = rows - y0

    # Rotate the offsets into the ellipse's own axes.
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    u = dx * cos_t + dy * sin_t
    v = -dx * sin_t + dy * cos_t

    semi_u = a * expand_factor
    semi_v = b * expand_factor

    inside = (u / semi_u) ** 2 + (v / semi_v) ** 2 <= 1.0
    return inside.float()


def percentile_range(values: np.ndarray, p_lo=1.0, p_hi=99.0):
    """Return a ``(low, high)`` display range from percentiles of ``values``.

    Non-finite entries are ignored. Fallbacks, in order:
    no finite values -> ``(0.0, 1.0)``; both percentiles equal -> the
    min/max of the finite values; min equal to max -> ``(v, v + 1.0)``.
    Both returned values are Python floats.
    """
    finite = values.ravel()
    finite = finite[np.isfinite(finite)]
    if not finite.size:
        return 0.0, 1.0

    lo, hi = (float(np.percentile(finite, p)) for p in (p_lo, p_hi))
    if lo != hi:
        return lo, hi

    # Degenerate percentiles: widen to the full finite range, then force a
    # non-empty interval if the data are constant.
    lo, hi = float(finite.min()), float(finite.max())
    if lo == hi:
        hi = lo + 1.0
    return lo, hi

# =========================
# QINCo MODULES (SAME AS TRAINING)
# =========================

class QINCoStep(nn.Module):
    """One quantization step with an input-conditioned codebook.

    A learned base codebook of ``K`` vectors of size ``D`` is adapted to the
    current reconstruction ``x_hat`` by a linear projection of
    ``[x_hat, base_entry]`` followed by residual MLP blocks.

    Attribute names (``base_codebook``, ``concat_proj``, ``blocks``) define
    the state_dict keys and must stay as they are to load saved weights.
    """

    def __init__(self, D: int, K: int, hidden_dim: int = 256, num_res_blocks: int = 2):
        super().__init__()
        self.D, self.K = D, K

        self.base_codebook = nn.Parameter(torch.randn(K, D) * 0.1)
        self.concat_proj = nn.Linear(2 * D, D)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(D, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, D),
            )
            for _ in range(num_res_blocks)
        ])

    def forward_codebook(self, x_hat: torch.Tensor) -> torch.Tensor:
        """Return the conditioned codebook, shape (N, K, D), for ``x_hat`` (N, D)."""
        num_vecs, _ = x_hat.shape
        # Pair every reconstruction with every base codebook entry: (N, K, 2D).
        pairs = torch.cat(
            [
                x_hat.unsqueeze(1).expand(-1, self.K, -1),
                self.base_codebook.unsqueeze(0).expand(num_vecs, -1, -1),
            ],
            dim=-1,
        )
        centroids = self.concat_proj(pairs)
        for res_block in self.blocks:
            centroids = centroids + res_block(centroids)
        return centroids

    def encode_step(self, x: torch.Tensor, x_hat: torch.Tensor):
        """Pick, per row, the codebook entry closest to the residual ``x - x_hat``.

        Returns:
            codes: (N,) index of the chosen entry.
            x_hat_new: (N, D) reconstruction after adding the chosen entry.
            r: (N, D) residual before this step.
            c_sel: (N, D) chosen entries.
        """
        num_vecs, _ = x.shape
        centroids = self.forward_codebook(x_hat)          # (N,K,D)

        r = x - x_hat                                     # (N,D)
        # Squared Euclidean distance from the residual to each entry: (N,K).
        sq_dists = (r.unsqueeze(1) - centroids).pow(2).sum(dim=-1)
        codes = sq_dists.argmin(dim=-1)                   # (N,)

        row_ids = torch.arange(num_vecs, device=x.device)
        c_sel = centroids[row_ids, codes]                 # (N,D)

        return codes, x_hat + c_sel, r, c_sel


class QINCoQuantizer(nn.Module):
    """Multi-step residual quantizer: a chain of ``M`` ``QINCoStep`` modules.

    The attribute name ``steps`` is part of the state_dict keys.
    """

    def __init__(self, D: int, K: int = 256, M: int = 4):
        super().__init__()
        self.M = M
        self.steps = nn.ModuleList([QINCoStep(D, K) for _ in range(M)])

    def forward(self, z: torch.Tensor):
        """Quantize ``z`` of shape (N, D).

        Returns:
            z_q_st: (N, D) quantized vectors whose value is the final
                reconstruction but whose gradient flows straight to ``z``.
            codes_all: (N, M) chosen code index per step.
            aux: dict with per-step lists ``"residuals"`` and ``"centroids"``.
        """
        _rows, _dim = z.shape  # unpacking enforces a 2-D input

        running = torch.zeros_like(z)
        step_codes, step_residuals, step_centroids = [], [], []

        # Each step refines the running reconstruction of z.
        for quant_step in self.steps:
            codes, running, r, c_sel = quant_step.encode_step(z, running)
            step_codes.append(codes)
            step_residuals.append(r)
            step_centroids.append(c_sel)

        codes_all = torch.stack(step_codes, dim=-1)  # (N,M)

        # Straight-through: forward value is `running`, backward is identity in z.
        z_q_st = z + (running - z).detach()

        aux = {"residuals": step_residuals, "centroids": step_centroids}
        return z_q_st, codes_all, aux


class QINCoQuantizerSpatial(nn.Module):
    """Apply ``QINCoQuantizer`` independently at every position of a (B, D, H, W) map.

    The attribute name ``inner`` is part of the state_dict keys.
    """

    def __init__(self, D: int, H: int, W: int, K: int = 256, M: int = 4):
        super().__init__()
        self.D, self.H, self.W = D, H, W
        self.inner = QINCoQuantizer(D=D, K=K, M=M)

    def forward(self, z_map: torch.Tensor):
        """Quantize a latent map.

        Returns:
            z_q_map: (B, D, H, W) quantized map.
            codes: (B, H, W, M) code indices per position and step.
            aux: auxiliary dict returned by the inner quantizer.
        """
        B, D, H, W = z_map.shape
        assert (D, H, W) == (self.D, self.H, self.W)

        # Channels last, then one D-vector per (batch, row, col) position.
        tokens = z_map.permute(0, 2, 3, 1).reshape(-1, D)       # (N,D)
        q_tokens, token_codes, aux = self.inner(tokens)

        # Restore the spatial layout.
        z_q_map = q_tokens.view(B, H, W, D).permute(0, 3, 1, 2)  # (B,D,H,W)
        codes = token_codes.view(B, H, W, -1)                    # (B,H,W,M)
        return z_q_map, codes, aux

# =========================
# MODEL (SAME ARCH AS TRAINING)
# =========================

class ConvAutoEncoder(nn.Module):
    """Convolutional autoencoder with an optional QINCo-quantized latent map.

    The encoder halves the spatial size three times (factor 8 overall) and a
    1x1 convolution maps to ``latent_dim`` channels; the decoder mirrors this
    with three stride-2 transposed convolutions. Whether the quantizer is
    built is decided by the module-level ``QINCO_USE`` flag.

    Attribute names and the layer order inside each ``nn.Sequential``
    determine the state_dict keys and must match the training script.
    """

    def __init__(self, in_channels: int = len(BANDS), latent_dim: int = LATENT_DIM):
        super().__init__()

        # Three conv -> ReLU -> 2x max-pool stages: 96 -> 48 -> 24 -> 12
        # for IMG_SIZE = 96.
        widths = (in_channels, 32, 64, 128)
        enc_layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            enc_layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(2),
            ])
        self.encoder_conv = nn.Sequential(*enc_layers)

        # Spatial size of the latent map after the three poolings.
        self.latent_H = self.latent_W = IMG_SIZE // 8

        self.to_latent = nn.Conv2d(128, latent_dim, kernel_size=1)
        self.from_latent = nn.Conv2d(latent_dim, 128, kernel_size=1)

        self.qinco = (
            QINCoQuantizerSpatial(
                D=latent_dim,
                H=self.latent_H,
                W=self.latent_W,
                K=QINCO_K,
                M=QINCO_STEPS,
            )
            if QINCO_USE
            else None
        )

        # Each transposed conv doubles the spatial size: 12 -> 24 -> 48 -> 96.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, in_channels, kernel_size=2, stride=2),
        )

    def encode(self, x):
        """Map an input batch to its (unquantized) latent map."""
        return self.to_latent(self.encoder_conv(x))

    def decode(self, z_map):
        """Map a latent map back to image space."""
        return self.decoder(self.from_latent(z_map))

    def forward(self, x):
        """Return ``(x_hat, z_map, codes, q_aux)``.

        ``z_map`` is always the unquantized latent. ``codes`` and ``q_aux``
        are ``None`` when the model was built without the quantizer.
        """
        z_map = self.encode(x)
        if self.qinco is None:
            return self.decode(z_map), z_map, None, None

        z_q_map, codes, q_aux = self.qinco(z_map)
        return self.decode(z_q_map), z_map, codes, q_aux

# =========================
# MAIN: LOAD DF, MODEL, ENCODE/DECODE, PLOT
# =========================

def build_batch_from_df(df: pl.DataFrame, n_samples: int, start: int = 4000):
    """Build an input batch and its pixel mask from consecutive rows of ``df``.

    Rows ``start`` .. ``start + n_samples - 1`` are used; the range is
    clamped to the frame height, so fewer than ``n_samples`` objects may be
    returned near the end of the frame.

    For each row, every band in ``BANDS`` is read from column
    ``splus_cut_<band>`` and converted to an image, with non-finite pixels
    replaced by 0. A pixel is marked valid when it is finite and non-zero
    AND lies inside the object's ellipse, built from ``a_pixel_det``,
    ``b_pixel_det`` and ``theta_det`` (treated as degrees), centred at
    ``IMG_SIZE // 2`` and enlarged by ``elliptical_mask``'s default factor.

    Args:
        df: Frame holding the cutout and shape columns named above.
        n_samples: Maximum number of rows to use.
        start: Index of the first row to use. Defaults to 4000, the offset
            this function previously had hard-coded.

    Returns:
        ``(x_batch, m_pix_batch)``, both CPU float tensors of shape
        (N, C, H, W).

    Raises:
        ValueError: If ``start`` is negative or no row falls in the
            requested range.
    """
    if start < 0:
        raise ValueError(f"start must be non-negative, got {start}")

    # Clamp to the frame so a short frame yields a smaller batch instead of
    # an out-of-range row access.
    stop = min(start + max(n_samples, 0), df.height)
    if stop <= start:
        raise ValueError(
            f"No rows available: requested {n_samples} row(s) starting at "
            f"row {start}, but the frame has {df.height} row(s)"
        )

    # Look the columns up once rather than once per row and band.
    band_columns = [df[f"splus_cut_{band}"] for band in BANDS]
    a_column = df["a_pixel_det"]
    b_column = df["b_pixel_det"]
    theta_column = df["theta_det"]

    # The ellipse centre is the same for every object.
    x0 = torch.tensor(IMG_SIZE // 2, dtype=torch.float32)
    y0 = torch.tensor(IMG_SIZE // 2, dtype=torch.float32)

    x_list = []
    m_pix_list = []

    for idx in range(start, stop):
        imgs = []
        masks_pix_binary = []
        for column in band_columns:
            img = _to_image_torch(column[idx])
            finite = torch.isfinite(img)
            img_clean = img.clone()
            img_clean[~finite] = 0.0
            imgs.append(img_clean)
            # Exact zeros are treated as missing data as well.
            masks_pix_binary.append((finite & (img != 0.0)).float())

        x = torch.stack(imgs, dim=0)                          # (C,H,W)
        m_pix_basic = torch.stack(masks_pix_binary, dim=0)    # (C,H,W)
        C, H, W = x.shape

        a = torch.tensor(float(a_column[idx]))
        b = torch.tensor(float(b_column[idx]))
        theta = torch.tensor(float(theta_column[idx])) * math.pi / 180.0

        obj_mask = elliptical_mask(H, W, x0, y0, a, b, theta, device="cpu")
        # The same object mask applies to every band.
        m_pix = m_pix_basic * obj_mask.unsqueeze(0).expand(C, H, W)

        x_list.append(x)
        m_pix_list.append(m_pix)

    x_batch = torch.stack(x_list, dim=0)          # (N,C,H,W)
    m_pix_batch = torch.stack(m_pix_list, dim=0)  # (N,C,H,W)
    return x_batch, m_pix_batch



In [3]:

# Prefer the GPU when one is visible to this process.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

# ----- 1) Load dataframe -----
df = pl.read_parquet(PARQUET_FILE)

Using device: cuda


In [4]:

# Keep only rows that have an F378 cutout when configured to do so.
if REQUIRE_F378_NOT_NULL:
    df = df.filter(pl.col("splus_cut_F378").is_not_null())
print("Rows in DF after filtering:", df.height)
if df.height == 0:
    print("No rows to visualize, exiting.")
    # Actually stop: without this the code below went on to load the model
    # and build a batch from an empty frame. Exit status 0 because an empty
    # selection is not an error. (In a notebook this ends the current cell.)
    raise SystemExit(0)

# ----- 2) Build model and load weights -----
model = ConvAutoEncoder(in_channels=len(BANDS), latent_dim=LATENT_DIM)
model = model.to(device)
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust; consider passing weights_only=True where the installed PyTorch
# supports it.
state = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state)
model.eval()  # inference mode for the layers
print("Model loaded from:", MODEL_PATH)

# ----- 3) Build a small batch from DF -----
x_batch, m_pix_batch = (
    t.to(device) for t in build_batch_from_df(df, N_SAMPLES)
)

# ----- 4) Encode / decode with QINCo -----
with torch.no_grad():
    x_hat_batch, z_map, codes, q_aux = model(x_batch)

print("x_batch shape:", x_batch.shape)      # (N,C,H,W)
print("x_hat_batch shape:", x_hat_batch.shape)
if codes is not None:
    print("codes shape (B,H_lat,W_lat,M):", codes.shape)

# Bring everything back to host memory for plotting.
x_batch_cpu, x_hat_cpu, m_pix_cpu = (
    t.detach().cpu() for t in (x_batch, x_hat_batch, m_pix_batch)
)

# ----- 5) Plot original vs reconstruction vs residual -----
band_idx = BANDS.index(PLOT_BAND)
N = x_batch_cpu.shape[0]

# One row per object; columns: original, reconstruction, residual.
fig, axes = plt.subplots(N, 3, figsize=(12, 4 * N), squeeze=False)

for row in range(N):
    x_orig, x_rec, m_pix = x_batch_cpu[row], x_hat_cpu[row], m_pix_cpu[row]  # (C,H,W) each

    img_orig = x_orig[band_idx].numpy()
    img_rec = x_rec[band_idx].numpy()
    img_res = img_rec - img_orig

    # Intensity statistics use only pixels inside the mask; fall back to
    # the whole image when the mask selects nothing.
    mask_band = m_pix[band_idx].numpy()
    inside = mask_band > 0
    valid_orig, valid_rec = img_orig[inside], img_rec[inside]
    if valid_orig.size == 0 or valid_rec.size == 0:
        valid_orig, valid_rec = img_orig, img_rec

    # Shared grey scale so original and reconstruction are comparable.
    vmin_o, vmax_o = percentile_range(valid_orig, 1, 99)
    vmin_r, vmax_r = percentile_range(valid_rec, 1, 99)
    vmin = min(vmin_o, vmin_r)
    vmax = max(vmax_o, vmax_r)

    # Symmetric colour scale for the residual, from its 99th-percentile
    # absolute value (1.0 when that is exactly zero).
    valid_res = img_res[inside]
    if valid_res.size == 0:
        valid_res = img_res
    res_amp = float(np.percentile(np.abs(valid_res.reshape(-1)), 99.0)) or 1.0

    ax0, ax1, ax2 = axes[row]
    gray_kw = dict(origin="lower", cmap="gray", vmin=vmin, vmax=vmax)
    cbar_kw = dict(fraction=0.046, pad=0.04)

    im0 = ax0.imshow(img_orig, **gray_kw)
    ax0.set_title(f"Original ({BANDS[band_idx]}) idx={row}")
    fig.colorbar(im0, ax=ax0, **cbar_kw)

    im1 = ax1.imshow(img_rec, **gray_kw)
    ax1.set_title("Reconstruction")
    fig.colorbar(im1, ax=ax1, **cbar_kw)

    im2 = ax2.imshow(img_res, origin="lower", cmap="bwr",
                     vmin=-res_amp, vmax=res_amp)
    ax2.set_title("Residual")
    fig.colorbar(im2, ax=ax2, **cbar_kw)

    for ax in (ax0, ax1, ax2):
        ax.set_xticks([])
        ax.set_yticks([])

plt.tight_layout()
out_path = Path("inspect_plots", "reconstruction_examples.png")
out_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(out_path, dpi=150)
plt.close(fig)
print("Saved plot to:", out_path)

Rows in DF after filtering: 19331
Model loaded from: conv_autoencoder_splus_qinco_spatial.pt
x_batch shape: torch.Size([4, 12, 96, 96])
x_hat_batch shape: torch.Size([4, 12, 96, 96])
codes shape (B,H_lat,W_lat,M): torch.Size([4, 12, 12, 4])
Saved plot to: inspect_plots/reconstruction_examples.png
