In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from astromodal.config import load_config
from tqdm import tqdm
import polars as pl
import random


In [4]:
config = load_config("config.yaml")

print(config['datacubes_paths'])

/home/astrodados4/downloads/hypercube/datacube_*.parquet


In [5]:
datacube_paths = config['datacubes_paths']

from astromodal.datasets.datacubes import load_datacube_files

In [6]:
train_files, val_files = load_datacube_files(
    datacubes_paths = datacube_paths,
    train_val_split = 0.7,
    nfiles_subsample = 10,
    seed = 42
)

[info] - Found 2444 datacube files
[info] - Subsampled to 10 files
[info] - Training files: 7
[info] - Validation files: 3


In [7]:
def read_random_rows_parquet(
    path: str,
    n_rows: int,
    seed: int | None = None,
) -> pl.DataFrame:
    if seed is not None:
        random.seed(seed)

    lf = pl.scan_parquet(path)

    total = lf.select(pl.len()).collect().item()
    if n_rows >= total:
        return lf.collect()

    # random indices
    idx = random.sample(range(total), n_rows)

    return (
        lf
        .with_row_count("_rowid")
        .filter(pl.col("_rowid").is_in(idx))
        .drop("_rowid")
        .collect()
    )

In [14]:
columns = [
    "splus_cut_F378",
    "splus_cut_F395",
    "splus_cut_F410",
    "splus_cut_F430",
    "splus_cut_F515",
    "splus_cut_F660",
    "splus_cut_F861",
    "splus_cut_R",
    "splus_cut_I",
    "splus_cut_Z",
    "splus_cut_U",
    "splus_cut_G",
]

bands = ["F378", "F395", "F410", "F430", "F515", "F660", "F861", "R", "I", "Z", "U", "G"]
cutout_size = 96

In [9]:
import polars as pl
from tqdm import tqdm

train_df = None

for f in tqdm(train_files, desc="Loading train files"):
    df = pl.read_parquet(f, columns=columns, use_pyarrow=True)
    df = df.filter(pl.col(columns[0]).is_not_null())

    if df.height == 0:
        continue

    train_df = df if train_df is None else pl.concat([train_df, df], how="vertical", rechunk=False)

train_df = train_df.rechunk()

Loading train files:   0%|          | 0/7 [00:00<?, ?it/s]

Loading train files: 100%|██████████| 7/7 [01:29<00:00, 12.74s/it]


In [10]:
import polars as pl
from tqdm import tqdm

val_df = None

for f in tqdm(val_files, desc="Loading val files"):
    df = pl.read_parquet(f, columns=columns, use_pyarrow=True)
    df = df.filter(pl.col(columns[0]).is_not_null())

    if df.height == 0:
        continue

    val_df = df if val_df is None else pl.concat([val_df, df], how="vertical", rechunk=False)

val_df = val_df.rechunk()

Loading val files: 100%|██████████| 3/3 [02:11<00:00, 43.93s/it]


In [15]:
from astromodal.datasets.spluscuts import SplusCutoutsDataset

train_dataset = SplusCutoutsDataset(
    train_df,
    bands=bands,
    img_size=cutout_size,
    return_valid_mask=True,
)
val_dataset = SplusCutoutsDataset(
    val_df,
    bands=bands,
    img_size=cutout_size,
    return_valid_mask=True,
)

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import polars as pl

def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    scaler: torch.cuda.amp.GradScaler,
    use_amp: bool = True,
) -> float:
    model.train()
    total_loss = 0.0
    n_batches = 0

    for batch in tqdm(dataloader, desc="Training", leave=False):
        x_norm, m_valid = batch
        x_norm = x_norm.to(device, non_blocking=True)
        m_valid = m_valid.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            x_recon, _ = model(x_norm)

            # boolean mask for indexing
            mv = m_valid > 0.5

            # if a batch ever has no valid pixels, skip safely
            if mv.any():
                loss = F.mse_loss(x_recon[mv], x_norm[mv])
            else:
                loss = torch.zeros((), device=device)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += float(loss.detach().cpu())
        n_batches += 1

    return total_loss / max(n_batches, 1)

def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
) -> float:
    model.eval()
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating", leave=False):
            x_norm, m_valid = batch
            x_norm = x_norm.to(device, non_blocking=True)
            m_valid = m_valid.to(device, non_blocking=True)

            x_recon, _ = model(x_norm)

            mv = m_valid > 0.5
            if mv.any():
                loss = F.mse_loss(x_recon[mv], x_norm[mv])
            else:
                loss = torch.zeros((), device=device)

            total_loss += float(loss.detach().cpu())
            n_batches += 1

    return total_loss / max(n_batches, 1)

In [32]:
batch_size = 64
max_gpu_batch_size = 32
num_epochs = 10
learning_rate = 1e-3
use_amp = True

In [None]:
from astromodal.models.autoencoder import AutoEncoder
from torch.amp.GradScaler import GradScaler

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [34]:
model = AutoEncoder(
    in_channels = len(bands),
    latent_dim = 2,
)

In [36]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loss = train_epoch(model, train_dataset, optimizer, device, scaler, USE_AMP)

NameError: name 'scaler' is not defined

In [None]:
val_loss = validate(model, val_loader, device)