In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from astroquery.hips2fits import hips2fits
import sys 
import os

from astropy.wcs import WCS

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from utils.sersic_functions import generate_random_pos
import torch

In [None]:
def get_wcs(ra, dec):
    """Return a 2-D RA---TAN / DEC--TAN WCS whose reference point is (ra, dec).

    The reference pixel is (15, 15) in FITS 1-based convention, and the pixel
    scale is 6.944444445183e-5 degrees/pixel (~0.25 arcsec) on both axes.
    The RA axis has a negative increment.
    """
    pixel_scale_deg = 6.944444445183e-5  # degrees / pixel
    projection = WCS(naxis=2)
    projection.wcs.crpix = [15.0, 15.0]
    projection.wcs.cdelt = [-pixel_scale_deg, pixel_scale_deg]
    projection.wcs.crval = [ra, dec]
    projection.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    projection.wcs.cunit = ["deg", "deg"]
    return projection

def get_galaxy_img(ra, dec, level, size, hips='CDS/P/PanSTARRS/DR1/r'):
    """Download a square cutout centred on (ra, dec) from a HiPS survey via hips2fits.

    Parameters
    ----------
    ra, dec : float
        Cutout centre in degrees (written into CRVAL1 / CRVAL2).
    level : int
        Resolution level. The pixel scale is the base scale times 2**level, so
        the same number of pixels covers a field twice as wide at each level.
    size : int
        Cutout width and height in pixels.
    hips : str, optional
        HiPS survey identifier. Defaults to the PanSTARRS DR1 r band.

    Returns
    -------
    numpy.ndarray
        Image data (presumably (size, size)) in native byte order.
        NaN is replaced by 0, and +/-inf by the largest finite values
        (np.nan_to_num defaults).
    """
    # The base scale is ~0.25 arcsec/pixel, doubled once per level.
    scale_deg = 6.94444461259988E-05 * (2 ** level)

    w = WCS(header={
        'NAXIS': 2,
        'NAXIS1': size,
        'NAXIS2': size,
        'CTYPE1': 'RA---TAN',
        'CTYPE2': 'DEC--TAN',
        'CDELT1': -scale_deg,  # negative: RA increases towards lower x
        'CDELT2': scale_deg,
        # NOTE(review): FITS pixels are 1-based, so the geometric centre of an
        # even-sized image is (size + 1) / 2. size / 2 puts the reference point
        # half a pixel off-centre. It is kept as-is for consistency with
        # get_wcs and with previously generated data.
        'CRPIX1': size/2,
        'CRPIX2': size/2,
        'CUNIT1': 'deg',
        'CUNIT2': 'deg',
        'CRVAL1': ra,
        'CRVAL2': dec,
    })

    result = hips2fits.query_with_wcs(
        hips=hips,
        wcs=w,
        get_query_payload=False,
        format='fits')

    raw = result[0].data
    # FITS data is typically big-endian. Convert it to native byte order.
    # ndarray.newbyteorder() was removed in NumPy 2.0, so go through
    # dtype.newbyteorder('='). Unlike byteswap().newbyteorder(), this never
    # turns already-native data into non-native data.
    img = raw.astype(raw.dtype.newbyteorder('='))

    # The second positional argument of nan_to_num is `copy`, not the fill value.
    # Spell out both: edit the fresh copy in place and fill NaN with 0.
    return np.nan_to_num(img, copy=False, nan=0.0)

def get_multires(df, idx, size, num_augmentations, num_levels=5):
    '''
    Build multi-resolution cutouts centred on random SN positions around one host.

    For each augmentation:
    - A pixel offset is drawn with `generate_random_pos` from the host's Sersic
      parameters (columns rSerRadius, rSerAb, rSerPhi of row `idx`).
    - The offset is converted to (ra, dec) with a TAN WCS centred on
      (host_ra, host_dec).
    - `num_levels` cutouts (levels 0 .. num_levels-1) of `size` x `size` pixels
      are downloaded, centred on that position.

    Returns the tuple (imgs, pos_host):
    - imgs: multi-resolution images centred on the SN, stacked over augmentations.
      Presumably shaped (num_augmentations, num_levels, size, size).
    - pos_host: offset from the image centre (the SN) to the host galaxy in
      pixels, as (x, y), one row per augmentation.
    '''
    row = df.iloc[idx]
    radius_sersic = row["rSerRadius"]
    ab_sersic = row["rSerAb"]
    phi_sersic = row["rSerPhi"]

    host_ra = row["host_ra"]
    host_dec = row["host_dec"]

    # The host-centred WCS depends only on the host coordinates, so build it once.
    wcs = get_wcs(host_ra, host_dec)

    imagenes = []
    posiciones = []
    for _ in range(num_augmentations):

        # Random SN position, expressed as a pixel offset from the host centre.
        pos = generate_random_pos(
            sersic_radius=radius_sersic,
            sersic_ab=ab_sersic,
            sersic_phi=phi_sersic,
            img_size=600  # the maximum radius was 300 px
        )

        # get_wcs places the host at FITS pixel 15 (1-based), which is pixel 14
        # in the 0-based convention pixel_to_world_values uses. Shift the offset
        # there before converting to (ra, dec).
        ra_sn, dec_sn = wcs.pixel_to_world_values([pos + 14])[0]

        # One cutout per resolution level, all centred on the SN.
        multi = [
            get_galaxy_img(ra_sn, dec_sn, level=level, size=size)
            for level in range(num_levels)
        ]

        imagenes.append(np.array(multi))
        # The image is centred on the SN, so the SN -> host offset is the negation.
        posiciones.append(-pos)

    return np.stack(imagenes), np.stack(posiciones)

