In [1]:
import math
import numpy as np
import zstandard as zstd
import datetime
import io
import satkit
from satkit import satstate, time

# Refresh satkit's auxiliary data files (gravity models, leap seconds, EOP, space weather).
# Per the cell output, files already present are skipped, but EOP-All.csv and SW-All.csv
# are re-downloaded on every run, so this line needs network access each time.
satkit.utils.update_datafiles()

def read_zst(filepath: str, encoding: str = "utf-8"):
    """Read a zstd-compressed text file and return its lines.

    Args:
        filepath: path to the ``.zst`` file.
        encoding: text encoding of the decompressed data (default UTF-8).

    Returns:
        list[str]: the decompressed lines, each with its trailing newline removed.
    """
    with open(filepath, "rb") as file:
        reader = zstd.ZstdDecompressor().stream_reader(file)
        # Closing the text wrapper also closes the decompression reader, so the zstd
        # stream is released deterministically rather than whenever it is garbage-collected.
        with io.TextIOWrapper(reader, encoding=encoding) as text_stream:
            return [line.rstrip("\n") for line in text_stream]


class State:
    """A single orbital state parsed from one comma-separated line.

    Expected field order (only the first 7 fields are read)::

        unix_timestamp, pos_x, pos_y, pos_z, vel_x, vel_y, vel_z

    Attributes:
        time: satkit time built from the POSIX timestamp, interpreted as UTC.
        pos_x, pos_y, pos_z: position components as stored in the file.
        vel_x, vel_y, vel_z: velocity components as stored in the file
            (``get_velocity_magnitude`` treats them as m/s).
    """

    def __init__(self, line: str):
        orbital_data = line.split(",")
        if len(orbital_data) < 7:
            raise ValueError(
                f"State: expected at least 7 comma-separated fields, got {len(orbital_data)}: {line!r}"
            )

        # TODO(perf): one datetime object is created per parsed line ("to be optimized").
        self.time = time.from_datetime(self._utc_datetime_from_unix(orbital_data[0]))
        self.pos_x = float(orbital_data[1])
        self.pos_y = float(orbital_data[2])
        self.pos_z = float(orbital_data[3])
        self.vel_x = float(orbital_data[4])
        self.vel_y = float(orbital_data[5])
        self.vel_z = float(orbital_data[6])

    @staticmethod
    def _utc_datetime_from_unix(timestamp: str) -> datetime.datetime:
        """Convert a POSIX timestamp string into a timezone-aware UTC datetime.

        ``datetime.fromtimestamp`` without ``tz`` returns a naive datetime in the
        machine's *local* time zone, which made the parsed epoch host-dependent.
        Surrounding whitespace is accepted (``float`` ignores it).
        """
        return datetime.datetime.fromtimestamp(float(timestamp), tz=datetime.timezone.utc)

    def get_position_vector(self):
        """
        Returns the position vector [pos_x, pos_y, pos_z] as a numpy array
        """
        return np.array([self.pos_x, self.pos_y, self.pos_z])

    def get_velocity_vector(self):
        """
        Returns the velocity vector [vel_x, vel_y, vel_z] as a numpy array
        """
        return np.array([self.vel_x, self.vel_y, self.vel_z])

    def get_velocity_magnitude(self):
        """
        Returns magnitude of the velocity vector in km/s, assuming the stored
        components are in m/s (hence the division by 1000).
        """
        return math.hypot(self.vel_x, self.vel_y, self.vel_z) / 1000

    def create_SatState(self):
        """
        Returns a SatState object from satkit with the current object's time,
        position vector and velocity vector
        """
        return satstate(self.time, self.get_position_vector(), self.get_velocity_vector())

Downloading data files to /home/venom_snake/miniconda3/envs/SWARM/lib/python3.13/site-packages/satkit/satkit-data
File EGM96.gfc exists; skipping download
File ITU_GRACE16.gfc exists; skipping download
File leap-seconds.list exists; skipping download
File linux_p1550p2650.440 exists; skipping download
File JGM2.gfc exists; skipping download
File EOP-All.csv exists; skipping download
File tab5.2a.txt exists; skipping download
File JGM3.gfc exists; skipping download
File tab5.2b.txt exists; skipping download
File tab5.2d.txt exists; skipping download
File SW-All.csv exists; skipping download
Now downloading files that are regularly updated:
  Space Weather & Earth Orientation Parameters
Downloading EOP-All.csv
Downloading SW-All.csv


In [None]:
import dsgp4
from customTLE import CustomTLE as TLE
#we use our own, leaner version of TLE

