In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

#import tensorflow as tf

In [2]:
initial_states = []

for dir, sub_dir, files in os.walk("input_data/"):
    for file in sorted(files):
        print(file)
        temp = pd.read_csv(("input_data/"+file),index_col=None, header=0)
        initial_states.append(temp)
#(initial_states)

00000_to_02284-initial_states.csv
02285_to_02357-initial_states.csv
02358_to_04264-initial_states.csv
04265_to_05570-initial_states.csv
05571_to_05614-initial_states.csv
05615_to_06671-initial_states.csv
06672_to_08118-initial_states.csv


In [3]:
initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)
#initial_states_df

In [4]:
hold = initial_states_df
x = initial_states_df.iloc[:,2:].values
x = normalize(x,norm='l2')
hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
initial_states_normalized = hold
#initial_states_normalized

In [5]:
initial_states_normalized['File ID'].dtype

dtype('int64')

In [None]:
#initial_states_normalized = ((initial_states_df.iloc[:,2:] - initial_states_df.iloc[:,2:].mean())/initial_states_df.iloc[:,2:].std())
#initial_states_normalized

In [69]:
dataloader = DataLoader(initial_states_normalized, batch_size=32, shuffle=True)

In [None]:

class StaticToFixedLengthTargetDataset(Dataset):
    def __init__(self, initial_states_df, targets_dir="data/dataset/test/sat_density", desired_len=432):
        self.data = initial_states_df.reset_index(drop=True)
        self.targets_dir = targets_dir
        self.desired_len = desired_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        static_input = torch.tensor(row.drop('File ID').values, dtype=torch.float32)
        file_id = str(row['File ID']).zfill(5)  # Ensure it's 5 characters with leading zeros if needed

        # Correct filename pattern: 6-char satellite code + '-' + File ID + '-'
        pattern = os.path.join(self.targets_dir, f"??????-{file_id}-*.csv")
        matched_files = glob.glob(pattern)

        if len(matched_files) == 0:
            raise FileNotFoundError(f"No target file found for File ID {file_id}")
        elif len(matched_files) > 1:
            raise RuntimeError(f"Multiple files matched for File ID {file_id}: {matched_files}")

        target_path = matched_files[0]

        # Load and process the target file
        target_df = pd.read_csv(target_path)
        density = target_df.iloc[:, 1].values
        actual_len = len(density)

        padded = torch.zeros(self.desired_len, dtype=torch.float32)
        mask = torch.zeros(self.desired_len, dtype=torch.bool)

        if actual_len >= self.desired_len:
            padded[:] = torch.tensor(density[:self.desired_len], dtype=torch.float32)
            mask[:] = 1
        else:
            padded[:actual_len] = torch.tensor(density, dtype=torch.float32)
            mask[:actual_len] = 1

        return static_input, padded, mask


In [36]:
class InitialStateEmbedder(nn.Module):
    def __init__(self, input_dim=9, embedding_dim=128):
        super(InitialStateEmbedder, self).__init__()
        self.embed = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.1),
            nn.Linear(64, embedding_dim),
            nn.ReLU()
        )

    def forward(self, x):
        # x shape: (batch_size, 9)
        return self.embed(x)  # output: (batch_size, embedding_dim)

In [70]:
model = InitialStateEmbedder()

In [71]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [None]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epoch = 5 #iterations over all data
train_losses, val_losses = [],[]

model = testModel()
model.to(device)

for epoch in range(num_epoch):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc='Train loop'):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    train_loss = running_loss / len(train_loader.dataset)
    train_losses.append(train_loss)


device(type='cuda', index=0)

In [None]:
# Assumes initial_states_normalized has 8119 rows
num_samples = initial_states_normalized.shape[0]  # should be 8119
output_dim = 3  # Predicting 3-day densities

np.random.seed(1)
# Simulate some dummy values (can be random or zeros)
target_values = pd.DataFrame(np.random.randn(num_samples, output_dim), columns=['day_1', 'day_2', 'day_3'])

# --- Step 2: Convert to Tensors ---
X = torch.tensor(initial_states_normalized.values, dtype=torch.float32)  # shape: (8119, 9)
y = torch.tensor(target_values.values, dtype=torch.float32)              # shape: (8119, 3)

