In [1]:
import sys
import os

exlib_path = "/shared_data0/chaenyk/exlib"
sys.path.insert(0, exlib_path)
import pyarrow as pa
import pyarrow_hotfix
import torch
import yaml
import argparse
import torch
import torch.nn as nn
from datasets import load_dataset
from collections import namedtuple

import exlib
from exlib.datasets.pretrain import setup_model_config, get_dataset, get_dataset, setup_model_config
from exlib.datasets.dataset_preprocess_raw import create_train_dataloader_raw, create_test_dataloader_raw, create_test_dataloader
from exlib.datasets.informer_models import InformerConfig, InformerForSequenceClassification
from tqdm.auto import tqdm
pa.PyExtensionType.set_auto_load(True)
pyarrow_hotfix.uninstall()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

from example_function_supernova import Timeseries, get_llm_generated_answer, isolate_individual_features, is_claim_relevant, distill_relevant_features, calculate_expert_alignment_score

In [2]:
import openai
from openai import OpenAI

# Initialize the client with your API key
client = OpenAI(api_key="sk-None-ptTj0i2Hx0GJvjOunH0QT3BlbkFJcdhzvU2L1X5HU4F23HWB")  # Replace with your actual API key

### Dataset

In [3]:
# load dataset
dataset = load_dataset("BrachioLab/supernova-timeseries")
train_dataset = dataset['train']
validation_dataset = dataset['validation']
test_dataset = dataset['test']

In [4]:
test_dataset

Dataset({
    features: ['objid', 'times_wv', 'target', 'label', 'redshift'],
    num_rows: 792
})

### Dataset Samples

In [5]:
config = InformerConfig.from_pretrained("BrachioLab/supernova-classification")
test_dataloader = create_test_dataloader_raw(
    config=config,
    dataset=test_dataset,
    batch_size=25,
    compute_loss=True
)

original dataset size: 792
remove nans dataset size: 792


In [6]:
# SNIa class 11

In [7]:
def find_label_in_dataloader(dataloader, target_label=11):
    for batch in dataloader:
        indices = (batch['labels'] == target_label).nonzero(as_tuple=True)[0]
        
        if len(indices) > 0:
            label_values_time = batch['past_time_features'][:, indices, :]
            label_values = batch['past_values'][:, indices, :]
            return label_values_time, label_values
    
    return None
label_11_past_time, label_11_past_values = find_label_in_dataloader(test_dataloader, target_label=11)
if label_11_past_time is not None:
    print(f"Past time for label 11: {label_11_past_time}")
else:
    print("No instances of label 11 found in the dataset.")

