In [1]:
# Reload imported modules automatically before each cell runs, so edits to the
# local astromodal package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from astromodal.models.autoencoder import AutoEncoder
from astromodal.config import load_config
import torch

import polars as pl

from pathlib import Path

In [55]:
# Project configuration; this notebook reads the models_folder,
# datacubes_paths and hdd_folder entries from it.
config = load_config("/home/schwarz/projetoFM/config.yaml")

# Checkpoint of the SiLU autoencoder inside the configured models folder.
model_path = Path(config['models_folder']) / "autoencoder_model_silu.pth"

In [4]:
# Restore the autoencoder from its checkpoint file (the loader takes a str path).
checkpoint_file = str(model_path)
model = AutoEncoder.load_from_file(checkpoint_file)

[info] - Loaded model from /home/schwarz/projetoFM/models/autoencoder_model_silu.pth


In [5]:
# Use CUDA when torch can see a GPU, otherwise fall back to the CPU,
# and move the model's parameters there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [6]:
from astromodal.datatypes import SplusCuts

In [36]:
# Cutout column names declared by the SplusCuts datatype, one per S-PLUS band.
cut_columns = SplusCuts.__metadata__["columns"]
cut_columns

['splus_cut_F378',
 'splus_cut_F395',
 'splus_cut_F410',
 'splus_cut_F430',
 'splus_cut_F515',
 'splus_cut_F660',
 'splus_cut_F861',
 'splus_cut_R',
 'splus_cut_I',
 'splus_cut_Z',
 'splus_cut_U',
 'splus_cut_G']

In [52]:
# Load one S-PLUS field: the object id plus every cutout column.
field = "HYDRA-0075"
# The configured datacube path is a pattern whose "*" stands for the field name.
path = str(config['datacubes_paths']).replace("*", field)

cut_columns = SplusCuts.__metadata__["columns"]
columns = ["id", *cut_columns]

# Keep only rows whose first cutout column is non-null
# (presumably objects for which a cutout exists — confirm).
df = (
    pl.read_parquet(path, columns=columns, use_pyarrow=True)
    .filter(pl.col(cut_columns[0]).is_not_null())
)

In [42]:
from astromodal.datasets.spluscuts import SplusCutoutsDataset
from torch.utils.data import DataLoader

# Same 12 bands, in the same order, as the splus_cut_* columns of SplusCuts.
bands = ["F378", "F395", "F410", "F430", "F515", "F660", "F861", "R", "I", "Z", "U", "G"]
cutout_size = 96  # side length passed as img_size (presumably pixels — confirm)

batch_size = 1024

dataset = SplusCutoutsDataset(
    df,
    bands=bands,
    img_size=cutout_size,
    return_valid_mask=True,
)

# shuffle MUST stay False: the encoding loop pairs each latent with df["id"]
# by row position, which is only correct if samples come out in dataset order.
# Shuffling here silently attaches latents to the wrong object ids.
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=14,
    pin_memory=True,
)

In [None]:
# Encode every cutout and collect (id, flattened latent) pairs.
# Ids are taken from df by row position, so this only works when the loader
# yields samples in dataset order — enforce that instead of assuming it.
from torch.utils.data import SequentialSampler

assert isinstance(loader.sampler, SequentialSampler), (
    "loader must not shuffle: ids are matched to latents by row position"
)

ids_out = []
latents_out = []
offset = 0  # rows encoded so far == df row index of the next sample

model.eval()
with torch.no_grad():
    for x_norm, _m_valid in loader:  # valid mask is not used for encoding
        # non_blocking pairs with pin_memory=True on the loader; no-op on CPU.
        latents = model.encode(x_norm.to(device, non_blocking=True))
        latents = latents.reshape(latents.shape[0], -1)  # one flat vector per sample

        # Advance by the actual batch length (the last batch may be short),
        # rather than recomputing from a separately declared batch size.
        n_rows = latents.shape[0]
        ids_out.extend(df["id"][offset:offset + n_rows].to_numpy())
        latents_out.extend(latents.cpu().numpy())
        offset += n_rows

# Every df row must have been encoded, otherwise the positional id pairing is off.
assert offset == len(df), f"encoded {offset} samples but df has {len(df)} rows"

In [48]:
# One row per encoded object: its id and its flattened latent vector.
# The latent column dtype is inferred by polars from the per-row numpy arrays.
latent_columns = {"id": ids_out, "latent": latents_out}
df_latents = pl.DataFrame(latent_columns)

In [58]:
# Write this field's latents to <hdd_folder>/image_latents/<field>.parquet,
# creating the image_latents directory if needed.
# (Named latents_path because it is a file path, not a folder.)
latents_path = Path(config['hdd_folder']) / "image_latents" / f"{field}.parquet"
latents_path.parent.mkdir(parents=True, exist_ok=True)

df_latents.write_parquet(latents_path)

In [59]:
df = pl.read_parquet("/home/astrodados4/downloads/projetin/image_latents/SPLUS-n04s23.parquet")

In [60]:
df

id,latent
binary,"array[f32, 1152]"
"b""i06n04s2300003""","[-3.307046, -3.794522, … -3.243335]"
"b""i06n04s2300005""","[-3.032363, -2.69076, … -0.818269]"
"b""i06n04s230000K""","[-3.60503, -3.818086, … -3.103487]"
"b""i06n04s230000V""","[-4.12054, -3.993376, … -2.275294]"
"b""i06n04s2300017""","[-4.157759, -3.296415, … -3.182621]"
…,…
"b""i06n04s2301COQ""","[-2.671929, -3.534103, … -1.723974]"
"b""i06n04s2301CP0""","[-4.512311, -3.279397, … -2.475172]"
"b""i06n04s2301CPS""","[-4.051606, -3.260572, … -2.235711]"
"b""i06n04s2301CQ1""","[-3.690176, -2.973642, … -2.426193]"
