In [1]:
import pyarrow as pa
import pyarrow_hotfix
import torch
import yaml
import argparse
import torch
import torch.nn as nn
import numpy as np
import exlib

from datasets import load_dataset
from collections import namedtuple
from exlib.datasets.pretrain import setup_model_config, get_dataset, get_dataset, setup_model_config
from exlib.datasets.dataset_preprocess_raw import create_train_dataloader_raw, create_test_dataloader_raw
from exlib.datasets.informer_models import InformerConfig, InformerForSequenceClassification
from tqdm.auto import tqdm
# Allow pyarrow to auto-load Python-defined extension types, then remove the
# pyarrow_hotfix guard — presumably required for the Arrow data behind the HF
# dataset loaded below (TODO confirm which dataset column needs it).
# NOTE(review): pyarrow_hotfix exists to block the PyExtensionType
# deserialization vulnerability (CVE-2023-47248); uninstalling it re-opens that
# path, so only do this when the dataset source is trusted.
pa.PyExtensionType.set_auto_load(True)
pyarrow_hotfix.uninstall()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Dataset

In [2]:
# Fetch the supernova time-series dataset from the Hugging Face Hub and unpack
# its three predefined splits.
dataset = load_dataset("BrachioLab/supernova-timeseries")
train_dataset, validation_dataset, test_dataset = (
    dataset[split_name] for split_name in ("train", "validation", "test")
)

### Model predictions

In [4]:
# Load the pretrained supernova classifier and its config from the Hugging Face Hub.
MODEL_ID = "BrachioLab/supernova-classification"
model = InformerForSequenceClassification.from_pretrained(MODEL_ID).to(device)
# This notebook only evaluates the model. Switch to eval mode explicitly so the
# classifier dropout reported by the config is disabled. HF-style
# from_pretrained normally does this already, but being explicit keeps the
# notebook correct if the model was put in train mode elsewhere.
model.eval()
config = InformerConfig.from_pretrained(MODEL_ID)
test_dataloader = create_test_dataloader_raw(
    config=config,
    dataset=test_dataset,
    batch_size=256,
    compute_loss=True
)

num labels: 14
Using Fourier PE
classifier dropout: 0.2
original dataset size: 792
remove nans dataset size: 792


In [5]:
# Run the classifier over the test split and report top-1 accuracy.
with torch.no_grad():
    y_true = []
    y_pred = []
    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
        # "objid" is excluded from the tensors passed to the model.
        batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
        outputs = model(**batch)
        y_true.extend(batch['labels'].cpu().numpy())
        # argmax over dim 2 leaves a (presumably singleton) dim 1 — inferred from
        # dim=2 and the scalar comparison below; confirm the logits shape.
        # Squeeze only that dim. A bare .squeeze() also drops the batch dim when
        # a batch has size 1, giving a 0-d array that .extend() cannot iterate.
        y_pred.extend(torch.argmax(outputs.logits, dim=2).squeeze(1).cpu().numpy())

if not y_true:
    raise ValueError("test_dataloader yielded no samples; cannot compute accuracy")
n_correct = sum(1 for label, pred in zip(y_true, y_pred) if label == pred)
print(f"accuracy: {n_correct / len(y_true)}")

  0%|          | 0/4 [00:00<?, ?it/s]

accuracy: 0.7967171717171717


### Alignment scores