# --- Step 3: Create Dataset and Dataloader ---
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# --- Step 4: Define Model ---
class InitialStateEmbedder(nn.Module):
    def __init__(self, input_dim=9, embedding_dim=128):
        super().__init__()
        self.embed = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.1),
            nn.Linear(64, embedding_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.embed(x)

input_dim = X.shape[1]
embedding_dim = 128

embedder = InitialStateEmbedder(input_dim=input_dim, embedding_dim=embedding_dim).to(device)
regressor = nn.Linear(embedding_dim, output_dim).to(device)
model = nn.Sequential(embedder, regressor).to(device)


# --- Step 5: Training Setup ---
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# --- Step 6: Training Loop ---
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for xb, yb in (dataloader):
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        preds = model(xb)
        loss = criterion(preds, yb)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss = {epoch_loss / len(dataloader):.4f}")

In [39]:


# ============================
# Dataset
# ============================
class StaticToFixedLengthTargetDataset(Dataset):
    def __init__(self, initial_states_df, targets_dir="data/dataset/test/sat_density", desired_len=432):
        self.data = initial_states_df.reset_index(drop=True)
        self.targets_dir = targets_dir
        self.desired_len = desired_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        static_input = torch.tensor(row.drop('File ID').values, dtype=torch.float32)
        file_id = str(int(row['File ID'])).zfill(5)  # Ensure it's 5 digits like "08081"

        # Pattern: <6-char satellite code>-<File ID>-*.csv
        pattern = os.path.join(self.targets_dir, f"??????-{file_id}-*.csv")
        matched_files = glob.glob(pattern)
        #print (matched_files)

        if len(matched_files) == 0:
            raise FileNotFoundError(f"No target file found for File ID {file_id}")
        elif len(matched_files) > 1:
            raise RuntimeError(f"Multiple files matched for File ID {file_id}: {matched_files}")

        target_path = matched_files[0]

        # Load density data from 2nd column
        target_df = pd.read_csv(target_path)
        density = pd.to_numeric(target_df.iloc[:, 1], errors='coerce')
        density = density.clip(lower=0, upper=1e3).fillna(0.0).values
        actual_len = len(density)

        # Pad or truncate to fixed length
        padded = torch.zeros(self.desired_len, dtype=torch.float32)
        mask = torch.zeros(self.desired_len, dtype=torch.bool)
        if actual_len >= self.desired_len:
            padded[:] = torch.tensor(density[:self.desired_len], dtype=torch.float32)
            mask[:] = 1
        else:
            padded[:actual_len] = torch.tensor(density, dtype=torch.float32)
            mask[:actual_len] = 1

        if torch.isnan(padded).any():
            print (f"NaN detected in padded target for File ID {file_id}")

        return static_input, padded, mask

# ============================
# Model
# ============================
class StaticToSequenceModel(nn.Module):
    def __init__(self, input_dim=9, embedding_dim=128, output_len=432):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.1),
            nn.Linear(64, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, output_len)
        )

    def forward(self, x):
        return self.model(x)  # (batch_size, 432)

# ============================
# Masked MSE Loss
# ============================
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)

    loss = (preds - targets) ** 2
    loss = loss * mask.float()

    denom = mask.sum().clamp(min=1.0)  # avoid division by zero
    return loss.sum() / denom

# ============================
# Training Loop
# ============================
def train_model(initial_states_df, targets_dir="data/dataset/test/sat_density", num_epochs=5, batch_size=64, lr=1e-3):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and Dataloader
    dataset = StaticToFixedLengthTargetDataset(initial_states_df, targets_dir)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model, optimizer, loss
    model = StaticToSequenceModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(num_epochs):
        start = time.time()
        model.train()
        epoch_loss = 0.0
        for static_input, target_seq, mask in dataloader:
            static_input = static_input.to(device)
            target_seq = target_seq.to(device)
            mask = mask.to(device)
            if torch.isnan(static_input).any():
                print("NaN detected in static_input")
            
            

            preds = model(static_input)
            preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
            loss = masked_mse_loss(preds, target_seq, mask)
            print (loss.item())

            if torch.isnan(preds).any() or torch.isnan(target_seq).any():
                print("⚠️ NaNs found in predictions or targets")
            if torch.isinf(preds).any() or torch.isinf(target_seq).any():
                print("⚠️ Infs found in predictions or targets")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        end = time.time()
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f} - Time per epoch: {(end - start):.4f}")

train_model(initial_states_normalized, lr=0.01)


