In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import joblib
from sklearn.preprocessing import MinMaxScaler
import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

In [None]:
omni2_scaler = joblib.load("omni2_scaler.gz") 
goes_scaler = joblib.load("goes_scaler.gz") 
initial_states_scaler = joblib.load("initial_states_scaler.gz") 

In [31]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

initial_states_norm_df = np.where(initial_states_df.iloc[:,2:] > 1e+10,0.0, initial_states_df.iloc[:,2:])

initial_states_scaler = MinMaxScaler()
initial_states_scaler_values = initial_states_scaler.fit(initial_states_norm_df)

initial_states_normalized = initial_states_scaler_values.transform(initial_states_df.iloc[:,2:].values)

initial_states_normalized = np.where(initial_states_normalized >=1, 0.99,initial_states_normalized)

initial_states_normalized = pd.concat([initial_states_df['Timestamp'],initial_states_df['File ID'],pd.DataFrame(initial_states_normalized)],axis=1)

initial_states_normalized


Unnamed: 0,Timestamp,File ID,0,1,2,3,4,5,6,7,8
0,2000-08-02 04:50:33,0,0.800022,0.410931,0.015596,0.400407,0.714733,0.284494,0.745053,0.326270,0.895052
1,2000-08-03 19:51:01,1,0.799765,0.410666,0.015691,0.398725,0.695626,0.303640,0.743967,0.695028,0.889269
2,2000-08-05 05:40:05,2,0.770871,0.435522,0.013965,0.397176,0.679259,0.319940,0.610945,0.790582,0.904985
3,2000-08-06 05:02:20,3,0.770750,0.437655,0.013863,0.396165,0.669874,0.329406,0.569862,0.812033,0.902606
4,2000-08-08 20:54:57,4,0.769970,0.442368,0.012990,0.393378,0.635436,0.363969,0.450618,0.137236,0.899777
...,...,...,...,...,...,...,...,...,...,...,...
8114,2019-12-25 00:00:00,8114,0.535067,0.607103,0.159052,0.284968,0.668307,0.521740,0.890529,0.801062,0.851842
8115,2019-12-27 00:00:00,8115,0.690987,0.229441,0.847954,0.277111,0.423743,0.628984,0.613536,0.776649,0.835652
8116,2019-12-28 00:00:00,8116,0.711425,0.203270,0.936350,0.269214,0.120393,0.364463,0.523885,0.269405,0.837877
8117,2019-12-30 00:00:00,8117,0.575161,0.506746,0.338160,0.257812,0.890747,0.457447,0.801671,0.252504,0.847061


In [37]:
class FullDataset(Dataset):
    def __init__(self, initial_states_df, density_length=432, goes_length=86400, omni2_length=1440, density_dir='data/dataset/test/sat_density', goes_dir="data/dataset/test/goes",
                 omni2_dir="data/dataset/test/omni2"):
        self.data = initial_states_df.reset_index(drop=True)
        self.density_dir = density_dir
        self.goes_dir = goes_dir
        self.omni2_dir = omni2_dir
        self.density_length = density_length
        self.goes_length = goes_length
        self.omni2_length = omni2_length
        #self.timestamps = initial_states_df['Timestamps']
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        ts = row['Timestamp']
        #print (ts)
        static_input = row.drop('Timestamp')
        static_input = torch.tensor(row.drop('File ID').fillna(0.0).values, dtype=torch.float32)
        print (static_input)
        
        file_id = str(int(row['File ID'])).zfill(5)
        
        density_file = glob.glob(os.path.join(self.density_dir, f"*-{file_id}-*.csv"))
        goes_file = glob.glob(os.path.join(self.goes_dir, f"*-{file_id}-*.csv"))
        omni2_file = glob.glob(os.path.join(self.omni2_dir, f"*-{file_id}-*.csv"))

        pos = len(self.density_dir)+1
        density_sat = density_file[0][pos:pos+6]

        density_df = ((pd.read_csv(density_file[0])))
        density_df['Orbit Mean Density (kg/m^3)'] = np.where(density_df['Orbit Mean Density (kg/m^3)']>=1,np.nan,density_df['Orbit Mean Density (kg/m^3)'])
        if density_df.shape[0] > self.density_length:
            density_df = density_df[:self.density_length]
        elif density_df.shape[0] < self.density_length:
            padding = pd.DataFrame(np.empty((self.density_length-density_df.shape[0],2)),columns=density_df.columns)
            padding[:] = np.nan
            density_df = pd.concat((density_df,padding),ignore_index=True)
        density_df_mask = (pd.notnull(density_df)).astype(int)
        density_tensor = torch.tensor(density_df['Orbit Mean Density (kg/m^3)'].fillna(0.0).values, dtype=torch.float32)
        density_df_mask_tensor = torch.tensor(density_df_mask.iloc[:,1].values, dtype=torch.float32)
        density_stacked = torch.stack((density_tensor,density_df_mask_tensor))

        goes_df = pd.read_csv(goes_file[0])
        if goes_df.shape[0] > self.goes_length:
            goes_df = goes_df[goes_df.shape[0]-self.goes_length:goes_df.shape[0]]
        elif goes_df.shape[0] < self.goes_length:
            padding = pd.DataFrame(np.empty((self.goes_length-goes_df.shape[0],43)),columns=goes_df.columns)
            padding[:] = np.nan
            goes_df = pd.concat((padding,goes_df),ignore_index=True)
        goes_mask = (~pd.isnull(goes_df)).astype(int)
        goes_valid_mask = ((goes_df['xrsa_flag'] == 0.0) & (goes_df['xrsb_flag'] == 0.0)).astype(int)
        goes_mask = goes_mask.mul(goes_valid_mask.values,axis=0)
        goes_tensor = torch.tensor(normalize(goes_df.iloc[:, 1:].fillna(0.0).values, norm='l2'), dtype=torch.float32)
        goes_mask_tensor = torch.tensor(goes_mask.iloc[:, 1:].values, dtype=torch.float32)
        #goes_stacked = torch.stack((goes_tensor,goes_mask_tensor))
        
        omni2_df = pd.read_csv(omni2_file[0])
        if omni2_df.shape[0] > self.omni2_length:
            omni2_df = omni2_df[omni2_df.shape[0]-self.omni2_length:omni2_df.shape[0]]
        elif goes_df.shape[0] < self.omni2_length:
            padding = pd.DataFrame(np.empty((self.omni2_length-omni2_df.shape[0],58)),columns=omni2_df.columns)
            padding[:] = np.nan
            omni2_df = pd.concat((padding,omni2_df),ignore_index=True)
        omni2_tensor = torch.tensor(normalize(omni2_df.iloc[:, :57].fillna(0.0).values.astype(float), norm='l2'), dtype=torch.float32)
        omni2_mask = (~pd.isnull(omni2_df)).astype(int)
        omni2_mask_tensor = torch.tensor(omni2_mask.iloc[:, :57].values, dtype=torch.float32) 
        omni2_stacked = torch.stack((omni2_tensor,omni2_mask_tensor))

        return static_input, density_tensor, density_df_mask_tensor, goes_tensor, goes_mask_tensor, omni2_tensor, omni2_mask_tensor, ts



