In [1]:
import json
from typing import TypedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from constants import DATA_DIR
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

from astrofit.utils import AsteroidLoader

In [2]:
asteroid_loader = AsteroidLoader(DATA_DIR)

ASTEROIDS_FREQ_DATA_PATH = DATA_DIR / "asteroids_freq_data.json"

In [3]:
with open(ASTEROIDS_FREQ_DATA_PATH, "r") as f:
    asteroids_freq_data = json.load(f)

In [4]:
config, asteroids_data = asteroids_freq_data["config"], asteroids_freq_data["asteroids"]

In [5]:
filtered_data = {name: data for name, data in asteroids_data.items() if not data["is_failed"]}
print(f"Filtered {len(filtered_data)} asteroids ({100*(len(asteroids_data) - len(filtered_data)) / len(asteroids_data):.2f}% failed)")

Filtered 2662 asteroids (40.41% failed)


In [6]:
class AsteroidData(TypedDict):
    is_failed: bool
    reason: str | None
    period: float
    processing_time: float
    freq_features: list[list]  # 1 - 4 sequences of 50 floats (freqs) from 0 to 12
    pow_features: list[list]  # 1 - 4 sequences of 50 floats (pows) from 0 to 1 (the same shape as freq_features)

In [7]:
filtered_data = {name: AsteroidData(**data) for name, data in filtered_data.items()}

In [8]:
train_keys, val_test_keys = train_test_split(list(filtered_data.keys()), test_size=0.2, random_state=884288)
val_keys, test_keys = train_test_split(val_test_keys, test_size=0.33, random_state=884288)

print(f"Train: {len(train_keys)} asteroids ({100 * len(train_keys) / len(filtered_data):.2f})")
print(f"Validation: {len(val_keys)} asteroids ({100*len(val_keys) / len(filtered_data):.2f})")
print(f"Test: {len(test_keys)} asteroids ({100 * len(test_keys) / len(filtered_data):.2f})")

Train: 2129 asteroids (79.98)
Validation: 357 asteroids (13.41)
Test: 176 asteroids (6.61)


In [9]:
train_set, val_set, test_set = (
    {key: filtered_data[key] for key in train_keys},
    {key: filtered_data[key] for key in val_keys},
    {key: filtered_data[key] for key in test_keys},
)

In [10]:
def extract_features(data_set: dict[str, AsteroidData]) -> list[np.ndarray]:
    freq_features = [np.array(sample["freq_features"]) for sample in data_set.values()]
    pow_features = [np.array(sample["pow_features"]) for sample in data_set.values()]
    return [np.stack([freqs, pows], axis=2) for freqs, pows in zip(freq_features, pow_features)]


def standardize_and_pad_sequences(sequences: list[np.ndarray], max_len: int, scaler=None) -> tuple[np.ndarray, StandardScaler]:
    # Flatten all sequences
    flat_sequences = np.vstack([seq.reshape(-1, seq.shape[-1]) for seq in sequences])

    # Fit or transform with scaler
    if scaler is None:
        scaler = StandardScaler()
        flat_standardized = scaler.fit_transform(flat_sequences)
    else:
        flat_standardized = scaler.transform(flat_sequences)

    # Reshape back to original sequence shapes
    standardized_sequences = []
    start = 0
    for seq in sequences:
        end = start + seq.shape[0] * seq.shape[1]
        standardized_seq = flat_standardized[start:end].reshape(seq.shape)
        standardized_sequences.append(standardized_seq)
        start = end

    # Pad standardized sequences with a special value (e.g., -1000)
    padded = [
        np.pad(seq, ((0, max_len - len(seq)), (0, 0), (0, 0)), mode="constant", constant_values=-1000)
        for seq in standardized_sequences
    ]
    return np.stack(padded), scaler


In [11]:
train_features = extract_features(train_set)
val_features = extract_features(val_set)
test_features = extract_features(test_set)


In [12]:
# Find max length
max_len = max(  # equals to 4 but just in case
    max(len(seq) for seq in train_features),
    max(len(seq) for seq in val_features),
    max(len(seq) for seq in test_features),
)

# Standardize and pad
train_features_scaled, feature_scaler = standardize_and_pad_sequences(train_features, max_len)
val_features_scaled, _ = standardize_and_pad_sequences(val_features, max_len, scaler=feature_scaler)
test_features_scaled, _ = standardize_and_pad_sequences(test_features, max_len, scaler=feature_scaler)