28021.69140625
14363.4453125
14090.5
36620.46484375
5777.41455078125
19455.025390625
21918.533203125
17455.666015625
9195.009765625
11026.693359375
10261.505859375
13063.8623046875
24919.052734375
21758.896484375
26406.052734375
28172.142578125
20175.396484375
12272.025390625
26948.126953125
17732.1484375
9222.9208984375
21868.943359375
14417.9111328125
20033.720703125
9644.40234375
21626.513671875
25434.49609375
12749.986328125
17761.0234375
14656.7451171875
12758.6923828125
10766.94921875
27642.2890625
13528.1611328125
19931.908203125
22102.392578125
11143.6962890625
11000.3994140625
24376.427734375
28503.533203125
28824.07421875
11837.4580078125
14904.0009765625
15849.8310546875
24148.212890625
21173.578125
12422.90234375
24034.279296875
31939.107421875
18433.171875
27821.67578125
14767.837890625
25433.240234375
24150.580078125
13648.037109375
23455.94140625
6841.58642578125
13320.1806640625
15932.8212890625
6432.7685546875
12198.787109375
14764.3271484375
16360.794921875
18337.2011

In [6]:
class FullDataset(Dataset):
    def __init__(self, initial_states_df, density_length=432, goes_length=86400, omni2_length=1440, density_dir='data/dataset/test/sat_density', goes_dir="data/dataset/test/goes",
                 omni2_dir="data/dataset/test/omni2"):
        self.data = initial_states_df.reset_index(drop=True)
        self.density_dir = density_dir
        self.goes_dir = goes_dir
        self.omni2_dir = omni2_dir
        self.density_length = density_length
        self.goes_length = goes_length
        self.omni2_length = omni2_length
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        static_input = torch.tensor(row.drop('File ID').fillna(0.0).values, dtype=torch.float32)
        
        file_id = str(int(row['File ID'])).zfill(5)
        
        density_file = glob.glob(os.path.join(self.density_dir, f"*-{file_id}-*.csv"))
        goes_file = glob.glob(os.path.join(self.goes_dir, f"*-{file_id}-*.csv"))
        omni2_file = glob.glob(os.path.join(self.omni2_dir, f"*-{file_id}-*.csv"))

        pos = len(self.density_dir)+1
        density_sat = density_file[0][pos:pos+6]

        density_df = ((pd.read_csv(density_file[0])))
        density_df['Orbit Mean Density (kg/m^3)'] = np.where(density_df['Orbit Mean Density (kg/m^3)']>=1,np.nan,density_df['Orbit Mean Density (kg/m^3)'])
        if density_df.shape[0] > self.density_length:
            density_df = density_df[:self.density_length]
        elif density_df.shape[0] < self.density_length:
            padding = pd.DataFrame(np.empty((self.density_length-density_df.shape[0],2)),columns=density_df.columns)
            padding[:] = np.nan
            density_df = pd.concat((density_df,padding),ignore_index=True)
        density_df_mask = (pd.notnull(density_df)).astype(int)
        density_tensor = torch.tensor(density_df['Orbit Mean Density (kg/m^3)'].fillna(0.0).values, dtype=torch.float32)
        density_df_mask_tensor = torch.tensor(density_df_mask.iloc[:,1].values, dtype=torch.float32)
        density_stacked = torch.stack((density_tensor,density_df_mask_tensor))

        goes_df = pd.read_csv(goes_file[0])
        if goes_df.shape[0] > self.goes_length:
            goes_df = goes_df[goes_df.shape[0]-self.goes_length:goes_df.shape[0]]
        elif goes_df.shape[0] < self.goes_length:
            padding = pd.DataFrame(np.empty((self.goes_length-goes_df.shape[0],43)),columns=goes_df.columns)
            padding[:] = np.nan
            goes_df = pd.concat((padding,goes_df),ignore_index=True)
        goes_mask = (~pd.isnull(goes_df)).astype(int)
        goes_valid_mask = ((goes_df['xrsa_flag'] == 0.0) & (goes_df['xrsb_flag'] == 0.0)).astype(int)
        goes_mask = goes_mask.mul(goes_valid_mask.values,axis=0)
        goes_tensor = torch.tensor(normalize(goes_df.iloc[:, 1:].fillna(0.0).values, norm='l2'), dtype=torch.float32)
        goes_mask_tensor = torch.tensor(goes_mask.iloc[:, 1:].values, dtype=torch.float32)
        #goes_stacked = torch.stack((goes_tensor,goes_mask_tensor))
        
        omni2_df = pd.read_csv(omni2_file[0])
        if omni2_df.shape[0] > self.omni2_length:
            omni2_df = omni2_df[omni2_df.shape[0]-self.omni2_length:omni2_df.shape[0]]
        elif goes_df.shape[0] < self.omni2_length:
            padding = pd.DataFrame(np.empty((self.omni2_length-omni2_df.shape[0],58)),columns=omni2_df.columns)
            padding[:] = np.nan
            omni2_df = pd.concat((padding,omni2_df),ignore_index=True)
        omni2_tensor = torch.tensor(normalize(omni2_df.iloc[:, :57].fillna(0.0).values.astype(float), norm='l2'), dtype=torch.float32)
        omni2_mask = (~pd.isnull(omni2_df)).astype(int)
        omni2_mask_tensor = torch.tensor(omni2_mask.iloc[:, :57].values, dtype=torch.float32) 
        omni2_stacked = torch.stack((omni2_tensor,omni2_mask_tensor))

        return static_input, density_tensor, density_df_mask_tensor, goes_tensor, goes_mask_tensor, omni2_tensor, omni2_mask_tensor#, density_sat



