In [1]:
import json
from typing import TypedDict

import numpy as np
import torch
import torch.nn as nn
from constants import DATA_DIR
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from astrofit.utils import AsteroidLoader

In [2]:
asteroid_loader = AsteroidLoader(DATA_DIR)

ASTEROIDS_FREQ_DATA_PATH = DATA_DIR / "asteroids_freq_data.json"

In [3]:
with open(ASTEROIDS_FREQ_DATA_PATH, "r") as f:
    asteroids_freq_data = json.load(f)

In [4]:
config, asteroids_data = asteroids_freq_data["config"], asteroids_freq_data["asteroids"]

In [5]:
filtered_data = {name: data for name, data in asteroids_data.items() if not data["is_failed"]}
print(f"Filtered {len(filtered_data)} asteroids ({100*(len(asteroids_data) - len(filtered_data)) / len(asteroids_data):.2f}% failed)")

Filtered 2662 asteroids (40.41% failed)


In [6]:
class AsteroidData(TypedDict):
    is_failed: bool
    reason: str | None
    period: float
    processing_time: float
    freq_features: list[list]  # 1 - 4 sequences of 50 floats (freqs) from 0 to 12
    pow_features: list[list]  # 1 - 4 sequences of 50 floats (pows) from 0 to 1 (the same shape as freq_features)

In [7]:
filtered_data = {name: AsteroidData(**data) for name, data in filtered_data.items()}

In [8]:
no_clipped = 0
for key in filtered_data:
    data = filtered_data[key]
    if len(data["freq_features"]) > 1:
        # Clip to just the first sequence
        data["freq_features"] = [data["freq_features"][0]]
        data["pow_features"] = [data["pow_features"][0]]

        no_clipped += 1

print(f"Clipped {no_clipped} asteroids")

Clipped 1652 asteroids


In [9]:
train_keys, val_test_keys = train_test_split(list(filtered_data.keys()), test_size=0.2, random_state=884288)
val_keys, test_keys = train_test_split(val_test_keys, test_size=0.33, random_state=884288)

print(f"Train: {len(train_keys)} asteroids ({100 * len(train_keys) / len(filtered_data):.2f})")
print(f"Validation: {len(val_keys)} asteroids ({100*len(val_keys) / len(filtered_data):.2f})")
print(f"Test: {len(test_keys)} asteroids ({100 * len(test_keys) / len(filtered_data):.2f})")

Train: 2129 asteroids (79.98)
Validation: 357 asteroids (13.41)
Test: 176 asteroids (6.61)


In [10]:
train_set, val_set, test_set = (
    {key: filtered_data[key] for key in train_keys},
    {key: filtered_data[key] for key in val_keys},
    {key: filtered_data[key] for key in test_keys},
)

In [11]:
def extract_features(data_set: dict[str, AsteroidData]) -> np.ndarray:
    freqs = np.array([data["freq_features"] for data in data_set.values()])
    powers = np.array([data["pow_features"] for data in data_set.values()])
    return np.stack([freqs, powers], axis=-1)

In [12]:
train_features = extract_features(train_set)
val_features = extract_features(val_set)
test_features = extract_features(test_set)

In [13]:
scaler = StandardScaler()
train_features_scaled: np.ndarray = scaler.fit_transform(
    train_features.reshape(-1, train_features.shape[-1]),
).reshape(train_features.shape)

# Transform validation and test data using the same scaler
val_features_scaled: np.ndarray = scaler.transform(
    val_features.reshape(-1, val_features.shape[-1]),
).reshape(val_features.shape)  # type: ignore

test_features_scaled: np.ndarray = scaler.transform(
    test_features.reshape(-1, test_features.shape[-1]),
).reshape(test_features.shape)  # type: ignore

In [14]:
train_periods = np.array([sample["period"] for sample in train_set.values()])
val_periods = np.array([sample["period"] for sample in val_set.values()])
test_periods = np.array([sample["period"] for sample in test_set.values()])

train_freqs = 24 / train_periods
val_freqs = 24 / val_periods
test_freqs = 24 / test_periods