In [13]:
# Find max value in train_features_scaled but not equal to -1000
max_value = np.max(train_features_scaled[train_features_scaled != -1000])
min_value = np.min(train_features_scaled[train_features_scaled != -1000])
min_value, max_value

(np.float64(-3.2857769060746778), np.float64(29.815852864803333))

In [14]:
train_periods = np.array([sample["period"] for sample in train_set.values()])
val_periods = np.array([sample["period"] for sample in val_set.values()])
test_periods = np.array([sample["period"] for sample in test_set.values()])


In [15]:
train_freqs = 24 / train_periods
val_freqs = 24 / val_periods
test_freqs = 24 / test_periods

# Standardize target frequencies
target_scaler = StandardScaler()
train_freqs_scaled = target_scaler.fit_transform(train_freqs.reshape(-1, 1)).flatten()
val_freqs_scaled = target_scaler.transform(val_freqs.reshape(-1, 1)).flatten()
test_freqs_scaled = target_scaler.transform(test_freqs.reshape(-1, 1)).flatten()

In [16]:
class AsteroidDataset(Dataset):
    def __init__(self, data: np.ndarray, targets: np.ndarray):
        self.data = torch.FloatTensor(data)
        self.targets = torch.FloatTensor(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

In [17]:
train_dataset = AsteroidDataset(train_features_scaled, train_freqs_scaled)
val_dataset = AsteroidDataset(val_features_scaled, val_freqs_scaled)
test_dataset = AsteroidDataset(test_features_scaled, test_freqs_scaled)

In [18]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [19]:
class AsteroidPeriodPredictor(nn.Module):
    def __init__(self):
        super(AsteroidPeriodPredictor, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        batch_size, num_sessions, num_freq, num_features = x.shape

        x = x.permute(0, 3, 1, 2)  # Change to (batch_size, channels, sessions, frequencies)

        x = self.cnn(x)

        x = x.view(batch_size, -1)  # Flatten: (batch_size, 128)

        x = self.fc(x)
        return x.squeeze()

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AsteroidPeriodPredictor().to(device)
model

AsteroidPeriodPredictor(
  (cnn): Sequential(
    (0): Conv2d(2, 32, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(32, 64, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(64, 128, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.25, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [21]:
data, targets = next(iter(train_loader))
data.shape, targets.shape

(torch.Size([32, 4, 50, 2]), torch.Size([32]))

In [22]:
def train_model(model, train_loader, val_loader, num_epochs=1000, patience=50):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.3, patience=40)

    best_val_loss = float("inf")
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_preds, train_targets = [], []

        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            train_preds.extend(outputs.cpu().detach().numpy())
            train_targets.extend(targets.cpu().numpy())

        train_loss /= len(train_loader)
        train_r2 = r2_score(train_targets, train_preds)
        train_mae = mean_absolute_error(train_targets, train_preds)

        model.eval()
        val_loss = 0
        val_preds, val_targets = [], []

        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                outputs = model(data)
                val_loss += criterion(outputs, targets).item()
                val_preds.extend(outputs.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

        val_loss /= len(val_loader)
        val_r2 = r2_score(val_targets, val_preds)
        val_mae = mean_absolute_error(val_targets, val_preds)

        scheduler.step(val_loss)

        print(f"{epoch + 1}/{num_epochs} - ", end="")
        print(f"Train-loss: {train_loss:.4f}, Train-R2: {train_r2:.4f}, Train-MAE: {train_mae:.4f}", end="\t- ")
        print(f"Val-loss: {val_loss:.4f}, Val-R2: {val_r2:.4f}, Val-MAE: {val_mae:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            best_model_state = model.state_dict()
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break

    print(f"Best validation loss: {best_val_loss}")

    # Load the best model
    model.load_state_dict(best_model_state)
    return model

In [23]:
trained_model = train_model(model, train_loader, val_loader, patience=500)


1/1000 - Train-loss: 1.0180, Train-R2: -0.0179, Train-MAE: 0.8180	- Val-loss: 1.0877, Val-R2: 0.0086, Val-MAE: 0.8182
2/1000 - Train-loss: 1.0062, Train-R2: -0.0066, Train-MAE: 0.8090	- Val-loss: 1.1051, Val-R2: -0.0006, Val-MAE: 0.8216
3/1000 - Train-loss: 1.0014, Train-R2: -0.0017, Train-MAE: 0.8100	- Val-loss: 1.0993, Val-R2: 0.0024, Val-MAE: 0.8210
4/1000 - Train-loss: 0.9969, Train-R2: 0.0001, Train-MAE: 0.8092	- Val-loss: 1.1071, Val-R2: -0.0056, Val-MAE: 0.8196
5/1000 - Train-loss: 1.0075, Train-R2: -0.0008, Train-MAE: 0.8083	- Val-loss: 1.0997, Val-R2: -0.0018, Val-MAE: 0.8341
6/1000 - Train-loss: 1.0038, Train-R2: -0.0030, Train-MAE: 0.8127	- Val-loss: 1.0963, Val-R2: 0.0011, Val-MAE: 0.8219
7/1000 - Train-loss: 1.0034, Train-R2: -0.0005, Train-MAE: 0.8120	- Val-loss: 1.0948, Val-R2: 0.0030, Val-MAE: 0.8200
8/1000 - Train-loss: 0.9992, Train-R2: 0.0010, Train-MAE: 0.8087	- Val-loss: 1.0928, Val-R2: 0.0046, Val-MAE: 0.8236
9/1000 - Train-loss: 1.0015, Train-R2: -0.0003, Train-M

In [26]:
def evaluate_model(model, test_loader, device, target_scaler):
    model.eval()
    test_preds = []
    test_targets = []

    with torch.no_grad():
        for data, targets in tqdm(test_loader, desc="Evaluating on test set"):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            test_preds.extend(outputs.cpu().numpy())
            test_targets.extend(targets.cpu().numpy())

    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)

    # Inverse transform the predictions and targets if they were scaled
    if target_scaler is not None:
        test_preds = target_scaler.inverse_transform(test_preds.reshape(-1, 1)).flatten()
        test_targets = target_scaler.inverse_transform(test_targets.reshape(-1, 1)).flatten()

    # Calculate metrics
    mse = mean_squared_error(test_targets, test_preds)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(test_targets, test_preds)
    r2 = r2_score(test_targets, test_preds)

    # Convert frequencies back to periods
    test_periods_pred = 24 / test_preds
    test_periods_true = 24 / test_targets

    # Calculate period-specific metrics
    period_mae = mean_absolute_error(test_periods_true, test_periods_pred)
    period_mse = mean_squared_error(test_periods_true, test_periods_pred)
    period_rmse = np.sqrt(period_mse)
    period_r2 = r2_score(test_periods_true, test_periods_pred)

    return {
        "Frequency MSE": mse,
        "Frequency RMSE": rmse,
        "Frequency MAE": mae,
        "Frequency R2": r2,
        "Period MAE": period_mae,
        "Period RMSE": period_rmse,
        "Period R2": period_r2,
        "Predictions": test_preds,
        "True Values": test_targets,
        "Period Predictions": test_periods_pred,
        "True Periods": test_periods_true,
    }

In [27]:
results = evaluate_model(trained_model, test_loader, device, target_scaler)

# Print the results
for metric, value in results.items():
    if isinstance(value, np.ndarray):
        print(f"{metric}: {value[:5]} ... {value[-5:]}")
    else:
        print(f"{metric}: {value:.4f}")

Evaluating on test set: 100%|██████████| 6/6 [00:00<00:00, 257.81it/s]

Frequency MSE: 2.9596
Frequency RMSE: 1.7204
Frequency MAE: 1.3901
Frequency R2: 0.0065
Period MAE: 4.9119
Period RMSE: 8.1482
Period R2: -0.1685
Predictions: [3.4303386 3.2419407 3.4303422 3.4303393 3.4303439] ... [3.2433562 3.2103775 3.2103758 3.4303472 3.4303448]
True Values: [1.5298998  3.8621905  8.243286   0.60228866 1.2255026 ] ... [3.4471862 3.4124355 2.9316916 1.3523716 4.6582766]
Period Predictions: [6.9963937 7.402973  6.9963865 6.9963923 6.996383 ] ... [7.399742  7.4757566 7.4757605 6.9963765 6.9963813]
True Periods: [15.687302   6.2140903  2.9114602 39.848003  19.583801 ] ... [ 6.9622    7.0331    8.186399 17.746601  5.15212 ]