In [35]:
class StormForecastModel(nn.Module):
    def __init__(self, 
                 static_dim=9,
                 goes_dim=42,
                 omni2_dim=57,
                 goes_hidden=64,
                 omni2_hidden=64,
                 static_embed_dim=64,
                 output_len=432):
        super().__init__()

        # Static MLP
        self.static_net = nn.Sequential(
            nn.Linear(static_dim, 64),
            nn.ReLU(),
            nn.Linear(64, static_embed_dim),
            nn.ReLU()
        )

        # Feature encoders
        self.goes_encoder = nn.Linear(goes_dim, goes_hidden)
        self.omni2_encoder = nn.Linear(omni2_dim, omni2_hidden)

        # Final regression
        self.fusion = nn.Sequential(
            nn.Linear(static_embed_dim + goes_hidden + omni2_hidden, 128),
            nn.ReLU(),
            nn.Linear(128, output_len)
        )

    def masked_mean(self, x, mask):
        # x: [B, T, F], mask: [B, T]
        mask = mask.unsqueeze(-1)  # [B, T, 1]
        x = x * mask  # mask out invalid steps
        summed = x.sum(dim=1)  # sum over time
        lengths = mask.sum(dim=1).clamp(min=1e-6)  # avoid divide-by-zero
        return summed / lengths  # [B, F]

    def forward(self, static_input, goes_seq, omni2_seq, goes_mask, omni2_mask):
        # Encode static
        static_embed = self.static_net(static_input)  # [B, static_embed_dim]

        # Encode GOES
        goes_encoded = self.goes_encoder(goes_seq)  # [B, T, H]
        goes_pooled = self.masked_mean(goes_encoded, goes_mask)  # [B, H]

        # Encode OMNI2
        omni2_encoded = self.omni2_encoder(omni2_seq)  # [B, T, H]
        omni2_pooled = self.masked_mean(omni2_encoded, omni2_mask)  # [B, H]

        # Combine and predict
        fused = torch.cat([static_embed, goes_pooled, omni2_pooled], dim=-1)  # [B, total]
        output = self.fusion(fused)  # [B, 432]

        return output


In [7]:
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    loss = (preds - targets) ** 2
    masked_loss = loss * mask.float()
    return masked_loss.sum() / (mask.sum() + eps)

def train(model, dataloader, device, num_epochs=5, lr=1e-4):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch in dataloader:
            static_input, target, target_mask, goes_seq, goes_mask, omni2_seq, omni2_mask, _ = batch

            # Move to GPU
            static_input = static_input.to(device)
            target = target.to(device)
            target_mask = target_mask.to(device)
            goes_seq = goes_seq.to(device)
            goes_mask = goes_mask.to(device)
            omni2_seq = omni2_seq.to(device)
            omni2_mask = omni2_mask.to(device)

            # Forward pass
            preds = model(static_input, goes_seq, omni2_seq, goes_mask, omni2_mask)

            # Loss
            loss = masked_mse_loss(preds, target, target_mask)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

In [None]:
model = StormForecastModel()
dataset = FullDataset(initial_states_normalized)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)

train(model, dataloader, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

In [7]:
class FullModel(nn.Module):
    def __init__(self, input_dim=9, embedding_dim=128, goes_len=86400, goes_dim=43, omni2_length=1440, omni2_dim=58, output_len=432):
        super().__init__()
        self.embed = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.1),
            nn.Linear(64, embedding_dim),
            nn.ReLU()
        )

        

    def forward(self,x):
        return (self.embed(x))


In [26]:
num_epochs = 100
batch_size = 64
num_batches = len(initial_states_normalized)/batch_size
dataset = FullDataset(initial_states_normalized)
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#dataset[1]
embedder = FullModel().to(device)
regressor = nn.Linear(128, 432).to(device)
model = nn.Sequential(embedder, regressor).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

