In [24]:
import json
import logging
from collections import Counter
from itertools import product
from random import choice
from time import perf_counter
from typing import Literal, TypedDict

import numpy as np
from constants import DATA_DIR
from tqdm import tqdm

from astrofit.model import Asteroid, Lightcurve, LightcurveBin
from astrofit.utils import (
    AsteroidLoader,
    FrequencyDecomposer,
    LightcurveBinner,
    LightcurvePlotter,
    LightcurveSplitter,
)

In [25]:
# Root configuration: DEBUG level with timestamped, logger-named records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)

# Only let matplotlib records of WARNING and above through.
logging.getLogger("matplotlib").setLevel(logging.WARNING)

# Logger used by the feature-extraction helpers in this notebook.
logger = logging.getLogger("freq")

In [26]:
# Helper objects shared by the cells below.
asteroid_loader = AsteroidLoader(DATA_DIR)
frequency_decomposer = FrequencyDecomposer()
lightcurve_binner = LightcurveBinner()
lightcurve_plotter = LightcurvePlotter()
lightcurve_splitter = LightcurveSplitter()

# Where the per-configuration feature dumps are written; `{config_no}` is
# filled in with the 1-based configuration number when a file is saved.
FEATURES_DIR = DATA_DIR.joinpath("features")
ASTEROIDS_JSON_FILE_NAME = "asteroids_freq_data_{config_no}.json"

# Create the output directory on first use; a pre-existing one is fine.
FEATURES_DIR.mkdir(exist_ok=True)

In [27]:
# Threshold for the optional period filter below (currently disabled).
# NOTE(review): presumably expressed in hours, like `Asteroid.period` - confirm.
MAX_PERIOD = 40

# Load every available asteroid, keyed by the name stored on the loaded object.
asteroids: dict[str, Asteroid] = {}
for asteroid_name in tqdm(asteroid_loader.available_asteroids):
    asteroid = asteroid_loader.load_asteroid(asteroid_name)
    # Optional filter, kept disabled so the full data set is loaded:
    # if asteroid.period > MAX_PERIOD:
    #     continue

    asteroids[asteroid.name] = asteroid

print(f"Loaded {len(asteroids)} asteroids")

100%|██████████| 5057/5057 [00:13<00:00, 380.04it/s]

Loaded 5057 asteroids





In [28]:
# Distribution of the periods of all loaded asteroids.
asteroid_periods = [loaded.period for loaded in asteroids.values()]
np.percentile(asteroid_periods, [5, 25, 50, 75, 90, 95])

array([  3.422438,   5.48469 ,   8.78678 ,  16.97036 ,  49.0198  ,
       108.6936  ])

In [29]:
class Config(TypedDict):
    """Parameter set for one feature-extraction run (consumed by `get_freq_features`)."""

    # Forwarded to `lightcurve_splitter.split_lightcurves` as `max_hours_diff`.
    max_hours_diff: float
    # Forwarded to `lightcurve_splitter.split_lightcurves` as `min_no_points`.
    min_no_points: int
    # Number of bins whose frequencies are collected per asteroid.
    top_k_bins: int
    # Extra bins selected on top of `top_k_bins`, tried when a selected bin
    # yields fewer than `top_k_freqs` frequencies.
    buffer_bins: int
    # Bin ranking: "lightcurves" uses the bins' own ordering (`sorted(bins)`),
    # "points" ranks by `bin.points_count`; both in descending order.
    select_bins_by: Literal["lightcurves", "points"]
    # Forwarded to `lightcurve_binner.bin_lightcurves` as `max_time_diff`.
    max_time_diff: float
    # Forwarded to `lightcurve_binner.bin_lightcurves` as `min_bin_size`.
    min_bin_size: int
    # Forwarded to `frequency_decomposer.decompose_bin` as `max_freq`.
    max_freq: float
    # Forwarded to `decompose_bin` as `top_k`; bins returning fewer rows are skipped.
    top_k_freqs: int
    # Forwarded to `decompose_bin` as `fourier_nterms`.
    nterms: int
    max_debug: bool  # If true, will print and plot everything


