In [21]:
from typing import NamedTuple

import numpy as np
import plotly.graph_objects as go
from constants import DATA_DIR
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

from astrofit.model import Asteroid, Lightcurve
from astrofit.utils import AsteroidLoader, LightcurvePlotter, LightcurveSplitter

In [2]:
asteroid_loader = AsteroidLoader(DATA_DIR)
lightcurve_splitter = LightcurveSplitter()
ligthcurve_plotter = LightcurvePlotter()

In [3]:
asteroids = asteroid_loader.load_asteroids()

In [4]:
periods = [asteroid.period for asteroid in asteroids.values()]
print(f"Percentiles: {np.percentile(periods, [0, 5, 25, 50, 75, 85, 95, 100]).tolist()}")

max_period = 45
satisfying_asteroids = {asteroid.name: asteroid for asteroid in asteroids.values() if asteroid.period <= max_period}
print(
    f"If we cut-off asteroids with max period={max_period} we will have {len(satisfying_asteroids)} ({100 * len(satisfying_asteroids) / len(periods):.4f}%) asteroids"
)

Percentiles: [2.031527, 3.422438, 5.48469, 8.78678, 16.97036, 29.069799999999976, 108.69359999999969, 1304.1]
If we cut-off asteroids with max period=45 we will have 4514 (89.2624%) asteroids


In [5]:
asteroids = satisfying_asteroids

In [6]:
def split_lightcurves(asteroid: Asteroid, max_hours_diff: int = 1, min_no_points: int = 20) -> list[Lightcurve]:
    """
    Parameters default values are set based on experiments conducted with the period prediction.
    """
    return lightcurve_splitter.split_lightcurves(
        asteroid.lightcurves,
        max_hours_diff=max_hours_diff,
        min_no_points=min_no_points,
    )


def filter_anomalous_series(lightcurves: list[Lightcurve], magnitude_threshold: int = 2):
    magnitude_data = [lc.brightness_arr for lc in lightcurves]

    if not magnitude_data:
        return lightcurves

    medians = np.array([np.median(series) for series in magnitude_data])
    overall_median = np.median(medians)
    ratios = medians / overall_median

    is_anomaly = np.logical_or(ratios > 10**magnitude_threshold, ratios < 10 ** (-magnitude_threshold))

    valid_lightcurves = [lc for i, lc in enumerate(lightcurves) if not is_anomaly[i]]

    return valid_lightcurves

In [7]:
print(f"Pre filtering: {len(asteroids)} asteroids")
filtered_lightcurves = {}
for ast_name, ast_data in asteroids.items():
    splitted = split_lightcurves(ast_data)
    filtered = filter_anomalous_series(splitted)

    if filtered:
        filtered_lightcurves[ast_name] = {"asteroid": ast_data, "lightcurves": filtered}
    
print(f"Post filtering: {len(filtered_lightcurves)} asteroids ({100 * len(filtered_lightcurves) / len(asteroids):.4f}%)")


Pre filtering: 4514 asteroids
Post filtering: 609 asteroids (13.4914%)


In [8]:
lc_periods = []
for data in filtered_lightcurves.values():
    for lc in data["lightcurves"]:
        lc_periods.append(lc.period)

In [9]:
fig = go.Figure()

fig.add_trace(go.Histogram(x=lc_periods, histnorm="percent", name="Periods"))
fig.update_layout(title_text="Periods distribution", xaxis_title="Period", yaxis_title="Frequency (%)")
fig.show()

In [10]:
def cartesian_to_spherical(x: float, y: float, z: float) -> tuple[float, float, float]:
    r = np.sqrt(x ** 2 + y ** 2 + z ** 2)
    theta = np.arccos(z / r)
    phi = np.arctan2(y, x)

    return float(r), float(theta), float(phi)

In [11]:
lcs: list[list[Lightcurve]] = list(map(lambda x: x["lightcurves"], filtered_lightcurves.values()))
lcs[0]

[Lightcurve(id=7192, period=8.33419h, points_count=54, first_JD=2454939.919305, last_JD=2454940.266563),
 Lightcurve(id=7192, period=7.78769h, points_count=43, first_JD=2454951.90286, last_JD=2454952.227347),
 Lightcurve(id=7192, period=5.11879h, points_count=61, first_JD=2454972.650074, last_JD=2454972.863357),
 Lightcurve(id=7192, period=3.17146h, points_count=46, first_JD=2454973.655351, last_JD=2454973.787495)]

In [12]:
targets: list[tuple[float, float]] = list(
    map(lambda ast_data: (ast_data["asteroid"].lambd, ast_data["asteroid"].beta), filtered_lightcurves.values())
)
targets[0]

(274.0, -68.0)

In [13]:
class FeatureTuple(NamedTuple):
    JD: float
    min_brightness: float
    max_brightness: float
    mean_brightness: float
    median_brightness: float
    std_brightness: float
    r_earth: float
    theta_earth: float
    phi_earth: float
    r_sun: float
    theta_sun: float
    phi_sun: float

In [14]:
asteroid_features = []
for asteroid_lcs in lcs:
    features = []

    first_lc = asteroid_lcs[0]
    
    ref_JD = (first_lc.last_JD + first_lc.first_JD) / 2
    
    all_brigh = [lc.brightness_arr for lc in asteroid_lcs]
    max_b = max([val for arr in all_brigh for val in arr])
    min_b = min([val for arr in all_brigh for val in arr])

    for lc in asteroid_lcs:
        scaled_brightness_arr = (np.array(lc.brightness_arr) - min_b) / (max_b - min_b)
        features.append(
            FeatureTuple(
                (lc.last_JD + lc.first_JD) / 2 - ref_JD,
                float(scaled_brightness_arr.min()),
                float(scaled_brightness_arr.max()),
                float(scaled_brightness_arr.mean()),
                float(np.median(scaled_brightness_arr)),
                float(scaled_brightness_arr.std()),
                *cartesian_to_spherical(
                    *np.mean([[point.x_earth, point.y_earth, point.z_earth] for point in lc.points], axis=0)
                ),
                *cartesian_to_spherical(*np.mean([[point.x_sun, point.y_sun, point.z_sun] for point in lc.points], axis=0)),
            )
        )

    asteroid_features.append(features)