def read_blocks(file_lines, states_per_block=5001):
    """Split the raw lines of an integration file into TLEs and propagated states.

    The file is a sequence of blocks, each made of 2 TLE lines followed by
    ``states_per_block`` state lines (one ``State`` per line).

    Args:
        file_lines: list of text lines (e.g. as returned by ``read_zst``).
        states_per_block: number of state lines following each TLE (default 5001).

    Returns:
        (tle_arr, state_arr): a list with one TLE per complete block, and a flat
        list of State objects in file order; block ``k`` owns
        ``state_arr[k * states_per_block:(k + 1) * states_per_block]``.

    Trailing lines that do not form a complete block are ignored, with a warning.
    """
    import warnings

    block_len = 2 + states_per_block  # 2 TLE lines + the state lines
    num_blocks, leftover = divmod(len(file_lines), block_len)
    if leftover:
        warnings.warn(
            f"read_blocks: ignoring {leftover} trailing line(s) that do not form "
            f"a complete block of {block_len} lines"
        )

    tle_arr = []
    state_arr = []
    for start in range(0, num_blocks * block_len, block_len):
        tle_arr.append(TLE([file_lines[start].rstrip(), file_lines[start + 1].rstrip()]))
        state_arr.extend(State(line) for line in file_lines[start + 2:start + block_len])
    return (tle_arr, state_arr)


In [5]:
# Load one validation integration file and split it into TLEs and propagated states.
FILEPATH = "/mnt/IronWolfPro8TB/SWARM/data/output/raw/val/integration_12.txt.zst"

lines = read_zst(FILEPATH)
num_lines = len(lines)
states = []  # NOTE(review): not filled in by any of the visible cells

tle_arr, state_arr = read_blocks(lines)
print(len(tle_arr), len(state_arr), sep="\n")

679
3395679


In [6]:
import torch
import torch.nn as nn
import numpy as np
import customMLDSGP4

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [7]:
# Use the first TLE of the file as the reference orbit and show its parsed fields.
init_tle = tle_arr[0]
print("init_tle:", init_tle, sep="\n")

model = customMLDSGP4.mldsgp4()

# Left as the last expression so Jupyter displays the value initialize_tle returns.
# NOTE(review): this is called on the module, not on `model`, and the result is not
# stored — confirm whether the model actually needs it.
customMLDSGP4.initialize_tle(init_tle)

init_tle:
satellite_catalog_number: 12
classification: U
international_designator: 59001B
epoch_year: 2023
epoch_days: 365.78901621
date_string: 2023-12-31 18:56:11.000544
date_mjd: 60309.78902250016
mean_motion_first_derivative: 5.252148212019973e-15
mean_motion_second_derivative: 0.0
b_star: 0.0001888
ephemeris_type: 0
element_number: 1
line1: 1 00012U 59001B   23365.78901621 +.00000312 +00000-0  18880-3 0 00019
_epochdays: 365.78901621
_bstar: 0.0001888
_ndot: 9.453866781635952e-12
_nddot: 0.0
inclination: 0.5741417654068026
raan: 2.5427248086699867
eccentricity: 0.165895
argument_of_perigee: 0.6271247255095945
mean_anomaly: 5.833223718052181
mean_motion: 0.0008333316416281253
revolution_number_at_epoch: 42810
line2: 2 00012  32.8959 145.6874 1658950  35.9316 334.2191 11.45913264428101
_inclo: 0.5741417654068026
_nodeo: 2.5427248086699867
_ecco: 0.165895
_argpo: 0.6271247255095945
_mo: 5.833223718052181
_no_kozai: 0.04999989849768752
_epochyr: 2023
_jdsatepoch: 2460309.5
_jdsatepoch

tensor([1.8880e-04, 9.4539e-12, 0.0000e+00, 1.6589e-01, 6.2712e-01, 5.7414e-01,
        5.8332e+00, 5.0000e-02, 2.5427e+00])

In [None]:
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim

# We use a Smooth L1 criterion: quadratic for small residuals (like MSE) and linear for
# large ones (like MAE), so it is robust to outliers while keeping smooth gradients near 0.
class SmoothL1Loss(torch.nn.Module):
    """Smooth L1 loss between predicted and target tensors.

    Thin wrapper around ``torch.nn.functional.smooth_l1_loss``. For each element with
    residual ``d = predicted - target``::

        loss = 0.5 * d**2 / beta   if |d| < beta
               |d| - 0.5 * beta    otherwise

    and the per-element losses are combined according to ``reduction``.

    Args:
        beta: threshold between the quadratic and linear regimes (default 1.0).
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``.
    """

    def __init__(self, beta: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, predicted, target):
        """Return the Smooth L1 loss of ``predicted`` against ``target``."""
        return torch.nn.functional.smooth_l1_loss(
            predicted, target, reduction=self.reduction, beta=self.beta
        )
    
