In [1]:
# Downloading a bunch of data (takes a few minutes)
#!wget https://users.flatironinstitute.org/~flanusse/provabgs_legacysurvey_eval_v2.fits

In [2]:
import torch
import numpy as np

from aion.codecs import CodecManager
# Instantiate the manager
codec_manager = CodecManager(device="cuda")


from aion.model import AION

# Disable gradients for this notebook
torch.set_grad_enabled(False)
# Importing model
model = AION.from_pretrained("polymathic-ai/aion-base").to("cuda").eval()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from tqdm import tqdm
from astropy.table import Table

from aion.modalities import (
    LegacySurveyImage,
    DESISpectrum,
    LegacySurveyFluxG,
    LegacySurveyFluxR,
    LegacySurveyFluxI,
    LegacySurveyFluxZ,
    Z,
)

# Loading the data
data = Table.read("provabgs_legacysurvey_eval_v2.fits")


# Define utility function to create modalities from the table
def format_data_modalities(data, device="cuda"):
    """Formats the input data into modality objects."""

    # Helper function
    def to_tensor(data_array, dtype="float32"):
        return torch.tensor(np.array(data_array).astype(dtype), device=device)

    # Create image modality
    image = LegacySurveyImage(
        flux=to_tensor(data["legacysurvey_image_flux"]),
        bands=["DES-G", "DES-R", "DES-I", "DES-Z"],
    )

    # Create spectrum modality
    spectrum = DESISpectrum(
        flux=to_tensor(data["desi_spectrum_flux"]),
        ivar=to_tensor(data["desi_spectrum_ivar"]),
        mask=to_tensor(data["desi_spectrum_mask"], dtype="bool"),
        wavelength=to_tensor(data["desi_spectrum_lambda"]),
    )

    # Create flux modalities
    g = LegacySurveyFluxG(value=to_tensor(data["legacysurvey_FLUX_G"]))
    r = LegacySurveyFluxR(value=to_tensor(data["legacysurvey_FLUX_R"]))
    i = LegacySurveyFluxI(value=to_tensor(data["legacysurvey_FLUX_I"]))
    z = LegacySurveyFluxZ(value=to_tensor(data["legacysurvey_FLUX_Z"]))

    return image, spectrum, g, r, i, z

In [4]:
# Create modalities from a batch of data
batch_size = 32
image, spectrum, g, r, i, z = format_data_modalities(data[:batch_size])

In [None]:
batch_size = 64
im_embeddings = []
sp_embeddings = []
all_embeddings = []

# Loop through the table in batches
for i in tqdm(range(0, 1, batch_size)):
    batch_data = data[i : i + batch_size]

    # Format data into modalities for the current batch
    image, spectrum, g, r, i, z = format_data_modalities(batch_data, device="cuda")

     # Helper function
    def to_tensor(data_array, dtype="float32"):
        return torch.tensor(np.array(data_array).astype(dtype), device="cuda")

    
    flux=to_tensor(batch_data["legacysurvey_image_flux"])
    print(flux.size())
    for i in range(4):
      array = flux[:,i,:,:].reshape(-1,)
      print(i, torch.mean(array), torch.std(array))

    
    spectrum =to_tensor(data["desi_spectrum_flux"])
    print(torch.mean(spectrum), torch.std(spectrum))
    print(spectrum.size())


    # Compute embeddings using the AION model
    im_embeddings.append(
        model.encode(codec_manager.encode(image), num_encoder_tokens=600).mean(axis=1)
    )

    sp_embeddings.append(
        model.encode(codec_manager.encode(spectrum), num_encoder_tokens=300).mean(
            axis=1
        )
    )

    all_embeddings.append(
        model.encode(
            codec_manager.encode(image, g, r, i, z), num_encoder_tokens=900
        ).mean(axis=1)
    )

# Concatenate the embeddings from all batches
im_embeddings = torch.cat(im_embeddings, dim=0).cpu().numpy()
sp_embeddings = torch.cat(sp_embeddings, dim=0).cpu().numpy()
all_embeddings = torch.cat(all_embeddings, dim=0).cpu().numpy()

print(f"Successfully processed {len(data)} images in batches of {batch_size}.")
print(f"Embeddings shape: {all_embeddings.shape}")

  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([64, 4, 96, 96])
0 tensor(0.0032, device='cuda:0') tensor(0.0829, device='cuda:0')
1 tensor(0.0070, device='cuda:0') tensor(0.1411, device='cuda:0')
2 tensor(0.0099, device='cuda:0') tensor(0.1633, device='cuda:0')
3 tensor(0.0119, device='cuda:0') tensor(0.1609, device='cuda:0')
tensor(3.0287, device='cuda:0') tensor(4.1162, device='cuda:0')
torch.Size([3815, 7800])