In [30]:
# Splitter results: (max_hours_diff, min_no_points) -> asteroid name -> split lightcurves.
cached_lightcurves: dict[tuple, dict[str, list[Lightcurve]]] = dict()
# Binner results: (max_hours_diff, min_no_points, max_time_diff, min_bin_size)
# -> asteroid name -> bins.
cached_bins: dict[tuple, dict[str, list[LightcurveBin]]] = dict()


def _split_lightcurves(asteroid: Asteroid, config: Config) -> list[Lightcurve]:
    """Split an asteroid's lightcurves, reusing a cached result when one exists.

    Results are memoised in the module-level `cached_lightcurves`, keyed first
    by ``(max_hours_diff, min_no_points)`` and then by the asteroid's name.

    Args:
        asteroid: Asteroid whose `lightcurves` are split.
        config: Supplies `max_hours_diff` and `min_no_points`, both forwarded
            to `lightcurve_splitter.split_lightcurves`.

    Returns:
        The split lightcurves (the cached list object on a cache hit).
    """
    max_hours_diff, min_no_points = config["max_hours_diff"], config["min_no_points"]

    logger.debug(
        f"Splitting lightcurves for {asteroid.name} with max_hours_diff={max_hours_diff} and min_no_points={min_no_points}"
    )
    logger.debug(f"Before splitting: {len(asteroid.lightcurves)} lightcurves")

    key = (max_hours_diff, min_no_points)
    try:
        splitted = cached_lightcurves[key][asteroid.name]
    except KeyError:
        # Cache miss on either level: compute, then store under both keys.
        splitted = lightcurve_splitter.split_lightcurves(
            asteroid.lightcurves,
            max_hours_diff=max_hours_diff,
            min_no_points=min_no_points,
        )
        cached_lightcurves.setdefault(key, {})[asteroid.name] = splitted
    else:
        logger.debug(f"Using cached lightcurves ({key}) for {asteroid.name}")

    logger.debug(f"After splitting: {len(splitted)} lightcurves")

    return splitted


def _get_top_k_bins(lightcurves: list[Lightcurve], config: Config, asteroid: Asteroid) -> list[LightcurveBin]:
    """Bin the lightcurves and return the best ``top_k_bins + buffer_bins`` bins.

    Binning results are memoised in the module-level `cached_bins`, keyed by
    the splitting and binning parameters and then by the asteroid's name.

    Args:
        lightcurves: Lightcurves to bin (only used on a cache miss).
        config: Supplies the binning parameters, the number of bins to keep
            and the ranking criterion (`select_bins_by`).
        asteroid: Only its `name` is used, as the cache key.

    Returns:
        At most ``top_k_bins + buffer_bins`` bins in descending rank order.

    Raises:
        ValueError: If `select_bins_by` is neither "lightcurves" nor "points".
    """
    max_time_diff = config["max_time_diff"]
    min_bin_size = config["min_bin_size"]
    top_k_bins = config["top_k_bins"]
    # Spare bins, in case some of the selected bins yield too few frequencies.
    buffer_bins = config["buffer_bins"]

    logger.debug(f"Getting top {top_k_bins} bins with max_time_diff={max_time_diff} and min_bin_size={min_bin_size}")

    # Same split parameters AND same binning parameters => the bins can be reused.
    composite_key = (config["max_hours_diff"], config["min_no_points"], max_time_diff, min_bin_size)
    try:
        bins = cached_bins[composite_key][asteroid.name]
    except KeyError:
        bins = lightcurve_binner.bin_lightcurves(
            lightcurves,
            max_time_diff=max_time_diff,
            min_bin_size=min_bin_size,
        )
        cached_bins.setdefault(composite_key, {})[asteroid.name] = bins
    else:
        logger.debug(f"Using cached bins ({composite_key}) for {asteroid.name}")

    logger.debug(f"After binning {len(bins)} bins available")
    if len(bins) < top_k_bins:
        logger.debug(f"Using {len(bins)} bins instead of {top_k_bins}")

    criterion = config["select_bins_by"]
    if criterion == "lightcurves":
        # Relies on the ordering defined by the bin objects themselves.
        ranked = sorted(bins, reverse=True)
    elif criterion == "points":
        ranked = sorted(bins, key=lambda lc_bin: lc_bin.points_count, reverse=True)
    else:
        raise ValueError("Invalid value for select_bins_by")

    return ranked[: top_k_bins + buffer_bins]


