In [1]:
import sys
import os
import torch
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath("../.."))

import AstroChemNet.data_processing as dp
from AstroChemNet.inference import Inference
import AstroChemNet.data_loading as dl
from configs.autoencoder import AEConfig
from configs.general import GeneralConfig
from nn_architectures.autoencoder import Autoencoder, load_autoencoder

In [2]:
# Restore the pretrained autoencoder; `inference=True` is forwarded to the
# project loader (the cell output reports the model being put in inference mode).
autoencoder = load_autoencoder(
    Autoencoder,
    GeneralConfig,
    AEConfig,
    inference=True,
)

# `processing` holds the scaling helpers used below; `inference` wraps the
# model and exposes the encode()/decode() calls used in the evaluation cell.
processing = dp.Processing(GeneralConfig, AEConfig)
inference = Inference(GeneralConfig, processing, autoencoder)

Loading Pretrained Model
Setting Autoencoder to Inference Mode
Latents MinMax: -0.16997122764587402, 40.74977493286133


In [3]:
# Both splits are returned by the loader, but only the validation split is
# evaluated in this notebook, so the training array is released immediately.
training_np, validation_np = dl.load_datasets(
    GeneralConfig,
    AEConfig.columns,
)
del training_np

# NOTE(review): the return value is ignored, so the scaling is presumably
# applied to `validation_np` in place — confirm against dp.Processing.
processing.abundances_scaling(validation_np)

# torch.from_numpy shares memory with `validation_np` (no copy is made).
validation_dataset = torch.from_numpy(validation_np)
validation_Dataset = dl.AutoencoderDataset(validation_dataset)
validation_dataloader = dl.tensor_to_dataloader(AEConfig, validation_Dataset)

Data_matrix Memory usage: 1129.568 MB


In [4]:
def calculate_errors(outputs, targets):
    """Return the mean absolute relative error of ``outputs``, reduced over dim 0.

    ``targets`` is first mapped through ``processing.inverse_abundances_scaling``
    (the module-level ``processing`` object), and the error is measured
    relative to that unscaled target: ``|outputs - t| / t``.

    NOTE(review): ``outputs`` is compared directly against the unscaled
    targets, so it is presumably already in unscaled space — confirm against
    ``Inference.decode``.
    """
    unscaled_targets = processing.inverse_abundances_scaling(targets)
    relative_errors = ((outputs - unscaled_targets) / unscaled_targets).abs()
    return relative_errors.mean(dim=0)

# Accumulate the per-species relative error over the whole validation set.
#
# Each batch contributes its per-species mean multiplied by its batch size, and
# the total is divided by the number of samples seen. Averaging the per-batch
# means by the number of batches would over-weight a smaller final batch.
species_error_sums = None  # allocated from the first batch, so no hard-coded species count
n_samples = 0
for batch in validation_dataloader:
    features = batch[0].to(GeneralConfig.device)
    latents = inference.encode(features)
    outputs = inference.decode(latents)

    # .detach().cpu() is required before .numpy(): the tensors live on
    # GeneralConfig.device, and .numpy() raises for CUDA tensors and for
    # tensors that still require grad.
    batch_errors = calculate_errors(outputs, features).detach().cpu().numpy()
    batch_size = features.shape[0]

    if species_error_sums is None:
        species_error_sums = np.zeros_like(batch_errors, dtype=np.float64)
    species_error_sums += batch_errors * batch_size
    n_samples += batch_size

if n_samples == 0:
    raise ValueError("validation_dataloader yielded no samples")

species_errors = species_error_sums / n_samples
print(f"Mean Error: {species_errors.mean():.8f}")
print(f"Std Error: {species_errors.std():.8f}")
print(f"Max Error: {species_errors.max():.8f}")

Mean Error: 0.01821904
Std Error: 0.01168675
Max Error: 0.05770965


In [None]:
# The per-epoch training history is read from a JSON file that shares the
# model checkpoint's path, with the extension swapped for ".json".
epochs_path = os.path.splitext(AEConfig.save_model_path)[0] + ".json"

with open(epochs_path, "r") as f:
    df = pd.DataFrame(json.load(f))

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(df.index, np.log10(df["mean"]), label="Mean")

# Draw a dashed vertical line at every epoch where a tracked hyperparameter
# dropped relative to the previous epoch. Only the first line of each kind
# carries a legend label, so each kind appears once in the legend.
for column, color, legend_label in (
    ("dropout", "red", "Dropout decrease"),
    ("learning_rate", "blue", "Learning rate decrease"),
):
    decreased = df[column].diff() < 0
    for position, epoch in enumerate(df.index[decreased]):
        ax.axvline(
            x=epoch,
            color=color,
            linestyle="--",
            linewidth=1,
            label=legend_label if position == 0 else "",
        )

ax.set_title("Log Mean Relative Error vs. Epochs Training Plot")
ax.set_xlabel("Epochs")
ax.set_ylabel("Log Mean Relative error")
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

df