In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import glob
import time

import math
import os
from sklearn.model_selection import train_test_split
import joblib

In [2]:
initial_states_scaler = joblib.load("initial_states_scaler.gz") 

In [None]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

initial_states_norm_df = np.where(initial_states_df.iloc[:,2:] > 1e+10,0.0, initial_states_df.iloc[:,2:])

initial_states_scaler = MinMaxScaler()
initial_states_scaler_values = initial_states_scaler.fit(initial_states_norm_df)

initial_states_normalized = initial_states_scaler_values.transform(initial_states_df.iloc[:,2:].values)

initial_states_normalized = np.where(initial_states_normalized >=1, 0.99,initial_states_normalized)

initial_states_normalized = pd.concat([initial_states_df['File ID'],pd.DataFrame(initial_states_normalized)],axis=1)



initial_states_normalized

[6.82638725e+03 4.58659977e-03 8.72826111e+01 3.59665972e+02
 3.59640694e+02 3.59845867e+02 4.50029728e+01 1.79871944e+02
 4.99987154e+02]


Unnamed: 0,File ID,0,1,2,3,4,5,6,7,8
0,0,0.800022,0.410931,0.015596,0.400407,0.714733,0.284494,0.745053,0.326270,0.895052
1,1,0.799765,0.410666,0.015691,0.398725,0.695626,0.303640,0.743967,0.695028,0.889269
2,2,0.770871,0.435522,0.013965,0.397176,0.679259,0.319940,0.610945,0.790582,0.904985
3,3,0.770750,0.437655,0.013863,0.396165,0.669874,0.329406,0.569862,0.812033,0.902606
4,4,0.769970,0.442368,0.012990,0.393378,0.635436,0.363969,0.450618,0.137236,0.899777
...,...,...,...,...,...,...,...,...,...,...
8114,8114,0.535067,0.607103,0.159052,0.284968,0.668307,0.521740,0.890529,0.801062,0.851842
8115,8115,0.690987,0.229441,0.847954,0.277111,0.423743,0.628984,0.613536,0.776649,0.835652
8116,8116,0.711425,0.203270,0.936350,0.269214,0.120393,0.364463,0.523885,0.269405,0.837877
8117,8117,0.575161,0.506746,0.338160,0.257812,0.890747,0.457447,0.801671,0.252504,0.847061


In [4]:
# initial_states = []

# path = "input_data/"

# for dir, sub_dir, files in os.walk(path):
#     for file in sorted(files):
#         #print(file)
#         temp = pd.read_csv((path+file),index_col=None, header=0)
#         initial_states.append(temp)

# initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

# hold = initial_states_df
# x = initial_states_df.iloc[:,2:].values
# x = normalize(x,norm='l2')
# hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
# initial_states_normalized = hold
# #'2000-08-02 04:50:33'
# timestamps = initial_states_df['Timestamp']
# initial_states_normalized

In [5]:
class FullDatasetPT(Dataset):
    def __init__(self, initial_states_df, pt_dir='data/new_processed/pt_files'):
        self.data = initial_states_df.reset_index(drop=True)
        self.pt_dir = pt_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_id = str(int(row['File ID'])).zfill(5)
        pt_path = os.path.join(self.pt_dir, f"{file_id}.pt")

        if not os.path.exists(pt_path):
            raise FileNotFoundError(f".pt file not found for File ID: {file_id}")

        static_input = torch.tensor(row.drop("File ID").fillna(0.0).values, dtype=torch.float32)
        pt_data = torch.load(pt_path)

        return (
            static_input,
            pt_data["density"],
            pt_data["density_mask"],
            pt_data["goes"],
            pt_data["goes_mask"],
            pt_data["omni2"],
            pt_data["omni2_mask"]
        )