In [15]:
asteroid_features[0]

[FeatureTuple(JD=0.0, min_brightness=0.06698639578861265, max_brightness=0.9679608827329206, mean_brightness=0.5394594271925635, median_brightness=0.5010150323789354, std_brightness=0.2501093044181226, r_earth=1.206444268498147, theta_earth=1.7053892493704461, phi_earth=0.5463802556772964, r_sun=2.2050836837997836, theta_sun=1.6442810695267913, phi_sun=0.524353804651573),
 FeatureTuple(JD=11.972169499844313, min_brightness=0.022325746956144227, max_brightness=0.8720809381838064, mean_brightness=0.5394594813878091, median_brightness=0.6049295055629328, std_brightness=0.252372296397668, r_earth=1.2165818743965822, theta_earth=1.6910537384369488, phi_earth=0.49460537798068943, r_sun=2.208249473272179, theta_sun=1.6369405697467934, phi_sun=0.588939620709234),
 FeatureTuple(JD=32.66378149949014, min_brightness=0.0, max_brightness=1.0, mean_brightness=0.5394593619503435, median_brightness=0.6143590733482894, std_brightness=0.31275606223138386, r_earth=1.3139208289785724, theta_earth=1.660081

In [16]:
brig_arr = []
for ast in asteroid_features:
    for feature in ast:
        feature: FeatureTuple
        brig_arr.append(
            [
                feature.min_brightness,
                feature.mean_brightness,
                feature.median_brightness,
                feature.max_brightness,
            ]
        )

np.percentile(np.array(brig_arr), [0, 25, 50, 75, 100], axis=0)

array([[0.00000000e+00, 2.85735481e-04, 2.92919021e-04, 4.93597974e-04],
       [1.20503724e-01, 4.64065535e-01, 4.54225625e-01, 6.41660342e-01],
       [2.35539222e-01, 5.02769106e-01, 5.08431515e-01, 7.49350347e-01],
       [3.54732685e-01, 5.40109617e-01, 5.55910573e-01, 8.55905869e-01],
       [9.28908196e-01, 9.65893420e-01, 9.82374260e-01, 1.00000000e+00]])

In [17]:
print(f"Total asteroids: {len(asteroid_features)}")
print(f"Total lightcurves: {sum(map(len, asteroid_features))}")
print(f"Average lightcurves per asteroid: {np.mean(list(map(len, asteroid_features)))}")
print(f"No. features per lightcurve: {len(asteroid_features[0][0])}")

Total asteroids: 609
Total lightcurves: 10831
Average lightcurves per asteroid: 17.78489326765189
No. features per lightcurve: 12


In [23]:
X_train, X_val_test, y_train, y_val_test = train_test_split(asteroid_features, targets, test_size=0.2)
X_val, X_test, y_val, y_test = train_test_split(X_val_test, y_val_test, test_size=0.33)

[len(x) for x in X_train]

[4,
 6,
 6,
 7,
 1,
 28,
 4,
 19,
 22,
 4,
 35,
 34,
 6,
 2,
 11,
 12,
 9,
 9,
 24,
 17,
 1,
 5,
 2,
 6,
 20,
 7,
 73,
 4,
 22,
 36,
 4,
 13,
 1,
 6,
 1,
 14,
 1,
 8,
 3,
 3,
 31,
 30,
 36,
 8,
 14,
 20,
 11,
 1,
 6,
 7,
 34,
 12,
 31,
 17,
 32,
 5,
 2,
 3,
 17,
 12,
 14,
 10,
 50,
 4,
 61,
 20,
 39,
 17,
 30,
 19,
 143,
 1,
 3,
 2,
 5,
 11,
 85,
 9,
 1,
 14,
 19,
 73,
 37,
 1,
 7,
 32,
 40,
 1,
 18,
 13,
 10,
 5,
 16,
 39,
 19,
 21,
 1,
 56,
 30,
 66,
 5,
 25,
 29,
 1,
 25,
 6,
 2,
 30,
 7,
 3,
 43,
 7,
 51,
 1,
 4,
 2,
 3,
 10,
 36,
 9,
 1,
 2,
 5,
 21,
 18,
 7,
 39,
 116,
 19,
 16,
 3,
 8,
 3,
 13,
 1,
 11,
 36,
 30,
 4,
 7,
 2,
 5,
 30,
 9,
 3,
 21,
 27,
 29,
 11,
 69,
 5,
 13,
 21,
 40,
 21,
 3,
 5,
 13,
 1,
 5,
 4,
 13,
 21,
 17,
 3,
 5,
 3,
 12,
 18,
 20,
 1,
 5,
 17,
 9,
 23,
 17,
 47,
 4,
 90,
 5,
 7,
 19,
 2,
 13,
 19,
 7,
 30,
 3,
 10,
 3,
 11,
 5,
 1,
 107,
 69,
 1,
 8,
 3,
 5,
 9,
 1,
 30,
 1,
 16,
 18,
 12,
 14,
 18,
 18,
 6,
 25,
 2,
 25,
 2,
 30,
 7,
 4,
 138,
 31,
 11,