losses = []



def MaskedRMSE(preds, targets, mask):
    # preds = torch.Tensor.numpy(preds)
    # targets = torch.Tensor.numpy(targets)
    # mask = torch.Tensor.numpy(mask)
    # print (preds)
    # print (targets)
    # print (mask)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

for epoch in (range(num_epochs)):
    start = time.time()
    model.eval()
    epoch_loss = 0.0
    for static_input, density_tensor, density_mask_tensor, goes_tensor, goes_mask_tensor, omni2_tensor, omni2_mask_tensor, density_sat in tqdm(dataloader):
        static_input = static_input.to(device)
        density_tensor = density_tensor.to(device)
        density_mask_tensor = density_mask_tensor.to(device)
        preds = model(static_input)
        # print(preds)
        # print (preds.shape)
        # print(f"shapes\nPreds: {preds.shape}, {preds.dtype}\nTargets: {density_tensor.shape}, {density_tensor.dtype}\nMask: {density_mask_tensor.shape}, {density_mask_tensor.dtype}")
        loss = MaskedRMSE(preds,density_tensor, density_mask_tensor)
        #print(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss/ len(dataloader)
    losses.append(avg_loss)
    end = time.time()
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.10f} - Time per epoch: {(end - start):.4f}")

    
    

  0%|          | 0/127 [00:00<?, ?it/s]

  2%|▏         | 2/127 [00:39<41:11, 19.77s/it]


KeyboardInterrupt: 

In [24]:
start = time.time()
dataloader = DataLoader(dataset,batch_size=1,shuffle=True)
model.eval()
epoch_loss = 0.0
for static_input, density_tensor, density_mask_tensor in tqdm(dataloader):
    static_input = static_input.to(device)
    density_tensor = density_tensor.to(device)
    density_mask_tensor = density_mask_tensor.to(device)
    preds = model(static_input)
    # print(preds)
    # print (preds.shape)
    # print(f"shapes\nPreds: {preds.shape}, {preds.dtype}\nTargets: {density_tensor.shape}, {density_tensor.dtype}\nMask: {density_mask_tensor.shape}, {density_mask_tensor.dtype}")
    loss = MaskedRMSE(preds,density_tensor, density_mask_tensor)
    break

preds.shape
preds

  0%|          | 0/8119 [00:00<?, ?it/s]