Past time for label 11: tensor([[[59596.0703,  9710.2803],
         [59601.3477,  7545.9800],
         [59606.0312,  9710.2803],
         [59609.0352,  8590.9004],
         [59634.0312,  9710.2803],
         [59636.0234,  8590.9004],
         [59640.0195,  8590.9004],
         [59649.0039,  9710.2803],
         [59650.0547,  6223.2402],
         [59653.9961,  9710.2803],
         [59660.0273,  4826.8501],
         [59664.0000,  8590.9004],
         [59669.9922,  8590.9004],
         [59684.1172,  9710.2803],
         [59686.1172,  9710.2803],
         [59687.0898,  8590.9004],
         [59712.0391,  9710.2803],
         [59718.0195,  7545.9800],
         [59725.9570,  8590.9004]],

        [[59729.4297,  8590.9004],
         [59748.1406,  7545.9800],
         [59749.3477,  8590.9004],
         [59750.1211,  6223.2402],
         [59770.1328,  8590.9004],
         [59775.2773,  8590.9004],
         [59776.0469,  7545.9800],
         [59779.0938,  4826.8501],
         [59786.1289,  4826.8

In [8]:
n = 0
time_data = label_11_past_time[n,:,0]
wv_data = label_11_past_time[n,:,1]
value_data = label_11_past_values[n,:,0]

In [9]:
print(time_data.size())
print(wv_data.size())
print(value_data.size())

torch.Size([19])
torch.Size([19])
torch.Size([19])


In [10]:
### Extract Explanation
example = Timeseries(time=time_data, wv=wv_data, value=value_data)

prompt = '''Analyze this supernova time series data. This dataset represents astrophysical observations, where each time series consists of values corresponding to different wavelengths and the time at which these values were recorded.
Among the type of supernova(e.g., SNIa, SNIbc, SNIax, SNII, RRL, PISN) this time series data is classified as Type Ia supernova(SNIa).
Provide a reasoning chain for what interpretable time series features you see from this data that you use to make such predictions. Provide a short paragraph that is around 100-200 words.'''

analysis_result = get_llm_generated_answer(example=example, prompt=prompt)
print(analysis_result)

The provided time series data for the Type Ia supernova (SNIa) exhibits several interpretable features indicative of the distinct behavior characteristic of this type of supernova. SNIas are typically marked by a rapid rise to a peak luminosity followed by a gradual decline, which can be analyzed in this dataset through the value data across time. The significant fluctuations in the value data—such as the sharp peak at approximately 59650.0546875 days where we see a recorded value of 111.83946228027344—suggests a luminous event, likely representing the explosion of the supernova. Concurrently, the wavelengths captured, especially the pronounced values at 8590.900390625 and 9710.2802734375, align with the emission lines expected for SNIa due to the presence of specific elements like nickel and cobalt, which contribute to the light curve shape. The presence of both high and low values around those peaks potentially signals the interaction of ejected material with the surrounding environm

In [11]:
### Make it to atomic claims
raw_atomic_claims = isolate_individual_features(analysis_result)

print("Extracted atomic claims:")
for i, claim in enumerate(raw_atomic_claims, 1):
    print(f"{i}. {claim}")

Extracted atomic claims:
1. *Feature of SNIa*: The provided time series data exhibits distinct behavior characteristic of Type Ia supernova.
2. *Luminosity Behavior*: SNIas typically show a rapid rise to peak luminosity followed by a gradual decline.
3. *Data Analysis*: This behavior can be analyzed through the value data across time in the dataset.
4. *Sharp Peak*: There is a significant peak in the value data at approximately 59650.0546875 days with a recorded value of 111.83946228027344.
5. *Luminous Event*: The sharp peak suggests a luminous event, likely representing the supernova explosion.
6. *Wavelengths Captured*: The dataset captures significant wavelengths, particularly at 8590.900390625 and 9710.2802734375.
7. *Emission Lines*: The pronounced values at these wavelengths align with emission lines expected for SNIas due to specific elements.
8. *Element Presence*: Elements like nickel and cobalt contribute to the light curve shape of SNIas.
9. *Interaction Signals*: The prese

In [12]:
### Relevant claims
example = Timeseries(time=time_data, wv=wv_data, value=value_data)
answer = "The supernova classification for this data is Type 1a supernova."
atomic_claims = raw_atomic_claims

relevant_claims = is_claim_relevant(example, answer, atomic_claims)

print("Extracted relevant claims:")
for i, claim in enumerate(relevant_claims, 1):
    print(f"{i}. {claim}")

Extracted relevant claims:
1. *Feature of SNIa*: The provided time series data exhibits distinct behavior characteristic of Type Ia supernova, judgment: contained, explain why this claim is directly supported by the dataset: The time series data shows a specific trend in brightness and decline that aligns with typical behaviors observed in Type Ia supernovae, confirming that the dataset showcases these characteristics of SNIa.
2. *Luminosity Behavior*: SNIas typically show a rapid rise to peak luminosity followed by a gradual decline, judgment: contained, explain why this claim is directly supported by the dataset: The time series data exhibits a clear pattern of a rapid increase in recorded values leading to a peak, followed by a gradual decline, consistent with known luminosity patterns of Type Ia supernovae.
3. *Data Analysis*: This behavior can be analyzed through the value data across time in the dataset, judgment: contained, explain why this claim is directly supported by the dat

In [13]:
### Distilled relevant claims
example = Timeseries(time=time_data, wv=wv_data, value=value_data)
answer = "The supernova classification for this data is Type 1a supernova."

distilled_claims = distill_relevant_features(example, answer, relevant_claims)

print("Distilled relevant claims:")
for i, claim in enumerate(distilled_claims, 1):
    print(f"{i}. {claim}")

Distilled relevant claims:
1. *Feature of SNIa*: The provided time series data exhibits distinct behavior characteristic of Type Ia supernova
2. *Luminosity Behavior*: SNIas typically show a rapid rise to peak luminosity followed by a gradual decline
3. *Data Analysis*: This behavior can be analyzed through the value data across time in the dataset
4. *Sharp Peak*: There is a significant peak in the value data at approximately 59650.0546875 days with a recorded value of 111.83946228027344
5. *Luminous Event*: The sharp peak suggests a luminous event
6. *Wavelengths Captured*: The dataset captures significant wavelengths
7. *Emission Lines*: The pronounced values at these wavelengths align with emission lines expected for SNIas due to specific elements
8. *Element Presence*: Elements like nickel and cobalt contribute to the light curve shape of SNIas
9. *Interaction Signals*: The presence of both high and low values around the peaks signals interaction of ejected material with the surro

In [14]:
groundtruth_claims = [
    "Flux values are nonzero for most timestamps",
    "Multiple wavelength bands are observed per timestamp",
    "Each timestamp contains both flux and uncertainty values at least once",
    "The light curve exhibits temporal continuity."
]

In [20]:
### Compute alignment scores
scores = calculate_expert_alignment_score(distilled_claims, groundtruth_claims)
scores

{'individual_scores': [{'claim': '*Feature of SNIa*: The provided time series data exhibits distinct behavior characteristic of Type Ia supernova',
   'matched_ground_truth': None,
   'result': 'NO',
   'score': 0.0},
  {'claim': '*Luminosity Behavior*: SNIas typically show a rapid rise to peak luminosity followed by a gradual decline',
   'matched_ground_truth': None,
   'result': 'NO',
   'score': 0.0},
  {'claim': '*Data Analysis*: This behavior can be analyzed through the value data across time in the dataset',
   'matched_ground_truth': None,
   'result': 'NO',
   'score': 0.0},
  {'claim': '*Sharp Peak*: There is a significant peak in the value data at approximately 59650.0546875 days with a recorded value of 111.83946228027344',
   'matched_ground_truth': None,
   'result': 'NO',
   'score': 0.0},
  {'claim': '*Luminous Event*: The sharp peak suggests a luminous event',
   'matched_ground_truth': None,
   'result': 'NO',
   'score': 0.0},
  {'claim': '*Wavelengths Captured*: The

In [17]:
### pipeline
# Input: example, answer, explanation
# Step 1: isolate individual explanation / distill relevant features (hallucination)
# Step 2: determine expert alignment for each feature
# Step 3: aggregate (relevance / alignment / coverage)

In [18]:
### ground truth
# Not locally 0 (approximately)
# More bands per timestamp (globally)
# Error + fluxes per timestamp at least once (reason: signed to noise ratio)
# Continuity in time (Structure: Peak)