In [32]:
import pyarrow as pa
import pyarrow_hotfix
from informer_models import InformerConfig, InformerForSequenceClassification
import torch
import yaml
import argparse
import torch
import torch.nn as nn

from datasets import load_dataset
from connect_later.pretrain import setup_model_config
from collections import namedtuple

from connect_later.dataset_preprocess_raw import create_train_dataloader_raw, create_test_dataloader_raw
from connect_later.informer_models import InformerForSequenceClassification
from connect_later.pretrain import get_dataset, setup_model_config

pa.PyExtensionType.set_auto_load(True)
pyarrow_hotfix.uninstall()

In [5]:
dataset = load_dataset("BrachioLab/supernova-timeseries")

In [6]:
train_dataset = dataset['train']
validation_dataset = dataset['validation']
test_dataset = dataset['test']

In [58]:
dataset

DatasetDict({
    train: Dataset({
        features: ['object_id', 'times_wv', 'lightcurve', 'label', 'redshift', 'hostgal_specz', 'hostgal_photoz', 'hostgal_photoz_err', 'ddf_bool'],
        num_rows: 7066
    })
    validation: Dataset({
        features: ['object_id', 'times_wv', 'lightcurve', 'label', 'redshift', 'hostgal_specz', 'hostgal_photoz', 'hostgal_photoz_err', 'ddf_bool'],
        num_rows: 782
    })
    test: Dataset({
        features: ['object_id', 'times_wv', 'lightcurve', 'label', 'redshift', 'hostgal_specz', 'hostgal_photoz', 'hostgal_photoz_err', 'ddf_bool'],
        num_rows: 3492890
    })
})

In [59]:
from collections import Counter

def count_labels(dataset):
    label_counts = Counter(dataset['label'])
    return dict(label_counts)

train_label_counts = count_labels(dataset['train'])
validation_label_counts = count_labels(dataset['validation'])
test_label_counts = count_labels(dataset['test'])

In [60]:
print("Train Label Counts:", train_label_counts)
print("Validation Label Counts:", validation_label_counts)
print("Test Label Counts:", test_label_counts)

Train Label Counts: {92: 215, 88: 333, 42: 1074, 90: 2082, 65: 883, 16: 832, 67: 187, 95: 158, 62: 436, 15: 446, 52: 165, 6: 136, 64: 92, 53: 27}
Validation Label Counts: {92: 24, 88: 37, 42: 119, 90: 231, 65: 98, 16: 92, 67: 21, 95: 17, 62: 48, 15: 49, 52: 18, 6: 15, 64: 10, 53: 3}
Test Label Counts: {42: 1000150, 90: 1659831, 16: 96572, 67: 40193, 62: 175094, 993: 9680, 92: 197155, 52: 63664, 88: 101424, 65: 93494, 991: 533, 992: 1702, 15: 13555, 95: 35782, 6: 1303, 53: 1453, 994: 1172, 64: 133}


In [50]:
#14 balanced classes
#15 balanced classes

In [61]:
unique_labels_train = set(dataset['train']['label'])
unique_labels_validation = set(dataset['validation']['label'])
unique_labels_test = set(dataset['test']['label'])

# Combine all unique labels across splits
all_unique_labels = unique_labels_train.union(unique_labels_validation)

print(all_unique_labels)

{64, 65, 67, 6, 15, 16, 88, 90, 92, 95, 42, 52, 53, 62}


In [62]:
print(unique_labels_train)
print(unique_labels_validation)
print(unique_labels_test)

{64, 65, 67, 6, 42, 15, 16, 52, 53, 88, 90, 92, 62, 95}
{64, 65, 67, 6, 42, 15, 16, 52, 53, 88, 90, 92, 62, 95}
{992, 993, 65, 67, 994, 64, 6, 42, 15, 16, 52, 53, 95, 88, 90, 92, 62, 991}


In [7]:
model = InformerForSequenceClassification.from_pretrained("BrachioLab/supernova-classification")

config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/88.6M [00:00<?, ?B/s]

num labels: 14
Using Fourier PE
classifier dropout: 0.2


In [8]:
config = InformerConfig.from_pretrained("BrachioLab/supernova-classification")
test_dataloader = create_test_dataloader_raw(
    config=config,
    dataset=test_dataset,
    batch_size=256,
    compute_loss=True
)

Map:   0%|          | 0/792 [00:00<?, ? examples/s]

original dataset size: 792


Filter:   0%|          | 0/792 [00:00<?, ? examples/s]

remove nans dataset size: 792


Flattening the indices:   0%|          | 0/792 [00:00<?, ? examples/s]