tensor([[-3.6412e-04, -3.5223e-04,  3.4084e-04,  4.5743e-04, -3.0997e-04,
          3.8624e-05, -4.9660e-04,  3.6685e-04,  6.6805e-04,  2.3664e-04,
          4.1881e-04, -1.3260e-04,  3.5301e-04, -3.9159e-04,  1.7426e-04,
         -1.5569e-05,  6.6600e-04,  1.7110e-04,  5.4594e-04,  4.4028e-04,
          6.7592e-04, -5.5230e-04, -7.7586e-04, -1.8942e-04, -1.5887e-05,
         -2.2934e-04, -3.3015e-04,  3.7521e-04,  4.8421e-04, -1.0995e-04,
         -7.3139e-04, -1.8615e-04,  6.9783e-04, -6.8153e-04, -5.4119e-04,
         -2.3852e-04,  7.0119e-04,  3.6764e-04,  4.4872e-04,  4.7903e-04,
         -2.2988e-04,  4.9530e-04, -2.3529e-04,  3.1703e-04,  3.2064e-04,
         -2.6860e-04,  3.3155e-05,  1.4689e-04, -3.1956e-04, -1.1548e-04,
         -4.7844e-04,  1.3072e-04,  4.7526e-04, -6.4395e-04, -5.2673e-04,
          2.4633e-04,  3.5200e-04,  5.0917e-04,  4.3526e-04,  2.7104e-04,
          7.5075e-04,  5.0053e-04, -4.9528e-04,  3.5879e-04, -6.7240e-04,
          6.0761e-04, -5.8948e-04,  4.

In [58]:


# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=4,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None

        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    loss = (preds - targets) ** 2 * mask
    return loss.sum() / (mask.sum() + eps)

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------
def train_storm_transformer(initial_states_df, num_epochs=3, batch_size=4, lr=1e-4, device=None):
    #from full_dataset import FullDataset  # Ensure your FullDataset class is in a separate file

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = FullDataset(initial_states_df)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        for batch in tqdm(dataloader):
            (static_input, 
             density_tensor, 
             density_mask_tensor, 
             goes_tensor, 
             goes_mask_tensor, 
             omni2_tensor, 
             omni2_mask_tensor, 
             _) = batch

            static_input = static_input.to(device)
            density_tensor = density_tensor.to(device)
            density_mask_tensor = density_mask_tensor.to(device)
            goes_tensor = goes_tensor.to(device)
            goes_mask_tensor = goes_mask_tensor.to(device)
            omni2_tensor = omni2_tensor.to(device)
            omni2_mask_tensor = omni2_mask_tensor.to(device)

            optimizer.zero_grad()
            preds = model(static_input, omni2_tensor, goes_tensor, omni2_mask_tensor, goes_mask_tensor)
            loss = masked_mse_loss(preds, density_tensor, density_mask_tensor)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
        batch_average_loss = total_loss/len(dataloader)
        losses.append(batch_average_loss)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {batch_average_loss:.15f}")


In [9]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=4,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None

        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask):
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------
train_losses =[]
val_losses= []
def train_storm_transformer(initial_states_df, num_epochs=10, batch_size=16, lr=1e-4, device=None):
    #from full_dataset import FullDataset

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    # 🔀 Split into train and val
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05)

    train_dataset = FullDataset(train_df)
    val_dataset = FullDataset(val_df)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0

        for batch in tqdm(train_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = [b.to(device) for b in batch]

            optimizer.zero_grad()
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # 🔍 Validation
        total_val_loss = 0.0
        try:
            with torch.no_grad():
                for batch in tqdm(val_loader):
                    static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = [b.to(device) for b in batch]
                    preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                    loss = masked_mse_loss(preds, density, density_mask)
                    total_val_loss += loss.item()
        except:
            print("Didnt work")

        avg_val_loss = total_val_loss / len(val_loader)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{num_epochs} — Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")

        # 💾 Save if improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_path = f"checkpoints/storm_epoch_{epoch+1}.pt"
            torch.save(model.state_dict(), save_path)
            print(f"✅ Model improved. Saved to: {save_path}")


In [10]:
train_storm_transformer(initial_states_normalized,num_epochs=20)


  0%|          | 0/483 [00:00<?, ?it/s]

100%|██████████| 483/483 [26:27<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.05s/it]


Epoch 1/20 — Train Loss: 0.011633931810526065 | Val Loss: 0.005352307702056491
✅ Model improved. Saved to: checkpoints/storm_epoch_1.pt


100%|██████████| 483/483 [26:23<00:00,  3.28s/it]
100%|██████████| 26/26 [01:20<00:00,  3.11s/it]


Epoch 2/20 — Train Loss: 0.00382988435008002 | Val Loss: 0.0029437169880391313
✅ Model improved. Saved to: checkpoints/storm_epoch_2.pt


100%|██████████| 483/483 [26:25<00:00,  3.28s/it]
100%|██████████| 26/26 [01:20<00:00,  3.08s/it]


Epoch 3/20 — Train Loss: 0.002602643213103724 | Val Loss: 0.0020691500349829975
✅ Model improved. Saved to: checkpoints/storm_epoch_3.pt


100%|██████████| 483/483 [26:24<00:00,  3.28s/it]
100%|██████████| 26/26 [01:19<00:00,  3.08s/it]


Epoch 4/20 — Train Loss: 0.001390635460076707 | Val Loss: 0.000812962277380463
✅ Model improved. Saved to: checkpoints/storm_epoch_4.pt


100%|██████████| 483/483 [26:27<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.05s/it]


Epoch 5/20 — Train Loss: 0.00010495477069448601 | Val Loss: 4.936498586315653e-06
✅ Model improved. Saved to: checkpoints/storm_epoch_5.pt


100%|██████████| 483/483 [26:25<00:00,  3.28s/it]
100%|██████████| 26/26 [01:20<00:00,  3.08s/it]


Epoch 6/20 — Train Loss: 4.595014123037376e-06 | Val Loss: 3.857655288951579e-06
✅ Model improved. Saved to: checkpoints/storm_epoch_6.pt


100%|██████████| 483/483 [26:26<00:00,  3.28s/it]
100%|██████████| 26/26 [01:19<00:00,  3.05s/it]


Epoch 7/20 — Train Loss: 4.83784419065504e-06 | Val Loss: 3.2902886372399876e-06
✅ Model improved. Saved to: checkpoints/storm_epoch_7.pt


100%|██████████| 483/483 [26:26<00:00,  3.28s/it]
100%|██████████| 26/26 [01:19<00:00,  3.07s/it]


Epoch 8/20 — Train Loss: 7.11319532077915e-06 | Val Loss: 6.818886609275628e-06


100%|██████████| 483/483 [26:26<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.04s/it]


Epoch 9/20 — Train Loss: 6.3424163616079376e-06 | Val Loss: 4.264409653842449e-06


100%|██████████| 483/483 [26:10<00:00,  3.25s/it]
100%|██████████| 26/26 [01:19<00:00,  3.05s/it]


Epoch 10/20 — Train Loss: 4.224208920827145e-06 | Val Loss: 3.6963406448474134e-06


100%|██████████| 483/483 [26:28<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.07s/it]


Epoch 11/20 — Train Loss: 4.336992924059742e-06 | Val Loss: 3.195294250592549e-06
✅ Model improved. Saved to: checkpoints/storm_epoch_11.pt


100%|██████████| 483/483 [26:26<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.08s/it]


Epoch 12/20 — Train Loss: 4.374549327121054e-06 | Val Loss: 3.5555588930527356e-06


100%|██████████| 483/483 [26:28<00:00,  3.29s/it]
100%|██████████| 26/26 [01:20<00:00,  3.08s/it]


Epoch 13/20 — Train Loss: 3.3975865500141865e-05 | Val Loss: 3.941598046926979e-06


100%|██████████| 483/483 [26:27<00:00,  3.29s/it]
100%|██████████| 26/26 [01:19<00:00,  3.06s/it]


Epoch 14/20 — Train Loss: 4.338248628370461e-06 | Val Loss: 5.75980201434644e-06


100%|██████████| 483/483 [26:26<00:00,  3.28s/it]
100%|██████████| 26/26 [01:19<00:00,  3.06s/it]


Epoch 15/20 — Train Loss: 4.3511188521180255e-06 | Val Loss: 4.536636459726231e-06


100%|██████████| 483/483 [26:26<00:00,  3.28s/it]
100%|██████████| 26/26 [01:20<00:00,  3.09s/it]


Epoch 16/20 — Train Loss: 4.097751099729119e-06 | Val Loss: 3.5075801966181293e-06


100%|██████████| 483/483 [26:26<00:00,  3.29s/it]
100%|██████████| 26/26 [01:20<00:00,  3.08s/it]


Epoch 17/20 — Train Loss: 4.202842442367153e-06 | Val Loss: 4.47616730525624e-06


 40%|███▉      | 191/483 [10:29<16:02,  3.30s/it]


KeyboardInterrupt: 

In [30]:
train_storm_transformer(initial_states_normalized,num_epochs=20)

100%|██████████| 965/965 [28:15<00:00,  1.76s/it]
 10%|▉         | 5/51 [00:09<01:26,  1.87s/it]


Didnt work
Epoch 1/20 — Train Loss: 0.00026589945274455847 | Val Loss: 0.00000013971743319595
✅ Model improved. Saved to: checkpoints/storm_epoch_1.pt


100%|██████████| 965/965 [28:07<00:00,  1.75s/it]
 10%|▉         | 5/51 [00:09<01:27,  1.90s/it]


Didnt work
Epoch 2/20 — Train Loss: 0.00001941913264528798 | Val Loss: 0.00000025277192354920


 28%|██▊       | 267/965 [07:55<20:41,  1.78s/it]


KeyboardInterrupt: 

In [26]:
train_df, val_df = train_test_split(initial_states_normalized, test_size=0.1, random_state=42)
val_dataset = FullDataset(val_df)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
for epoch in range(10):
    for batch in tqdm(val_loader):
        static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = [b.to(device) for b in batch]
        preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
        loss = masked_mse_loss(preds, density, density_mask)
        print(loss)
        total_val_loss += loss.item()

  0%|          | 0/26 [00:07<?, ?it/s]


NameError: name 'model' is not defined

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

def train_storm_transformer(initial_states_df, num_epochs=10, batch_size=8, lr=1e-4, device=None):
    from full_dataset import FullDataset
    from storm_transformer import STORMTransformer, masked_mse_loss

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05, random_state=42)

    train_dataset = FullDataset(train_df)
    val_dataset = FullDataset(val_df)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        for batch in tqdm(train_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask, _ = batch

            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)

            if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                print("⚠️ Skipping batch with fully masked inputs")
                continue

            optimizer.zero_grad()
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # 🧪 Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask, _ = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Train Loss: {avg_train_loss:.20f} | Val Loss: {avg_val_loss:.20f}")

        # 💾 Save full checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss
        }, checkpoint_path)

        # 💎 Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), best_model_path)
            print("✅ Best model updated.")


