In [1]:
import sys

sys.path.insert(0, '..')

In [34]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

In [None]:
"""
We used 16 features from Gaia EDR3, 2MASS and AllWISE. These
include the EDR3 G, BP, RP magnitudes and the associated uncertainties, the $BP - RP$ color, the $BP - RP$ excess factor, signal-to-noise
ratios in G and BP, the renormalized unit weight error (RUWE), the
$J - K_s$ color, the absolute $W_{RP}$ magnitude and the absolute $W_{JK}$
magnitude. The EDR3 signal-to-noise ratios are essentially the observed flux divided by the error in the flux. As noted
earlier, the EDR3 photometric uncertainties and flux errors encode
information about the photometric variability of stars. We also used
the absolute, “reddening-free” Wesenheit magnitudes (Madore 1982;
Lebzelter et al. 2018)
    $W_{RP} = M_{RP} - 1.3 (BP - RP)$    (1)
and
    $W_{JK} = M_{K_s} - 0.686 (J - K_s)$    (2)
and the probabilistic EDR3 distances from Bailer-Jones et al. (2021).
"""

In [70]:
def train_epoch():
    """Run one optimisation pass over ``train_dataloader``.

    Uses the notebook-level globals ``model``, ``train_dataloader``,
    ``device``, ``optimizer`` and ``criterion``.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the unweighted
        mean of the per-batch losses and ``accuracy`` is the fraction of
        samples whose predicted class equals the target.
    """
    model.train()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    for X, y in tqdm(train_dataloader):
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        logits = model(X)
        loss = criterion(logits, y)
        batch_losses.append(loss.item())

        # Accuracy bookkeeping: predicted class = index of the largest probability.
        class_probs = torch.nn.functional.softmax(logits, dim=1)
        predicted_labels = torch.max(class_probs, dim=1)[1]
        n_correct += (predicted_labels == y).sum().item()
        n_seen += y.size(0)

        # Parameter update for this batch.
        loss.backward()
        optimizer.step()

    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy
    
def val_epoch():
    """Evaluate the model on ``val_dataloader`` without updating parameters.

    Uses the notebook-level globals ``model``, ``val_dataloader``, ``device``
    and ``criterion``.

    Returns:
        tuple: ``(mean_loss, accuracy)`` where ``mean_loss`` is the unweighted
        mean of the per-batch losses and ``accuracy`` is the fraction of
        samples whose predicted class equals the target.
    """
    model.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for X, y in tqdm(val_dataloader):
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            batch_losses.append(criterion(logits, y).item())

            # Predicted class = index of the largest probability.
            class_probs = torch.nn.functional.softmax(logits, dim=1)
            predicted_labels = torch.max(class_probs, dim=1)[1]

            n_correct += (predicted_labels == y).sum().item()
            n_seen += y.size(0)

    mean_loss = sum(batch_losses) / len(batch_losses)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

In [41]:
# Feature columns, in the order they are handed to the model (36 entries).
# The order is significant: it fixes the column order of every feature vector.
METADATA_COLS = [
    'mean_vmag', 'amplitude', 'period',
    'phot_g_mean_mag', 'e_phot_g_mean_mag',
    'lksl_statistic', 'rfr_score',
    'phot_bp_mean_mag', 'e_phot_bp_mean_mag',
    'phot_rp_mean_mag', 'e_phot_rp_mean_mag',
    'bp_rp',
    'parallax', 'parallax_error', 'parallax_over_error',
    'pmra', 'pmra_error', 'pmdec', 'pmdec_error',
    'j_mag', 'e_j_mag', 'h_mag', 'e_h_mag', 'k_mag', 'e_k_mag',
    'w1_mag', 'e_w1_mag', 'w2_mag', 'e_w2_mag', 'w3_mag', 'w4_mag',
    'j_k', 'w1_w2', 'w3_w4',
    'pm', 'ruwe',
]

