In [None]:
import math
import numpy as np
import zstandard as zstd
import datetime
import io
import satkit
from satkit import satstate, time
from typing import List, Tuple, Dict

# Refresh satkit's data files before using satkit below.
# NOTE(review): presumably this may fetch files over the network on first run /
# when they are stale — confirm, and consider guarding it for offline machines.
satkit.utils.update_datafiles()

def read_zst(filepath:str) -> List[str]:
    """Read a zstandard-compressed UTF-8 text file and return its lines.

    The file is decompressed as a stream and decoded line by line, so the
    compressed bytes are never fully materialized; the decoded lines are.

    Args:
        filepath: path to the ``.zst`` file.

    Returns:
        The file's lines, each with its trailing newline removed.
    """
    decompressor = zstd.ZstdDecompressor()
    with open(filepath, "rb") as compressed_file:
        reader = decompressor.stream_reader(compressed_file)
        # Closing the text wrapper also closes the decompression reader, so the
        # zstd stream is released deterministically instead of whenever GC runs.
        with io.TextIOWrapper(reader, encoding="utf-8") as text_stream:
            return [line.rstrip("\n") for line in text_stream]


class State:
    """A single satellite state parsed from one comma-separated text line.

    Field layout read by ``__init__``:
        0: Unix timestamp in seconds
        1-3: position components (x, y, z)
        4-6: velocity components (x, y, z)

    NOTE(review): ``get_velocity_magnitude`` divides by 1000 to report km/s, so
    velocities are presumably stored in m/s (and positions in m) — confirm
    against the data generator.
    """

    def __init__(self, line:str):
        orbital_data = line.split(",")

        # Build a timezone-aware UTC datetime. A bare fromtimestamp() returns a
        # naive datetime in the *machine's local timezone*, which makes the epoch
        # handed to satkit depend on where the notebook is run.
        # TODO: to be optimized (avoid the datetime round-trip).
        dt_time = datetime.datetime.fromtimestamp(
            float(orbital_data[0]), tz=datetime.timezone.utc
        )
        self.time = time.from_datetime(dt_time)
        self.pos_x = float(orbital_data[1])
        self.pos_y = float(orbital_data[2])
        self.pos_z = float(orbital_data[3])
        self.vel_x = float(orbital_data[4])
        self.vel_y = float(orbital_data[5])
        self.vel_z = float(orbital_data[6])

    def get_position_vector(self):
        """
        Returns the position vector [x, y, z] as a numpy array
        """
        return np.array([self.pos_x, self.pos_y, self.pos_z])

    def get_velocity_vector(self):
        """
        Returns the velocity vector [vx, vy, vz] as a numpy array
        """
        return np.array([self.vel_x, self.vel_y, self.vel_z])

    def get_velocity_magnitude(self):
        """
        Returns the magnitude of the velocity vector divided by 1000
        (km/s if the stored velocity is in m/s).
        """
        return math.sqrt(self.vel_x**2 + self.vel_y**2 + self.vel_z**2)/1000

    def create_SatState(self):
        """
        Returns a satkit ``satstate`` built from this object's time, position
        vector and velocity vector.
        """
        return satstate(self.time, self.get_position_vector(), self.get_velocity_vector())

In [None]:
import dsgp4
from customTLE import CustomTLE as TLE
# We use our own, leaner TLE class (imported under the name `TLE`).

def read_blocks(file_lines, states_per_tle: int = 5001):
    """Split a flat list of lines into TLEs and their associated states.

    The input is a sequence of fixed-size blocks. Each block consists of the two
    TLE lines followed by ``states_per_tle`` state lines (5001 by default).
    Trailing lines that do not form a complete block are ignored.

    Args:
        file_lines: lines of the decompressed file, e.g. from ``read_zst``.
        states_per_tle: number of state lines following each TLE.

    Returns:
        ``(tle_arr, state_arr)``. ``tle_arr`` holds one TLE per block.
        ``state_arr`` holds the states of ALL blocks concatenated in file order,
        so block k's states are
        ``state_arr[k * states_per_tle:(k + 1) * states_per_tle]``.

    Raises:
        ValueError: if ``states_per_tle`` is negative.
    """
    if states_per_tle < 0:
        raise ValueError(f"states_per_tle must be >= 0, got {states_per_tle}")

    block_size = 2 + states_per_tle  # 2 TLE lines + the state lines
    num_tles = len(file_lines) // block_size
    tle_arr = []
    state_arr = []
    for start in range(0, num_tles * block_size, block_size):
        tle_arr.append(TLE([file_lines[start].rstrip(), file_lines[start + 1].rstrip()]))
        state_arr.extend(State(line) for line in file_lines[start + 2:start + block_size])
    return (tle_arr, state_arr)

def compute_tsinces(epoch:float, states: List[State]):
    """Return, for every state, ``(state.time - epoch) / 60`` in input order.

    Args:
        epoch: reference epoch of the TLE (documented as a Unix timestamp).
        states: State objects representing satellite states in time.

    NOTE(review): ``State.time`` is a satkit time object, not a float, so
    subtracting a float Unix timestamp may not yield seconds (or even support
    ``/ 60``). Confirm what type ``epoch`` really is and what unit the
    difference has before relying on these values as minutes.
    """
    return [(state.time - epoch) / 60 for state in states]

def batch_list(input_list: List[TLE], batch_size: int = 32):
    """Chop ``input_list`` into consecutive slices of ``batch_size`` items.

    The final slice holds whatever remains and may be shorter than
    ``batch_size``; each slice has the same type as the input sequence.
    """
    batches = []
    for start in range(0, len(input_list), batch_size):
        batches.append(input_list[start:start + batch_size])
    return batches

In [None]:
FILEPATH = "/mnt/IronWolfPro8TB/SWARM/data/output/raw/val/integration_12.txt.zst"