In [94]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=4,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None
        global tester_mask
        global tester_seq
        tester_seq = goes_seq
        tester_mask = goes_mask
        #print(goes_seq,"\n",goes_mask)
        #print(goes_mask.sum())
        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        global tester_key_mask
        global tester_embed
        tester_embed = goes_embed
        tester_key_mask = goes_key_mask
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    loss = (preds - targets) ** 2 * mask
    return loss.sum() / (mask.sum() + eps)

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------
train_losses =[]
val_losses= []
def eval_storm_transformer(initial_states_df, num_epochs=10, batch_size=4, lr=1e-4, device=None):
    #from full_dataset import FullDataset

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    # 🔀 Split into train and val
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05)

    train_dataset = FullDataset(train_df)
    val_dataset = FullDataset(val_df)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    model = STORMTransformer().to(device)
    model.load_state_dict(torch.load("checkpoints/storm_epoch_10.pt", weights_only=True))
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_loss = float("inf")

    for epoch in range(num_epochs):
        

        # 🔍 Validation
        model.train()
        total_val_loss = 0.0
        #try:
        with torch.no_grad():
            for batch in tqdm(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)
                valid_goes = goes_mask.any(dim=-1).sum(dim=1) > 0  # (B,)
                valid_omni2 = omni2_mask.any(dim=-1).sum(dim=1) > 0  # (B,)
                valid_mask = valid_goes & valid_omni2

                # Skip batch if any sample is fully masked
                # if not valid_mask.all():
                #     print("⚠️ Skipping eval batch due to fully masked sample(s).")
                #     continue
                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()
        #except:
         #   print("Didnt work")

        avg_val_loss = total_val_loss / len(val_loader)

        val_losses.append(avg_val_loss)
        print(f"Epoch {epoch+1}/{num_epochs} | Val Loss: {avg_val_loss}")

        

