In [71]:
import json

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.utils.data import DataLoader, Dataset

from constants import ASTEROIDS_DIR, DATA_DIR
from src.model import Asteroid, Lightcurve
from src.utils import LightcurveBinner

sns.set_theme(style="darkgrid")
plt.rcParams["figure.figsize"] = (14, 6)

In [3]:
PERIOD_FILE = "period.txt"
C_G_S = "\033[1;32m"
C_Y_S = "\033[1;33m"
C_E = "\033[0m"


In [4]:
ASTEROIDS_DF = pd.read_csv(DATA_DIR / "asteroids.csv", index_col=0)
ASTEROIDS_DF = ASTEROIDS_DF.dropna(subset=["number"])
ASTEROIDS_DF["number"] = ASTEROIDS_DF["number"].astype(int)
ASTEROIDS_DF

Unnamed: 0_level_0,number,name,designation,comment,created,modified
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
3414,1,Ceres,,albedo effects,2021-11-18 10:13:38,2021-11-18 10:14:01
101,2,Pallas,,,,2019-05-31 11:22:28
102,3,Juno,,,,2010-04-15 15:54:09
3415,4,Vesta,,albedo effects,2021-11-19 11:18:10,2021-11-19 11:27:06
103,5,Astraea,,,,2010-04-15 15:54:09
...,...,...,...,...,...,...
7103,353971,,,,2023-05-23 07:51:09,2023-05-23 07:51:09
7110,354510,,,,2023-05-23 07:51:09,2023-05-23 07:51:09
7194,362935,,,,2023-05-23 07:51:13,2023-05-23 07:51:13
7320,380282,,,,2023-05-23 07:51:18,2023-05-23 07:51:18


In [5]:
AVAILABLE_ASTEROIDS = {}
for directory in ASTEROIDS_DIR.iterdir():
    if not directory.is_dir():
        continue

    asteroid_name = directory.name.split("_")[0]
    work_name = directory.name

    res = ASTEROIDS_DF.query(f"name == '{asteroid_name}'")
    if len(res) != 1:
        raise ValueError(f"Found multiple asteroids with name {asteroid_name} (work name: {work_name})")

    (asteroid_num,) = res["number"]

    if not (directory / PERIOD_FILE).exists():
        raise FileNotFoundError(f"Missing {PERIOD_FILE} for {work_name}")
    
    with open(directory / PERIOD_FILE, "r") as f:
        period = float(f.read().strip())
    
    AVAILABLE_ASTEROIDS[work_name] = {"id": asteroid_num, "name": asteroid_name, "period": period}

AVAILABLE_ASTEROIDS = {k: AVAILABLE_ASTEROIDS[k] for k in sorted(AVAILABLE_ASTEROIDS)}
AVAILABLE_ASTEROIDS

{'Ceres': {'id': 1, 'name': 'Ceres', 'period': 9.074173},
 'Eros': {'id': 433, 'name': 'Eros', 'period': 5.27025528},
 'Eunomia': {'id': 15, 'name': 'Eunomia', 'period': 6.082754},
 'Flora': {'id': 8, 'name': 'Flora', 'period': 12.86667},
 'Interamnia': {'id': 704, 'name': 'Interamnia', 'period': 8.712337},
 'Iris': {'id': 7, 'name': 'Iris', 'period': 7.138844},
 'Metis': {'id': 9, 'name': 'Metis', 'period': 5.079177},
 'Pallas': {'id': 2, 'name': 'Pallas', 'period': 7.81322},
 'Sylvia': {'id': 87, 'name': 'Sylvia', 'period': 5.183641},
 'Vesta': {'id': 4, 'name': 'Vesta', 'period': 5.342124}}

In [6]:
asteroid_name = "Eunomia"  # Interamnia, Eros, Ceres, Eunomia
chosen_asteroid = AVAILABLE_ASTEROIDS[asteroid_name]
asteroid_id, known_period = chosen_asteroid["id"], chosen_asteroid["period"]

with open(ASTEROIDS_DIR / asteroid_name / "lc.json", "r") as f:
    raw_data = json.load(f)

In [7]:
asteroid = Asteroid.from_lightcurves(
    id=asteroid_id,
    name=asteroid_name,
    period=known_period,
    data=raw_data
)
asteroid

Asteroid(id=15, name=Eunomia, period=6.082754, lightcurves=109)

In [8]:
lightcurve_binner = LightcurveBinner()

In [9]:
bins = lightcurve_binner.bin_lightcurves_by_asteroid(asteroid, max_time_diff=30, min_n=3)