In [6]:
# baseline
def baseline(valid_length):
    num_groups = (valid_length // 10) + 1
    pred_groups = []
    for group_idx in range(num_groups):
        start_index = group_idx * 10
        end_index = min((group_idx + 1) * 10, valid_length)
        group_list = [1 if start_index <= i < end_index else 0 for i in range(valid_length)]
        pred_groups.append(group_list)
    return pred_groups

In [7]:
# alignment
def process_group_pair(pred_group, true_group, device):
    """Per-time-step alignment between predicted and ground-truth groupings.

    Both inputs are membership matrices (groups x time steps) over the same time
    axis, where a nonzero entry means "step belongs to group".

    1. Each predicted group gets its best IoU against any ground-truth group.
    2. Each time step gets the mean of those best IoUs over the predicted groups
       whose entry for that step equals 1.
    3. A step that no predicted group covers scores 0.0.

    Args:
        pred_group: (num_pred_groups, length) nested list or tensor of 0/1 values.
        true_group: (num_true_groups, length) nested list or tensor of 0/1 values.
        device: torch device (or device string) to compute on.

    Returns:
        A list of ``length`` Python floats, one per time step. An empty
        ``pred_group`` returns ``[]``. An empty ``true_group`` gives every
        predicted group a best IoU of 0.
    """
    pred_groups = torch.as_tensor(pred_group, dtype=torch.float32, device=device)
    if pred_groups.dim() == 1 and pred_groups.numel() == 0:
        # torch.as_tensor([]) is 1-D: there is no time axis to score.
        return []
    num_steps = pred_groups.size(1)
    true_groups = torch.as_tensor(true_group, dtype=torch.float32, device=device)
    if true_groups.numel() == 0:
        # Normalise "no ground-truth groups" to an explicit (0, num_steps) matrix.
        true_groups = pred_groups.new_zeros((0, num_steps))

    pred_groups_bool = pred_groups.to(torch.bool)
    true_groups_bool = true_groups.to(torch.bool)
    if true_groups_bool.size(0) == 0:
        best_iou = pred_groups.new_zeros(pred_groups.size(0))
    else:
        # (P, 1, L) vs (1, T, L) broadcasts to (P, T, L); summing over L gives
        # the pairwise intersection / union sizes as (P, T).
        intersections = (pred_groups_bool.unsqueeze(1) & true_groups_bool.unsqueeze(0)).float().sum(dim=2)
        unions = (pred_groups_bool.unsqueeze(1) | true_groups_bool.unsqueeze(0)).float().sum(dim=2)
        # 0/0 (two empty groups) is defined as IoU 0.
        ious = torch.nan_to_num(intersections / unions, nan=0.0)
        best_iou = ious.max(dim=1).values  # (P,)

    # Vectorised per-step mean of best_iou over the covering predicted groups.
    # This replaces a Python loop that forced a device sync (.item()) per step.
    membership = (pred_groups == 1).float()                      # (P, L)
    coverage = membership.sum(dim=0)                             # (L,)
    iou_sum = (best_iou.unsqueeze(1) * membership).sum(dim=0)    # (L,)
    per_step = torch.where(
        coverage > 0,
        iou_sum / coverage.clamp(min=1),
        torch.zeros_like(iou_sum),
    )
    return per_step.tolist()

def calculate_alignment_scores(pred_groups_batch, true_groups_batch, device):
    """Return one alignment score per sample in a batch.

    Sample ``i`` pairs ``pred_groups_batch[i]`` with ``true_groups_batch[i]``.
    Its score is the mean of the per-time-step values from
    ``process_group_pair``, or 0 when that list is empty.

    Args:
        pred_groups_batch: Sequence of predicted group matrices, one per sample.
        true_groups_batch: Indexable sequence of ground-truth group matrices,
            aligned by position with ``pred_groups_batch``.
        device: torch device passed through to ``process_group_pair``.

    Returns:
        A list with one score per entry of ``pred_groups_batch``.
    """
    def _mean_or_zero(values):
        # An empty per-step list (zero-length series) scores 0.
        return sum(values) / len(values) if values else 0

    return [
        _mean_or_zero(process_group_pair(pred_groups, true_groups_batch[sample_idx], device))
        for sample_idx, pred_groups in enumerate(pred_groups_batch)
    ]

In [8]:
# Alignment score of the fixed-size baseline grouping for every test sample.
# TODO: ground-truth groups are not available yet. true_groups_batch below is a
# copy of the predicted groups, so every score is trivially 1.0. Replace it
# before interpreting the printed average.
alignment_scores_all = []
for batch in tqdm(test_dataloader, total=len(test_dataloader)):
    # No model call here: only the time channel (channel 0 of
    # past_time_features) is needed, so read it on the CPU instead of moving
    # the whole batch to `device` and back.
    time_values = batch['past_time_features'][:, :, 0].cpu().tolist()

    pred_groups_batch = []
    for time_list in time_values:
        # The valid prefix ends at the first step whose time does not increase.
        # NOTE(review): assumes padding after the real observations breaks
        # monotonicity (series shorter than the padded length, presumably 300,
        # are filled with non-increasing values) — confirm against the dataloader.
        valid_length = next(
            (j for j in range(1, len(time_list)) if time_list[j] <= time_list[j - 1]),
            len(time_list),
        )
        # One (num_groups x valid_length) membership matrix per sample.
        pred_groups_batch.append(baseline(valid_length))

    # TODO: placeholder ground truth (see the note at the top of this cell).
    true_groups_batch = pred_groups_batch

    alignment_scores_all.extend(
        calculate_alignment_scores(pred_groups_batch, true_groups_batch, device)
    )

if not alignment_scores_all:
    raise ValueError("test_dataloader yielded no samples; cannot average alignment scores")
print(f"average alignment score: {sum(alignment_scores_all) / len(alignment_scores_all)}")

  0%|          | 0/4 [00:00<?, ?it/s]

average alignment score: 1.0
