In [1]:
import satkit

# satkit.utils.update_datafiles()

In [2]:
from dotenv import load_dotenv
import os
import logging

# Append timestamped records (INFO and above) to training.log in the working directory.
logging.basicConfig(
    filename='training.log',
    filemode='a',
    level=logging.INFO,
    format='%(asctime)s %(levelname)s:%(message)s',
)
logging.info("Starting training script")

# Pull variables from a .env file into the process environment; a later cell
# reads the dataset paths with os.getenv.
load_dotenv()

True

In [3]:
import torch
import torch.nn as nn
import numpy as np

# Prefer the GPU when one is visible; otherwise everything stays on the CPU.
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'

if device_str == 'cuda':
    # Make newly created tensors land on the GPU unless a device is given explicitly.
    torch.set_default_device('cuda')

device = torch.device(device_str)

In [4]:
from torch.cuda.amp import autocast, GradScaler
from torch import optim
from dsgp4 import mldsgp4
from lazy_dataset.dataset import LazyDataset, TrainingStep, State
from dsgp4.tle import TLE
from util.transform import teme_to_gcrf, gcrf_to_teme
import satkit as sk

def extract_batch_data(batch: list[TrainingStep]) -> tuple[list[TLE], list[State], torch.Tensor]:
    """
    Flatten a batch of training steps into parallel, per-state sequences.

    Args:
        batch: Training steps. Each step is read for its ``tle``, ``states``
            and ``tsinces`` attributes.

    Returns:
        A tuple ``(flattened_tles, states, batched_steps)``:
            flattened_tles: each step's TLE repeated once per entry in that
                step's ``tsinces``.
            states: every step's states, concatenated in batch order.
            batched_steps: 1-D tensor on the module-level ``device`` holding,
                for each step with at least one tsince, ``len(tsinces)``
                evenly spaced values from 0 to that step's last tsince.
                Empty (shape ``[0]``) when no step has any tsince.
    """
    states = [state for step in batch for state in step.states]

    flattened_tles = []
    step_grids = []
    for step in batch:
        tsince_list = step.tsinces
        count = len(tsince_list)
        flattened_tles.extend([step.tle] * count)
        if count > 0:
            # NOTE(review): the propagation times are an even grid from 0 to the
            # last tsince rather than the recorded tsinces themselves. This only
            # matches the states if the recorded tsinces are evenly spaced and
            # start at 0 -- confirm against LazyDataset.
            step_grids.append(torch.linspace(0, tsince_list[-1], count, device=device))

    # Concatenate on-device instead of round-tripping every grid through a
    # Python list (which forces a device -> host sync per TLE).
    if step_grids:
        batched_steps = torch.cat(step_grids)
    else:
        batched_steps = torch.empty(0, device=device)

    return flattened_tles, states, batched_steps

def normalize_ground_truth(states, normalization_R=6958.137, normalization_V=7.947155867983262):
    """
    Normalize ground truth states to match model's output space.

    The first three components along the last axis are divided by
    ``normalization_R`` and the remaining components by ``normalization_V``.

    Args:
        states: Tensor [..., 6] with positions (km) and velocities (km/s) on
            the last axis. A single state of shape [6] and batched input of
            shape [N, 6] are both accepted.
        normalization_R: Divisor applied to the position components.
        normalization_V: Divisor applied to the velocity components.

    Returns:
        normalized_states: New tensor with the same shape as ``states`` in
        normalized units. Integer input yields a floating-point result
        instead of being truncated back to the input dtype.
    """
    # Slice on the last axis so any number of leading batch dimensions works.
    positions = states[..., :3] / normalization_R
    velocities = states[..., 3:] / normalization_V

    # Concatenating the quotients (rather than writing into zeros_like(states))
    # keeps the true-division dtype, so integer inputs are not truncated.
    return torch.cat((positions, velocities), dim=-1)