criterion = SmoothL1Loss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
# ReduceLROnPlateau must be given the optimizer whose learning rate it lowers when the
# monitored metric stops improving; calling it with no arguments raises TypeError.
scheduler = ReduceLROnPlateau(optimizer)
# torch.cuda.amp.GradScaler is deprecated in favour of torch.amp.GradScaler, which takes
# the device type explicitly. The model is not moved to `device` in the cells shown, so
# key loss scaling off where its parameters actually live; it only matters on CUDA.
_amp_device_type = next(model.parameters()).device.type
scaler = torch.amp.GradScaler(_amp_device_type, enabled=(_amp_device_type == "cuda"))

def _tle_block_tensors(tle, block_states):
    """Build ``(tsinces, targets)`` tensors for one TLE and the states integrated from it.

    NOTE(review): every convention below is an assumption to confirm against
    customMLDSGP4 and the data files:
      * tsinces are minutes since the TLE epoch (the usual SGP4 input convention);
      * the epoch is rebuilt from ``epoch_year``/``epoch_days`` rather than ``date_mjd``,
        because the printed init_tle in this notebook shows date_mjd ~0.54 s away from
        the epoch implied by epoch_days;
      * ``State.time.as_mjd()`` (satkit) returns an MJD on the same (UTC) scale;
      * targets are the raw [pos, vel] values from the file; their units and any
        normalisation must match the model's output.
    """
    mjd_jan1 = (datetime.datetime(int(tle.epoch_year), 1, 1) - datetime.datetime(1858, 11, 17)).days
    epoch_mjd = mjd_jan1 + float(tle.epoch_days) - 1.0  # TLE day-of-year 1.0 == Jan 1, 00:00
    tsinces = torch.tensor([(s.time.as_mjd() - epoch_mjd) * 1440.0 for s in block_states])
    targets = torch.tensor(
        [[s.pos_x, s.pos_y, s.pos_z, s.vel_x, s.vel_y, s.vel_z] for s in block_states]
    )
    return tsinces, targets


def train_mldsgp4(model: torch.nn.Module, tles, states, epochs=100, batch_size=32):
    """Fine-tune ``model`` so its propagated states match the integrated ``states``.

    Work in progress: data conventions are isolated in ``_tle_block_tensors`` and marked
    NOTE(review) there.

    Args:
        model: the customMLDSGP4.mldsgp4 network, called as ``model(tle_list, tsinces)``
            with one TLE entry per propagation time (as in the original draft).
        tles: list of TLE objects, one per block (as returned by ``read_blocks``).
        states: flat list of State objects; ``len(states)`` must be a non-zero multiple
            of ``len(tles)``, the k-th TLE owning the k-th contiguous slice.
        epochs: number of passes over all TLE blocks.
        batch_size: number of TLE blocks propagated together per optimizer step
            (NOTE(review): interpretation chosen here; the draft never used it).

    Returns:
        list[float]: mean training loss of each epoch.

    Raises:
        ValueError: if ``tles`` is empty, ``batch_size < 1``, or the states cannot be
            split evenly over the TLEs.

    Uses the notebook globals ``criterion``, ``optimizer``, ``scheduler`` and ``scaler``.
    """
    if not tles:
        raise ValueError("train_mldsgp4: 'tles' is empty")
    if batch_size < 1:
        raise ValueError(f"train_mldsgp4: batch_size must be >= 1, got {batch_size}")
    states_per_tle, leftover = divmod(len(states), len(tles))
    if leftover or states_per_tle == 0:
        raise ValueError(
            f"train_mldsgp4: {len(states)} states cannot be split evenly over {len(tles)} TLEs"
        )

    # Inputs/targets do not change between epochs, so build them once.
    blocks = [
        _tle_block_tensors(tle, states[k * states_per_tle:(k + 1) * states_per_tle])
        for k, tle in enumerate(tles)
    ]

    # Keep data where the model's parameters are; mixed precision only on CUDA.
    model_device = next(model.parameters()).device
    use_amp = model_device.type == "cuda"

    epoch_losses = []
    model.train()
    for _ in range(epochs):
        loss_sum = 0.0
        num_steps = 0
        for start in range(0, len(tles), batch_size):
            stop = min(start + batch_size, len(tles))
            tle_batch = [tle for tle in tles[start:stop] for _ in range(states_per_tle)]
            tsinces = torch.cat([blocks[k][0] for k in range(start, stop)]).to(model_device)
            targets = torch.cat([blocks[k][1] for k in range(start, stop)]).to(model_device)

            optimizer.zero_grad(set_to_none=True)
            # NOTE(review): reduced-precision autocast may be too coarse for orbit
            # propagation; disable it if accuracy suffers.
            with torch.autocast(device_type=model_device.type, enabled=use_amp):
                predicted = model(tle_batch, tsinces)
                loss = criterion(predicted, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            loss_sum += loss.item()
            num_steps += 1

        epoch_loss = loss_sum / num_steps
        scheduler.step(epoch_loss)  # ReduceLROnPlateau monitors the epoch loss
        epoch_losses.append(epoch_loss)
    return epoch_losses