In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from astromodal.config import load_config
from tqdm import tqdm
import polars as pl
import random

In [3]:
# Restrict this process to GPU index 1.
# The environment variable is deliberately assigned *before* `import torch`;
# presumably it has to be in place before CUDA is initialised to take effect,
# so keep this ordering (TODO confirm against the CUDA / PyTorch docs).
# NOTE(review): the index is hard-coded; on a host without a second GPU no
# CUDA device will be visible to torch.

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

import torch 

In [4]:
# Load the project configuration.
# NOTE(review): this is an absolute, user-specific path; the notebook only
# runs where this file exists.
config = load_config("/home/schwarz/projetoFM/config.yaml")

# Show the datacube glob pattern that the following cells read from.
print(config['datacubes_paths'])

/home/astrodados4/downloads/hypercube/datacube_*.parquet


In [5]:
# Glob pattern of the datacube parquet files, taken from the config.
datacube_paths = config['datacubes_paths']

from astromodal.datasets.datacubes import load_datacube_files

In [6]:
# File selection for this notebook: a single datacube file, subsampled with a
# fixed seed, with the whole selection assigned to the training side.
datacube_selection = dict(
    datacubes_paths=datacube_paths,
    train_val_split=1,
    nfiles_subsample=1,
    seed=42,
)
train_files, val_files = load_datacube_files(**datacube_selection)

[info] - Found 2444 datacube files
[info] - Subsampled to 1 files
[info] - Training files: 1
[info] - Validation files: 0


In [7]:
# Band names, in the order handed to the dataset and used for plot titles.
bands = ["F378", "F395", "F410", "F430", "F515", "F660", "F861", "R", "I", "Z", "U", "G"]

# One parquet column per band, named "splus_cut_<band>" and kept in band order.
columns = [f"splus_cut_{band}" for band in bands]

# Value passed to the dataset as `img_size`.
cutout_size = 96

In [8]:
import polars as pl
from tqdm import tqdm

# Read every training datacube, keep only the rows that have a cutout in the
# first band column, and stack the non-empty pieces into one frame.
train_frames = []

for f in tqdm(train_files, desc="Loading train files"):
    df = pl.read_parquet(f, columns=columns + ["mag_psf_r"], use_pyarrow=True)
    df = df.filter(pl.col(columns[0]).is_not_null())

    # Files without any usable row contribute nothing.
    if df.height == 0:
        continue

    train_frames.append(df)

# Fail with a clear message instead of an AttributeError on `None.rechunk()`
# when no file yielded a single row.
if not train_frames:
    raise ValueError(
        f"No rows with a non-null '{columns[0]}' found in {len(train_files)} training file(s)"
    )

# Concatenate once (instead of re-concatenating inside the loop), then make
# the result contiguous.
train_df = pl.concat(train_frames, how="vertical", rechunk=False).rechunk()

# The per-file pieces are no longer needed once the contiguous copy exists.
del train_frames

Loading train files:   0%|          | 0/1 [00:00<?, ?it/s]

Loading train files: 100%|██████████| 1/1 [00:29<00:00, 29.89s/it]


In [9]:
# Keep only rows whose `mag_psf_r` is below 21. The frame is reassigned in
# place, but re-running this cell is harmless because the filter is idempotent.
train_df = train_df.filter(pl.col("mag_psf_r") < 21)

In [10]:
from astromodal.datasets.spluscuts import SplusCutoutsDataset

# Cutout dataset over the filtered training frame.
# `return_valid_mask=True` presumably makes each item carry a validity mask
# next to the image — confirm in SplusCutoutsDataset.
cutout_dataset_kwargs = dict(
    bands=bands,
    img_size=cutout_size,
    return_valid_mask=True,
)
train_dataset = SplusCutoutsDataset(train_df, **cutout_dataset_kwargs)

# No validation dataset is built: the file selection above produced zero
# validation files, so no `val_df` exists. If one is added later, build it with
# SplusCutoutsDataset(val_df, **cutout_dataset_kwargs).

In [11]:
from pathlib import Path

# Loader batch size and the largest batch intended for a single GPU pass.
batch_size = max_gpu_batch_size = 1024

