In [153]:
import sys
sys.path.insert(0, '..')

from datasets import load_dataset
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import wandb
from tqdm import tqdm
import pandas as pd
import torch

from core.dataset import PSMDataset

In [116]:
class HFPSMDataset(Dataset):
    """Torch ``Dataset`` view over one split of a Hugging Face-style dataset dict.

    Every item is returned as the 5-tuple
    ``(photometry, photometry_mask, spectra, metadata, label)``.

    Args:
        ds: Mapping from split name to a dataset object that supports
            ``with_format('numpy')``, ``len()`` and integer indexing.
        classes: Iterable of class names; integer ids follow their sorted order.
        seq_len: Fixed number of photometry rows returned per item.
        split: Key into ``ds``. ``'train'`` selects random cropping of long
            light curves, any other value selects a centre crop.
    """

    def __init__(self, ds, classes, seq_len, split='train'):
        super().__init__()

        # Keep only the requested split and ask it to return numpy arrays.
        self.ds = ds[split].with_format('numpy')
        self.seq_len = seq_len
        self.split = split

        # Class ids are assigned in sorted order of the class names.
        self.id2label = dict(enumerate(sorted(classes)))
        self.label2id = {name: idx for idx, name in self.id2label.items()}
        self.num_classes = len(classes)

    def preprocess_lc(self, X, aux_values):
        """Normalise, crop/pad and augment one light curve.

        Args:
            X: 2-D array with at least three columns. Column 0 is the time
                stamp (HJD), column 1 is normalised with mean/MAD and column 2
                is divided by the same MAD (presumably flux and flux error —
                TODO confirm against the data card).
            aux_values: 1-D array of per-object auxiliary values.

        Returns:
            ``(X, mask)`` as float32 arrays. ``X`` has ``seq_len`` rows; each
            row holds the light-curve columns followed by ``aux_values``,
            ``log10(mad)`` and ``delta_t``. ``mask`` has length ``seq_len``
            with 1 for real rows and 0 for padding rows.
        """
        # Remove duplicate entries (np.unique returns a new array, so the
        # caller's X is not modified by the in-place writes below)
        X = np.unique(X, axis=0)

        # Sort based on HJD
        X = X[np.argsort(X[:, 0])]

        # Normalize column 1 by mean / MAD and rescale column 2 by the same MAD
        mean = X[:, 1].mean()
        mad = stats.median_abs_deviation(X[:, 1])
        X[:, 1] = (X[:, 1] - mean) / mad
        X[:, 2] = X[:, 2] / mad

        # Save delta t (time span divided by 365) before scaling
        t_min, t_max = X[:, 0].min(), X[:, 0].max()
        delta_t = (t_max - t_min) / 365

        # Scale time from 0 to 1
        X[:, 0] = (X[:, 0] - t_min) / (t_max - t_min)

        # Trim if longer than seq_len
        n_obs = X.shape[0]
        if n_obs > self.seq_len:
            if self.split == 'train':   # random crop
                # np.random.randint excludes its upper bound, so "+ 1" is
                # required for the last window (the one ending on the final
                # observation) to be selectable.
                start = np.random.randint(0, n_obs - self.seq_len + 1)
            else:  # 'center'
                start = (n_obs - self.seq_len) // 2

            X = X[start:start + self.seq_len, :]

        # Pad if needed and create mask
        mask = np.ones(self.seq_len)
        if X.shape[0] < self.seq_len:
            mask[X.shape[0]:] = 0
            X = np.pad(X, ((0, self.seq_len - X.shape[0]), (0, 0)), 'constant', constant_values=(0,))

        # Add mad and delta t to aux
        aux_values = np.concatenate((aux_values, [np.log10(mad), delta_t]))

        # Add aux to X: the same aux vector is repeated on every row,
        # including padding rows
        aux_values = np.tile(aux_values, (self.seq_len, 1))
        X = np.concatenate((X, aux_values), axis=-1)

        # Convert X and mask from float64 to float32
        X = X.astype(np.float32)
        mask = mask.astype(np.float32)

        return X, mask

    def preprocess_spectra(self, spectra):
        """Resample a spectrum onto a fixed wavelength grid and normalise it.

        Args:
            spectra: 2-D array whose columns are wavelength, flux and flux
                error, in that order.

        Returns:
            float32 array of shape ``(3, len(np.arange(3850, 9000, 2)))`` with
            rows: normalised flux, scaled flux error, and ``log10(mad)``
            repeated along the grid.
        """
        wavelengths = spectra[:, 0]
        flux = spectra[:, 1]
        flux_err = spectra[:, 2]

        # Linear interpolation onto the common grid 3850, 3852, ..., 8998
        new_wavelengths = np.arange(3850, 9000, 2)
        flux = np.interp(new_wavelengths, wavelengths, flux)
        flux_err = np.interp(new_wavelengths, wavelengths, flux_err)

        # The MAD ignores exact zeros in the resampled flux; the mean does not
        mean = np.mean(flux)
        mad = stats.median_abs_deviation(flux[flux != 0])

        flux = (flux - mean) / mad
        flux_err = flux_err / mad
        aux_values = np.full_like(flux, np.log10(mad))

        spectra = np.vstack([flux, flux_err, aux_values]).astype(np.float32)

        return spectra

    def __len__(self):
        """Return the number of items in the selected split."""
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``(photometry, photometry_mask, spectra, metadata, label)`` for item ``idx``."""
        el = self.ds[idx]

        label = self.label2id[el['label']]
        metadata = np.array(list(el['metadata']['meta_cols'].values()))

        # The photometry-related metadata travels with the light curve as aux values
        photo_cols = np.array(list(el['metadata']['photo_cols'].values()))
        photometry, photometry_mask = self.preprocess_lc(el['photometry'], photo_cols)

        spectra = self.preprocess_spectra(el['spectra'])

        return photometry, photometry_mask, spectra, metadata, label

In [37]:
# Class names handed to the dataset wrappers; HFPSMDataset sorts them before
# assigning integer ids, so the order written here does not affect the ids.
CLASSES = ['EW', 'SR', 'EA', 'RRAB', 'EB', 'ROT', 'RRC', 'HADS', 'M', 'DSCT']

In [130]:
# Load the 'full_42' configuration of the processed dataset from the Hugging Face Hub.
ds = load_dataset('MeriDK/AstroM3Processed', name='full_42')

In [131]:
# One HFPSMDataset wrapper per split, all sharing the same classes and sequence length.
hf_train_dataset, hf_val_dataset, hf_test_dataset = (
    HFPSMDataset(ds, classes=CLASSES, seq_len=200, split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [114]:
# Walk through hf_train_dataset with a progress bar and discard every item;
# presumably a smoke test that preprocessing succeeds for each element.
for el in tqdm(hf_train_dataset):
    pass

In [120]:
# Sizes of the Hugging Face-backed train / validation / test datasets.
tuple(map(len, (hf_train_dataset, hf_val_dataset, hf_test_dataset)))

In [121]:
# Fetch the configuration of a reference Weights & Biases run; it is passed to
# PSMDataset in a later cell.
api = wandb.Api()
run = api.run('meridk/AstroCLIPResults3/runs/3c2da15u')
config = run.config
# Switch the 'use_wandb' flag off in the fetched config (presumably so that
# PSMDataset does not try to log to W&B — TODO confirm).
config['use_wandb'] = False

In [122]:
# Reference datasets built by the project's own PSMDataset, one per split.
train_dataset, val_dataset, test_dataset = (
    PSMDataset(config, split=split_name)
    for split_name in ('train', 'val', 'test')
)

In [143]:
# Read the training-split CSV.
# NOTE(review): absolute, machine-specific path — this cell only runs on a
# machine where that file exists.
df = pd.read_csv('/home/mariia/AstroML/data/asassn/preprocessed_data/full_lb/spectra_and_v_train.csv')

In [147]:
# Metadata columns (as listed by train_dataset.meta_cols) of the first CSV row, cast to float64.
df.iloc[0][train_dataset.meta_cols].values.astype(np.float64)

In [148]:
# First item from each dataset. The h-prefixed names come from HFPSMDataset and are
# (photometry, photometry_mask, spectra, metadata, label); the un-prefixed names
# assume PSMDataset returns the same 5-tuple layout — TODO confirm.
p, pm, s, m, l = train_dataset[0]
hp, hpm, hs, hm, hl = hf_train_dataset[0]

In [150]:
# Number of metadata values in the HF-backed item.
len(hm)

In [151]:
# dtype of the metadata array in the HF-backed item.
hm.dtype

In [138]:
# Re-fetch item 0 from the HF-backed dataset and show the metadata dtype.
# NOTE(review): this cell's execution counter is lower than its neighbours', i.e.
# it was run out of order; it repeats work done in the cells above.
hp, hpm, hs, hm, hl = hf_train_dataset[0]
hm.dtype

In [159]:
# Shape of the preprocessed spectra array in the HF-backed item.
hs.shape

In [123]:
# Sizes of the reference PSMDataset train / val / test datasets.
tuple(len(split_ds) for split_ds in (train_dataset, val_dataset, test_dataset))

In [126]:
# Metadata of item 0 from the reference dataset.
m

In [129]:
# Metadata of item 0 from the HF-backed dataset.
hm

In [60]:
# Do both datasets assign the same label id to item 0?
l == hl

In [66]:
# Pairwise shape agreement between reference and HF-backed outputs
# (photometry, mask, spectra, metadata).
tuple(
    ref.shape == hf.shape
    for ref, hf in ((p, hp), (pm, hpm), (s, hs), (m, hm))
)

In [78]:
# Exact element-wise equality of the first three photometry columns, the masks
# and the spectra between the reference and HF-backed items.
tuple(
    np.array_equal(ref, hf)
    for ref, hf in ((p[:, :3], hp[:, :3]), (pm, hpm), (s, hs))
)

In [98]:
# First three columns of row 1 from both photometry arrays, side by side.
p[1, :3], hp[1, :3]

In [97]:
# Boolean mask of entries in the first three photometry columns whose deviation
# from the reference `p` exceeds 0.1 % of |p|.
# The tolerance is scaled by |p| instead of dividing by the signed `p`: a signed
# denominator makes the ratio negative wherever p < 0, so those entries could
# never exceed the threshold, and zero entries would produce NaN/inf.
np.abs(p[:, :3] - hp[:, :3]) > 0.001 * np.abs(p[:, :3])

In [79]:
# Row 1 of the first three columns from both photometry arrays.
p[:, :3][1], hp[:, :3][1]

In [105]:
# Take the first row of the reference dataset's dataframe and fetch its light curve
# through get_vlc (presumably the raw, un-preprocessed V-band curve — TODO confirm).
# NOTE: this rebinds `p`, which previously held the preprocessed photometry.
el = train_dataset.df.iloc[0]
p = train_dataset.get_vlc(el['name'])

In [106]:
# Photometry of the first training item straight from `ds`, bypassing HFPSMDataset.
# NOTE: this rebinds `hp`. `ds` itself is never passed through with_format in this
# notebook, so the returned container type depends on its default format — verify
# before relying on array attributes such as .dtype.
hp = ds['train'][0]['photometry']

In [90]:
# Exact equality of the two photometry sources.
np.array_equal(p, hp)

In [91]:
# dtypes of the two photometry arrays.
p.dtype, hp.dtype

In [108]:
# Load the 'sub10_42' configuration of the un-processed dataset and have it return numpy arrays.
# NOTE(review): trust_remote_code=True lets the dataset repository's loading script
# run locally (per the `datasets` documentation) — only use with a trusted repo.
ds2 = load_dataset('MeriDK/AstroM3Dataset', name='sub10_42', trust_remote_code=True)
ds2 = ds2.with_format('numpy')

In [109]:
# Photometry of the first training item from the second dataset.
hp2 = ds2['train'][0]['photometry']

In [110]:
# First photometry row from each of the three sources, side by side.
p[0], hp[0], hp2[0]

In [112]:
# Boolean mask of entries whose absolute difference between `p` and `hp` exceeds 0.001.
np.abs(p - hp) > 0.001

In [156]:
# Load a saved checkpoint with torch.load(..., weights_only=True).
# NOTE(review): absolute, machine-specific path — this cell only runs on a
# machine where that file exists.
weights = torch.load("/home/mariia/AstroML/weights/2024-09-18-13-37-2wz4ysvn/weights-best.pth", weights_only=True)

In [158]:
# Top-level keys of the loaded checkpoint (assumes torch.load returned a dict-like object — TODO confirm).
weights.keys()