In [9]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4320):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Feature Mask Concatenation (No Downsampling)
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=8,
                 num_layers=4,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        # Inputs are doubled due to feature-mask concatenation
        self.omni2_proj = nn.Linear(omni2_dim * 2, d_model)
        self.goes_proj = nn.Linear(goes_dim * 2, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True
        )
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # self.fusion = nn.Sequential(
        #     nn.Linear(d_model * 3, 256),
        #     nn.ReLU(),
        #     nn.Dropout(dropout),
        #     nn.Linear(256, output_len)
        # )
        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256,360),
            nn.BatchNorm1d(360),
            nn.ReLU(),
            nn.Linear(360, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        # ----- Static Embedding -----
        #print("static input",static_input)
        static_embed = self.static_encoder(static_input)
        #print("static embed", static_embed)

        # ----- OMNI2 -----
        if omni2_mask is not None:
            omni2_cat = torch.cat([omni2_seq, omni2_mask], dim=-1)  # [B, T, 2D]
        else:
            omni2_cat = omni2_seq
        omni2_embed = self.omni2_proj(omni2_cat)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_out = self.omni2_encoder(omni2_embed)  # ⬅️ No key mask
        omni2_summary = omni2_out.mean(dim=1)

        # ----- GOES Downsampling to 8640 -----
        if goes_seq.shape[1] > 4320:
            step = goes_seq.shape[1] // 4320
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None

        if goes_mask is not None:
            goes_cat = torch.cat([goes_seq, goes_mask], dim=-1)  # [B, T, 2D]
        else:
            goes_cat = goes_seq
        goes_embed = self.goes_proj(goes_cat)
        goes_embed = self.goes_pos(goes_embed)
        goes_out = self.goes_encoder(goes_embed)  # ⬅️ No key mask
        goes_summary = goes_out.mean(dim=1)

        # print("static",static_embed)
        # print("omni2",omni2_summary)
        # print("goes",goes_summary)

        # ----- Fusion -----
        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    # preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    # targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    # loss = (preds - targets) ** 2 * mask
    # return loss.sum() / (mask.sum() + eps)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

train_loss_history = []
val_loss_history = []

In [None]:
def predict(initial_states_df, batch_size=16, device=None):

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    

    test_dataset = FullDataset(initial_states_df)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, )
    
    model = STORMTransformer().to(device)
    model.load_state_dict(torch.load('epoch56.pt', weights_only=True))


    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask, ts = batch
            
            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            print (preds)


In [38]:
predict(initial_states_normalized)

  0%|          | 0/508 [00:00<?, ?it/s]

  0%|          | 0/508 [00:00<?, ?it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/ardrit/app/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ardrit/app/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_7240/67470720.py", line 21, in __getitem__
    static_input = torch.tensor(row.drop('File ID').fillna(0.0).values, dtype=torch.float32)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.