In [6]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=256,
                 output_len=432,
                 nhead=8,
                 num_layers=4,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256,360),
            nn.ReLU(),
            nn.Linear(360, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        #print("static input",static_input.shape)
        static_embed = self.static_encoder(static_input)
        #print("static embed", static_embed)
        #static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None
        # global tester_mask
        # global tester_seq
        # tester_seq = goes_seq
        # tester_mask = goes_mask
        #print(goes_seq,"\n",goes_mask)
        #print(goes_mask.sum())
        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        # global tester_key_mask
        # global tester_embed
        # tester_embed = goes_embed
        # tester_key_mask = goes_key_mask
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    # preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    # targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    # loss = (preds - targets) ** 2 * mask
    # return loss.sum() / (mask.sum() + eps)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

        

In [7]:
def train_storm_transformer(initial_states_df, num_epochs=20, batch_size=2, lr=1e-3, device=None):
    # from full_dataset import FullDataset
    # from storm_transformer import STORMTransformer, masked_mse_loss
    torch.manual_seed(42)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df[0:100], test_size=0.05, random_state=42)

    train_dataset = FullDatasetPT(train_df)
    val_dataset = FullDatasetPT(val_df)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=8, pin_memory=True, )
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=8, pin_memory=True)


    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    train_loss_history = []
    val_loss_history = []
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        #start_load = time.time()
        for batch in tqdm(train_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch
            #nd_load = time.time()

            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)

            # if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
            #     print("⚠️ Skipping batch with fully masked inputs")
            #     continue

            optimizer.zero_grad()
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            #print (preds)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

            #end_batch = time.time()

            # print ("Load time:", end_load - start_load )
            # print ("Calc time:", end_batch - end_load)

        avg_train_loss = total_train_loss / len(train_loader)

        print ("Preds:", preds)
        print ("Targets",density)
        # 🧪 Validation
        
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")
        
        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)
        
        print ("Preds:", preds)
        print ("Targets",density)

        # 💾 Save full checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss
        }, checkpoint_path)

        # 💎 Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), f"checkpoints/epoch{epoch}.pt")
            print("✅ Best model updated.")


In [8]:
train_storm_transformer(initial_states_normalized)


🚀 Epoch 1/20


100%|██████████| 48/48 [00:17<00:00,  2.81it/s]


