In [None]:
# Uncomment the line below to install exlib into the kernel's environment
# %pip install exlib

In [1]:
import torch
import yaml
import argparse
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import sys
# NOTE(review): machine-specific absolute path, presumably a local exlib source
# checkout — adjust or drop this line when exlib is installed via pip.
sys.path.append('/shared_data0/chaenyk/exlib/src')
import exlib
import math
import torch.nn.functional as F

from datasets import load_dataset
from collections import namedtuple
from exlib.datasets.supernova import SupernovaDataset, SupernovaClsModel, SupernovaFixScore, get_supernova_scores
from exlib.utils.supernova_helper import *
from tqdm.auto import tqdm

# Baselines
from exlib.features.time_series.identity import IdentityGroups
from exlib.features.time_series.random import RandomGroups
from exlib.features.time_series.slice import SliceGroups
from exlib.features.time_series.clustering import ClusterGroups
from exlib.features.time_series.archipelago import ArchipelagoGroups

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Overview
* The objective is to classify time-varying astronomical sources into one of 14 classes, using their time-series observations.

### Load datasets and pre-trained models

In [2]:
# Test split of the supernova time-series dataset.
test_dataset = SupernovaDataset(
    data_dir="BrachioLab/supernova-timeseries",
    split="test",
)
# Pre-trained classification model.
model = SupernovaClsModel(
    model_path="BrachioLab/supernova-classification",
)

num labels: 14
Using Fourier PE
classifier dropout: 0.2


### Model prediction

In [3]:
# Place the model on the selected device before running inference.
model = model.to(device)

# Batched loader over the test set, consumed by the prediction loop.
test_dataloader = create_test_dataloader(dataset=test_dataset, batch_size=5, compute_loss=True)

original dataset size: 792
remove nans dataset size: 792


In [4]:
# Model prediction: run the classifier over the test set and report accuracy.

# Inference mode: the model reports a classifier dropout of 0.2 on load, and
# dropout must be inactive for reproducible predictions. eval() is idempotent.
model.eval()

y_true = []
y_pred = []
with torch.no_grad():  # no gradients are needed for evaluation
    for batch in tqdm(test_dataloader, total=len(test_dataloader)):
        # Move every tensor to the target device; "objid" is not passed to the model.
        batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
        outputs = model(**batch)
        y_true.extend(batch['labels'].cpu().numpy())
        # reshape(-1) rather than squeeze(): a bare squeeze() also removes the
        # batch axis when a batch holds a single sample (e.g. a trailing batch
        # of size 1), producing a 0-d array that list.extend() cannot iterate.
        # Assumes one prediction per sample, i.e. logits of shape
        # (batch, 1, num_classes) — TODO confirm.
        y_pred.extend(torch.argmax(outputs.logits, dim=2).reshape(-1).cpu().numpy())

# Fraction of samples whose predicted class equals the true label.
print(f"accuracy: {sum(1 for t, p in zip(y_true, y_pred) if t == p) / len(y_true)}")

  0%|          | 0/159 [00:00<?, ?it/s]

accuracy: 0.7967171717171717


### Feature alignment

In [5]:
# NOTE(review): this rebinds `test_dataloader`, replacing the loader built for
# the prediction cell; re-running the prediction loop after this cell would
# iterate this "raw" loader instead.
test_dataloader = create_test_dataloader_raw(dataset=test_dataset, batch_size=5, compute_loss=True)

original dataset size: 792
remove nans dataset size: 792


### Baselines
- Identity
- Random
- 5 slices
- 10 slices
- 15 slices
- Clustering
- Archipelago

In [6]:
# Score the baseline feature groupings on the test set; the call prints the
# average alignment per baseline.
scores = get_supernova_scores(
    dataset=test_dataset,
    batch_size=5,
)

original dataset size: 792
remove nans dataset size: 792
num labels: 14
Using Fourier PE
classifier dropout: 0.2


  0%|          | 0/159 [00:00<?, ?it/s]

Avg alignment of identity features: 0.0060
Avg alignment of random features: 0.0145
Avg alignment of 5 features: 0.0146
Avg alignment of 10 features: 0.0257
Avg alignment of 15 features: 0.0247
Avg alignment of clustering features: 0.0649
Avg alignment of archipelago features: 0.0519