# Decompress one integration file and split it into TLEs and their states.
lines = read_zst(FILEPATH)
(tle_arr, state_arr) = read_blocks(lines)
print(len(tle_arr))
print(len(state_arr))

In [None]:
import torch
import torch.nn as nn
import numpy as np
import customMLDSGP4

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
init_tle = tle_arr[0]
print("init_tle:")
print(init_tle)

# Move the model to the training device *before* the optimizer is built from
# model.parameters(). Otherwise, on CUDA the model would stay on the CPU while
# train_mldsgp4 sends its inputs to `device`.
model = customMLDSGP4.mldsgp4().to(device)

# NOTE(review): the return value is discarded, so this presumably initializes
# init_tle in place (as dsgp4.initialize_tle does) — confirm for customMLDSGP4.
customMLDSGP4.initialize_tle(init_tle)

In [None]:
from torch.cuda.amp import autocast
from torch.amp import GradScaler
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.optim as optim

# Smooth L1 criterion: quadratic (MSE-like) for small residuals and linear
# (MAE-like) for large ones. This makes it robust to outliers while keeping
# gradients smooth near zero.
class SmoothL1Loss(nn.Module):
    """Smooth L1 loss between predictions and targets.

    Per element, with residual ``x = predicted - target``:
        0.5 * x**2 / beta   if |x| < beta
        |x| - 0.5 * beta    otherwise
    The elementwise losses are then reduced according to ``reduction``.
    Equivalent to ``torch.nn.SmoothL1Loss(beta=beta, reduction=reduction)``.

    Args:
        beta: residual magnitude at which the loss switches from quadratic to linear.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``.
    """

    def __init__(self, beta: float = 1.0, reduction: str = "mean"):
        super().__init__()
        self.beta = beta
        self.reduction = reduction

    def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the smooth L1 loss of ``predicted`` against ``target``."""
        return nn.functional.smooth_l1_loss(
            predicted, target, reduction=self.reduction, beta=self.beta
        )

def train_mldsgp4(model: nn.Module, tles_batch: List[TLE], tsinces: List[float], targets: List[State], density: int = 1, epochs = 100, batch_size = 32):
    """Run one pass over ``tles_batch``, taking one optimizer step per TLE.

    For TLE ``i``, the model is evaluated at ``density`` time steps ending at
    ``tsinces[i]``, and its output is compared with ``targets[i]`` using
    ``criterion``.

    This function reads the notebook-level globals ``optimizer``, ``criterion``,
    ``scaler`` and ``device``. The cells defining them must run before it is
    called.

    Args:
        model: the ML-dSGP4 model to train (called as ``model(tles, time_steps)``).
        tles_batch: TLEs to train on; one optimizer step is taken per TLE.
        tsinces: per-TLE final propagation time passed to the model
            (presumably minutes since the TLE epoch, as in dsgp4 — confirm).
        targets: per-TLE reference values; each must support ``.to(device)``.
        density: number of time steps from 0 up to ``tsinces[i]``; must be >= 1.
        epochs: currently unused.
        batch_size: currently unused.

    Returns:
        ``(model, avg_loss)``, where ``avg_loss`` is the mean per-TLE loss.

    Raises:
        ValueError: if ``tles_batch`` is empty or ``density`` < 1.
    """
    if len(tles_batch) == 0:
        raise ValueError("tles_batch must contain at least one TLE")
    if density < 1:
        raise ValueError(f"density must be >= 1, got {density}")

    model.train()
    total_loss = 0.0
    # float16 autocast / loss scaling only apply on CUDA.
    use_amp = device.type == "cuda"
    for i, tle in enumerate(tles_batch):
        tle = tle.to(device)
        optimizer.zero_grad()

        tle_expanded = [tle] * density
        if density > 1:
            time_steps = torch.linspace(0, tsinces[i], density, device=device)
        else:
            # torch.linspace(..., steps=1) returns only the start value (0),
            # which would silently ignore tsinces[i]; use the target time itself.
            time_steps = torch.tensor([float(tsinces[i])], device=device)

        # NOTE(review): targets is annotated List[State], but State has no .to();
        # this expects tensor targets — confirm the caller converts them (and that
        # their units/shape match the model output).
        target = targets[i].to(device)
        with torch.autocast(device_type=device.type, enabled=use_amp):
            output_segment_states = model(tle_expanded, time_steps)
            loss = criterion(output_segment_states, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
    avg_loss = total_loss / len(tles_batch)
    return (model, avg_loss)

In [None]:
# Training hyperparameters, kept together so they are easy to find and tune.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.05

criterion = nn.SmoothL1Loss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# mode="min" (the default) because the monitored quantity is a loss.
# NOTE(review): scheduler.step(val_loss) must be called once per epoch for this
# to have any effect; no such call appears in the cells visible here.
scheduler = ReduceLROnPlateau(optimizer, mode="min")
# Loss scaling only matters for float16 autocast on CUDA. Disable it elsewhere so
# GradScaler does not warn and fall back by itself on CPU-only machines.
scaler = GradScaler(enabled=device.type == "cuda")

In [None]:
from custom_dataset.dataset import CustomDataset

# All splits live under one root directory; change it here to point at another disk.
DATA_ROOT = "/mnt/IronWolfPro8TB/SWARM/data/output/raw"
TRAIN_PATH = f"{DATA_ROOT}/train"
TEST_PATH = f"{DATA_ROOT}/test"
VAL_PATH = f"{DATA_ROOT}/val"

train_satellites = CustomDataset(folder = TRAIN_PATH)
test_satellites = CustomDataset(folder = TEST_PATH)
val_satellites = CustomDataset(folder = VAL_PATH)

print(train_satellites[0])