In [None]:
# Refresh satkit's data files before any propagation runs.
# NOTE(review): going by the name, this updates satkit's auxiliary data (see the
# satkit docs for exactly what is fetched). It likely needs network access and
# runs on every Restart & Run All — consider guarding it behind a config flag.
import satkit

satkit.utils.update_datafiles()

In [None]:
import os

from util.read import read_zst, read_blocks

# Path to one zstd-compressed integration file. Defaults to the original
# machine-specific location; set INTEGRATION_FILE to point somewhere else.
FILEPATH = os.environ.get(
    "INTEGRATION_FILE",
    "/mnt/IronWolfPro8TB/SWARM/data/output/raw/val/integration_12.txt.zst",
)

# read_zst is assumed to return the decompressed file as a sequence of lines.
# read_blocks is assumed to split them into TLE and state arrays — TODO confirm
# whether the two arrays are meant to be parallel (same length).
lines = read_zst(FILEPATH)
tle_arr, state_arr = read_blocks(lines)

# Quick sanity check of how many records were parsed.
print(len(tle_arr))
print(len(state_arr))

In [None]:
import torch
import torch.nn as nn
import numpy as np

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
from torch.cuda.amp import autocast
from torch.amp import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch import optim
from custom_satkit.customMLDSGP4 import mldsgp4
from custom_satkit.CustomTLE import CustomTLE as TLE
from custom_dataset.dataset import State, LazyDataset

def train(
        model: mldsgp4,
        dataset: LazyDataset,
        optimizer: optim.Optimizer,
        criterion: nn.Module = nn.SmoothL1Loss(),
        scaler: GradScaler = GradScaler(),
        density = 5001
        ):
    """Run one training pass of ``model`` over ``dataset`` (NOT IMPLEMENTED YET).

    The optimisation loop is still a TODO. For now this function only puts
    ``model`` into training mode and returns ``None``; it does not touch
    ``optimizer``, ``criterion`` or ``scaler``. A ``RuntimeWarning`` is emitted
    so that a notebook run does not silently "train" nothing.

    Args:
        model: the ``mldsgp4`` model to train.
        dataset: the data to train on. The notebook passes the result of
            ``LazyDataset.get_next_batch()``; its item structure is not yet
            pinned down — TODO confirm.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss comparing propagated states with target states.
        scaler: mixed-precision gradient scaler. NOTE: the default instance is
            created once, when this cell runs, and is shared by every call that
            omits it.
        density: number of propagation time steps per TLE in the planned loop.

    Returns:
        None for now. The planned implementation returns the mean item loss.
    """
    import warnings

    model.train()
    warnings.warn(
        "train() is a stub: no optimisation step is performed yet",
        RuntimeWarning,
        stacklevel=2,
    )

    # TODO: implement the training loop. Sketch of the intended algorithm,
    # mirroring evaluate() (one TLE, one tsince and one target State per item):
    #
    # total_loss = 0.0
    # for tle, tsince, target in <aligned items from dataset>:
    #     tle = tle.to(device)
    #     optimizer.zero_grad()
    #
    #     tle_expanded = [tle] * density
    #     time_steps = torch.linspace(0, tsince, density, device=device)
    #
    #     # NOTE: torch.cuda.amp.autocast's first positional argument is
    #     # `enabled`, not a device. Use torch.autocast(device_type=...) instead.
    #     with torch.autocast(device_type=device.type):
    #         output_segment_states = model(tle_expanded, time_steps)
    #         # NOTE(review): the output holds `density` states but `target` is
    #         # a single State — decide whether to compare only the last step.
    #         loss = criterion(output_segment_states, target.to(device))
    #
    #     scaler.scale(loss).backward()
    #     scaler.step(optimizer)
    #     scaler.update()
    #
    #     total_loss += loss.item()
    # return total_loss / <number of items>

def evaluate(
        model: mldsgp4,
        tles_batch: list[TLE],
        tsinces: list[float],
        targets: list[State],
        criterion: nn.Module = nn.SmoothL1Loss(),
        density = 1
        ):
    """Return the mean loss of ``model`` over aligned (TLE, tsince, target) items.

    Item ``i`` propagates ``tles_batch[i]`` over ``density`` time steps ending
    at ``tsinces[i]`` and compares the model output with ``targets[i]``. The
    model is put in eval mode and run under ``torch.no_grad()``.

    Args:
        model: model to evaluate; called as ``model(tles, time_steps)``.
        tles_batch: TLEs to propagate, one per item.
        tsinces: propagation end time for each item, in whatever unit the
            model expects (TODO confirm; SGP4 convention is minutes since epoch).
        targets: reference state for each item.
        criterion: loss between model output and target.
        density: number of time steps per item. With ``density > 1`` the steps
            run evenly from 0 to ``tsinces[i]``. With ``density == 1`` the single
            step is taken at ``tsinces[i]``, not at 0.

    Returns:
        Mean of the per-item losses as a Python float.

    Raises:
        ValueError: if ``tles_batch`` is empty, or if the three lists have
            different lengths.
    """
    if not tles_batch:
        raise ValueError("evaluate() needs at least one TLE")
    if not (len(tles_batch) == len(tsinces) == len(targets)):
        raise ValueError(
            "tles_batch, tsinces and targets must have the same length "
            f"(got {len(tles_batch)}, {len(tsinces)}, {len(targets)})"
        )

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for tle, tsince, target in zip(tles_batch, tsinces, targets):
            tle = tle.to(device)

            tle_expanded = [tle] * density
            # torch.linspace(0, t, 1) yields only the start point (0). That
            # would compare the epoch state against the target at tsince, so
            # for a single step start the range at tsince itself.
            start = tsince if density == 1 else 0
            time_steps = torch.linspace(start, tsince, density, device=device)

            output_segment_states = model(tle_expanded, time_steps)
            loss = criterion(output_segment_states, target.to(device))

            total_loss += loss.item()
    return total_loss / len(tles_batch)

