In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

In [2]:
initial_states = []

for dir, sub_dir, files in os.walk("input_data/"):
    for file in sorted(files):
        print(file)
        temp = pd.read_csv(("input_data/"+file),index_col=None, header=0)
        initial_states.append(temp)
#(initial_states)
initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)
#initial_states_df

hold = initial_states_df
x = initial_states_df.iloc[:,2:].values
x = normalize(x,norm='l2')
hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
initial_states_normalized = hold
#initial_states_normalized

00000_to_02284-initial_states.csv
02285_to_02357-initial_states.csv
02358_to_04264-initial_states.csv
04265_to_05570-initial_states.csv
05571_to_05614-initial_states.csv
05615_to_06671-initial_states.csv
06672_to_08118-initial_states.csv


In [3]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

hold = initial_states_df
x = initial_states_df.iloc[:,2:].values
x = normalize(x,norm='l2')
hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
initial_states_normalized = hold
#'2000-08-02 04:50:33'
timestamps = initial_states_df['Timestamp']
initial_states_df

Unnamed: 0,File ID,Timestamp,Semi-major Axis (km),Eccentricity,Inclination (deg),RAAN (deg),Argument of Perigee (deg),True Anomaly (deg),Latitude (deg),Longitude (deg),Altitude (km)
0,0,2000-08-02 04:50:33,6826.387247,0.003882,87.275306,144.135111,257.314389,102.383270,43.637815,-62.543128,466.448890
1,1,2000-08-03 19:51:01,6826.327748,0.003879,87.275694,143.529694,250.438806,109.273118,43.444458,70.139709,463.435053
2,2,2000-08-05 05:40:05,6819.634802,0.004114,87.268611,142.972111,244.549389,115.138737,19.764250,104.521278,471.625453
3,3,2000-08-06 05:02:20,6819.606603,0.004134,87.268194,142.608389,241.172000,118.545161,12.450738,112.239558,470.385914
4,4,2000-08-08 20:54:57,6819.425918,0.004178,87.264611,141.605111,228.779611,130.982981,-8.776992,-130.559634,468.911226
...,...,...,...,...,...,...,...,...,...,...,...
8114,8114,2019-12-25 00:00:00,6765.013678,0.005730,87.863978,102.587920,240.608198,187.758342,69.535173,108.291937,443.930606
8115,8115,2019-12-27 00:00:00,6801.130577,0.002172,90.690901,99.760357,152.602156,226.350702,20.225336,99.507926,435.492910
8116,8116,2019-12-28 00:00:00,6805.864837,0.001925,91.053632,96.918243,43.442569,131.160787,4.265831,-83.003782,436.652863
8117,8117,2019-12-30 00:00:00,6774.300973,0.004785,88.598951,92.814340,320.652681,164.621884,53.716936,-89.084995,441.438958


In [4]:
class FullDatasetPT(Dataset):
    def __init__(self, initial_states_df, pt_dir='data/processed/pt_files'):
        self.data = initial_states_df.reset_index(drop=True)
        self.pt_dir = pt_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_id = str(int(row['File ID'])).zfill(5)
        pt_path = os.path.join(self.pt_dir, f"{file_id}.pt")

        if not os.path.exists(pt_path):
            raise FileNotFoundError(f".pt file not found for File ID: {file_id}")

        static_input = torch.tensor(row.drop("File ID").fillna(0.0).values, dtype=torch.float32)
        pt_data = torch.load(pt_path)

        return (
            static_input,
            pt_data["density"],
            pt_data["density_mask"],
            pt_data["goes"],
            pt_data["goes_mask"],
            pt_data["omni2"],
            pt_data["omni2_mask"],
            file_id
        )



