In [1]:
import os
import sys

import numpy as np
import pandas as pd
import torch

sys.path.append(os.path.abspath("../.."))

import AstroChemNet.data_loading as dl
import AstroChemNet.data_processing as dp
from AstroChemNet.inference import Inference

from configs.autoencoder import AEConfig
from configs.emulator import EMConfig
from configs.general import GeneralConfig
from nn_architectures.autoencoder import Autoencoder, load_autoencoder

In [2]:
# Restore the pretrained autoencoder; inference=True asks the loader to put
# the model in inference mode.
autoencoder = load_autoencoder(Autoencoder, GeneralConfig, AEConfig, inference=True)

# Processing helper built from the general and autoencoder configs.
processing = dp.Processing(GeneralConfig, AEConfig)

# Bundle config, processing and model into the object whose encode() is used
# further down to map species data to latents.
inference = Inference(GeneralConfig, processing, autoencoder)

Loading Pretrained Model
Setting Autoencoder to Inference Mode
Latents MinMax: -0.16997122764587402, 31.517663955688477


In [3]:
# Load both splits restricted to the emulator columns; only the validation
# split is used in this notebook.
training_np, validation_np = dl.load_datasets(GeneralConfig, EMConfig.columns)

# Drop the notebook's reference to the training split so its memory can be
# reclaimed (kernel variables otherwise live for the whole session).
del training_np

In [4]:
# Fixed seed so the same validation models are drawn on every
# Restart & Run All; change SEED to draw a different sample.
SEED = 42
N_MODELS = 100

# Column 1 of the validation array holds the model identifier that the
# per-model mask below filters on.
unique_models = np.unique(validation_np[:, 1])

# Sampling without replacement raises ValueError when asked for more items
# than exist, so cap the request at the number of available models.
rng = np.random.default_rng(SEED)
selected_models = rng.choice(
    unique_models, min(N_MODELS, len(unique_models)), replace=False
)

print(f"Selected {len(selected_models)} random models: {selected_models[:10]}...")

Selected 100 random models: [ 936.  793. 2113.  526. 7697. 1523. 8444. 9125. 9001.  917.]...


In [None]:
# Column labels: metadata names first, then one column per latent dimension.
metadata_cols = list(GeneralConfig.metadata)
latent_cols = [f"latent_{j}" for j in range(AEConfig.latent_dim)]
columns = metadata_cols + latent_cols

# One 2-D array per selected model, stacked once at the end. This replaces a
# per-row / per-column Python loop that built one dict per timestep.
latent_blocks = []

for model in selected_models:
    # All rows belonging to this model, in their original order.
    subset = validation_np[validation_np[:, 1] == model]

    # Species values sit in the trailing num_species columns; encode them in
    # a single call for the whole model.
    species_data = subset[:, -GeneralConfig.num_species :]
    latents = inference.encode(species_data).cpu().numpy()

    # Metadata occupies the leading num_metadata columns.
    metadata_data = subset[:, : GeneralConfig.num_metadata]

    # Keep exactly the columns that have labels: one per metadata name and
    # the first latent_dim latent components.
    latent_blocks.append(
        np.hstack(
            [
                metadata_data[:, : len(metadata_cols)],
                latents[:, : AEConfig.latent_dim],
            ]
        )
    )

# np.vstack rejects an empty list, so fall back to a zero-row table when no
# models were selected.
table = np.vstack(latent_blocks) if latent_blocks else np.empty((0, len(columns)))

# Cast to float64 so column dtypes do not depend on the input array dtypes.
# NOTE(review): the scalar-dict construction this replaces should also have
# produced float64 columns for float data - confirm if downstream code is
# dtype-sensitive.
df = pd.DataFrame(table, columns=columns, dtype=np.float64)

In [6]:
# Show the assembled metadata + latent table; the notebook renders a
# truncated view (first and last rows) for a frame of this length.
df

Unnamed: 0,Index,Model,Time,latent_0,latent_1,latent_2,latent_3,latent_4,latent_5,latent_6,latent_7,latent_8,latent_9,latent_10,latent_11,latent_12,latent_13
0,277992.0,936.0,0.000000,-0.005123,0.033869,0.966454,-0.159672,0.877128,-0.016465,-0.026586,-0.022536,-0.103704,-0.058041,0.148431,1.211116,-0.142424,-0.165564
1,277993.0,936.0,92.900002,-0.005585,0.046064,0.959038,-0.162984,0.879264,-0.019986,-0.030470,-0.023766,-0.101088,-0.065106,0.162406,1.231017,-0.144976,-0.164165
2,277994.0,936.0,185.800003,-0.005742,0.050054,0.956381,-0.163927,0.879857,-0.021255,-0.031810,-0.024182,-0.100215,-0.067481,0.167083,1.237340,-0.145793,-0.163672
3,277995.0,936.0,278.700012,-0.005838,0.052464,0.954718,-0.164463,0.880190,-0.022051,-0.032635,-0.024437,-0.099683,-0.068929,0.169928,1.241112,-0.146283,-0.163365
4,277996.0,936.0,371.600006,-0.005897,0.053918,0.953694,-0.164774,0.880384,-0.022541,-0.033139,-0.024592,-0.099360,-0.069807,0.171651,1.243370,-0.146577,-0.163177
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
29695,2396191.0,8067.0,27126.800781,-0.005986,0.056128,0.952110,-0.165229,0.880671,-0.023301,-0.033914,-0.024829,-0.098869,-0.071149,0.174277,1.246782,-0.147022,-0.162886
29696,2396192.0,8067.0,27219.699219,-0.005986,0.056128,0.952110,-0.165229,0.880671,-0.023301,-0.033914,-0.024829,-0.098869,-0.071149,0.174277,1.246782,-0.147022,-0.162886
29697,2396193.0,8067.0,27312.599609,-0.005986,0.056128,0.952110,-0.165229,0.880671,-0.023301,-0.033914,-0.024829,-0.098869,-0.071149,0.174277,1.246782,-0.147022,-0.162886
29698,2396194.0,8067.0,27405.500000,-0.005986,0.056128,0.952110,-0.165229,0.880671,-0.023301,-0.033914,-0.024829,-0.098869,-0.071149,0.174277,1.246782,-0.147022,-0.162886
