In [None]:
import math
import numpy as np
import zstandard as zstd
import datetime
import io
import satkit
from satkit import satstate, time

# Refresh the auxiliary data files that satkit uses; this runs every time this cell executes.
# NOTE(review): presumably this checks/downloads files over the network — consider guarding it
# if the notebook must run offline or re-run quickly.
satkit.utils.update_datafiles()

def read_zst(filepath:str, encoding: str = "utf-8"):
    """
    Decompress a zstd-compressed text file and return all of its lines.

    Parameters
    ----------
    filepath : str
        Path to the ``.zst`` file.
    encoding : str, optional
        Text encoding of the decompressed payload (default ``"utf-8"``).

    Returns
    -------
    list[str]
        Every line of the decompressed text with the trailing ``"\\n"`` removed.
        The whole file is materialised in memory.
    """
    with open(filepath, "rb") as file:
        decompressor = zstd.ZstdDecompressor()
        # Closing the TextIOWrapper also closes the wrapped zstd stream reader, so the
        # decompression stream is released deterministically rather than at GC time.
        with io.TextIOWrapper(decompressor.stream_reader(file), encoding=encoding) as text_stream:
            # Universal-newline mode (the TextIOWrapper default) turns "\r\n" into "\n",
            # so stripping "\n" is enough for both line-ending styles.
            return [line.rstrip("\n") for line in text_stream]


class State:
    """
    A single propagated orbital state parsed from one comma-separated line.

    Field layout of ``line`` (split on ","):
        0    epoch as a Unix timestamp in seconds (may be fractional)
        1-3  position components  -> pos_x, pos_y, pos_z
        4-6  velocity components  -> vel_x, vel_y, vel_z

    NOTE(review): units are not stated in the data. get_velocity_magnitude divides
    by 1000 and reports km/s, which implies velocities are stored in m/s (and
    presumably positions in m) — confirm against the producer of the file.
    """

    def __init__(self, line:str):
        orbital_data = line.split(",")

        dt_time = self._utc_datetime_from_field(orbital_data[0])  # to be optimized
        self.time = time.from_datetime(dt_time)
        self.pos_x = float(orbital_data[1])
        self.pos_y = float(orbital_data[2])
        self.pos_z = float(orbital_data[3])
        self.vel_x = float(orbital_data[4])
        self.vel_y = float(orbital_data[5])
        self.vel_z = float(orbital_data[6])

    @staticmethod
    def _utc_datetime_from_field(field: str) -> datetime.datetime:
        """
        Convert a Unix-timestamp text field into a timezone-aware UTC datetime.

        An explicit UTC tzinfo is used because a naive ``fromtimestamp`` result is
        expressed in the host's local time. Depending on how the consumer interprets
        naive datetimes, that could make the parsed epoch vary with the timezone of
        the machine running the notebook. An aware UTC value is unambiguous.
        """
        return datetime.datetime.fromtimestamp(float(field.strip()), tz=datetime.timezone.utc)

    def get_position_vector(self):
        """
        Returns the position vector [pos_x, pos_y, pos_z] as a numpy array
        """
        return np.array([self.pos_x, self.pos_y, self.pos_z])

    def get_velocity_vector(self):
        """
        Returns the velocity vector [vel_x, vel_y, vel_z] as a numpy array
        """
        return np.array([self.vel_x, self.vel_y, self.vel_z])

    def get_velocity_magnitude(self):
        """
        Returns magnitude of the velocity vector divided by 1000, reported as km/s
        (i.e. assumes the stored components are in m/s)
        """
        return math.sqrt(self.vel_x**2 + self.vel_y**2 + self.vel_z**2)/1000

    def create_SatState(self):
        """
        Returns a satkit satstate built from this object's epoch, position and velocity
        """
        return satstate(self.time, self.get_position_vector(), self.get_velocity_vector())

In [17]:
import dsgp4

def read_blocks(file_lines, states_per_tle=5001):
    """
    Split decoded file lines into TLEs and the propagated states that follow them.

    The input is a sequence of fixed-size blocks: 2 TLE lines followed by
    ``states_per_tle`` state lines (each parsed into a ``State``).

    Parameters
    ----------
    file_lines : sequence of str
        Lines of the integration file, e.g. as returned by ``read_zst``.
    states_per_tle : int, optional
        Number of state lines after each TLE pair (default 5001).

    Returns
    -------
    tuple[list, list]
        ``(tle_arr, state_arr)``:
        - ``tle_arr`` holds one ``dsgp4.tle.TLE`` per block.
        - ``state_arr`` holds the ``State`` objects of all blocks concatenated
          in file order. The states belonging to ``tle_arr[k]`` are therefore
          ``state_arr[k * states_per_tle:(k + 1) * states_per_tle]``.

    Trailing lines that do not form a complete block are ignored, and a
    warning is emitted so that a truncated file is not silently accepted.
    """
    import warnings

    block_size = 2 + states_per_tle  # 2 lines for TLE, the rest for states
    num_tles, leftover = divmod(len(file_lines), block_size)
    if leftover:
        warnings.warn(
            f"read_blocks: ignoring {leftover} trailing line(s) that do not form "
            f"a complete {block_size}-line block",
            stacklevel=2,
        )

    tle_arr = []
    state_arr = []
    for start in range(0, num_tles * block_size, block_size):
        tle_arr.append(dsgp4.tle.TLE([file_lines[start].rstrip(), file_lines[start + 1].rstrip()]))
        state_arr.extend(State(line) for line in file_lines[start + 2:start + block_size])
    return (tle_arr, state_arr)


In [18]:
FILEPATH = "/mnt/IronWolfPro8TB/SWARM/data/output/raw/integration_12.txt.zst"

# Decompress the whole integration file into memory, one string per line.
lines = read_zst(FILEPATH)
num_lines = len(lines)
states = []  # NOTE(review): not used by any cell visible here

# The file is a run of blocks: 2 TLE lines followed by that TLE's propagated states.
tle_arr, state_arr = read_blocks(lines)
print(len(tle_arr), len(state_arr), sep="\n")

679
3395679


In [19]:
import torch
import torch.nn as nn
import numpy as np
import customMLDSGP4

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [20]:
# Use the first TLE in the file as the starting point.
init_tle = tle_arr[0]
print("init_tle:", init_tle, sep="\n")

model = customMLDSGP4.mldsgp4()

# Left as the last expression so the notebook displays the returned element tensor.
# NOTE(review): the TLE is initialised through the module-level function, not via
# `model` — confirm that is intended.
customMLDSGP4.initialize_tle(init_tle)

init_tle:
TLE(
1 00012U 59001B   23365.78901621 +.00000312 +00000-0  18880-3 0 00019
2 00012  32.8959 145.6874 1658950  35.9316 334.2191 11.45913264428101
)


tensor([1.8880e-04, 9.4539e-12, 0.0000e+00, 1.6589e-01, 6.2712e-01, 5.7414e-01,
        5.8332e+00, 5.0000e-02, 2.5427e+00])

In [None]:
def train_mldsgp4(model, tles, states, epochs=100, batch_size=32):
    """
    Placeholder for the mldsgp4 training loop.

    Not implemented yet: every argument is accepted and ignored, and the
    function returns None.
    """
    # TODO: implement training of `model` on (tles, states) for `epochs` epochs.
    return None