In [1]:
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from constants import DATA_DIR
from torch.utils.data import DataLoader, Dataset

from astrofit.model import Asteroid
from astrofit.utils import AsteroidLoader, LightcurveBinner
from astrofit.utils.enums import BinningMethod

sns.set_theme(style="darkgrid")
plt.rcParams["figure.figsize"] = (14, 6)

In [2]:
assert torch.cuda.is_available(), "CUDA is not available"

AssertionError: CUDA is not available

In [4]:
PERIOD_FILE = "period.txt"
C_G_S = "\033[1;32m"
C_Y_S = "\033[1;33m"
C_E = "\033[0m"


In [3]:
asteroid_loader = AsteroidLoader(DATA_DIR)
lightcurve_binner = LightcurveBinner()

In [5]:
asteroid_name = "Eunomia"  # Interamnia, Eros, Ceres, Eunomia


In [6]:
asteroid = asteroid_loader.load_asteroid(asteroid_name)
asteroid

Asteroid(id=15, name=Eunomia, period=6.082754, lightcurves=109)

In [8]:
bins = lightcurve_binner.bin_lightcurves_from_asteroid(
    asteroid,
    max_time_diff=30,
    binning_method=BinningMethod.FIRST_TO_FIRST_DIFF,
    min_bin_size=3,
)