# Remaining training settings; within the cells shown here they are not read
# again (the model is loaded from disk rather than trained).
num_epochs, learning_rate, latent_dim = 10, 1e-3, 2

# Checkpoint location inside the configured models folder.
model_output_path = Path(config['models_folder']).joinpath("./autoencoder_model_silu.pth")

In [12]:
from torch.utils.data import DataLoader

from astromodal.models.autoencoder import AutoEncoder
from torch.amp import GradScaler


# Shuffled loader over the training cutouts; batches are loaded in the main
# process (num_workers=0) with pinned host memory requested.
train_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)
train_loader = DataLoader(train_dataset, **train_loader_kwargs)

# No validation loader: there is no validation dataset in this notebook. A
# matching one would use batch_size=max_gpu_batch_size and shuffle=False.



In [15]:
from astromodal.models.autoencoder import AutoEncoder

# Prefer the GPU, but fall back to the CPU so the notebook still runs when no
# CUDA device is visible.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Map the checkpoint onto the device selected above. A hard-coded 'cuda' here
# would make loading fail on CPU-only hosts even though `device` falls back.
model = AutoEncoder.load_from_file(model_output_path, map_location=device)

# Inference only: move the weights to the device and switch to eval mode.
model = model.to(device).eval()

[info] - Loaded model from /home/schwarz/projetoFM/models/autoencoder_model_silu.pth


In [22]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def percentile_stretch(img2d: np.ndarray, p_lo=1, p_hi=99):
    """Linearly rescale an image to [0, 1] between two of its percentiles.

    The percentiles are computed over the finite pixels only. Values outside
    the percentile window are clipped; non-finite pixels stay non-finite.

    Parameters
    ----------
    img2d : np.ndarray
        Image to rescale.
    p_lo, p_hi : float
        Lower / upper percentile (0-100) mapped to 0 and 1 respectively.

    Returns
    -------
    (out, vmin, vmax)
        The rescaled image and the limits that were used. When the image has
        no finite pixel, the input array itself is returned with limits
        ``(0.0, 1.0)``.
    """
    valid_pixels = img2d[np.isfinite(img2d)]
    if valid_pixels.size == 0:
        return img2d, 0.0, 1.0

    lower = np.percentile(valid_pixels, p_lo)
    upper = np.percentile(valid_pixels, p_hi)
    # Guard against a zero-width window (e.g. a constant image).
    if upper <= lower:
        upper = lower + 1e-6

    scaled = (img2d - lower) / (upper - lower)
    return np.clip(scaled, 0, 1), lower, upper

def symmetric_residual_stretch(res2d: np.ndarray, p_hi=99):
    """Map residuals to [0, 1] with a window that is symmetric around zero.

    The window is ``[-v, +v]`` with ``v = percentile(|res|, p_hi)`` taken over
    the finite residuals, so a residual of zero lands on 0.5. Values outside
    the window are clipped.

    Parameters
    ----------
    res2d : np.ndarray
        Residual image.
    p_hi : float
        Percentile (0-100) of the absolute residuals that defines ``v``.

    Returns
    -------
    (out, -v, v)
        The mapped image and the window limits. Without any finite residual
        the output is filled with 0.5 and the limits are ``(-1.0, 1.0)``.
    """
    magnitudes = np.abs(res2d[np.isfinite(res2d)])
    if magnitudes.size == 0:
        neutral = np.zeros_like(res2d) + 0.5
        return neutral, -1.0, 1.0

    limit = np.percentile(magnitudes, p_hi)
    # An all-zero residual would otherwise give a zero-width window.
    if limit <= 0:
        limit = 1e-6

    mapped = (res2d + limit) / (2 * limit)
    return np.clip(mapped, 0, 1), -limit, limit