In [90]:
#Working train loop testing eval

# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=128,
                 output_len=432,
                 nhead=4,
                 num_layers=2,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None

        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    loss = (preds - targets) ** 2 * mask
    return loss.sum() / (mask.sum() + eps)

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------
def reEval_storm_transformer(initial_states_df, num_epochs=3, batch_size=8, lr=1e-4, device=None):
    #from full_dataset import FullDataset  # Ensure your FullDataset class is in a separate file

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_df, val_df = train_test_split(initial_states_df, test_size=0.05)

    train_dataset = FullDataset(train_df)
    val_dataset = FullDataset(val_df)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

    model = STORMTransformer().to(device)
    model.load_state_dict(torch.load("checkpoints/storm_epoch_10.pt", weights_only=True))
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                (static_input, 
                density_tensor, 
                density_mask_tensor, 
                goes_tensor, 
                goes_mask_tensor, 
                omni2_tensor, 
                omni2_mask_tensor, 
                _) = batch

                static_input = static_input.to(device)
                density_tensor = density_tensor.to(device)
                density_mask_tensor = density_mask_tensor.to(device)
                goes_tensor = goes_tensor.to(device)
                goes_mask_tensor = goes_mask_tensor.to(device)
                omni2_tensor = omni2_tensor.to(device)
                omni2_mask_tensor = omni2_mask_tensor.to(device)

                preds = model(static_input, omni2_tensor, goes_tensor, omni2_mask_tensor, goes_mask_tensor)
                loss = masked_mse_loss(preds, density_tensor, density_mask_tensor)
                
                total_loss += loss.item()
        batch_average_loss = total_loss/len(val_loader)
        losses.append(batch_average_loss)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {batch_average_loss}")


In [83]:
reEval_storm_transformer(initial_states_normalized)

  0%|          | 0/7 [00:17<?, ?it/s]


ValueError: not enough values to unpack (expected 8, got 7)

In [44]:
tester_key_mask_int = tester_key_mask.
if tester_key_mask.any() == False:
    print ("yes")
print ((tester_key_mask.size()[0]*tester_key_mask.size()[1]))
print (np.count_nonzero(tester_key_mask.to("cpu")))

tester_embed.size
# if tester_embed == 0.9454:
#     print (True)

8232
8232


<function Tensor.size>

In [95]:
eval_storm_transformer(initial_states_normalized)

  0%|          | 0/102 [00:00<?, ?it/s]

  0%|          | 0/102 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 46.00 MiB. GPU 0 has a total capacity of 7.60 GiB of which 20.00 MiB is free. Process 5350 has 1.27 GiB memory in use. Including non-PyTorch memory, this process has 5.53 GiB memory in use. Of the allocated memory 5.24 GiB is allocated by PyTorch, and 161.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)