In [5]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=4,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None
        global tester_mask
        global tester_seq
        tester_seq = goes_seq
        tester_mask = goes_mask
        #print(goes_seq,"\n",goes_mask)
        #print(goes_mask.sum())
        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        global tester_key_mask
        global tester_embed
        tester_embed = goes_embed
        tester_key_mask = goes_key_mask
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    loss = (preds - targets) ** 2 * mask
    return loss.sum() / (mask.sum() + eps)

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

        

In [8]:
def train_storm_transformer(initial_states_df, num_epochs=10, batch_size=8, lr=1e-4, device=None):
    # from full_dataset import FullDataset
    # from storm_transformer import STORMTransformer, masked_mse_loss

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05, random_state=42)

    train_dataset = FullDatasetPT(train_df)
    val_dataset = FullDatasetPT(val_df)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=5, pin_memory=True, )
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=5, pin_memory=True)


    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        #start_load = time.time()
        for batch in tqdm(train_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask, _ = batch
            #nd_load = time.time()

            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)

            # if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
            #     print("⚠️ Skipping batch with fully masked inputs")
            #     continue

            optimizer.zero_grad()
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            #print (preds)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

            #end_batch = time.time()

            # print ("Load time:", end_load - start_load )
            # print ("Calc time:", end_batch - end_load)

        avg_train_loss = total_train_loss / len(train_loader)

        # 🧪 Validation
        #model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in tqdm(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask, _ = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Train Loss: {avg_train_loss:.20f} | Val Loss: {avg_val_loss:.20f}")

        # 💾 Save full checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss
        }, checkpoint_path)

        # 💎 Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print("✅ Best model updated.")


In [7]:
train_storm_transformer(initial_states_normalized)


🚀 Epoch 1/10


100%|██████████| 483/483 [07:17<00:00,  1.10it/s]



📊 Epoch 1/10 — Train Loss: 0.00052474837222998339 | Val Loss: 0.00000000000000000000
✅ Best model updated.

🚀 Epoch 2/10


  2%|▏         | 8/483 [00:19<19:37,  2.48s/it]  


RuntimeError: Caught RuntimeError in pin memory thread for device 0.
Original Traceback (most recent call last):
  File "/home/ardrit/app/py_torch/lib/python3.13/site-packages/torch/utils/data/_utils/pin_memory.py", line 41, in do_one_step
    data = pin_memory(data, device)
  File "/home/ardrit/app/py_torch/lib/python3.13/site-packages/torch/utils/data/_utils/pin_memory.py", line 98, in pin_memory
    clone[i] = pin_memory(item, device)
               ~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/ardrit/app/py_torch/lib/python3.13/site-packages/torch/utils/data/_utils/pin_memory.py", line 64, in pin_memory
    return data.pin_memory(device)
           ~~~~~~~~~~~~~~~^^^^^^^^
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



Exception in thread Thread-8 (_pin_memory_loop):
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7275aa055010>>
Traceback (most recent call last):
  File "/home/ardrit/app/py_torch/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


Traceback (most recent call last):
  File [35m"/home/ardrit/anaconda3/envs/test/lib/python3.13/threading.py"[0m, line [35m1041[0m, in [35m_bootstrap_inner[0m
    [31mself.run[0m[1;31m()[0m
    [31m~~~~~~~~[0m[1;31m^^[0m
  File [35m"/home/ardrit/app/py_torch/lib/python3.13/site-packages/ipykernel/ipkernel.py"[0m, line [35m766[0m, in [35mrun_closure[0m
    [31m_threading_Thread_run[0m[1;31m(self)[0m
    [31m~~~~~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^[0m
  File [35m"/home/ardrit/anaconda3/envs/test/lib/python3.13/threading.py"[0m, line [35m992[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/home/ardrit/app/py_torch/lib/python3.13/site-packages/torch/utils/data/_utils/pin_memory.py"[0m, line [35m59[0m, in [35m_pin_memory_loop[0m
    [31mdo_one_step[0m[1;31m()[0m
    [31m~~~~~~~~~~~[0m[1;31m^^[0m
  File [35m"/home/ardrit/app/py_torch/li