In [80]:
class MetaVDataset(Dataset):
    """Tabular dataset of catalogue metadata features with a ``variable_type`` label.

    The CSV is read, reduced to ``METADATA_COLS`` plus ``edr3_source_id`` and
    ``variable_type``, cleaned (NaN rows and duplicated source ids removed),
    optionally filtered/balanced by class, split 80/10/10 by source id and
    finally standardised with a ``StandardScaler``.

    The scaler is fitted and written to ``scaler_path`` when ``split='train'``
    and read back from ``scaler_path`` for every other split, so the train
    split has to be built first.

    Args:
        file: path (or buffer) accepted by ``pd.read_csv``.
        split: ``'train'``, ``'val'`` or ``'test'``; any other value keeps all rows.
        classes: optional collection of ``variable_type`` values to keep.
        min_samples: classes with fewer rows than this are dropped.
        max_samples: classes with more rows than this are randomly down-sampled
            to exactly this many rows.
        random_seed: seed for the NumPy global RNG and for ``train_test_split``.
        verbose: print progress messages.
        scaler_path: where the fitted scaler is stored / loaded from.

    Attributes:
        id2target / target2id: class-id <-> class-name maps. They are built
            before the split so that every split of the same file and settings
            shares one mapping, even if a class is absent from a small split.
        num_classes: number of entries in ``id2target``.
    """

    def __init__(self, file, split='train', classes=None, min_samples=None, max_samples=None,
                 random_seed=42, verbose=True, scaler_path='scaler.pkl'):
        self.df = pd.read_csv(file)
        self.metadata_cols = list(METADATA_COLS)
        self.df = self.df[self.metadata_cols + ['edr3_source_id', 'variable_type']]

        self.split = split
        self.verbose = verbose
        self.classes = classes
        self.min_samples = min_samples
        self.max_samples = max_samples
        self.scaler_path = scaler_path

        # Seeding the global NumPy RNG makes the down-sampling in
        # _limit_samples() identical for every split built with the same seed.
        self.random_seed = random_seed
        np.random.seed(random_seed)

        self._drop_nan()
        self._drop_duplicates()
        self._filter_classes()
        self._limit_samples()
        # Must run before _split(): a split may miss a class entirely, which
        # would otherwise shift the ids of all classes sorted after it.
        self._build_label_maps()
        self._split()
        self._normalize()
        self._cache_arrays()

    def _drop_nan(self):
        """Drop every row that has a missing value in any retained column."""
        if self.verbose:
            print('Dropping nan values...', end=' ')

        self.df = self.df.dropna(axis=0, how='any')

        if self.verbose:
            print(f'Done. Left with {len(self.df)} rows.')

    def _drop_duplicates(self):
        """Keep only the last row for each ``edr3_source_id``."""
        if self.verbose:
            print('Dropping duplicated values...', end=' ')

        self.df = self.df.drop_duplicates(subset=['edr3_source_id'], keep='last')

        if self.verbose:
            print(f'Done. Left with {len(self.df)} rows.')

    def _filter_classes(self):
        """Restrict the frame to ``self.classes`` when that filter is given."""
        if self.classes:
            if self.verbose:
                print(f'Leaving only classes: {self.classes}... ', end='')

            self.df = self.df[self.df['variable_type'].isin(self.classes)]

            if self.verbose:
                print(f'{len(self.df)} objects left.')

    def _limit_samples(self):
        """Drop classes below ``min_samples`` and down-sample classes above ``max_samples``."""
        if not (self.max_samples or self.min_samples):
            return

        if self.verbose:
            print(f'Removing objects that have more than {self.max_samples} or less than {self.min_samples} '
                  f'samples... ', end='')

        value_counts = self.df['variable_type'].value_counts()

        if self.min_samples:
            classes_to_remove = value_counts[value_counts < self.min_samples].index
            self.df = self.df[~self.df['variable_type'].isin(classes_to_remove)]

        if self.max_samples:
            classes_to_limit = value_counts[value_counts > self.max_samples].index
            for class_type in classes_to_limit:
                class_indices = self.df[self.df['variable_type'] == class_type].index
                # The class may already have been removed by the min_samples
                # filter (possible when min_samples > max_samples); sampling
                # from it would raise.
                if len(class_indices) <= self.max_samples:
                    continue
                indices_to_keep = np.random.choice(class_indices, size=self.max_samples, replace=False)
                self.df = self.df.drop(index=class_indices.difference(indices_to_keep))

        if self.verbose:
            print(f'{len(self.df)} objects left.')

    def _build_label_maps(self):
        """Create the class-id maps from the classes currently present in ``self.df``."""
        class_names = sorted(self.df['variable_type'].unique())
        self.id2target = dict(enumerate(class_names))
        self.target2id = {name: idx for idx, name in self.id2target.items()}
        self.num_classes = len(self.id2target)

    def _split(self):
        """Select the rows of the requested split (80% train, 10% val, 10% test by source id)."""
        unique_ids = self.df['edr3_source_id'].unique()
        train_ids, temp_ids = train_test_split(unique_ids, test_size=0.2, random_state=self.random_seed)
        val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=self.random_seed)

        if self.split == 'train':
            self.df = self.df[self.df['edr3_source_id'].isin(train_ids)]
        elif self.split == 'val':
            self.df = self.df[self.df['edr3_source_id'].isin(val_ids)]
        elif self.split == 'test':
            self.df = self.df[self.df['edr3_source_id'].isin(test_ids)]
        else:
            print('Split is not train, val, or test. Keeping the whole dataset')

        if self.verbose:
            print(f'{self.split} split is selected: {len(self.df)} objects left.')

    def _normalize(self):
        """Standardise the feature columns; fit and save the scaler on train, load it otherwise."""
        if self.split == 'train':
            self.scaler = StandardScaler()
            self.scaler.fit(self.df[self.metadata_cols])
            joblib.dump(self.scaler, self.scaler_path)
        else:
            if not os.path.exists(self.scaler_path):
                raise FileNotFoundError(
                    f'Scaler file {self.scaler_path!r} not found: build the train split first.'
                )
            # NOTE(review): joblib.load unpickles the file; only load scaler files you created.
            self.scaler = joblib.load(self.scaler_path)

        # self.df is the result of boolean filtering; take an explicit copy so
        # the column assignment below writes to an independent frame.
        self.df = self.df.copy()
        self.df[self.metadata_cols] = self.scaler.transform(self.df[self.metadata_cols])

    def _cache_arrays(self):
        """Materialise features/targets as arrays so item access avoids per-row pandas indexing."""
        self._features = self.df[self.metadata_cols].to_numpy(dtype=np.float32)
        self._targets = self.df['variable_type'].map(self.target2id).to_numpy(dtype=np.int64)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(X, y)``: a float32 feature vector and the integer class id."""
        X = self._features[idx].copy()
        y = int(self._targets[idx])

        return X, y

In [65]:
class MetaClassifier(nn.Module):
    """Fully connected classifier: two ReLU hidden layers with dropout, then a linear output layer.

    Args:
        input_dim: length of each input feature vector.
        hidden_dim: width of both hidden layers.
        num_classes: number of output scores per sample.
    """

    def __init__(self, input_dim=36, hidden_dim=128, num_classes=15):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, num_classes)
        # A single dropout module (p=0.5) is applied after each hidden activation.
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return the unnormalised class scores (logits) for the batch ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(layer(hidden)))

        return self.fc3(hidden)

In [81]:
# NOTE(review): absolute, machine-specific path.
file = '/home/mariia/AstroML/data/asassn/asassn_catalog_full.csv'

# Settings shared by both splits so they are filtered and sampled identically.
dataset_kwargs = dict(classes=None, min_samples=5000, max_samples=20000, random_seed=42, verbose=True)

# The train split must be created first: it fits and saves the scaler that
# the val split loads.
train_dataset = MetaVDataset(file, split='train', **dataset_kwargs)
val_dataset = MetaVDataset(file, split='val', **dataset_kwargs)

In [95]:
# Spot-check one item: MetaVDataset.__getitem__ returns (feature array, class id).
train_dataset[4]

In [82]:
# Same batch size for both loaders; only the training data is shuffled.
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [83]:
# Prefer the second GPU (index 1) as before, but fall back to the first GPU
# when only one is visible and to the CPU when CUDA is unavailable. Asking
# for 'cuda:1' on a single-GPU machine would fail when the model is moved.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f'cuda:{gpu_index}')
else:
    device = torch.device('cpu')
print('Using', device)

# The output layer is sized from the number of classes in the training dataset.
model = MetaClassifier(num_classes=train_dataset.num_classes)
model = model.to(device)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Per-epoch history, appended to by the training loop and read by the plots.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

In [85]:
# Train for 20 epochs, validating after each one and recording the metrics.
# Re-running this cell keeps training the same model and extends the histories.
for epoch in range(20):
    print(f'Epoch {epoch}')

    train_loss, train_acc = train_epoch()
    print(f'Train Loss: {round(train_loss, 3)} Acc: {round(train_acc, 2)}')

    val_loss, val_acc = val_epoch()
    print(f'Val Loss: {round(val_loss, 3)} Acc: {round(val_acc, 2)}')

    for history, value in (
        (train_losses, train_loss),
        (val_losses, val_loss),
        (train_accs, train_acc),
        (val_accs, val_acc),
    ):
        history.append(value)

In [86]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: losses. Right panel: accuracies. Both show train vs validation.
panels = [
    (ax1, 'Loss', train_losses, val_losses),
    (ax2, 'Accuracy', train_accs, val_accs),
]
for ax, metric, train_history, val_history in panels:
    ax.plot(train_history, label=f'Train {metric}')
    ax.plot(val_history, label=f'Validation {metric}')
    ax.set_title(f'Training and Validation {metric}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric)
    ax.legend()

plt.show()

In [87]:
# Collect true and predicted class ids for the whole validation set.
model.eval()

all_true_labels = []
all_predicted_labels = []

# No gradients are needed for inference.
with torch.no_grad():
    for X, y in tqdm(val_dataloader):
        X = X.to(device)

        logits = model(X)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        predicted_labels = torch.max(probabilities, dim=1)[1]

        # y was never moved to the device, so it can be converted directly.
        all_true_labels.extend(y.numpy())
        all_predicted_labels.extend(predicted_labels.cpu().numpy())

In [88]:
# Overall validation accuracy: fraction of positions where prediction equals truth.
sum(truth == pred for truth, pred in zip(all_true_labels, all_predicted_labels)) / len(all_predicted_labels)

In [90]:
# Calculate confusion matrix.
# The class ids are passed explicitly so rows/columns always line up with
# val_dataset.id2target; without `labels`, confusion_matrix only includes ids
# that occur in the true or predicted labels, which would misalign the ticks.
class_ids = sorted(val_dataset.id2target)
conf_matrix = confusion_matrix(all_true_labels, all_predicted_labels, labels=class_ids)

# Calculate row-normalised percentages (share of each true class).
# Rows without any true sample stay at 0 instead of dividing by zero.
row_totals = conf_matrix.sum(axis=1, keepdims=True)
conf_matrix_percent = 100 * np.divide(
    conf_matrix,
    row_totals,
    out=np.zeros(conf_matrix.shape, dtype=float),
    where=row_totals > 0,
)

# Get the labels from the id2target mapping
labels = [val_dataset.id2target[i] for i in class_ids]

# Plot both confusion matrices side by side
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 7))

# Plot absolute values confusion matrix
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')
axes[0].set_title('Confusion Matrix - Absolute Values')

# Plot percentage values confusion matrix
sns.heatmap(conf_matrix_percent, annot=True, fmt='.0f', cmap='Blues', xticklabels=labels, yticklabels=labels, ax=axes[1])
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')
axes[1].set_title('Confusion Matrix - Percentages')

plt.show()

In [26]:
# The 36 feature columns selected from the exploration frame below.
# NOTE(review): this repeats the METADATA_COLS list defined earlier in the notebook.
METADATA_COLS = [
    'mean_vmag', 'amplitude', 'period',
    'phot_g_mean_mag', 'e_phot_g_mean_mag',
    'lksl_statistic', 'rfr_score',
    'phot_bp_mean_mag', 'e_phot_bp_mean_mag',
    'phot_rp_mean_mag', 'e_phot_rp_mean_mag',
    'bp_rp',
    'parallax', 'parallax_error', 'parallax_over_error',
    'pmra', 'pmra_error', 'pmdec', 'pmdec_error',
    'j_mag', 'e_j_mag', 'h_mag', 'e_h_mag', 'k_mag', 'e_k_mag',
    'w1_mag', 'e_w1_mag', 'w2_mag', 'e_w2_mag', 'w3_mag', 'w4_mag',
    'j_k', 'w1_w2', 'w3_w4',
    'pm', 'ruwe',
]

In [5]:
# Load the catalogue CSV into a DataFrame for ad-hoc exploration.
# NOTE(review): absolute, machine-specific path.
df = pd.read_csv('/home/mariia/AstroML/data/asassn/asassn_catalog_full.csv')

In [16]:
# Wider candidate column list (52 entries) used for exploring the catalogue;
# METADATA_COLS is a 36-column subset of it.
metadata_cols = [
    'mean_vmag', 'amplitude', 'period',
    'phot_g_mean_mag', 'e_phot_g_mean_mag',
    'lksl_statistic', 'rfr_score',
    'phot_bp_mean_mag', 'e_phot_bp_mean_mag',
    'phot_rp_mean_mag', 'e_phot_rp_mean_mag',
    'bp_rp',
    'parallax', 'parallax_error', 'parallax_over_error',
    'pmra', 'pmra_error', 'pmdec', 'pmdec_error',
    'j_mag', 'e_j_mag', 'h_mag', 'e_h_mag', 'k_mag', 'e_k_mag',
    'w1_mag', 'e_w1_mag', 'w2_mag', 'e_w2_mag',
    'w3_mag', 'e_w3_mag', 'w4_mag', 'e_w4_mag',
    'j_k', 'w1_w2', 'w3_w4',
    'apass_vmag', 'e_apass_vmag', 'apass_bmag', 'e_apass_bmag',
    'apass_gpmag', 'e_apass_gpmag', 'apass_rpmag', 'e_apass_rpmag',
    'apass_ipmag', 'e_apass_ipmag',
    'FUVmag', 'e_FUVmag', 'NUVmag', 'e_NUVmag',
    'pm', 'ruwe',
]

In [17]:
# Keep only the candidate feature columns plus the source id and the label.
# This rebinds `df`, so the other columns are gone until the CSV is re-read.
df = df[metadata_cols + ['edr3_source_id', 'variable_type']]

In [22]:
# Row counts of the 30 most frequent variable_type values.
df['variable_type'].value_counts().head(30)

In [28]:
# Narrow `df` to the METADATA_COLS features. This drops 'variable_type' and
# 'edr3_source_id', so cells that read those columns cannot be re-run after this one.
df = df[METADATA_COLS]

In [29]:
# Percentage of missing values in each column.
df.isna().sum() / len(df) * 100

In [31]:
# Display the current frame as the cell output.
df

In [32]:
# Preview the rows that have no missing value in any column.
# inplace=False returns a new frame, so `df` itself is left unchanged.
df.dropna(axis=0, how='any', inplace=False)

In [None]:
class MetaVDataset(Dataset):
    """Work-in-progress dataset variant that takes externally supplied ``scales``.

    NOTE(review): this cell reuses the name of the ``MetaVDataset`` class
    defined earlier in the notebook, so executing it rebinds that name.
    Normalisation is still a stub, and ``__getitem__`` relies on helpers and
    columns that are not defined in this class (see the notes there).

    Args:
        file: path (or buffer) accepted by ``pd.read_csv``.
        split: ``'train'``, ``'val'`` or ``'test'``; any other value keeps all rows.
        scales: normalisation parameters; required for the val/test splits.
        classes: optional collection of ``variable_type`` values to keep.
        min_samples: classes with fewer rows than this are dropped.
        max_samples: classes with more rows than this are randomly down-sampled.
        random_seed: seed for the NumPy global RNG and for ``train_test_split``.
        verbose: print progress messages.
    """

    def __init__(self, file, split='train', scales=None, classes=None, min_samples=None, max_samples=None, random_seed=42, verbose=True):
        self.df = pd.read_csv(file)
        self.split = split
        self.verbose = verbose
        self.scales = scales
        self.classes = classes
        self.min_samples = min_samples
        self.max_samples = max_samples

        # _split() reads self.random_seed and _limit_samples() draws from the
        # global NumPy RNG, so both are set up here before the pipeline runs.
        self.random_seed = random_seed
        np.random.seed(random_seed)

        self._filter_classes()
        self._limit_samples()
        self._split()
        self._normalize()

        self.id2target = {i: x for i, x in enumerate(sorted(self.df['variable_type'].unique()))}
        self.target2id = {v: k for k, v in self.id2target.items()}
        self.num_classes = len(self.id2target)

    def _filter_classes(self):
        """Restrict the frame to ``self.classes`` when that filter is given."""
        if self.classes:
            if self.verbose:
                print(f'Leaving only classes: {self.classes}... ', end='')

            self.df = self.df[self.df['variable_type'].isin(self.classes)]

            if self.verbose:
                print(f'{len(self.df)} objects left.')

    def _limit_samples(self):
        """Drop classes below ``min_samples`` and down-sample classes above ``max_samples``."""
        if self.max_samples or self.min_samples:
            if self.verbose:
                print(f'Removing objects that have more than {self.max_samples} or less than {self.min_samples} samples... ', end='')

            value_counts = self.df['variable_type'].value_counts()

            if self.min_samples:
                classes_to_remove = value_counts[value_counts < self.min_samples].index
                self.df = self.df[~self.df['variable_type'].isin(classes_to_remove)]

            if self.max_samples:
                classes_to_limit = value_counts[value_counts > self.max_samples].index
                for class_type in classes_to_limit:
                    class_indices = self.df[self.df['variable_type'] == class_type].index
                    indices_to_keep = np.random.choice(class_indices, size=self.max_samples, replace=False)
                    self.df = self.df.drop(index=set(class_indices) - set(indices_to_keep))

            if self.verbose:
                print(f'{len(self.df)} objects left.')

    def _split(self):
        """Select the rows of the requested split (80% train, 10% val, 10% test by source id)."""
        unique_ids = self.df['edr3_source_id'].unique()
        train_ids, temp_ids = train_test_split(unique_ids, test_size=0.2, random_state=self.random_seed)
        val_ids, test_ids = train_test_split(temp_ids, test_size=0.5, random_state=self.random_seed)

        if self.split == 'train':
            self.df = self.df[self.df['edr3_source_id'].isin(train_ids)]
        elif self.split == 'val':
            self.df = self.df[self.df['edr3_source_id'].isin(val_ids)]
        elif self.split == 'test':
            self.df = self.df[self.df['edr3_source_id'].isin(test_ids)]
        else:
            print('Split is not train, val, or test. Keeping the whole dataset')

        if self.verbose:
            print(f'{self.split} split is selected: {len(self.df)} objects left.')

    def _normalize(self):
        """Check that ``scales`` is available where required; scaling itself is not implemented yet.

        Raises:
            ValueError: if the split is ``'val'`` or ``'test'`` and no scales were given.
        """
        if self.split in ('val', 'test') and self.scales is None:
            raise ValueError('Scales must be provided for val/test splits')

        # TODO: apply the normalisation; currently a no-op.
        pass

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(X, mask, y)`` for row ``idx``.

        NOTE(review): ``get_vlc``, ``get_glc`` and ``preprocess`` are not
        defined in this class, and the code reads the columns 'name', 'band'
        and 'target' although the label maps are built from 'variable_type'.
        Confirm the intended source of these before using this method.
        """
        el = self.df.iloc[idx]

        X = self.get_vlc(el['name']) if el['band'] == 'v' else self.get_glc(el['name'])
        X, mask = self.preprocess(X, el['period'], el['band'])
        y = self.target2id[el['target']]

        return X, mask, y