In [10]:
for ind, _bin in enumerate(bins):
    first_JD = _bin[0].first_JD
    last_JD = _bin[-1].last_JD

    duration = last_JD - first_JD
    lc_range = f"{duration:5.2f} days"
    if duration < 1:
        lc_range += f" - {duration * 24:.2f} hours"

    print(
        f"{ind:2} - {C_Y_S}{len(_bin):2}{C_E} lcs from {first_JD:.2f} to {last_JD:.2f} {C_G_S}({lc_range}){C_E}"
    )

 0 - [1;33m 3[0m lcs from 2435462.58 to 2435468.70 [1;32m( 6.12 days)[0m
 1 - [1;33m 3[0m lcs from 2442103.37 to 2442130.40 [1;32m(27.04 days)[0m
 2 - [1;33m 5[0m lcs from 2442149.30 to 2442158.36 [1;32m( 9.06 days)[0m
 3 - [1;33m 6[0m lcs from 2444912.26 to 2444941.00 [1;32m(28.74 days)[0m
 4 - [1;33m 3[0m lcs from 2445363.86 to 2445388.76 [1;32m(24.89 days)[0m
 5 - [1;33m 3[0m lcs from 2445831.84 to 2445858.85 [1;32m(27.01 days)[0m
 6 - [1;33m 4[0m lcs from 2446345.58 to 2446363.77 [1;32m(18.19 days)[0m
 7 - [1;33m 9[0m lcs from 2453886.48 to 2453897.63 [1;32m(11.15 days)[0m
 8 - [1;33m 9[0m lcs from 2454934.46 to 2454964.42 [1;32m(29.95 days)[0m
 9 - [1;33m 8[0m lcs from 2454965.20 to 2454992.29 [1;32m(27.10 days)[0m
10 - [1;33m15[0m lcs from 2455936.37 to 2455953.49 [1;32m(17.12 days)[0m
11 - [1;33m13[0m lcs from 2456303.62 to 2456327.77 [1;32m(24.15 days)[0m
12 - [1;33m 4[0m lcs from 2458284.58 to 2458295.67 [1;32m(11.09 days)[0m

In [11]:
for lc in bins:
    print(f"- {len(lc)} lightcurves")
    for points in lc:
        print(f"  - {len(points)} points")

- 3 lightcurves
  - 66 points
  - 90 points
  - 24 points
- 3 lightcurves
  - 70 points
  - 55 points
  - 29 points
- 5 lightcurves
  - 46 points
  - 20 points
  - 30 points
  - 29 points
  - 21 points
- 6 lightcurves
  - 64 points
  - 44 points
  - 7 points
  - 14 points
  - 33 points
  - 19 points
- 3 lightcurves
  - 6 points
  - 25 points
  - 18 points
- 3 lightcurves
  - 13 points
  - 42 points
  - 11 points
- 4 lightcurves
  - 36 points
  - 37 points
  - 29 points
  - 14 points
- 9 lightcurves
  - 37 points
  - 36 points
  - 45 points
  - 29 points
  - 49 points
  - 54 points
  - 62 points
  - 61 points
  - 113 points
- 9 lightcurves
  - 78 points
  - 116 points
  - 50 points
  - 39 points
  - 46 points
  - 16 points
  - 51 points
  - 57 points
  - 71 points
- 8 lightcurves
  - 48 points
  - 76 points
  - 75 points
  - 61 points
  - 50 points
  - 57 points
  - 67 points
  - 44 points
- 15 lightcurves
  - 226 points
  - 143 points
  - 228 points
  - 193 points
  - 201 points
  - 16

In [114]:
class LightCurveSessionDataset(Dataset):
    def __init__(self, bins: list[list[Lightcurve]], labels: list[float]):
        """
        Args:
            sessions (list of lists of Lightcurve): Nested list where each sublist represents a session of lightcurves, each lightcurve has a shape of (sequence_length, 2)
            labels (list of floats): The corresponding rotational period for each session
        """
        self.sessions = self._extract_points(bins)
        self.labels = labels

    def __len__(self):
        return len(self.sessions)

    def __getitem__(self, idx):
        return self.sessions[idx], self.labels[idx]

    def _extract_points(self, bins: list[list[Lightcurve]]) -> list[list[torch.Tensor]]:
        sessions = []
        for session in bins:
            session_points = []
            for lightcurve in session:
                data_points = [(point.JD, point.brightness) for point in lightcurve.points]
                session_points.append(torch.tensor(data_points, dtype=torch.float32))
            sessions.append(session_points)

        return sessions


def collate_fn(batch):
    sessions, labels = zip(*batch)

    sessions_padded = [pad_sequence(session, batch_first=True, padding_value=0.0) for session in sessions]
    labels = torch.tensor(labels, dtype=torch.float32)

    return sessions_padded, labels

In [135]:
periods = [period] * len(bins)
dataset = LightCurveSessionDataset(bins, periods)
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

for session_data, target in data_loader:
    print(f"Session data shapes: {[data.shape for data in session_data]}")

Session data shapes: [torch.Size([3, 90, 2])]
Session data shapes: [torch.Size([3, 70, 2])]
Session data shapes: [torch.Size([5, 46, 2])]
Session data shapes: [torch.Size([6, 64, 2])]
Session data shapes: [torch.Size([3, 25, 2])]
Session data shapes: [torch.Size([3, 42, 2])]
Session data shapes: [torch.Size([4, 37, 2])]
Session data shapes: [torch.Size([9, 113, 2])]
Session data shapes: [torch.Size([9, 116, 2])]
Session data shapes: [torch.Size([8, 76, 2])]
Session data shapes: [torch.Size([15, 228, 2])]
Session data shapes: [torch.Size([13, 447, 2])]
Session data shapes: [torch.Size([4, 121, 2])]


In [151]:
class LightCurveLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super(LightCurveLSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layer
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)

        # Fully connected layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # x shape should be [num_lightcurves, seq_length, input_dim]
        print(x.shape)
        num_lightcurves, seq_len, input_dim = x.shape
        
        # Initialize hidden state and cell state
        h0 = torch.zeros(self.num_layers, num_lightcurves, self.hidden_dim).to(x.device)
        c0 = torch.zeros(self.num_layers, num_lightcurves, self.hidden_dim).to(x.device)

        # Forward propagate LSTM
        out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size * num_lightcurves, seq_len, hidden_dim)

        # Decode the hidden state of the last time step
        out = self.fc(out[:, -1, :])

        # Average the predictions across lightcurves within the same session
        out = out.mean(dim=0, keepdim=True)

        return out

In [157]:
for session_data, target in data_loader:
    break

torch.stack(session_data).size()


torch.Size([1, 3, 90, 2])

In [152]:
input_dim = 2  # brightness and time
hidden_dim = 128
output_dim = 1  # predicting a single float value
num_layers = 1

# Model instantiation
model = LightCurveLSTM(input_dim, hidden_dim, output_dim, num_layers)
model

LightCurveLSTM(
  (lstm): LSTM(2, 128, batch_first=True)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)

In [153]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


# Training loop
def train_model(model, data_loader, optimizer, criterion, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        for session_data, targets in data_loader:
            session_data = torch.stack(session_data)  # Convert list of tensors to a 3D tensor
            outputs = model(session_data)
            loss = criterion(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

In [154]:
train_model(model, data_loader, optimizer, criterion, num_epochs=10)

torch.Size([1, 3, 90, 2])


ValueError: too many values to unpack (expected 3)

In [None]:
raise

In [12]:
PADDED_DATA_TYPE = list[list[list[tuple[float, float]]]]
MASKS_TYPE = list[list[list[int]]]

In [13]:
def preprocess_data(bins: list[list[Lightcurve]]) -> tuple[PADDED_DATA_TYPE, MASKS_TYPE]:
    padded_data = []
    masks = []
    for group in bins:
        max_length = max(len(lc) for lc in group)
        padded_group = []
        mask_group = []
        for lc in group:
            padded_lc = [(point.brightness, point.JD) for point in lc.points] + [(0, 0)] * (max_length - len(lc))
            mask_lc = [1] * len(lc) + [0] * (max_length - len(lc))
            padded_group.append(padded_lc)
            mask_group.append(mask_lc)

        padded_data.append(padded_group)
        masks.append(mask_group)

    return padded_data, masks

In [14]:
data, masks = preprocess_data(bins)

In [15]:
for lc in data:
    print(f"- {len(lc)} lightcurves")
    for points in lc:
        print(f"  - {len(points)} points")

- 3 lightcurves
  - 90 points
  - 90 points
  - 90 points
- 3 lightcurves
  - 70 points
  - 70 points
  - 70 points
- 5 lightcurves
  - 46 points
  - 46 points
  - 46 points
  - 46 points
  - 46 points
- 6 lightcurves
  - 64 points
  - 64 points
  - 64 points
  - 64 points
  - 64 points
  - 64 points
- 3 lightcurves
  - 25 points
  - 25 points
  - 25 points
- 3 lightcurves
  - 42 points
  - 42 points
  - 42 points
- 4 lightcurves
  - 37 points
  - 37 points
  - 37 points
  - 37 points
- 9 lightcurves
  - 113 points
  - 113 points
  - 113 points
  - 113 points
  - 113 points
  - 113 points
  - 113 points
  - 113 points
  - 113 points
- 9 lightcurves
  - 116 points
  - 116 points
  - 116 points
  - 116 points
  - 116 points
  - 116 points
  - 116 points
  - 116 points
  - 116 points
- 8 lightcurves
  - 76 points
  - 76 points
  - 76 points
  - 76 points
  - 76 points
  - 76 points
  - 76 points
  - 76 points
- 15 lightcurves
  - 228 points
  - 228 points
  - 228 points
  - 228 points
  -