def _hide_ticks_and_frame(ax):
    """Remove ticks and spines from ``ax`` while keeping its axis labels drawable."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def plot_before_after_residual_12bands(
    x_in: torch.Tensor,        # (C,H,W)
    x_rec: torch.Tensor,       # (C,H,W)
    band_names=None,
    p_lo=1,
    p_hi=99,
    use_input_percentiles=True,
    residual_p_hi=99,
    suptitle=None,
    residual_cmap="RdBu_r",
):
    """Plot input, reconstruction and residual for each of the 12 bands.

    The figure has three rows (input, reconstruction, residual = recon - input)
    and one column per band. Input and reconstruction of a band share one
    percentile stretch to [0, 1]; the residual is shown in raw units with a
    symmetric range ``[-v, +v]``, ``v`` being the ``residual_p_hi`` percentile
    of the finite absolute residuals of that band.

    Parameters
    ----------
    x_in, x_rec : torch.Tensor
        Input and reconstruction, both of shape (C, H, W) with C == 12.
    band_names : sequence of str, optional
        Column titles; defaults to ``b0 .. b11``.
    p_lo, p_hi : float
        Percentiles (0-100) defining the display range of input / recon.
    use_input_percentiles : bool
        If True the display range comes from the input band only, otherwise
        from input and reconstruction together.
    residual_p_hi : float
        Percentile of ``|residual|`` used as the symmetric residual limit.
    suptitle : str, optional
        Figure title.
    residual_cmap : str
        Colormap for the residual row.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If the two tensors differ in shape or are not 3-D.
    AssertionError
        If the number of bands is not 12.
    """
    x_in = x_in.detach().float().cpu().numpy()
    x_rec = x_rec.detach().float().cpu().numpy()
    # Reject mismatched shapes up front instead of letting the subtraction
    # broadcast silently or fail with an obscure message.
    if x_in.shape != x_rec.shape:
        raise ValueError(
            f"x_in and x_rec must have the same shape, got {x_in.shape} and {x_rec.shape}"
        )
    res = x_rec - x_in

    C, _, _ = x_in.shape  # the unpacking also enforces a 3-D input
    assert C == 12, f"Expected 12 bands, got {C}"

    if band_names is None:
        band_names = [f"b{i}" for i in range(C)]

    fig, axes = plt.subplots(
        3, C, figsize=(2.2*C, 6.6), constrained_layout=True
    )

    residual_im = None  # imshow handle the colorbar is built from
    v = 1.0             # symmetric residual limit of the band drawn last

    for c in range(C):
        # ----- display range shared by input and reconstruction
        if use_input_percentiles:
            ref = x_in[c]
        else:
            ref = np.concatenate([x_in[c].ravel(), x_rec[c].ravel()])
        ref = ref[np.isfinite(ref)]
        if ref.size:
            vmin = np.percentile(ref, p_lo)
            vmax = np.percentile(ref, p_hi)
            if vmax <= vmin:
                vmax = vmin + 1e-6  # avoid a zero-width range
        else:
            vmin, vmax = 0.0, 1.0

        span = vmax - vmin
        im_in = np.clip((x_in[c] - vmin) / span, 0, 1)
        im_rec = np.clip((x_rec[c] - vmin) / span, 0, 1)

        # ----- residual range (raw units, symmetric around zero)
        abs_res = np.abs(res[c][np.isfinite(res[c])])
        if abs_res.size:
            v = np.percentile(abs_res, residual_p_hi)
            if v <= 0:
                v = 1e-6
        else:
            v = 1.0

        # ----- plotting
        axes[0, c].imshow(im_in, origin="lower")
        axes[0, c].set_title(str(band_names[c]))

        axes[1, c].imshow(im_rec, origin="lower")

        residual_im = axes[2, c].imshow(
            res[c],
            origin="lower",
            cmap=residual_cmap,
            vmin=-v,
            vmax=+v,
        )
        # Every band has its own residual limit, so print it on the panel;
        # the colorbar below can only express the scale relative to it.
        axes[2, c].set_title(f"±{v:.3g}", fontsize=8)

        # `axis("off")` would also hide the row labels set below, so only the
        # ticks and the frame are removed.
        for row in range(3):
            _hide_ticks_and_frame(axes[row, c])

    axes[0, 0].set_ylabel("Input", fontsize=12)
    axes[1, 0].set_ylabel("Recon", fontsize=12)
    axes[2, 0].set_ylabel("Residual", fontsize=12)

    # ----- colorbar for the residual row
    cbar = fig.colorbar(
        residual_im,
        ax=axes[2, :],
        orientation="horizontal",
        fraction=0.06,
        pad=0.08,
    )
    # The mappable belongs to the last band only. Its numeric ticks would be
    # wrong for the other bands, so label the ends relative to each panel's
    # own limit (the value printed above the panel).
    cbar.set_ticks([-v, 0.0, v])
    cbar.set_ticklabels(["−limit", "0", "+limit"])
    cbar.set_label("Residual (Recon − Input)", fontsize=11)

    if suptitle:
        fig.suptitle(suptitle, fontsize=14)

    return fig

In [24]:
from torch.utils.data import DataLoader

def save_reconstructions(
    model,
    dataloader,
    device,
    band_names=None,
    n_examples=3,
    p_lo=1,
    p_hi=99,
    residual_p_hi=99,
    output_dir=None,
):
    """Plot input / reconstruction / residual figures for the first examples.

    Batches are drawn from ``dataloader`` until ``n_examples`` figures have
    been produced (or the loader is exhausted), so ``n_examples`` may exceed
    the loader's batch size. Each example is passed through
    ``model.encode`` followed by ``model.decode`` under ``torch.no_grad()``.
    The model's train/eval mode is not changed here; that is up to the caller.

    Parameters
    ----------
    model
        Object exposing ``encode`` and ``decode``.
    dataloader
        Iterable of batches; for tuple/list batches the first element is used
        as the image tensor.
    device
        Device the batch is moved to before the forward pass.
    band_names, p_lo, p_hi, residual_p_hi
        Forwarded to ``plot_before_after_residual_12bands``.
    n_examples : int
        Maximum number of figures to produce.
    output_dir : str or path-like, optional
        If given, figures are saved as ``reconstruction_<i>.png`` in this
        directory (created if missing) and closed afterwards. If None the
        figures are left open.

    Raises
    ------
    AssertionError
        If a reconstruction does not have the shape of its input.
    """
    from pathlib import Path  # local import keeps the function self-contained

    if n_examples <= 0:
        return

    if output_dir is not None:
        # Accept plain strings as well and make sure the target exists.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

    n_done = 0
    for batch in dataloader:
        x = batch[0] if isinstance(batch, (tuple, list)) else batch
        x = x.to(device)

        with torch.no_grad():
            x_rec = model.decode(model.encode(x))

        # Sanity checks: same shape, and not a trivial identity mapping.
        assert x.shape == x_rec.shape, f"Input and reconstruction shapes do not match: {x.shape} vs {x_rec.shape}"
        if torch.allclose(x, x_rec):
            print("Warning: Input and reconstruction are identical!")

        for j in range(min(n_examples - n_done, x.shape[0])):
            fig = plot_before_after_residual_12bands(
                x_in=x[j],
                x_rec=x_rec[j],
                band_names=band_names,
                p_lo=p_lo,
                p_hi=p_hi,
                residual_p_hi=residual_p_hi,
                use_input_percentiles=True,
                suptitle=f"Example {n_done}",
            )
            if output_dir is not None:
                fig.savefig(output_dir / f"reconstruction_{n_done}.png")
                # Closing saved figures keeps memory bounded for large n_examples.
                plt.close(fig)
            n_done += 1

        # Stop before fetching another batch once enough figures exist.
        if n_done >= n_examples:
            break

In [25]:
# Number of reconstructions to render. The loader batch is sized to match, so
# a single batch already contains every requested example (a batch of 8 would
# silently cap the output at 8 figures).
n_examples = 40

# A seeded generator makes `shuffle=True` pick the same examples on every run.
loader = DataLoader(
    train_dataset,
    batch_size=n_examples,
    shuffle=True,
    num_workers=0,
    generator=torch.Generator().manual_seed(42),
)

outfolder = Path(config['models_folder']) / "plots" / "autoencoder_reconstructions"
outfolder.mkdir(parents=True, exist_ok=True)


# `bands` is the list defined next to `columns` near the top of the notebook.
save_reconstructions(
    model=model,
    dataloader=loader,
    device=device,
    band_names=bands,
    n_examples=n_examples,
    p_lo=1,
    p_hi=99,
    output_dir=outfolder
)