In [None]:
# Seed PyTorch so that weight initialisation is reproducible across runs.
SEED = 42
torch.manual_seed(SEED)

# Move the model to `device` *before* building the optimizer, so the optimizer
# tracks the on-device parameters. evaluate() moves its inputs to `device` too.
model = mldsgp4()
model.to(device)

# SmoothL1 acts like MSE for small errors and like MAE for large ones. This
# makes it robust to outliers while keeping gradients smooth.
criterion = nn.SmoothL1Loss()
optimizer = optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 0.05)
# NOTE: ReduceLROnPlateau only acts when scheduler.step(val_loss) is called.
scheduler = ReduceLROnPlateau(optimizer)
scaler = GradScaler()

In [None]:
import os

from custom_dataset.dataset import LazyDataset

# Root folder holding the train/test/val splits. Defaults to the repo-relative
# "../data". Set SWARM_DATA_DIR to use another copy instead, e.g.
# "/mnt/IronWolfPro8TB/SWARM/data/output/raw".
DATA_DIR = os.environ.get("SWARM_DATA_DIR", "../data")

TRAIN_PATH = os.path.join(DATA_DIR, "train")
TEST_PATH = os.path.join(DATA_DIR, "test")
VAL_PATH = os.path.join(DATA_DIR, "val")

# LazyDataset is a custom dataset class that lazily loads data from the given folder,
# so only the data actually needed is read into memory.
# The folder is expected to contain TLE files and matching state files in the project's format.
train_satellites = LazyDataset(folder = TRAIN_PATH)
test_satellites = LazyDataset(folder = TEST_PATH)
val_satellites = LazyDataset(folder = VAL_PATH)

# Preview one training sample. Guard the empty case so a wrong DATA_DIR
# gives a readable message instead of an indexing error.
if len(train_satellites) > 0:
    print(train_satellites[0])
else:
    print(f"No samples found in {TRAIN_PATH}")

In [None]:
EPOCHS = 100
VAL_BATCH_SIZE = 32

# Placeholders for a future learning-rate / optimizer sweep (not used below yet).
LRS = [0.01, 0.001, 0.0001, 0.00001]
OPTIMIZERS = [
]

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")

    # Training: train() switches the model to train mode itself.
    training_batch = train_satellites.get_next_batch()
    train(model, training_batch, optimizer, criterion, scaler)

    # Evaluation: evaluate() expects aligned lists with exactly one TLE, one
    # tsince and one target State per item. Each state of a TLE therefore
    # becomes its own item, paired with that TLE. Its tsince is measured from
    # that TLE's own epoch, not from the first TLE's.
    val_loss_sum = 0.0
    val_items = 0
    for start in range(0, len(val_satellites), VAL_BATCH_SIZE):
        tles_batch = val_satellites[start:start + VAL_BATCH_SIZE]

        eval_tles, eval_tsinces, eval_targets = [], [], []
        for tle in tles_batch:
            tle_states = list(tle.states)
            # TODO: compute_tsinces is not defined anywhere in this notebook.
            # It is assumed to return one tsince per state, in the same order.
            eval_tles.extend([tle] * len(tle_states))
            eval_tsinces.extend(compute_tsinces(tle.epoch, tle_states))
            eval_targets.extend(tle_states)

        if not eval_tles:
            continue  # this batch has no states to compare against

        batch_loss = evaluate(model, eval_tles, eval_tsinces, eval_targets, criterion)
        # Weight by item count so a short final batch does not skew the mean.
        val_loss_sum += batch_loss * len(eval_tles)
        val_items += len(eval_tles)

    val_loss = val_loss_sum / val_items if val_items else float("nan")
    print(f"Validation Loss: {val_loss}")
    # ReduceLROnPlateau only adjusts the LR when it is given the monitored metric.
    if val_items:
        scheduler.step(val_loss)