Preds: tensor([[ 4.5051e-04,  4.2748e-05, -8.1519e-04, -1.8458e-04,  1.9350e-03,
         -1.9516e-03, -1.3435e-03,  6.2724e-04,  5.0115e-04, -3.5229e-04,
         -8.5708e-04,  5.9400e-04,  5.5478e-04, -1.0511e-03, -4.1045e-04,
          9.4252e-05,  1.4596e-03,  6.4479e-04,  9.2144e-04,  4.1785e-04,
          6.4290e-04, -3.7951e-04, -8.3137e-04,  3.2952e-04,  8.2809e-05,
         -4.2203e-04,  2.2619e-04, -5.6146e-04, -1.1807e-03,  5.8202e-04,
          3.7260e-04,  1.1405e-03,  2.6640e-04, -5.4953e-05,  3.8037e-04,
         -1.5548e-03, -8.8468e-05, -1.1471e-03, -1.7978e-03, -9.3518e-04,
         -1.3375e-03,  2.9879e-04, -8.4687e-05,  5.0087e-05,  3.3275e-04,
         -3.0634e-04,  7.0128e-04, -7.0680e-04,  3.9315e-04,  4.9062e-04,
          1.9522e-04, -1.5027e-03,  1.0462e-03, -2.5434e-04, -7.1629e-04,
         -7.4010e-05, -2.2833e-04,  8.1785e-04, -6.0689e-05,  1.0185e-03,
         -4.7459e-04, -4.1417e-04,  1.0787e-03,  8.6594e-04, -6.0217e-04,
         -1.1614e-03,  5.4628e-

100%|██████████| 48/48 [00:11<00:00,  4.19it/s]

Preds: tensor([[ 3.4162e-04, -4.3797e-04,  2.2092e-04, -8.5333e-04, -3.7112e-04,
          2.1292e-04,  8.8401e-05, -9.6011e-04,  8.9230e-04, -4.6972e-05,
         -4.3465e-04,  1.6028e-04,  1.5003e-04, -1.9299e-04, -1.7041e-04,
          1.4603e-06,  6.2020e-05, -2.9534e-04,  3.2356e-04,  4.4915e-04,
          3.0412e-04, -6.8359e-06,  4.2234e-04, -2.5702e-04,  4.2464e-05,
          6.1751e-05,  3.5126e-05, -2.0496e-04, -6.9020e-04,  7.1943e-05,
         -7.6093e-05, -1.0907e-04,  6.0609e-04, -2.9913e-04, -1.7012e-05,
          4.3314e-04, -1.2012e-04, -3.0136e-04,  1.4850e-04, -6.7721e-04,
          4.7234e-05, -4.0124e-04,  8.5559e-05, -4.6335e-04, -3.8563e-04,
         -3.0380e-04,  2.5057e-04, -5.3898e-05, -1.5236e-04,  8.1574e-04,
          5.7117e-04,  2.5840e-04,  5.1477e-04,  4.6602e-05,  1.9152e-05,
          6.3508e-04,  1.1346e-04,  2.2251e-04, -3.1764e-04,  3.6599e-05,
          2.6591e-04,  4.6803e-04, -2.2759e-04,  9.3216e-05,  1.5692e-04,
         -4.6473e-06, -1.5144e-




🧪 Eval Batch 1/3 — OMNI2 Mask Sum: 164160.0 | GOES Mask Sum: 683824.0 | Density Mask Sum: 784.0
⚠️ Low GOES coverage in this batch!
🧪 Eval Batch 2/3 — OMNI2 Mask Sum: 164160.0 | GOES Mask Sum: 679068.0 | Density Mask Sum: 718.0
⚠️ Low GOES coverage in this batch!
🧪 Eval Batch 3/3 — OMNI2 Mask Sum: 82080.0 | GOES Mask Sum: 333536.0 | Density Mask Sum: 254.0
⚠️ Low GOES coverage in this batch!

📊 Epoch 2/20 — Train Loss: 0.0005848915637519289 | Val Loss: 0.0003590346605051309
Preds: tensor([[ 4.3470e-04, -6.3501e-04,  3.4735e-04, -7.6108e-05, -5.4824e-04,
          4.1937e-04, -1.2435e-04, -3.8436e-04,  1.0954e-03, -2.6990e-04,
          1.2399e-04,  6.2231e-05,  1.7390e-04, -4.1629e-04, -2.8375e-04,
          8.6901e-05, -1.7512e-04, -8.8902e-04,  6.9370e-04,  4.2580e-04,
          4.2803e-04,  2.9121e-05,  8.6379e-04, -7.0258e-04, -3.5615e-04,
          1.4111e-04,  1.1342e-04, -6.1767e-04, -9.6753e-04,  1.2809e-04,
         -1.5374e-04, -2.5682e-04,  6.2377e-04, -7.5516e-04, -5.3991e-

100%|██████████| 48/48 [00:11<00:00,  4.19it/s]

Preds: tensor([[-6.3445e-05, -1.1360e-04, -9.6907e-05, -9.4324e-06, -3.5752e-05,
          1.2611e-04, -9.1456e-07,  3.7081e-05, -3.5630e-04, -5.2042e-05,
          1.4446e-04, -2.5649e-04, -1.6347e-04,  3.2205e-04,  2.5554e-04,
          9.4501e-05,  6.7333e-05,  3.0968e-04,  2.8174e-04,  4.6025e-04,
         -2.7101e-06,  1.8232e-04, -2.1942e-04, -3.8178e-04,  1.5560e-04,
         -1.5236e-04, -1.1110e-04, -1.3961e-04,  2.1156e-05, -2.1819e-05,
          8.7284e-05, -2.6456e-04,  1.6244e-04, -1.2677e-04, -8.2012e-05,
          8.2713e-05,  2.7551e-04,  4.2708e-04,  2.4854e-04,  4.9887e-05,
         -4.0680e-05, -1.1949e-05, -2.3898e-04, -2.0158e-04, -3.2585e-05,
         -5.4469e-05, -1.2624e-04, -2.3253e-04,  2.5329e-04, -5.3007e-04,
          1.5652e-04, -1.1583e-04, -7.9488e-07,  6.4874e-05,  5.6277e-04,
          7.6315e-04,  2.8513e-04,  1.2822e-04,  2.6803e-05,  5.2029e-05,
          1.5033e-04, -4.9260e-06, -4.5978e-04, -9.1717e-05,  2.2740e-04,
          5.5002e-05,  2.2659e-




🧪 Eval Batch 1/3 — OMNI2 Mask Sum: 164160.0 | GOES Mask Sum: 683824.0 | Density Mask Sum: 784.0
⚠️ Low GOES coverage in this batch!
🧪 Eval Batch 2/3 — OMNI2 Mask Sum: 164160.0 | GOES Mask Sum: 679068.0 | Density Mask Sum: 718.0
⚠️ Low GOES coverage in this batch!
🧪 Eval Batch 3/3 — OMNI2 Mask Sum: 82080.0 | GOES Mask Sum: 333536.0 | Density Mask Sum: 254.0
⚠️ Low GOES coverage in this batch!

📊 Epoch 3/20 — Train Loss: 0.0003173446387639463 | Val Loss: 0.0004077698298109074
Preds: tensor([[-7.4923e-05,  3.8303e-05, -2.5476e-04, -4.5741e-05,  2.6779e-04,
         -1.1458e-04, -2.4416e-05,  6.8052e-05,  2.7750e-05, -1.9160e-04,
         -3.7843e-04,  2.0465e-04,  1.5336e-04, -6.6070e-04, -2.5076e-04,
         -1.4329e-04,  2.8004e-05,  7.2408e-05, -3.6089e-04, -2.0764e-04,
          2.9523e-05,  1.0661e-04, -2.4312e-04,  2.3801e-04, -1.8357e-04,
          1.3560e-04, -2.8005e-04, -1.0522e-05, -3.9084e-04, -1.8222e-04,
          1.7397e-04,  3.0462e-04, -1.6634e-04,  8.3978e-05, -2.8610e-

 94%|█████████▍| 45/48 [00:10<00:00,  4.15it/s]


KeyboardInterrupt: 