In [9]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
y_true = []
y_pred = []
for i, batch in enumerate(test_dataloader):
    print(f"processing batch {i}")
    batch = {k: v.to(device) for k, v in batch.items() if k != "objid"}
    with torch.no_grad():
        outputs = model(**batch)
    y_true.extend(batch['labels'].cpu().numpy())
    y_pred.extend(torch.argmax(outputs.logits, dim=2).squeeze().cpu().numpy())
print(f"accuracy: {sum([1 for i, j in zip(y_true, y_pred) if i == j]) / len(y_true)}")

processing batch 0
processing batch 1
processing batch 2
processing batch 3
accuracy: 0.8017676767676768


In [10]:
target_column = test_dataset['target']
times_wv_column = test_dataset['times_wv']
print(len(target_column[1]))

300


In [16]:
def count_humps(lst):
    hump_count = 0
    in_hump = False
    for elem in lst:
        if elem == 1 and not in_hump:
            hump_count += 1
            in_hump = True
        elif elem == 0:
            in_hump = False
    return hump_count

In [17]:
ones_count = [sublist.count(1) for sublist in daf_list]
humps_per_sublist = [count_humps(sublist) for sublist in daf_list]

In [18]:
print(len(humps_per_sublist))

792


In [33]:
class AccuracyMetric(nn.Module):
    def forward(self, predictions, ground_truth):
        predictions = torch.tensor(predictions)
        ground_truth = torch.tensor(ground_truth)
        correct = torch.eq(predictions, ground_truth).sum()
        accuracy = correct.float() / ground_truth.size(0)
        return accuracy

In [35]:
class PrecisionMetric(nn.Module):
    # Precision: TP / (TP + FP)
    def forward(self, predictions, ground_truth):
        predictions = torch.tensor(predictions)
        ground_truth = torch.tensor(ground_truth)
        true_positive = torch.logical_and(predictions == 1, ground_truth == 1).sum()
        predicted_positive = (predictions == 1).sum()
        precision = true_positive.float() / predicted_positive if predicted_positive != 0 else torch.tensor(0.)
        return precision

In [34]:
class RecallMetric(nn.Module):
    # Recall: TP / (TP + FN)
    def forward(self, predictions, ground_truth):
        predictions = torch.tensor(predictions)
        ground_truth = torch.tensor(ground_truth)
        true_positive = torch.logical_and(predictions == 1, ground_truth == 1).sum()
        actual_positive = (ground_truth == 1).sum()
        recall = true_positive.float() / actual_positive if actual_positive != 0 else torch.tensor(0.)
        return recall

In [36]:
class F1ScoreMetric(nn.Module):
    def __init__(self):
        super().__init__()
        self.precision_metric = PrecisionMetric()
        self.recall_metric = RecallMetric()
    
    def forward(self, predictions, ground_truth):
        precision = self.precision_metric(predictions, ground_truth)
        recall = self.recall_metric(predictions, ground_truth)
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) != 0 else torch.tensor(0.)
        return f1_score

In [42]:
predictions = [[0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 1, 1, 0]]
ground_truth = [[0, 0, 0, 1, 1, 1, 0], [0, 0, 0, 1, 1, 1, 0]]

metric = AccuracyMetric()
precision_metric = PrecisionMetric()
recall_metric = RecallMetric()
f1_score_metric = F1ScoreMetric()

accuracy = metric(predictions, ground_truth)
precision = precision_metric(predictions, ground_truth)
recall = recall_metric(predictions, ground_truth)
f1_score = f1_score_metric(predictions, ground_truth)

In [43]:
print("Accuracy:", accuracy.item())
print("Precision:", precision.item())
print("Recall:", recall.item())
print("F1 Score:", f1_score.item())

Accuracy: 0.8571428656578064
Precision: 1.0
Recall: 0.6666666865348816
F1 Score: 0.800000011920929


In [63]:
import torch
import torch.nn as nn

class MaxAccuracyMetric(nn.Module):
    def forward(self, predictions, ground_truth):
        max_accuracies = []
        for gt in ground_truth:
            accuracies = []
            for pred in predictions:
                correct = torch.eq(pred, gt).sum()
                accuracy = correct.float() / gt.size(0)
                accuracies.append(accuracy)
            max_accuracies.append(max(accuracies))
        
        return torch.tensor(max_accuracies)

In [64]:
predictions = [[0, 0, 0, 0, 1], [0, 0, 0, 1, 1]]
ground_truth = [[0, 0, 0, 1, 1], [0, 0, 0, 0, 1], [0, 0, 0, 1, 0]]
predictions = torch.tensor(predictions)
ground_truth = torch.tensor(ground_truth)

metric = MaxAccuracyMetric()

max_accuracies = metric(predictions, ground_truth)
print("Max Accuracies:", max_accuracies)

Max Accuracies: tensor([1.0000, 1.0000, 0.8000])
