In [1]:
import pyarrow as pa
import pyarrow_hotfix
import torch
import yaml
import argparse
import torch
import torch.nn as nn
import numpy as np

from datasets import load_dataset
from collections import namedtuple
from exlib.datasets.pretrain import setup_model_config, get_dataset, get_dataset, setup_model_config
from exlib.datasets.dataset_preprocess_raw import create_train_dataloader_raw, create_test_dataloader_raw
from exlib.datasets.informer_models import InformerConfig, InformerForSequenceClassification
# NOTE(review): these two calls appear to re-enable automatic loading of pickled
# PyArrow extension types and to remove the `pyarrow_hotfix` mitigation
# (presumably the guard for CVE-2023-47248). That is only safe with trusted
# data -- confirm it is still required for the dataset loaded in this notebook.
pa.PyExtensionType.set_auto_load(True)
pyarrow_hotfix.uninstall()
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



In [2]:
# Fetch the supernova time-series dataset and bind each split to its own name.
dataset = load_dataset("BrachioLab/supernova-timeseries")
train_dataset, validation_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

In [3]:
# Load the pretrained classifier and its configuration from the same checkpoint,
# then build the dataloader over the test split.
MODEL_CHECKPOINT = "BrachioLab/supernova-classification"
model = InformerForSequenceClassification.from_pretrained(MODEL_CHECKPOINT).to(device)
config = InformerConfig.from_pretrained(MODEL_CHECKPOINT)
test_dataloader = create_test_dataloader_raw(
    config=config, dataset=test_dataset, batch_size=256, compute_loss=True
)

  return self.fget.__get__(instance, owner)()


num labels: 14
Using Fourier PE
classifier dropout: 0.2




Map:   0%|          | 0/792 [00:00<?, ? examples/s]

original dataset size: 792


Filter:   0%|          | 0/792 [00:00<?, ? examples/s]

remove nans dataset size: 792


Flattening the indices:   0%|          | 0/792 [00:00<?, ? examples/s]

