In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

In [2]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

hold = initial_states_df
x = initial_states_df.iloc[:,2:].values
x = normalize(x,norm='l2')
hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
initial_states_normalized = hold
#'2000-08-02 04:50:33'
timestamps = initial_states_df['Timestamp']
#initial_states_df

In [3]:
class FullDatasetPT(Dataset):
    def __init__(self, initial_states_df, pt_dir='data/processed/pt_files'):
        self.data = initial_states_df.reset_index(drop=True)
        self.pt_dir = pt_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_id = str(int(row['File ID'])).zfill(5)
        pt_path = os.path.join(self.pt_dir, f"{file_id}.pt")

        if not os.path.exists(pt_path):
            raise FileNotFoundError(f".pt file not found for File ID: {file_id}")

        static_input = torch.tensor(row.drop("File ID").fillna(0.0).values, dtype=torch.float32)
        pt_data = torch.load(pt_path)

        return (
            static_input,
            pt_data["density"],
            pt_data["density_mask"],
            pt_data["goes"],
            pt_data["goes_mask"],
            pt_data["omni2"],
            pt_data["omni2_mask"]
        )



In [11]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=256,
                 output_len=432,
                 nhead=8,
                 num_layers=4,
                 dropout=0.5):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256,360),
            nn.ReLU(),
            nn.Linear(360, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None
        global tester_mask
        global tester_seq
        tester_seq = goes_seq
        tester_mask = goes_mask
        #print(goes_seq,"\n",goes_mask)
        #print(goes_mask.sum())
        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        global tester_key_mask
        global tester_embed
        tester_embed = goes_embed
        tester_key_mask = goes_key_mask
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    # preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    # targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    # loss = (preds - targets) ** 2 * mask
    # return loss.sum() / (mask.sum() + eps)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

        

In [9]:
def train_storm_transformer(initial_states_df, num_epochs=10, batch_size=8, lr=1e-3, device=None):
    # from full_dataset import FullDataset
    # from storm_transformer import STORMTransformer, masked_mse_loss
    torch.manual_seed(42)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05, random_state=42)

    train_dataset = FullDatasetPT(train_df)
    val_dataset = FullDatasetPT(val_df)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True, )
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)


    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        # model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # start_epoch = checkpoint.get('epoch', 0)
        # best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        #start_load = time.time()
        # for batch in tqdm(train_loader):
        #     static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch
        #     #nd_load = time.time()

        #     static_input = static_input.to(device)
        #     density = density.to(device)
        #     density_mask = density_mask.to(device)
        #     goes = goes.to(device)
        #     goes_mask = goes_mask.to(device)
        #     omni2 = omni2.to(device)
        #     omni2_mask = omni2_mask.to(device)

        #     # if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
        #     #     print("⚠️ Skipping batch with fully masked inputs")
        #     #     continue

        #     optimizer.zero_grad()
        #     preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
        #     #print (preds)
        #     loss = masked_mse_loss(preds, density, density_mask)
        #     loss.backward()
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        #     optimizer.step()

        #     total_train_loss += loss.item()

        #     #end_batch = time.time()

        #     # print ("Load time:", end_load - start_load )
        #     # print ("Calc time:", end_batch - end_load)

        # avg_train_loss = total_train_loss / len(train_loader)

        # print ("Preds:", preds)
        # print ("Targets",density)
        # 🧪 Validation
        
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Val Loss: {avg_val_loss}")
        
        print ("Preds:", preds)
        print ("Targets",density)

        # 💾 Save full checkpoint
        # torch.save({
        #     'epoch': epoch + 1,
        #     'model_state_dict': model.state_dict(),
        #     'optimizer_state_dict': optimizer.state_dict(),
        #     'val_loss': avg_val_loss
        # }, checkpoint_path)

        # # 💎 Save best model
        # if avg_val_loss < best_val_loss:
        #     best_val_loss = avg_val_loss
        #     torch.save(model.state_dict(), best_model_path)
        #     print("✅ Best model updated.")


In [12]:
train_storm_transformer(initial_states_normalized)

🔁 Resuming from checkpoint: checkpoints/storm_last.pt

🚀 Epoch 1/10
🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 1/10 — Val Loss: 8.811491830320115e-06
Preds: tensor([[ 3.1056e-05,  3.4915e-05,  1.7463e-04,  ...,  2.0950e-04,
          8.5663e-05,  2.0908e-04],
        [-3.5213e-05,  1.6214e-06,  1.5371e-04,  ..., -2.9411e-05,
          1.1948e-04, -8.8779e-05],
        [-3.5213e-05,  1.6214e-06,  1.5371e-04,  ..., -2.9411e-05,
          1.1948e-04, -8.8779e-05],
        ...,
        [-3.5213e-05,  1.6214e-06,  1.5371e-04,  ..., -2.9411e-05,
          1.1948e-04, -8.8779e-05],
        [-3.4222e-05,  2.6450e-06,  1.5587e-04,  ...

Exception in thread Thread-17 (_pin_memory_loop):
Traceback (most recent call last):
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/multiprocessing/queues.py", line 122, in ge

KeyboardInterrupt: 