# Standardize target frequencies using training set statistics
target_scaler = StandardScaler()
train_freqs_scaled: np.ndarray = target_scaler.fit_transform(train_freqs.reshape(-1, 1)).flatten()
val_freqs_scaled: np.ndarray = target_scaler.transform(val_freqs.reshape(-1, 1)).flatten()  # type: ignore
test_freqs_scaled: np.ndarray = target_scaler.transform(test_freqs.reshape(-1, 1)).flatten()  # type: ignore


In [15]:
class AsteroidDataset(Dataset):
    def __init__(self, data: np.ndarray, targets: np.ndarray):
        self.data = torch.FloatTensor(data)
        self.targets = torch.FloatTensor(targets)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

In [16]:
train_dataset = AsteroidDataset(train_features_scaled, train_freqs_scaled)
val_dataset = AsteroidDataset(val_features_scaled, val_freqs_scaled)
test_dataset = AsteroidDataset(test_features_scaled, test_freqs_scaled)

In [17]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [18]:
data, targets = next(iter(train_loader))
print(data.shape, targets.shape)

torch.Size([32, 1, 50, 2]) torch.Size([32])


In [19]:
class AsteroidPeriodPredictor(nn.Module):
    def __init__(self):
        super(AsteroidPeriodPredictor, self).__init__()

        self.cnn = nn.Sequential(
            nn.Conv1d(2, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        )

        self.fc = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        )

    def forward(self, x):
        batch_size, num_sessions, num_freq, num_features = x.shape

        # Reshape to (batch_size, num_features, num_freq)
        x = x.squeeze(1)  # Remove the num_sessions dimension (which is 1)
        x = x.permute(0, 2, 1)

        x = self.cnn(x)
        x = x.view(batch_size, -1)  # Flatten: (batch_size, 128)
        x = self.fc(x)
        return x.squeeze()

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AsteroidPeriodPredictor().to(device)
model