In [10]:
# baseline
def baseline(num_groups, group_length, device, group_size=10):
    """Build fixed-width, non-overlapping baseline groups.

    Group ``i`` is a 0/1 mask over ``group_length`` positions that is 1 on the
    half-open index range ``[group_size * i, group_size * (i + 1))``. Ranges
    that run past ``group_length`` are clipped by slicing, so trailing groups
    may be shorter than ``group_size`` or entirely zero.

    Args:
        num_groups: Number of group masks (rows) to create.
        group_length: Number of positions (columns) in every mask.
        device: Torch device on which the result is allocated.
        group_size: Width of each group in positions. Defaults to 10, the
            value that was previously hard-coded.

    Returns:
        A ``(num_groups, group_length)`` tensor of zeros and ones in torch's
        default floating-point dtype, allocated on ``device``.

    Raises:
        ValueError: If ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")

    groups = torch.zeros(num_groups, group_length, device=device)
    for index in range(num_groups):
        start = group_size * index
        # Slice assignment silently clips at group_length.
        groups[index, start:start + group_size] = 1
    return groups

In [11]:
# groundtruth (placeholder: the author flagged this as still needing an update)
def groundtruth(group_length, device, num_groups=10, active_length=150):
    """Build placeholder ground-truth groups: identical prefix masks.

    Every row is the same mask: ones on the first ``active_length`` positions
    and zeros afterwards. When ``group_length <= active_length`` the rows are
    all ones.

    Args:
        group_length: Number of positions (columns) in every mask.
        device: Torch device on which the result is allocated. Previously this
            argument was accepted but ignored and the tensor was always
            created on the CPU.
        num_groups: Number of identical rows to stack. Defaults to 10, the
            value that was previously hard-coded.
        active_length: Number of leading positions set to 1. Defaults to 150,
            the value that was previously hard-coded.

    Returns:
        A ``(num_groups, group_length)`` ``torch.int`` tensor on ``device``.
    """
    group = torch.ones(group_length, dtype=torch.int, device=device)
    group[active_length:] = 0
    return torch.stack([group] * num_groups)

In [12]:
# alignment
# overlap score (named "iou", but normalised by group_length rather than the union)
def calculate_iou(tensor1, tensor2, group_length):
    """Score the row-wise overlap between two sets of group masks.

    Both inputs are converted to boolean masks and combined with a logical AND
    (normal broadcasting applies). The number of overlapping positions per row
    is then divided by ``group_length``.

    Note: despite the name this is not a true intersection-over-union; the
    union-based normalisation is disabled and ``group_length`` is used as the
    denominator instead.

    Args:
        tensor1: Mask tensor; non-zero entries count as members.
        tensor2: Mask tensor broadcastable against ``tensor1``.
        group_length: Denominator applied to every overlap count.

    Returns:
        A float tensor with one score per row of the broadcast result.
    """
    overlap_mask = tensor1.to(torch.bool) & tensor2.to(torch.bool)
    overlap_count = overlap_mask.float().sum(dim=1)
    return overlap_count / group_length

# average
def compute_average_iou_per_column(pred_groups, best_iou_list, x_column_length):
    """Average, per position, the scores of the groups that contain it.

    For every position ``col`` in ``range(x_column_length)`` the scores
    ``best_iou_list[g]`` of all groups ``g`` with ``pred_groups[g][col] == 1``
    are averaged. Positions that belong to no group get ``0``.

    Args:
        pred_groups: 2-D mask indexed as ``pred_groups[group][position]`` and
            exposing ``.shape[0]`` as the number of groups.
        best_iou_list: One score per group, indexed like the rows of
            ``pred_groups``.
        x_column_length: Number of positions to evaluate.

    Returns:
        A list of length ``x_column_length`` with the per-position averages.
    """
    per_position = []
    for col in range(x_column_length):
        member_scores = [
            best_iou_list[row]
            for row in range(pred_groups.shape[0])
            if pred_groups[row][col] == 1
        ]
        if member_scores:
            per_position.append(sum(member_scores) / len(member_scores))
        else:
            per_position.append(0)
    return per_position

In [13]:
# Score the fixed-width baseline groups against the placeholder ground truth,
# printing one alignment score per batch.
# y_true / y_pred are initialised here but are not filled anywhere in this cell.
y_true = []
y_pred = []
batch_size = 256
# NOTE(review): this rebinds `test_dataloader` to a plain DataLoader over the
# raw test split, replacing the loader built by create_test_dataloader_raw.
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
for i, batch in enumerate(test_dataloader):
    target_column = batch['target']
    times_wv_column = batch['times_wv']
    x_column = np.concatenate((times_wv_column, target_column), axis=1)
    # NOTE(review): len(x_column) is used as the group length below. What it
    # measures depends on how the DataLoader collates these columns -- confirm
    # it is the sequence length and not the batch size.
    pred_groups = baseline(30, len(x_column), device).to(device)
    true_groups = groundtruth(len(x_column), device).to(device)
    best_iou_list = []

    # For each predicted group keep its best score over all ground-truth groups.
    for pred_group in pred_groups:
        ious = calculate_iou(pred_group.unsqueeze(0), true_groups, len(x_column))
        best_iou = ious.max().item()
        best_iou_list.append(best_iou)
        
    # Average the per-position scores into a single number for this batch.
    times_align = compute_average_iou_per_column(pred_groups, best_iou_list, len(x_column))
    alignment_score = sum(times_align) / len(times_align)
    print(alignment_score)

0.01666666753590107
0.01666666753590107
0.01666666753590107
0.01666666753590107


#### ground truth

In [9]:
# Combine the 'times_wv' and 'target' columns of the whole test split into one
# array by joining them along axis 2.
# NOTE(review): this rebinds target_column / times_wv_column / x_column, which
# the per-batch alignment loop also assigns.
target_column = test_dataset['target']
times_wv_column = test_dataset['times_wv']
x_column = np.concatenate((times_wv_column, target_column), axis=2)
x_column.shape
# Feature order on the last axis: time, wv (presumably wavelength), flux, flux error

(792, 300, 4)

In [69]:
def find_example_indices_by_class(dataset):
    """Map each class label to the index of its first occurrence in ``dataset``.

    The labels are read from ``dataset['label']`` exactly once. Previously the
    function ignored its argument and read the global ``test_dataset``, and it
    re-fetched the whole label column on every iteration.

    Args:
        dataset: Any object for which ``dataset['label']`` yields an iterable
            of hashable class labels, one per example.

    Returns:
        A dict ``{label: index}`` in order of first appearance, where
        ``index`` is the position of the first example carrying that label.
    """
    labels = list(dataset['label'])
    num_classes = len(set(labels))

    example_indices = {}
    for index, label in enumerate(labels):
        if label not in example_indices:
            example_indices[label] = index
            # Early exit once every class has been seen.
            if len(example_indices) == num_classes:
                break
    return example_indices

# Pick one representative example (the first occurrence) per class in the test split.
test_example_indices = find_example_indices_by_class(test_dataset)

# The printed mapping has the form {label: index}.
print("Example indices by class for the test set:", test_example_indices)

Example indices by class for the test set: {12: 0, 10: 24, 3: 61, 11: 181, 8: 413, 2: 512, 9: 605, 13: 626, 6: 644, 1: 693, 4: 743, 0: 762, 7: 778, 5: 789}


In [70]:
def process_data(entries, s=3):
    """Group observations by wavelength and by integer time.

    Each entry is unpacked as ``(time, wavelength, flux, flux_error)``. A
    thresholded flux ``flux_std`` is derived per entry: it is ``0`` when the
    interval ``[flux - s * flux_error, flux + s * flux_error]`` contains zero,
    and the raw ``flux`` otherwise.

    Args:
        entries: Iterable of 4-item records ``(time, wavelength, flux, flux_error)``.
        s: Multiplier applied to ``flux_error`` when testing against zero.

    Returns:
        A tuple ``(data_by_wv, time_dictionary)`` where

        * ``data_by_wv`` maps each wavelength to a dict of parallel lists under
          the keys ``'time'``, ``'flux'``, ``'flux_std'`` and ``'flux_error'``;
        * ``time_dictionary`` maps ``int(time)`` to the list of ``flux_std``
          values observed at that truncated time.
    """
    data_by_wv = {}
    time_dictionary = {}

    for time, wavelength, flux, flux_error in entries:
        margin = s * flux_error
        spans_zero = (flux - margin <= 0) and (flux + margin >= 0)
        flux_std = 0 if spans_zero else flux

        series = data_by_wv.setdefault(
            wavelength,
            {'time': [], 'flux': [], 'flux_std': [], 'flux_error': []},
        )
        for field, value in (
            ('time', time),
            ('flux', flux),
            ('flux_std', flux_std),
            ('flux_error', flux_error),
        ):
            series[field].append(value)

        # int() truncates toward zero, bucketing observations by whole time unit.
        time_dictionary.setdefault(int(time), []).append(flux_std)

    return data_by_wv, time_dictionary

In [71]:
def create_time_bool_dictionaries(w_size, time_dictionary):
    """Derive per-time activity flags and a windowed "keep" mask.

    Steps:
      1. Flag each key of ``time_dictionary`` with 1 if any of its values is
         non-zero, else 0.
      2. Fill every integer between the smallest and largest key, treating
         missing keys as 0.
      3. For each position, look at the window of ``w_size // 2`` neighbours
         on each side (clipped at the ends). The position is marked 0 only if
         that window contains no activity and spans exactly ``w_size``
         positions; otherwise it stays 1.

    NOTE(review): for an even ``w_size`` an unclipped window spans
    ``w_size + 1`` positions, so only edge-clipped windows can satisfy the
    width test -- confirm whether even sizes are meant to be supported.

    Args:
        w_size: Window width used for the quiet-region test.
        time_dictionary: Mapping of integer time -> list of values.

    Returns:
        ``(time_bool_dictionary, window_bool_dictionary, full_window_bool_dictionary)``.
        The last two are the same dict object (the windowed mask over the
        filled integer range).

    Raises:
        ValueError: If ``time_dictionary`` is empty (raised by ``min``).
    """
    time_bool_dictionary = {
        moment: (1 if any(value != 0 for value in observed) else 0)
        for moment, observed in time_dictionary.items()
    }

    first_moment = min(time_bool_dictionary.keys())
    last_moment = max(time_bool_dictionary.keys())
    timeline = list(range(first_moment, last_moment + 1))
    activity = [time_bool_dictionary.get(moment, 0) for moment in timeline]

    total = len(activity)
    reach = w_size // 2

    window_bool_dictionary = {}
    for position, moment in enumerate(timeline):
        lo = max(0, position - reach)
        hi = min(total, position + reach + 1)
        is_quiet = sum(activity[lo:hi]) == 0
        has_exact_width = (hi - lo) == w_size
        window_bool_dictionary[moment] = 0 if (is_quiet and has_exact_width) else 1

    # The filled-range result is the windowed mask itself.
    full_window_bool_dictionary = window_bool_dictionary
    return time_bool_dictionary, window_bool_dictionary, full_window_bool_dictionary

In [72]:
def get_continuous_blocks(window_dict):
    """Collapse runs of flagged times into ``[start, end]`` pairs.

    Keys are visited in sorted order. Each maximal run of consecutive keys (in
    that order) whose value equals 1 becomes one ``[first_key, last_key]``
    entry. Numeric gaps between keys are not checked; only the flag value
    ends a run.

    Args:
        window_dict: Mapping of sortable time keys to flags.

    Returns:
        A list of two-element lists ``[start, end]``, in ascending key order.
    """
    blocks = []
    run_start = run_end = None
    in_run = False

    for moment in sorted(window_dict.keys()):
        if window_dict[moment] == 1:
            if not in_run:
                run_start = moment
                in_run = True
            run_end = moment
        elif in_run:
            blocks.append([run_start, run_end])
            in_run = False

    # Flush a run that reaches the final key.
    if in_run:
        blocks.append([run_start, run_end])

    return blocks