In [1]:
import satkit

# satkit.utils.update_datafiles()

In [2]:
from util.read import read_zst, read_blocks
from dotenv import load_dotenv
import os

load_dotenv()

FILEPATH = os.getenv("TEST_FILEPATH")
lines = read_zst(FILEPATH)
states = []

num_lines = len(lines)
steps = read_blocks(lines)
print(len(steps))

107


In [3]:
import torch
import torch.nn as nn
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
from torch.cuda.amp import GradScaler, autocast
from torch import optim
from custom_satkit.customMLDSGP4 import mldsgp4
from lazy_dataset.dataset import LazyDataset


def train(
        model: mldsgp4,
        dataset: LazyDataset,
        optimizer: optim.Optimizer,
        criterion: nn.Module = nn.SmoothL1Loss(),
        scaler: GradScaler = GradScaler(device),
        density=5001
        ):
    model.train()
    total_loss = 0.0

    for batch in dataset:
        # Extract TLEs and states from the batch
        tles = [step.tle for step in batch]
        all_state_lists = [step.states for step in batch]
        all_tsince_lists = [step.tsinces for step in batch]
        states = [state for state_list in all_state_lists for state in state_list]
        tsinces = [tsince for tsince_list in all_tsince_lists for tsince in tsince_list]

        optimizer.zero_grad()

        # Propagate TLEs using the model
        propagated_states = []
        for tle, tsince in zip(tles, tsinces):
            tle_expanded = [tle] * density
            time_steps = torch.linspace(0, tsince, density, device=device)

            with autocast(device):
                output_segment_states = model(tle_expanded, time_steps)
                propagated_states.append(output_segment_states)

        # Convert propagated states to a tensor
        propagated_states = torch.cat(propagated_states, dim=0).to(device)
        print(f"Propagated states shape: {propagated_states.shape}")

        # Convert ground truth states to a tensor
        target_states = torch.tensor(
            [state.get_position_vector() for state in states],
            dtype=torch.float32,
            device=device
        )

        # Compute the loss
        with autocast(device):
            loss = criterion(propagated_states, target_states)

        # Backpropagation and optimization
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataset)
    return avg_loss

def eval(
    model: mldsgp4,
    dataset: LazyDataset,
    optimizer: optim.Optimizer,
    criterion: nn.Module = nn.SmoothL1Loss(),
    density=5001
):
    total_loss = 0.0

    for batch in dataset:
        # Extract TLEs and states from the batch
        tles = [step.tle for step in batch]
        all_state_lists = [step.states for step in batch]
        all_tsince_lists = [step.tsinces for step in batch]
        states = [state for state_list in all_state_lists for state in state_list]
        tsinces = [tsince for tsince_list in all_tsince_lists for tsince in tsince_list]

        optimizer.zero_grad()

        # Propagate TLEs using the model
        propagated_states = []
        for tle, tsince in zip(tles, tsinces):
            tle_expanded = [tle] * density
            time_steps = torch.linspace(0, tsince, density, device=device)

            with autocast(device):
                output_segment_states = model(tle_expanded, time_steps)
                propagated_states.append(output_segment_states)

        # Convert propagated states to a tensor
        propagated_states = torch.cat(propagated_states, dim=0).to(device)

        # Convert ground truth states to a tensor
        target_states = torch.tensor(
            [state.get_position_vector() for state in states],
            dtype=torch.float32,
            device=device
        )

        # Compute the loss
        with autocast(device):
            loss = criterion(propagated_states, target_states)

        # Evaluation cycle, we skip backprop

        total_loss += loss.item()

    avg_loss = total_loss / len(dataset)
    return avg_loss

  scaler: GradScaler = GradScaler(device),


In [5]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = mldsgp4()
criterion = nn.SmoothL1Loss()
optimizer = optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 0.05)
scheduler = ReduceLROnPlateau(optimizer)
criterion = nn.SmoothL1Loss() #we will use a SmoothL1 Criterion, which combines MSE and MAE in order to be robust to outliers and get smooth gradients
scaler = GradScaler(device)

  scaler = GradScaler(device)


In [6]:
from lazy_dataset.dataset import LazyDataset

TRAIN_PATH = os.getenv("TRAIN_PATH")
TEST_PATH = os.getenv("TEST_PATH")
VAL_PATH = os.getenv("VAL_PATH")

# LazyDataset is a custom dataset class that lazily loads data from the specified folder
# It is designed to handle large datasets efficiently by loading only the necessary data when needed
# It is assumed that the folder contains TLE files and corresponding state files in a specific format
mp = False
bs = 1
train_satellites = LazyDataset(folder = TRAIN_PATH, multiprocess=mp, batch_size=bs)
test_satellites = LazyDataset(folder = TEST_PATH, multiprocess=mp, batch_size=bs)
val_satellites = LazyDataset(folder = VAL_PATH, multiprocess=mp, batch_size=bs)

print(train_satellites[0])

['integration_85210.txt.zst']


In [None]:
EPOCHS = 100

LRS = [0.01, 0.001, 0.0001, 0.00001]
OPTIMIZERS = [
]

training_loss_over_time = []    
testing_loss_over_time = []
validation_loss_over_time = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    
    # Training
    model.train()
    train_loss = train(model, train_satellites, optimizer, criterion)
    training_loss_over_time.append(train_loss)

    # Evaluation
    # model.eval()
    # test_loss = eval(model, test_satellites, optimizer)
    # testing_loss_over_time.append(test_loss)

    # # Validation
    # val_loss = eval(model, val_satellites, optimizer)
    # validation_loss_over_time.append(val_loss)

Epoch 1/100


  with autocast(device):