AsteroidPeriodPredictor(
  (cnn): Sequential(
    (0): Conv1d(2, 32, kernel_size=(3,), stride=(1,), padding=(1,))
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
    (7): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): AdaptiveAvgPool1d(output_size=1)
  )
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=256, out_features=128, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [21]:
def train_model(model, train_loader, val_loader, num_epochs=1000, patience=50):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=20)

    best_val_loss = float("inf")
    epochs_without_improvement = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        train_preds, train_targets = [], []

        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()
            train_preds.extend(outputs.cpu().detach().numpy())
            train_targets.extend(targets.cpu().numpy())

        train_loss /= len(train_loader)
        train_r2 = r2_score(train_targets, train_preds)
        train_mae = mean_absolute_error(train_targets, train_preds)

        model.eval()
        val_loss = 0
        val_preds, val_targets = [], []

        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                outputs = model(data)
                val_loss += criterion(outputs, targets).item()
                val_preds.extend(outputs.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

        val_loss /= len(val_loader)
        val_r2 = r2_score(val_targets, val_preds)
        val_mae = mean_absolute_error(val_targets, val_preds)

        scheduler.step(val_loss)

        print(f"{epoch + 1}/{num_epochs} - ", end="")
        print(f"Train-loss: {train_loss:.4f}, Train-R2: {train_r2:.4f}, Train-MAE: {train_mae:.4f}", end="\t- ")
        print(f"Val-loss: {val_loss:.4f}, Val-R2: {val_r2:.4f}, Val-MAE: {val_mae:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            best_model_state = model.state_dict()
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered after {epoch + 1} epochs")
            break

    print(f"Best validation loss: {best_val_loss}")

    # Load the best model
    model.load_state_dict(best_model_state)
    return model

In [22]:
trained_model = train_model(model, train_loader, val_loader, patience=100)

1/1000 - Train-loss: 0.9748, Train-R2: 0.0248, Train-MAE: 0.7890	- Val-loss: 1.1157, Val-R2: 0.0187, Val-MAE: 0.7909
2/1000 - Train-loss: 0.9581, Train-R2: 0.0464, Train-MAE: 0.7772	- Val-loss: 1.0923, Val-R2: 0.0237, Val-MAE: 0.8053
3/1000 - Train-loss: 0.9312, Train-R2: 0.0680, Train-MAE: 0.7679	- Val-loss: 1.0816, Val-R2: 0.0463, Val-MAE: 0.7883
4/1000 - Train-loss: 0.9276, Train-R2: 0.0745, Train-MAE: 0.7636	- Val-loss: 1.1136, Val-R2: 0.0160, Val-MAE: 0.8017
5/1000 - Train-loss: 0.9310, Train-R2: 0.0718, Train-MAE: 0.7695	- Val-loss: 1.1254, Val-R2: -0.0060, Val-MAE: 0.8099
6/1000 - Train-loss: 0.9103, Train-R2: 0.0891, Train-MAE: 0.7607	- Val-loss: 1.1232, Val-R2: -0.0061, Val-MAE: 0.8082
7/1000 - Train-loss: 0.9146, Train-R2: 0.0915, Train-MAE: 0.7581	- Val-loss: 1.1248, Val-R2: -0.0124, Val-MAE: 0.8017
8/1000 - Train-loss: 0.9128, Train-R2: 0.0872, Train-MAE: 0.7585	- Val-loss: 1.1017, Val-R2: 0.0201, Val-MAE: 0.7889
9/1000 - Train-loss: 0.9023, Train-R2: 0.1014, Train-MAE: 0.7

In [23]:
def evaluate_model(model, test_loader, device, target_scaler):
    model.eval()
    test_preds = []
    test_targets = []

    with torch.no_grad():
        for data, targets in tqdm(test_loader, desc="Evaluating on test set"):
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            test_preds.extend(outputs.cpu().numpy())
            test_targets.extend(targets.cpu().numpy())

    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)

    # Inverse transform the predictions and targets if they were scaled
    if target_scaler is not None:
        test_preds = target_scaler.inverse_transform(test_preds.reshape(-1, 1)).flatten()
        test_targets = target_scaler.inverse_transform(test_targets.reshape(-1, 1)).flatten()

    # Calculate metrics
    mse = mean_squared_error(test_targets, test_preds)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(test_targets, test_preds)
    r2 = r2_score(test_targets, test_preds)

    # Convert frequencies back to periods
    test_periods_pred = 24 / test_preds
    test_periods_true = 24 / test_targets

    # Calculate period-specific metrics
    period_mae = mean_absolute_error(test_periods_true, test_periods_pred)
    period_mse = mean_squared_error(test_periods_true, test_periods_pred)
    period_rmse = np.sqrt(period_mse)
    period_r2 = r2_score(test_periods_true, test_periods_pred)

    return {
        "Frequency MSE": mse,
        "Frequency RMSE": rmse,
        "Frequency MAE": mae,
        "Frequency R2": r2,
        "Period MAE": period_mae,
        "Period RMSE": period_rmse,
        "Period R2": period_r2,
        "Predictions": test_preds,
        "True Values": test_targets,
        "Period Predictions": test_periods_pred,
        "True Periods": test_periods_true,
    }

In [24]:
results = evaluate_model(trained_model, test_loader, device, target_scaler)

# Print the results
for metric, value in results.items():
    if isinstance(value, np.ndarray):
        print(f"{metric}: {value[:5]} ... {value[-5:]}")
    else:
        print(f"{metric}: {value:.4f}")

Evaluating on test set: 100%|██████████| 6/6 [00:00<00:00, 787.98it/s]

Frequency MSE: 2.5998
Frequency RMSE: 1.6124
Frequency MAE: 1.1652
Frequency R2: 0.1273
Period MAE: 4.1774
Period RMSE: 7.1485
Period R2: 0.1006
Predictions: [2.5978858 3.2501724 4.708563  2.7569394 1.3529259] ... [3.7265263 3.5082579 3.182117  2.5361636 5.026518 ]
True Values: [1.5298998  3.8621905  8.243286   0.60228866 1.2255026 ] ... [3.4471862 3.4124355 2.9316916 1.3523716 4.6582766]
Period Predictions: [ 9.238281   7.384224   5.0970964  8.705305  17.73933  ] ... [6.4403143 6.8410025 7.542149  9.463112  4.7746773]
True Periods: [15.687302   6.2140903  2.9114602 39.848003  19.583801 ] ... [ 6.9622    7.0331    8.186399 17.746601  5.15212 ]