def _get_top_k_freqs(lightcurve_bin: LightcurveBin, config: Config, asteroid: Asteroid) -> np.ndarray:
    """Decompose one bin into its top frequencies.

    Args:
        lightcurve_bin: Bin handed to `frequency_decomposer.decompose_bin`.
        config: Supplies `nterms`, `top_k_freqs`, `max_freq` and `max_debug`.
        asteroid: Its `period` is only used in debug mode.

    Returns:
        The array returned by `decompose_bin`. When `max_debug` is set, one
        extra column is appended: column 0 divided by ``24 / asteroid.period``
        (presumably the detected frequency relative to the true one).
    """
    nterms = config["nterms"]
    top_k_freqs = config["top_k_freqs"]

    logger.debug(
        f"Getting top {top_k_freqs} frequencies with nterms={nterms} for "
        f"lightcurves={len(lightcurve_bin)} with total points={len(lightcurve_bin.times)}"
    )

    freq_data = frequency_decomposer.decompose_bin(
        lightcurve_bin,
        fourier_nterms=nterms,
        top_k=top_k_freqs,
        max_freq=config["max_freq"],
        show_plot=config["max_debug"],
    )

    if not config["max_debug"]:
        return freq_data

    # Debug mode: also report each frequency relative to the reference frequency.
    true_freq = 24 / asteroid.period
    return np.column_stack((freq_data, freq_data[:, 0] / true_freq))


def _has_anomalous_series(data: list[list[float]], magnitude_threshold: int = 2):
    if not data:
        return False

    medians = np.array([np.median(series) for series in data])
    overall_median = np.median(medians)
    ratios = medians / overall_median

    anomalous_series_exist = np.any(np.logical_or(ratios > 10**magnitude_threshold, ratios < 10 ** (-magnitude_threshold)))

    return anomalous_series_exist


def get_freq_features(
    asteroid: Asteroid,
    config: Config,
) -> tuple[list[list], list[LightcurveBin]] | dict:
    """Extract frequency features for one asteroid.

    Pipeline: split the lightcurves, reject asteroids with anomalous series,
    bin the lightcurves, then decompose the best bins into frequencies until
    `config["top_k_bins"]` bins have produced a full set of frequencies.

    Args:
        asteroid: Asteroid to process.
        config: Parameters for splitting, binning and decomposition.

    Returns:
        On success, a 2-tuple ``(freq_data, top_k_bins)``:
        `freq_data` holds one nested list (``ndarray.tolist()``) per accepted
        bin; `top_k_bins` is the full list of candidate bins, buffer bins
        included. On failure, a dict ``{"status": "failed", "reason": ...}``
        whose reason is "anomalous series", "no bins" or "no frequencies".
    """
    splitted_lightcurves = _split_lightcurves(asteroid, config)
    if _has_anomalous_series([lc.brightness_arr for lc in splitted_lightcurves]):
        logger.debug("Anomalous series detected")

        return {"status": "failed", "reason": "anomalous series"}

    # Includes buffer bins
    top_k_bins = _get_top_k_bins(splitted_lightcurves, config, asteroid)

    if not top_k_bins:
        logger.debug("No bins available")

        return {"status": "failed", "reason": "no bins"}

    top_k_bins_no = config["top_k_bins"]
    buffer_bins_no = config["buffer_bins"]

    freq_data = []
    for ind, _bin in enumerate(top_k_bins):
        # Stop as soon as enough bins have delivered a full set of frequencies.
        if len(freq_data) == top_k_bins_no:
            break

        # Past the first `top_k_bins_no` candidates we are consuming buffer bins.
        if ind >= top_k_bins_no:
            logger.debug(f"Using buffer bin {ind - top_k_bins_no + 1} / {buffer_bins_no}")

        if config["max_debug"]:
            lightcurve_plotter.plot_lightcurves(_bin)

        bin_freq = _get_top_k_freqs(_bin, config, asteroid)
        if len(bin_freq) < config["top_k_freqs"]:
            logger.debug(f"Bin {ind} has only {len(bin_freq)} frequencies, skipping")

            continue

        freq_data.append(bin_freq.tolist())

    if not freq_data:
        logger.debug("No frequencies available")

        return {"status": "failed", "reason": "no frequencies"}

    logger.debug(f"{'-'*50}\n")

    return freq_data, top_k_bins