In [30]:
# Training catalogue: one row per host, with the Sersic and host_ra/host_dec
# columns read by get_multires.
# os.path.join avoids the invalid "\d" / "\S" escapes of a backslash literal
# (SyntaxWarning on Python 3.12+) and works on every OS, not only Windows.
df = pd.read_csv(os.path.join("..", "data", "SERSIC", "df_pasquet_train.csv"))

In [22]:
def _load_part(part):
    """Return (imgs, pos) from training part `part` (1-based) of the x10-augmented set."""
    path = os.path.join(
        "..", "data", "SERSIC", f"X_train_pasquet_linear_p{part}_augmented_x10.npz"
    )
    # The context manager closes the underlying zip file once the arrays are read.
    with np.load(path) as npz:
        return npz["imgs"], npz["pos"]


X_train_1, y_train_1 = _load_part(1)
X_train_2, y_train_2 = _load_part(2)
X_train_3, y_train_3 = _load_part(3)

In [24]:
# Join the three image parts along the first (sample) axis.
X_train = np.concatenate((X_train_1, X_train_2, X_train_3), axis=0)

In [26]:
# Join the matching position targets in the same part order as the images.
y_train = np.concatenate((y_train_1, y_train_2, y_train_3), axis=0)

In [33]:
# Flag samples where at least one resolution level is entirely zero
# (X_train is (N, H, W, levels)). get_galaxy_img maps NaN to 0, so these are
# presumably failed or empty downloads.
# Testing "no non-zero pixel" is exact, unlike "sum == 0". The sum test can also
# flag an image whose positive and negative pixels happen to cancel out.
mask_ceros = (~X_train.any(axis=(1, 2))).any(axis=1)

In [None]:
# Positions of the flagged samples in X_train / y_train.
missing_idxs = np.flatnonzero(mask_ceros)

In [38]:
size = 30  # cutout side length in pixels
# Augmented samples per catalogue row. The loop below maps a sample index
# back to its df row with idx // num_augmentations.
# Presumably this must match the "x10" augmentation of the saved arrays; verify.
num_augmentations = 10

In [39]:
# Re-download every flagged sample. Each attempt draws a fresh random SN
# position for the same host row, so both X_train and y_train are overwritten.
# Only Exception is caught, so Ctrl-C (KeyboardInterrupt) still stops the cell.
# Retries are bounded, so a sample that keeps failing cannot hang the notebook.
MAX_RETRIES = 10
failed = {}  # sample index -> last exception raised while fetching it

for idx in tqdm(missing_idxs):
    last_error = None
    for _attempt in range(MAX_RETRIES):
        try:
            img, pos = get_multires(df, idx // num_augmentations, size=size, num_augmentations=1)
        except Exception as exc:  # remote query failures; try again
            last_error = exc
            continue
        # Batch of one: (levels, size, size) -> (size, size, levels), the X_train layout.
        X_train[idx] = img[0].transpose(1, 2, 0)
        y_train[idx] = pos[0]
        break
    else:
        failed[idx] = last_error

if failed:
    print(f"{len(failed)} samples still missing after {MAX_RETRIES} attempts, "
          f"e.g. {sorted(failed)[:10]}; last error: {next(iter(failed.values()))!r}")

100%|██████████| 461/461 [13:52<00:00,  1.81s/it]


In [40]:
# Re-check after the re-download. Expect 0 samples with an all-zero resolution level.
# "No non-zero pixel" is exact, unlike "sum == 0", which cancelling values can trigger.
mask_ceros = (~X_train.any(axis=(1, 2))).any(axis=1)
mask_ceros.sum()

0

In [41]:
# Save the merged, repaired training set (uncompressed .npz, keys 'imgs' and 'pos').
# os.path.join replaces a placeholder-less f-string whose backslashes formed
# invalid escape sequences.
np.savez(
    os.path.join("..", "data", "SERSIC", "X_train_pasquet_augmented_x10.npz"),
    imgs=X_train,
    pos=y_train,
)

In [1]:
import os

import numpy as np

# Fresh kernel: reload the merged training set and sanity-check it.
merged_path = os.path.join("..", "data", "SERSIC", "X_train_pasquet_augmented_x10.npz")
with np.load(merged_path) as data:
    print(data.files)  # should list 'imgs' and 'pos'
    X_train = data["imgs"]
    y_train = data["pos"]

assert X_train.shape[0] == y_train.shape[0], "imgs and pos must have one row per sample"

print(X_train.shape, X_train.dtype)
print(y_train.shape, y_train.dtype)

['imgs', 'pos']
(605460, 30, 30, 5) float32
(605460, 2) float64