In [9]:
for ind, _bin in enumerate(bins):
    first_JD = _bin[0].first_JD
    last_JD = _bin[-1].last_JD

    duration = last_JD - first_JD
    lc_range = f"{duration:5.2f} days"
    if duration < 1:
        lc_range += f" - {duration * 24:.2f} hours"

    print(
        f"{ind:2} - {C_Y_S}{len(_bin):2}{C_E} lcs from {first_JD:.2f} to {last_JD:.2f} {C_G_S}({lc_range}){C_E}"
    )

 0 - [1;33m 3[0m lcs from 2435462.58 to 2435468.70 [1;32m( 6.12 days)[0m
 1 - [1;33m 3[0m lcs from 2442103.37 to 2442130.40 [1;32m(27.04 days)[0m
 2 - [1;33m 5[0m lcs from 2442149.30 to 2442158.36 [1;32m( 9.06 days)[0m
 3 - [1;33m 6[0m lcs from 2444912.26 to 2444941.00 [1;32m(28.74 days)[0m
 4 - [1;33m 3[0m lcs from 2445363.86 to 2445388.76 [1;32m(24.89 days)[0m
 5 - [1;33m 3[0m lcs from 2445831.84 to 2445858.85 [1;32m(27.01 days)[0m
 6 - [1;33m 4[0m lcs from 2446345.58 to 2446363.77 [1;32m(18.19 days)[0m
 7 - [1;33m 9[0m lcs from 2453886.48 to 2453897.63 [1;32m(11.15 days)[0m
 8 - [1;33m 9[0m lcs from 2454934.46 to 2454964.42 [1;32m(29.95 days)[0m
 9 - [1;33m 9[0m lcs from 2454965.20 to 2454995.33 [1;32m(30.13 days)[0m
10 - [1;33m15[0m lcs from 2455936.37 to 2455953.49 [1;32m(17.12 days)[0m
11 - [1;33m13[0m lcs from 2456303.62 to 2456327.77 [1;32m(24.15 days)[0m
12 - [1;33m 4[0m lcs from 2458284.58 to 2458295.67 [1;32m(11.09 days)[0m

In [None]:
# TODO:
# - sprawdzić jak działa LSTM dokładniej i co jest outputem :v
# - poddać krzywe do analizy Fourierowskiej (każda krzywa osobno i dostaję X elementów), może być zrobione w binach
# - przerzucić przez LSTMa i z każdego bina zagregować te wyjścia - na przykład średnią
# - przerzucić to nam sam koniec przez jaką warstwę FC i zwrócić wynik

In [10]:
for lc in bins:
    print(f"- {len(lc)} lightcurves")
    for points in lc:
        print(f"  - {len(points)} points")

- 3 lightcurves
  - 66 points
  - 90 points
  - 24 points
- 3 lightcurves
  - 70 points
  - 55 points
  - 29 points
- 5 lightcurves
  - 46 points
  - 20 points
  - 30 points
  - 29 points
  - 21 points
- 6 lightcurves
  - 64 points
  - 44 points
  - 7 points
  - 14 points
  - 33 points
  - 19 points
- 3 lightcurves
  - 6 points
  - 25 points
  - 18 points
- 3 lightcurves
  - 13 points
  - 42 points
  - 11 points
- 4 lightcurves
  - 36 points
  - 37 points
  - 29 points
  - 14 points
- 9 lightcurves
  - 37 points
  - 36 points
  - 45 points
  - 29 points
  - 49 points
  - 54 points
  - 62 points
  - 61 points
  - 113 points
- 9 lightcurves
  - 78 points
  - 116 points
  - 50 points
  - 39 points
  - 46 points
  - 16 points
  - 51 points
  - 57 points
  - 71 points
- 9 lightcurves
  - 48 points
  - 76 points
  - 75 points
  - 61 points
  - 50 points
  - 57 points
  - 67 points
  - 44 points
  - 52 points
- 15 lightcurves
  - 226 points
  - 143 points
  - 228 points
  - 193 points
  - 201

In [11]:
class AsteroidDataset(Dataset):
    """
    Dataset containing asteroids, each composed of Lightcurve sessions
    """

    def __init__(self, asteroids: list[Asteroid]):
        self.asteroids = self._extract_points(asteroids)
        self.periods = self._extract_periods(asteroids)

    def __len__(self):
        return len(self.asteroids)

    def __getitem__(self, idx):
        return self.asteroids[idx], self.periods[idx]

    def _extract_points(self, asteroids: list[Asteroid]) -> list[list[list[torch.Tensor]]]:
        parsed_asteroids = []
        for asteroid in asteroids:
            bins = lightcurve_binner.bin_lightcurves_from_asteroid(
                asteroid,
                max_time_diff=30,
                binning_method=BinningMethod.FIRST_TO_FIRST_DIFF,
                min_bin_size=3,
            )
            sessions = []
            for session in bins:
                session_points = []
                for lightcurve in session:
                    data_points = [(point.JD, point.brightness) for point in lightcurve.points]
                    session_points.append(torch.tensor(data_points, dtype=torch.float32))

                sessions.append(session_points)

            parsed_asteroids.append(sessions)

        return parsed_asteroids

    def _extract_periods(self, asteroids: list[Asteroid]) -> list[float]:
        return [asteroid.period for asteroid in asteroids]


def collate_fn(asteroid_batch):
    assert len(asteroid_batch) == 1, "Batch size must be 1"

    asteroid, period = asteroid_batch[0]

    max_lightcurves = max(len(session) for session in asteroid)
    max_points = max(max(lightcurve.size(0) for lightcurve in session) for session in asteroid)

    padded_sessions = []
    for session in asteroid:
        padded_lightcurves = [
            (
                torch.cat([lightcurve, torch.zeros((max_points - lightcurve.size(0), 2))])
                if lightcurve.size(0) < max_points
                else lightcurve
            )
            for lightcurve in session
        ]
        padded_lightcurves = torch.stack(padded_lightcurves)
        num_padding_lightcurves = max_lightcurves - padded_lightcurves.size(0)
        if num_padding_lightcurves > 0:
            padding = torch.zeros((num_padding_lightcurves, max_points, 2))
            padded_lightcurves = torch.cat((padded_lightcurves, padding), dim=0)

        padded_sessions.append(padded_lightcurves)

    padded_asteroid = torch.stack(padded_sessions)
    period_tensor = torch.tensor([period], dtype=torch.float32)

    padded_asteroid = padded_asteroid.unsqueeze(0)  # Add batch dimension

    return padded_asteroid, period_tensor

In [12]:
asteroids = []
for key in asteroid_loader.available_asteroids:
    if key in ("Interamnia", "Eros"):
        continue

    asteroid = asteroid_loader.load_asteroid(key)
    if len(asteroid.lightcurves) < 10:
        continue

    asteroids.append(asteroid)

asteroids

[Asteroid(id=54, name=Alexandra, period=7.02264, lightcurves=38),
 Asteroid(id=82, name=Alkmene, period=13.00079, lightcurves=16),
 Asteroid(id=29, name=Amphitrite, period=5.39012, lightcurves=66),
 Asteroid(id=64, name=Angelina, period=8.75033, lightcurves=22),
 Asteroid(id=43, name=Ariadne, period=5.761987, lightcurves=43),
 Asteroid(id=5, name=Astraea, period=16.80059, lightcurves=25),
 Asteroid(id=36, name=Atalante, period=9.92692, lightcurves=31),
 Asteroid(id=94, name=Aurora, period=7.226189, lightcurves=22),
 Asteroid(id=63, name=Ausonia, period=9.29759, lightcurves=16),
 Asteroid(id=28, name=Bellona, period=15.70785, lightcurves=24),
 Asteroid(id=1, name=Ceres, period=9.074173, lightcurves=46),
 Asteroid(id=34, name=Circe, period=12.17458, lightcurves=17),
 Asteroid(id=65, name=Cybele, period=6.081435, lightcurves=62),
 Asteroid(id=41, name=Daphne, period=5.987981, lightcurves=49),
 Asteroid(id=99, name=Dike, period=18.11914, lightcurves=30),
 Asteroid(id=48, name=Doris, period

In [13]:
test_asteroids = [asteroid_loader.load_asteroid("Interamnia"), asteroid_loader.load_asteroid("Eros")]
test_asteroids

[Asteroid(id=704, name=Interamnia, period=8.712337, lightcurves=188),
 Asteroid(id=433, name=Eros, period=5.27025528, lightcurves=118)]

In [14]:
dataset = AsteroidDataset(asteroids)
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

for batch in data_loader:
    inputs, targets = batch
    print(targets)
    print(inputs.size(), targets.size())

tensor([7.0226])
torch.Size([1, 8, 5, 345, 2]) torch.Size([1])
tensor([13.0008])
torch.Size([1, 2, 6, 15, 2]) torch.Size([1])
tensor([5.3901])
torch.Size([1, 9, 9, 174, 2]) torch.Size([1])
tensor([8.7503])
torch.Size([1, 4, 6, 74, 2]) torch.Size([1])
tensor([5.7620])
torch.Size([1, 8, 9, 149, 2]) torch.Size([1])
tensor([16.8006])
torch.Size([1, 3, 7, 44, 2]) torch.Size([1])
tensor([9.9269])
torch.Size([1, 4, 11, 369, 2]) torch.Size([1])
tensor([7.2262])
torch.Size([1, 3, 4, 42, 2]) torch.Size([1])
tensor([9.2976])
torch.Size([1, 2, 5, 331, 2]) torch.Size([1])
tensor([15.7079])
torch.Size([1, 4, 7, 381, 2]) torch.Size([1])
tensor([9.0742])
torch.Size([1, 7, 8, 465, 2]) torch.Size([1])
tensor([12.1746])
torch.Size([1, 3, 8, 220, 2]) torch.Size([1])
tensor([6.0814])
torch.Size([1, 10, 6, 467, 2]) torch.Size([1])
tensor([5.9880])
torch.Size([1, 7, 7, 651, 2]) torch.Size([1])
tensor([18.1191])
torch.Size([1, 4, 10, 402, 2]) torch.Size([1])
tensor([11.8901])
torch.Size([1, 2, 18, 591, 2]) to

In [15]:
class AsteroidLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout=0.2):
        super(AsteroidLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x: torch.Tensor):
        # Reshape x to (batch_size * num_sessions, num_lightcurves, input_size)
        batch_size, num_sessions, num_lightcurves, num_points, num_features = x.size()
        
        # Flatten the num_points dimension while keeping the last dimension as input_size
        x = x.view(batch_size * num_sessions, num_lightcurves * num_points, num_features)

        # Pass through LSTM
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
 
        out, _ = self.lstm(x, (h0, c0))

        # Take the output from the last time step
        out = out[:, -1, :]

        # Reshape to (batch_size, num_sessions, hidden_size)
        out = out.view(batch_size, num_sessions, -1)

        # Aggregate the session outputs (e.g., mean)
        out = torch.mean(out, dim=1)

        # Pass through the fully connected layer
        out = self.fc(out)

        return out

In [16]:
input_size = 2  # JD and brightness
hidden_size = 32
num_layers = 2
dropout = 0.15

In [17]:
model = AsteroidLSTM(input_size, hidden_size, num_layers).cuda()
model

RuntimeError: No CUDA GPUs are available

In [19]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)


def train_model(
    model: nn.Module,
    data_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.MSELoss, 
    num_epochs: int,
    save_path: Path,
):
    model.train()
    best_loss = float("inf")

    for epoch in range(num_epochs):
        for inputs, targets in data_loader:
            #inputs, targets = inputs.cuda(), targets.cuda()

            outputs = model(inputs)
            loss = criterion(outputs, targets.view(-1, 1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()
        print(f"Epoch [{epoch+1}/100], Loss: {loss.item():.4f}")

        if loss.item() < best_loss:
            best_loss = loss.item()
            torch.save(model.state_dict(), save_path)
            print(f"Model saved with loss {best_loss:.4f}")

In [20]:
save_path = MODELS_DIR / f"asteroid_lstm_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt"

train_model(model, data_loader, optimizer, criterion, num_epochs=100, save_path=save_path)

Epoch [1/100], Loss: 12.3935
Model saved with loss 12.3935
Epoch [2/100], Loss: 10.5926
Model saved with loss 10.5926
Epoch [3/100], Loss: 10.3596
Model saved with loss 10.3596
Epoch [4/100], Loss: 10.3191
Model saved with loss 10.3191
Epoch [5/100], Loss: 10.3019
Model saved with loss 10.3019


KeyboardInterrupt: 

In [None]:
test_asteroids

In [None]:
# save_path = MODELS_DIR / "asteroid_lstm_20240516_005338.pt"

In [None]:
dataset = AsteroidDataset(test_asteroids)
test_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)

best_model = AsteroidLSTM(input_size, hidden_size, num_layers, dropout)
best_model.load_state_dict(torch.load(save_path))

for ind, batch in enumerate(test_loader):
    inputs, targets = batch

    predictions = best_model(inputs)
    print(f"{test_asteroids[ind].name} predicted period: {predictions.item():.4f} - true period: {targets.item():.4f}")
    

In [None]:
raise

In [None]:
def predict_model(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    predictions = []
    with torch.no_grad():  # No need to track gradients for predictions
        for sessions, _ in data_loader:
            outputs = model(sessions)  # Assuming sessions are formatted correctly
            predictions.extend(outputs.detach().cpu().numpy())  # Store predictions
    

    return sum(predictions) / len(predictions)

In [None]:
data_loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)


In [None]:
predicted_period, = predict_model(model, data_loader)
true_period = dataset.periods[0]

print(f"True period: {true_period:.2f} hours")
print(f"Predicted period: {predicted_period:.2f} hours")