In [None]:
# Candidate values per parameter; the grid is the cartesian product of all lists.
options = {
    "max_hours_diff": [1, 2, 4, 8],
    "min_no_points": [10, 20],
    "top_k_bins": [2, 4],
    "buffer_bins": [3],
    "select_bins_by": ["lightcurves", "points"],
    "max_time_diff": [30, 45, 60],
    "min_bin_size": [1, 2],
    "max_freq": [12],
    "top_k_freqs": [50],
    "nterms": [3],
    "max_debug": [False],
}

# One Config per combination, in `itertools.product` order.
option_names = list(options)
configs = [Config(**dict(zip(option_names, combination))) for combination in product(*options.values())]

print(f"Generated {len(configs)} configurations")
print(f"Given ~4m per configuration, this will take ~{len(configs) * 4 / 60} hours")

In [None]:
def calculate_and_save_features(config: Config, config_no: int):
    """Compute frequency features for every loaded asteroid and dump them to JSON.

    The output file is ``FEATURES_DIR / ASTEROIDS_JSON_FILE_NAME`` formatted
    with `config_no`; it stores the configuration used (with `max_debug`
    forced to False) and one record per asteroid.

    Args:
        config: Parameters for `get_freq_features`. It is not modified.
        config_no: Number substituted into the output file name.

    Raises:
        AssertionError: If a successful result holds no sequences or more than
            `config["top_k_bins"]` of them.
    """
    # Work on a copy: `config` is usually an element of the shared `configs`
    # list, and debug plotting must stay off for a bulk run.
    config = Config(**{**config, "max_debug": False})
    max_sequences = config["top_k_bins"]

    failed_cnt = 0
    asteroids_data = {}
    for asteroid_name, asteroid in tqdm(asteroids.items()):
        start = perf_counter()
        result = get_freq_features(asteroid, config)
        processing_time = perf_counter() - start

        record = {
            "is_failed": False,
            "reason": None,
            "period": asteroid.period,
            "processing_time": processing_time,
            "features": [],
        }
        asteroids_data[asteroid_name] = record

        # A dict signals failure and carries the reason.
        if isinstance(result, dict):
            failed_cnt += 1
            record["is_failed"] = True
            record["reason"] = result["reason"]
            continue

        # On success `get_freq_features` returns (freq_data, bins). Only the
        # frequency data belongs in the dump; a bare list is accepted as well.
        features = result[0] if isinstance(result, tuple) else result

        assert (
            1 <= len(features) <= max_sequences
        ), f"Invalid number of sequences: {len(features)} for {asteroid_name}"

        record["features"] = features

    # Guard the percentage against an empty `asteroids` dict.
    failed_pct = failed_cnt / len(asteroids) * 100 if asteroids else 0.0
    print(f"Failed asteroids: {failed_cnt} ({failed_pct:.2f}%)")
    print(f"{'-'*50}")

    dump_data = {
        "config": config,
        "asteroids": asteroids_data,
    }

    with open(FEATURES_DIR / (ASTEROIDS_JSON_FILE_NAME.format(config_no=config_no)), "w") as f:
        json.dump(dump_data, f, indent=4)

