In [6]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import joblib
from sklearn.preprocessing import MinMaxScaler
import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

In [7]:
omni2_scaler = joblib.load("omni2_scaler.gz") 
goes_scaler = joblib.load("goes_scaler.gz") 
initial_states_scaler = joblib.load("initial_states_scaler.gz") 

In [8]:
omni2_scaler.data_min_

array([ 2.000e+03,  1.550e+02,  0.000e+00,  2.277e+03,  5.100e+01,
        4.500e+01,  9.000e+00,  1.000e+00,  1.600e+00,  6.000e-01,
       -8.510e+01,  4.000e-01, -1.660e+01, -1.740e+01, -4.520e+01,
       -1.750e+01, -4.530e+01,  0.000e+00,  3.000e-01,  0.000e+00,
        1.000e-01,  1.000e-01,  6.772e+03,  5.000e-01,  3.240e+02,
       -1.300e+01, -1.050e+01,  8.000e-03,  0.000e+00,  0.000e+00,
        0.000e+00,  0.000e+00,  0.000e+00,  0.000e+00,  2.700e-01,
       -2.296e+01,  0.000e+00,  1.200e+00,  1.200e+00,  1.200e-03,
        3.000e+00,  1.130e+02, -3.000e+02,  2.000e+00,  1.524e+02,
        3.000e+01, -1.164e+03, -1.500e+01, -7.700e+00,  8.363e-03,
        3.600e-01,  3.400e-01,  3.400e-01,  3.200e-01,  1.800e-01,
        1.000e-01, -1.000e+00])

In [9]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

initial_states_norm_df = np.where(initial_states_df.iloc[:,2:] > 1e+10,0.0, initial_states_df.iloc[:,2:])

initial_states_scaler = MinMaxScaler()
initial_states_scaler_values = initial_states_scaler.fit(initial_states_norm_df)

initial_states_normalized = initial_states_scaler_values.transform(initial_states_df.iloc[:,2:].values)

initial_states_normalized = np.where(initial_states_normalized >=1, 0.99,initial_states_normalized)

initial_states_normalized = pd.concat([initial_states_df['File ID'],pd.DataFrame(initial_states_normalized)],axis=1)

# hold = initial_states_df
# x = initial_states_df.iloc[:,2:].values
# x = normalize(x,norm='l2')
# hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
# initial_states_normalized = hold
# #'2000-08-02 04:50:33'
# timestamps = initial_states_df['Timestamp']
# #initial_states_df

In [10]:
class FullDatasetPT(Dataset):
    def __init__(self, initial_states_df, pt_dir='data/new_processed/pt_files'):
        self.data = initial_states_df.reset_index(drop=True)
        self.pt_dir = pt_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_id = str(int(row['File ID'])).zfill(5)
        pt_path = os.path.join(self.pt_dir, f"{file_id}.pt")

        if not os.path.exists(pt_path):
            raise FileNotFoundError(f".pt file not found for File ID: {file_id}")

        static_input = torch.tensor(row.drop("File ID").fillna(0.0).values, dtype=torch.float32)
        pt_data = torch.load(pt_path)
        
        #print (pt_data)
        return (
            static_input,
            pt_data["density"],
            pt_data["density_mask"],
            pt_data["goes"],
            pt_data["goes_mask"],
            pt_data["omni2"],
            pt_data["omni2_mask"]
        )



In [15]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=4320):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Feature Mask Concatenation (No Downsampling)
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=8,
                 num_layers=4,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        # Inputs are doubled due to feature-mask concatenation
        self.omni2_proj = nn.Linear(omni2_dim * 2, d_model)
        self.goes_proj = nn.Linear(goes_dim * 2, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True
        )
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # self.fusion = nn.Sequential(
        #     nn.Linear(d_model * 3, 256),
        #     nn.ReLU(),
        #     nn.Dropout(dropout),
        #     nn.Linear(256, output_len)
        # )
        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256,360),
            nn.BatchNorm1d(360),
            nn.ReLU(),
            nn.Linear(360, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        # ----- Static Embedding -----
        #print("static input",static_input)
        static_embed = self.static_encoder(static_input)
        #print("static embed", static_embed)

        # ----- OMNI2 -----
        if omni2_mask is not None:
            omni2_cat = torch.cat([omni2_seq, omni2_mask], dim=-1)  # [B, T, 2D]
        else:
            omni2_cat = omni2_seq
        omni2_embed = self.omni2_proj(omni2_cat)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_out = self.omni2_encoder(omni2_embed)  # ⬅️ No key mask
        omni2_summary = omni2_out.mean(dim=1)

        # ----- GOES Downsampling to 8640 -----
        if goes_seq.shape[1] > 4320:
            step = goes_seq.shape[1] // 4320
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None

        if goes_mask is not None:
            goes_cat = torch.cat([goes_seq, goes_mask], dim=-1)  # [B, T, 2D]
        else:
            goes_cat = goes_seq
        goes_embed = self.goes_proj(goes_cat)
        goes_embed = self.goes_pos(goes_embed)
        goes_out = self.goes_encoder(goes_embed)  # ⬅️ No key mask
        goes_summary = goes_out.mean(dim=1)

        # print("static",static_embed)
        # print("omni2",omni2_summary)
        # print("goes",goes_summary)

        # ----- Fusion -----
        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    # preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    # targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    # loss = (preds - targets) ** 2 * mask
    # return loss.sum() / (mask.sum() + eps)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

train_loss_history = []
val_loss_history = []

In [19]:
def train_storm_transformer(initial_states_df, num_epochs=100, batch_size=16, lr=1e-3, device=None):
    # from full_dataset import FullDataset
    # from storm_transformer import STORMTransformer, masked_mse_loss

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df[0:8112], test_size=0.05, random_state=42)

    train_dataset = FullDatasetPT(train_df)
    val_dataset = FullDatasetPT(val_df)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, )
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)


    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    
    
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        #start_load = time.time()
        for batch in tqdm(train_loader):
            
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch
            #nd_load = time.time()
            global omni2_out
            omni2_out = omni2
            
            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)
            # print ("static",static_input)
            # print ("dens",density)
            # print ("goes",goes)
            # print ("omni2",omni2)

            # if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
            #     print("⚠️ Skipping batch with fully masked inputs")
            #     continue

            optimizer.zero_grad()
            
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            # print (preds)
            # print ("------------------Preds--------------------",len(preds))
            # print ("Preds:", preds)
            # print ("Targets",density)
            #print (preds)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

            #end_batch = time.time()

            # print ("Load time:", end_load - start_load )
            # print ("Calc time:", end_batch - end_load)

        avg_train_loss = total_train_loss / len(train_loader)
        print ("Preds:", preds)
        print ("Targets",density)
        
        # 🧪 Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch
                

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                
                loss = masked_mse_loss(preds, density, density_mask)
                #print ("loss",loss,"-------------------")
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")
        
        print ("Preds:", preds)
        print ("Targets",density)
        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)

        # 💾 Save full checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss
        }, checkpoint_path)

        # 💎 Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), f"checkpoints/epoch{epoch}.pt")
            print("✅ Best model updated.")