def train(
        model: mldsgp4,
        dataset: LazyDataset,
        optimizer: optim.Optimizer,
        criterion: nn.Module = nn.SmoothL1Loss(),
        scaler: GradScaler = GradScaler(),
        chunk_size=512,
        ):
    """
    Run one training pass over ``dataset`` and return the mean loss per batch.

    For every batch the ground-truth states are converted from GCRF to TEME,
    normalized with ``normalize_ground_truth`` and compared against the model
    output. The forward/backward pass is done in chunks of ``chunk_size``
    samples; gradients from all chunks accumulate and a single optimizer step
    is taken per batch.

    Args:
        model: Network called as ``model(tles, tsinces)``.
        dataset: Iterable of batches; each batch is passed to
            ``extract_batch_data``. Must support ``len()``.
        optimizer: Optimizer stepped once per batch through ``scaler``.
        criterion: Loss applied to (prediction, target) for each chunk.
        scaler: Gradient scaler used for mixed-precision training.
        chunk_size: Maximum number of samples per forward/backward pass.

    Returns:
        Sum of the per-batch accumulated losses divided by ``len(dataset)``
        (0.0 when the dataset is empty).

    Note:
        The default ``criterion`` and ``scaler`` are created once, when the
        function is defined, and are shared by every call that omits them.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(dataset)

    for batch_idx, batch in enumerate(dataset):
        logging.info(f"Processing batch {batch_idx+1}/{num_batches} with {len(batch)} steps")
        tles, states, steps = extract_batch_data(batch)

        if len(steps) == 0:
            # Nothing to propagate; torch.stack / torch.cat below would raise on empty lists.
            logging.info(f"Skipping empty batch {batch_idx+1}")
            continue

        optimizer.zero_grad()

        # Convert ground truth states to a tensor (GCRF -> TEME, one row per state)
        state_tensors = []
        for state in states:
            state_gcrf = torch.tensor(
                np.concatenate([state.get_position_vector(), state.get_velocity_vector()]),
                dtype=torch.float32,
                device=device
            ).unsqueeze(0)
            state_teme = gcrf_to_teme(state_gcrf, sk.time.from_datetime(state.dt_time))
            # Flatten so stacking yields one row per state.
            # assumes gcrf_to_teme returns exactly one 6-vector per call -- TODO confirm
            state_tensors.append(state_teme.reshape(-1))

        # NOTE(review): targets are scaled with normalize_ground_truth on the
        # assumption that the model emits normalized states (that helper is
        # documented as matching the model's output space) -- confirm.
        target_states = normalize_ground_truth(torch.stack(state_tensors).to(device))

        # Process in smaller chunks to save GPU memory
        propagated_states = []
        accumulated_loss = 0.0
        logging.info(f"Propagating {len(tles)} TLEs with {len(steps)} steps in chunks of size {chunk_size}")
        # A dedicated loop variable keeps batch_idx intact for the log line below.
        for start in range(0, len(steps), chunk_size):
            stop = start + chunk_size
            with autocast():
                chunk_states = model(tles[start:stop], steps[start:stop]).to(device)
                # Compute the loss for the chunk
                loss = criterion(chunk_states, target_states[start:stop])

            # Backward runs outside the autocast region; gradients accumulate
            # across chunks until the optimizer step after the loop.
            scaler.scale(loss).backward()
            accumulated_loss += loss.item()

            # Keep only a graph-free copy for the shape report below.
            propagated_states.append(chunk_states.detach())

            # Drop the references first so the cache flush can actually release them.
            del chunk_states, loss
            torch.cuda.empty_cache()

        total_loss += accumulated_loss
        logging.info(f"Accumulated loss for batch {batch_idx+1}: {accumulated_loss} | Average loss so far: {total_loss / (batch_idx+1)}")

        # Convert propagated states to a tensor
        propagated_states = torch.cat(propagated_states, dim=0).to(device)
        print(f"Propagated states shape: {propagated_states.shape}")

        # Optimization step on the gradients accumulated over all chunks
        scaler.step(optimizer)
        scaler.update()

    if num_batches == 0:
        return 0.0
    avg_loss = total_loss / num_batches
    return avg_loss

  scaler: GradScaler = GradScaler(),


In [5]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = mldsgp4().to(device)

# We use a SmoothL1 criterion, which combines MSE and MAE in order to be
# robust to outliers and get smooth gradients.
criterion = nn.SmoothL1Loss()

optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
scheduler = ReduceLROnPlateau(optimizer)

# Gradient scaler for mixed-precision training.
scaler = GradScaler()

  scaler = GradScaler()


In [6]:
from lazy_dataset.dataset import LazyDataset

# Dataset folders come from the environment (populated by load_dotenv earlier).
TRAIN_PATH, TEST_PATH, VAL_PATH = (
    os.getenv(env_name) for env_name in ("TRAIN_PATH", "TEST_PATH", "VAL_PATH")
)

if any(path is None for path in (TRAIN_PATH, TEST_PATH, VAL_PATH)):
    raise ValueError("One or more dataset paths are not set in the environment variables.")

# LazyDataset is a custom dataset class that loads data from a folder on demand,
# so large datasets never have to be held in memory all at once.
# The folder is assumed to contain TLE files and matching state files in the
# format LazyDataset expects.
mp = True  # Enable multiprocessing for LazyDataset
bs = 4     # batch_size passed to every split
ns = 512   # num_states_per_tle passed to every split

train_satellites = LazyDataset(
    folder=TRAIN_PATH,
    multiprocess=mp,
    batch_size=bs,
    num_states_per_tle=ns,
)
test_satellites = LazyDataset(
    folder=TEST_PATH,
    multiprocess=mp,
    batch_size=bs,
    num_states_per_tle=ns,
)
val_satellites = LazyDataset(
    folder=VAL_PATH,
    multiprocess=mp,
    batch_size=bs,
    num_states_per_tle=ns,
)

# Quick sanity check: show the first training batch.
print(train_satellites[0])

['integration_30122.txt.zst', 'integration_30124.txt.zst', 'integration_30127.txt.zst', 'integration_30131.txt.zst']


In [None]:
EPOCHS = 100

# Candidate settings for a hyper-parameter sweep; not used by the loop below.
LRS = [0.01, 0.001, 0.0001, 0.00001]
OPTIMIZERS = [
]

training_loss_over_time = []
testing_loss_over_time = []
validation_loss_over_time = []

for epoch in range(EPOCHS):
    logging.info(f"Epoch {epoch+1}/{EPOCHS}")

    # Training (train() puts the model into training mode itself).
    # Pass the notebook's scaler explicitly so train() does not silently fall
    # back to its definition-time default instance.
    train_loss = train(model, train_satellites, optimizer, criterion, scaler=scaler, chunk_size=1024)
    logging.info(f"Training loss: {train_loss}")
    training_loss_over_time.append(train_loss)

    # ReduceLROnPlateau only lowers the learning rate when it is stepped with a metric.
    # NOTE(review): stepping on the training loss because no validation loss is
    # computed yet; switch to val_loss once evaluation below is enabled.
    scheduler.step(train_loss)

    # TODO: no evaluation helper is defined in the visible cells, so `eval`
    # below would resolve to Python's builtin. Define one under a different
    # name before enabling this block.
    # Evaluation
    # model.eval()
    # test_loss = eval(model, test_satellites, optimizer)
    # testing_loss_over_time.append(test_loss)

    # # Validation
    # val_loss = eval(model, val_satellites, optimizer)
    # validation_loss_over_time.append(val_loss)