In [None]:
# Collect the configuration stored in every JSON dump already present in
# FEATURES_DIR, so finished configurations can be skipped later on.
calculated_configs = []
json_paths = (path for path in FEATURES_DIR.iterdir() if path.suffix == ".json")
for json_path in json_paths:
    with json_path.open("r") as dump_file:
        calculated_configs.append(json.load(dump_file)["config"])

print(f"Already calculated {len(calculated_configs)} configurations")

In [None]:
# Switch the root logger to INFO for the duration of the sweep.
logging.getLogger().setLevel(logging.INFO)

times = []
try:
    for ind, config in enumerate(configs):
        # Skip configurations whose results are already on disk.
        if config in calculated_configs:
            continue

        print(f"Config {ind + 1} / {len(configs)}")
        print(f"{'#'*10}")
        print(config)
        print(f"{'#'*10}")
        start = perf_counter()
        calculate_and_save_features(config, config_no=ind + 1)
        times.append(perf_counter() - start)

        print(f"Average time per configuration so far: {np.mean(times):.2f} seconds")
        print(f"{'-'*50}")
finally:
    # Restore DEBUG logging even if a configuration fails or the run is interrupted.
    logging.getLogger().setLevel(logging.DEBUG)

In [None]:
# Pick one asteroid for a manual debug run. Despite its name, `ast_name` is
# bound to the value stored in `asteroids` (an Asteroid), not to a string.
ast_name = asteroids["1177 T-3"]
ast_name

In [None]:
# Debug variant of the first configuration. A copy is taken so that enabling
# debug output does not alter the entry stored in `configs`, which is compared
# against the configurations already saved on disk.
config = Config(**{**configs[0], "max_debug": True})

In [None]:
# Run the extraction for the asteroid bound to `ast_name` with the current
# `config`; the result (success tuple or failure dict) is kept in `data`.
data = get_freq_features(ast_name, config)

In [None]:
# Load the dump written for configuration number 1.
features_path = FEATURES_DIR / ASTEROIDS_JSON_FILE_NAME.format(config_no=1)
with features_path.open("r") as f:
    data = json.load(f)

# Per-asteroid records, keyed by asteroid name.
asteroids_data = data["asteroids"]


In [None]:
# (name, record) pairs of the asteroids whose extraction failed. Stored as a
# list rather than a one-shot `filter` iterator, so cells that consume it can
# be re-run without silently getting an empty result.
failed_asteroids = [(name, record) for name, record in asteroids_data.items() if record["is_failed"]]


In [None]:
# Tally the failure reasons across all failed asteroids.
reasons = Counter(record["reason"] for _name, record in failed_asteroids)
print(reasons)

In [None]:
# For every successfully processed asteroid: how close (relatively) the best
# extracted frequency comes to the reference frequency 24 / period.
diffs = {}
for asteroid_name, asteroid_data in asteroids_data.items():
    if not asteroid_data["is_failed"]:
        target_freq = 24 / asteroid_data["period"]
        # First column of every row of every stored sequence.
        candidate_freqs = np.array(asteroid_data["features"])[:, :, 0]

        # Smallest |ratio - 1|, i.e. the candidate whose ratio is closest to 1.
        diffs[asteroid_name] = np.min(np.abs(candidate_freqs / target_freq - 1))

In [None]:
# Distribution of the best relative frequency error over all asteroids.
diff_values = list(diffs.values())
percentiles = np.percentile(diff_values, [0, 5, 25, 50, 75, 95, 100])
percentiles

In [None]:
# Keep the asteroids whose best relative error exceeds the 95th percentile
# (second-to-last entry of `percentiles`).
worst_diff_threshold = percentiles[-2]
selected_asteroids = {
    name: record
    for name, record in asteroids_data.items()
    if name in diffs and diffs[name] > worst_diff_threshold
}
len(selected_asteroids)

In [None]:
# Inspect a randomly chosen asteroid from the selected (worst-matching) group.
ast_name = choice(list(selected_asteroids))
picked_asteroid = asteroids[ast_name]
print(repr(picked_asteroid))
print(f"Target frequency: {24 / picked_asteroid.period}")

get_freq_features(picked_asteroid, config)