In [18]:
train_storm_transformer(initial_states_normalized)


🚀 Epoch 1/100


100%|██████████| 964/964 [10:04<00:00,  1.59it/s]

Preds: tensor([[ 3.0248e-03,  3.0156e-03, -4.4864e-03, -7.0914e-03,  2.5942e-03,
         -1.4351e-03, -8.5881e-03, -7.0874e-03,  6.7298e-03,  8.8343e-03,
          2.2860e-03,  7.1481e-03,  4.5089e-03,  5.3674e-03, -4.9099e-03,
         -1.5461e-03, -3.3813e-03, -1.3004e-03, -3.8007e-03,  6.8478e-03,
         -3.7238e-03,  1.1579e-03,  3.5341e-04, -5.4347e-03,  4.7645e-03,
          3.6788e-03,  6.0750e-03, -2.3068e-04,  7.2794e-03,  5.4921e-04,
         -2.6330e-03, -7.9644e-04, -3.2662e-03, -5.7969e-03,  1.3516e-03,
         -1.4603e-03, -5.2927e-03,  9.1365e-03, -4.1576e-03, -1.3049e-03,
          6.4635e-04,  3.4087e-03, -2.7705e-03, -4.7039e-05,  3.9814e-03,
         -7.8685e-03,  6.4209e-03,  9.5020e-04, -1.7637e-03, -7.9494e-05,
         -8.1475e-03, -2.5583e-03,  5.2639e-03,  6.3449e-03, -8.3751e-04,
          9.5774e-04,  9.6082e-04,  1.6680e-03, -7.4036e-03,  4.2483e-03,
         -6.9514e-03,  7.2583e-03, -3.4201e-03,  2.0080e-03,  2.3140e-03,
         -4.2588e-03,  1.2134e-




🧪 Eval Batch 1/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6515662.0 | Density Mask Sum: 3443.0
🧪 Eval Batch 16/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 4162416.0 | Density Mask Sum: 3343.0
🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5345290.0 | Density Mask Sum: 3416.0
🧪 Eval Batch 49/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6166782.0 | Density Mask Sum: 3456.0
🧪 Eval Batch 51/51 — OMNI2 Mask Sum: 492480.0 | GOES Mask Sum: 2937816.0 | Density Mask Sum: 2592.0

📊 Epoch 1/100 — Train Loss: 0.06857383849207482 | Val Loss: 0.001162332117411436
Preds: tensor([[ 0.0200,  0.0170,  0.0061,  ...,  0.0204,  0.0022,  0.0330],
        [ 0.0086,  0.0257, -0.0267,  ...,  0.0172, -0.0041,  0.0142],
        [-0.0064,  0.0110, -0.0029,  ...,  0.0101, -0.0230, -0.0234],
        [ 0.0004,  0.0095, -0.0073,  ..., -0.0118, -0.0027,  0.0012],
        [ 0.0176,  0.0177, -0.0025,  ...,  0.0288, -0.0060,  0.0338],
        [-0.0102,  0.0165,  0.0005,  ...,  0.0052, -0.0245, -0

  3%|▎         | 30/964 [00:24<12:54,  1.21it/s] 
Exception in thread Thread-29 (_pin_memory_loop):
Traceback (most recent call last):
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/runpod_conda/lib/pyth

KeyboardInterrupt: 

In [None]:
plt.plot(val_loss_history)

In [None]:
omni2_out