In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transformers
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
import sklearn
from sklearn.preprocessing import normalize
from tqdm import tqdm

import glob
import time

import math
import os
from sklearn.model_selection import train_test_split

In [3]:
initial_states = []

path = "input_data/"

for dir, sub_dir, files in os.walk(path):
    for file in sorted(files):
        #print(file)
        temp = pd.read_csv((path+file),index_col=None, header=0)
        initial_states.append(temp)

initial_states_df = pd.concat(initial_states,axis=0,ignore_index=True)

hold = initial_states_df
x = initial_states_df.iloc[:,2:].values
x = normalize(x,norm='l2')
hold = pd.concat([hold['File ID'],pd.DataFrame(x)],axis=1)
initial_states_normalized = hold
#'2000-08-02 04:50:33'
timestamps = initial_states_df['Timestamp']
#initial_states_df

In [4]:
class FullDatasetPT(Dataset):
    def __init__(self, initial_states_df, pt_dir='data/processed/pt_files'):
        self.data = initial_states_df.reset_index(drop=True)
        self.pt_dir = pt_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        file_id = str(int(row['File ID'])).zfill(5)
        pt_path = os.path.join(self.pt_dir, f"{file_id}.pt")

        if not os.path.exists(pt_path):
            raise FileNotFoundError(f".pt file not found for File ID: {file_id}")

        static_input = torch.tensor(row.drop("File ID").fillna(0.0).values, dtype=torch.float32)
        pt_data = torch.load(pt_path)

        return (
            static_input,
            pt_data["density"],
            pt_data["density_mask"],
            pt_data["goes"],
            pt_data["goes_mask"],
            pt_data["omni2"],
            pt_data["omni2_mask"]
        )



In [None]:
# -----------------------------------
# Positional Encoding for Sequences
# -----------------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=10000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(-torch.arange(0, d_model, 2) * math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

# -----------------------------------
# STORMTransformer with Mask Handling
# -----------------------------------
class STORMTransformer(nn.Module):
    def __init__(self,
                 static_dim=9,
                 omni2_dim=57,
                 goes_dim=42,
                 d_model=256,
                 output_len=432,
                 nhead=8,
                 num_layers=4,
                 dropout=0.1):
        super().__init__()

        self.static_encoder = nn.Sequential(
            nn.Linear(static_dim, d_model),
            nn.ReLU(),
            nn.LayerNorm(d_model)
        )

        self.omni2_proj = nn.Linear(omni2_dim, d_model)
        self.goes_proj = nn.Linear(goes_dim, d_model)

        self.omni2_pos = PositionalEncoding(d_model)
        self.goes_pos = PositionalEncoding(d_model)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout, batch_first=True)
        self.omni2_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.goes_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fusion = nn.Sequential(
            nn.Linear(d_model * 3, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256,360),
            nn.ReLU(),
            nn.Linear(360, output_len)
        )

    def forward(self, static_input, omni2_seq, goes_seq, omni2_mask=None, goes_mask=None):
        B = static_input.size(0)

        static_embed = self.static_encoder(static_input)

        omni2_embed = self.omni2_proj(omni2_seq)
        omni2_embed = self.omni2_pos(omni2_embed)
        omni2_key_mask = (~omni2_mask.bool()).any(dim=-1) if omni2_mask is not None else None
        omni2_out = self.omni2_encoder(omni2_embed, src_key_padding_mask=omni2_key_mask)
        omni2_summary = omni2_out.mean(dim=1)

        if goes_seq.shape[1] > 1024:
            step = goes_seq.shape[1] // 1024
            goes_seq = goes_seq[:, ::step, :]
            goes_mask = goes_mask[:, ::step, :] if goes_mask is not None else None
        # global tester_mask
        # global tester_seq
        # tester_seq = goes_seq
        # tester_mask = goes_mask
        #print(goes_seq,"\n",goes_mask)
        #print(goes_mask.sum())
        goes_embed = self.goes_proj(goes_seq)
        goes_embed = self.goes_pos(goes_embed)
        goes_key_mask = (~goes_mask.bool()).any(dim=-1) if goes_mask is not None else None
        # global tester_key_mask
        # global tester_embed
        # tester_embed = goes_embed
        # tester_key_mask = goes_key_mask
        goes_out = self.goes_encoder(goes_embed, src_key_padding_mask=goes_key_mask)
        goes_summary = goes_out.mean(dim=1)

        combined = torch.cat((static_embed, omni2_summary, goes_summary), dim=-1)
        return self.fusion(combined)

# -----------------------------------
# Masked MSE Loss
# -----------------------------------
def masked_mse_loss(preds, targets, mask, eps=1e-8):
    # preds = torch.nan_to_num(preds, nan=0.0, posinf=1e3, neginf=0.0)
    # targets = torch.nan_to_num(targets, nan=0.0, posinf=1e3, neginf=0.0)
    # loss = (preds - targets) ** 2 * mask
    # return loss.sum() / (mask.sum() + eps)
    diff = (targets - preds) * mask
    sq = torch.square(diff)
    sum = torch.sum(sq)
    N = torch.sum(mask)
    # print(sum)
    # print(N)
    loss = torch.sqrt((sum/N))
    return loss

# -----------------------------------
# Full Training Loop with FullDataset
# -----------------------------------

        

In [6]:
def train_storm_transformer(initial_states_df, num_epochs=20, batch_size=8, lr=1e-3, device=None):
    # from full_dataset import FullDataset
    # from storm_transformer import STORMTransformer, masked_mse_loss
    torch.manual_seed(42)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("checkpoints", exist_ok=True)

    torch.manual_seed(42)

    # 🔀 Train/validation split
    train_df, val_df = train_test_split(initial_states_df, test_size=0.05, random_state=42)

    train_dataset = FullDatasetPT(train_df)
    val_dataset = FullDatasetPT(val_df)

    # train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8, pin_memory=True, )
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=8, pin_memory=True)


    model = STORMTransformer().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_path = "checkpoints/storm_last.pt"
    best_model_path = "checkpoints/storm_best.pt"
    start_epoch = 0
    best_val_loss = float("inf")

    # 🔁 Resume support
    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint.get('epoch', 0)
        best_val_loss = checkpoint.get('val_loss', float("inf"))

    # 🚀 Training loop
    train_loss_history = []
    val_loss_history = []
    for epoch in range(start_epoch, num_epochs):
        model.train()
        total_train_loss = 0.0

        print(f"\n🚀 Epoch {epoch + 1}/{num_epochs}")
        #start_load = time.time()
        for batch in tqdm(train_loader):
            static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch
            #nd_load = time.time()

            static_input = static_input.to(device)
            density = density.to(device)
            density_mask = density_mask.to(device)
            goes = goes.to(device)
            goes_mask = goes_mask.to(device)
            omni2 = omni2.to(device)
            omni2_mask = omni2_mask.to(device)

            # if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
            #     print("⚠️ Skipping batch with fully masked inputs")
            #     continue

            optimizer.zero_grad()
            preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
            #print (preds)
            loss = masked_mse_loss(preds, density, density_mask)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_train_loss += loss.item()

            #end_batch = time.time()

            # print ("Load time:", end_load - start_load )
            # print ("Calc time:", end_batch - end_load)

        avg_train_loss = total_train_loss / len(train_loader)

        print ("Preds:", preds)
        print ("Targets",density)
        # 🧪 Validation
        
        total_val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                static_input, density, density_mask, goes, goes_mask, omni2, omni2_mask = batch

                static_input = static_input.to(device)
                density = density.to(device)
                density_mask = density_mask.to(device)
                goes = goes.to(device)
                goes_mask = goes_mask.to(device)
                omni2 = omni2.to(device)
                omni2_mask = omni2_mask.to(device)

                if (omni2_mask.any(dim=-1).sum(dim=1) == 0).any() or (goes_mask.any(dim=-1).sum(dim=1) == 0).any():
                    continue

                preds = model(static_input, omni2, goes, omni2_mask, goes_mask)
                loss = masked_mse_loss(preds, density, density_mask)
                total_val_loss += loss.item()

                # 🧠 Mask diagnostics
                goes_mask_sum = goes_mask.sum().item()
                omni2_mask_sum = omni2_mask.sum().item()
                density_mask_sum = density_mask.sum().item()

                print(f"🧪 Eval Batch {batch_idx+1}/{len(val_loader)} — "
                      f"OMNI2 Mask Sum: {omni2_mask_sum} | "
                      f"GOES Mask Sum: {goes_mask_sum} | "
                      f"Density Mask Sum: {density_mask_sum}")

                # ⚠️ Alert if any mask has < 10% coverage
                if omni2_mask_sum < 0.1 * omni2_mask.numel():
                    print("⚠️ Low OMNI2 coverage in this batch!")
                if goes_mask_sum < 0.1 * goes_mask.numel():
                    print("⚠️ Low GOES coverage in this batch!")
                if density_mask_sum < 0.1 * density_mask.numel():
                    print("⚠️ Low density mask coverage in this batch!")

        avg_val_loss = total_val_loss / len(val_loader)
        print(f"\n📊 Epoch {epoch+1}/{num_epochs} — "
              f"Train Loss: {avg_train_loss} | Val Loss: {avg_val_loss}")
        
        train_loss_history.append(avg_train_loss)
        val_loss_history.append(avg_val_loss)
        
        print ("Preds:", preds)
        print ("Targets",density)

        # 💾 Save full checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': avg_val_loss
        }, checkpoint_path)

        # 💎 Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), f"checkpoints/epoch{epoch}.pt")
            print("✅ Best model updated.")


In [7]:
train_storm_transformer(initial_states_normalized)


🚀 Epoch 1/20


100%|██████████| 965/965 [10:10<00:00,  1.58it/s]


Preds: tensor([[ 3.9402e-05,  8.7712e-05, -2.0693e-05,  1.3167e-04, -1.2110e-04,
         -2.3879e-06, -7.8280e-05,  3.2947e-05,  5.5514e-05, -1.5350e-05,
         -5.2065e-05,  1.5615e-04, -1.2612e-04,  1.8484e-04,  7.7445e-05,
          1.0847e-05,  4.0329e-05,  3.6263e-05, -8.2167e-05, -2.6942e-05,
          1.8129e-05, -1.0129e-04, -3.2103e-06,  3.5432e-06,  9.1679e-06,
         -5.3673e-05,  8.2208e-05, -1.4570e-04, -1.2565e-04, -1.3030e-04,
          1.9823e-04,  1.8151e-04,  4.5472e-07, -1.1252e-04,  1.4195e-04,
         -3.0460e-05,  8.6253e-05,  1.6276e-05, -1.3999e-04, -2.8439e-05,
          8.6628e-05, -1.9561e-04, -2.5705e-04,  1.3993e-04, -1.0626e-04,
          6.7829e-05, -8.6127e-05, -2.6988e-05,  4.3591e-04, -1.1528e-04,
          7.4595e-05, -5.3792e-05,  8.1333e-05,  8.6876e-05, -8.9440e-05,
          4.0028e-06, -3.7494e-05,  5.3923e-05,  1.1966e-04, -1.1676e-04,
         -4.4431e-04, -8.6917e-05,  8.8550e-05,  1.1012e-04,  1.8124e-04,
          6.1877e-06, -6.8946e-

100%|██████████| 965/965 [10:05<00:00,  1.59it/s]

Preds: tensor([[-1.9803e-04, -5.4283e-05, -9.0166e-05, -9.1335e-05,  9.5299e-05,
          4.2382e-05,  9.7277e-05, -1.6482e-04, -9.5487e-05, -6.6299e-05,
          5.3996e-05, -4.3314e-05,  3.9145e-05, -8.3808e-05, -1.4286e-04,
          1.4068e-04,  2.4684e-05, -1.2591e-04,  6.2514e-05,  3.2296e-05,
         -1.4480e-04,  1.1528e-04,  3.4710e-06,  7.3914e-05,  2.5909e-06,
          1.3149e-04, -5.4801e-05,  2.1055e-04,  2.0456e-04, -1.9968e-04,
          6.9177e-05, -4.8129e-05,  1.4184e-04, -3.7926e-05, -2.3176e-04,
          1.1645e-04, -1.7784e-04, -7.5782e-06,  9.4719e-05,  1.4282e-04,
         -3.1719e-05,  1.7431e-04,  8.9809e-05, -1.1435e-04,  1.6168e-04,
         -1.2016e-04,  1.3033e-04,  1.3908e-04,  3.7368e-05,  2.5312e-04,
          1.5328e-04, -1.0392e-04,  2.3158e-05, -1.3578e-04,  4.2924e-04,
          9.8174e-05,  1.3752e-04, -9.4959e-05, -2.1793e-05,  6.3173e-05,
          4.5002e-06,  1.3743e-04, -1.5243e-04, -1.4052e-04,  4.7786e-05,
          1.0997e-04,  7.5115e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 2/20 — Train Loss: 0.00012828812872179074 | Val Loss: 8.368396568888178e-06
Preds: tensor([[ 6.9104e-05,  6.6350e-05,  7.8139e-05,  ..., -8.6531e-05,
         -1.5586e-04,  8.5451e-05],
        [ 6.9104e-05,  6.6350e-05,  7.8139e-05,  ..., -8.6531e-05,
         -1.5586e-04,  8.5451e-05],
        [ 6.9104e-05,  6.6350e-05,  7.8139e-05,  ..., -8.6531e-05,
         -1.5586e-04,  8.5451e-05],
        ...,
        [ 6.9104e-05,  6.6350e-05,  7.8139e-05,  ..., -8.6531e-05,
         -1.5586e-04,  8.5451e-05],
        [ 6.9104e-05,  6.6350e-05,  7.8139e-05,  ..., -8.6531e-05,
         -1.5586

100%|██████████| 965/965 [09:55<00:00,  1.62it/s]

Preds: tensor([[ 1.4379e-04,  8.4663e-05,  8.3341e-05,  1.0823e-04,  1.1055e-05,
          9.2587e-05, -1.9763e-04,  1.8954e-04,  1.9370e-04,  1.1960e-04,
         -7.4925e-05,  2.9549e-04, -1.4198e-04,  1.1839e-04,  2.0299e-04,
         -1.4051e-04,  1.0439e-04,  1.1841e-04, -8.5477e-05,  4.3716e-06,
          5.5665e-05, -1.5455e-04,  7.7692e-05,  7.7064e-05,  2.8956e-05,
         -7.7635e-05,  1.4523e-04,  2.0641e-04, -9.1981e-05,  6.3203e-05,
          1.3457e-04, -5.7407e-05, -2.4009e-06,  4.8260e-05,  1.9459e-04,
         -1.4347e-04, -5.8160e-05,  6.1582e-05, -1.3010e-04, -8.5412e-05,
          2.8198e-05, -1.2271e-04, -7.6370e-05,  8.2036e-05, -7.6968e-05,
          3.8232e-05, -1.2632e-04, -1.6890e-04,  9.2544e-05, -1.1410e-04,
          8.5507e-05,  3.9457e-05, -2.7045e-05,  1.2544e-04, -8.3001e-05,
         -9.6443e-05, -9.3371e-05,  1.0960e-04,  8.5164e-05, -5.0759e-05,
         -1.5580e-04, -1.4605e-04,  1.4361e-04,  8.8897e-05, -1.0307e-04,
         -9.6638e-05, -1.0706e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 3/20 — Train Loss: 0.00011102588035388236 | Val Loss: 7.746736753010647e-06
Preds: tensor([[-5.8353e-05, -1.1126e-04, -3.2766e-05,  ...,  7.7842e-05,
          7.3204e-05, -1.2883e-04],
        [-5.8353e-05, -1.1126e-04, -3.2766e-05,  ...,  7.7842e-05,
          7.3204e-05, -1.2883e-04],
        [-5.8353e-05, -1.1126e-04, -3.2766e-05,  ...,  7.7842e-05,
          7.3204e-05, -1.2883e-04],
        ...,
        [-5.8353e-05, -1.1126e-04, -3.2766e-05,  ...,  7.7842e-05,
          7.3204e-05, -1.2883e-04],
        [-5.8353e-05, -1.1126e-04, -3.2766e-05,  ...,  7.7842e-05,
          7.3204

100%|██████████| 965/965 [09:59<00:00,  1.61it/s]

Preds: tensor([[ 6.4563e-05,  7.6968e-05, -1.8430e-05,  1.3464e-04, -6.7749e-05,
          1.0312e-04, -2.7145e-05,  1.2417e-04,  6.0879e-05,  1.3450e-04,
         -1.2210e-04,  3.3036e-05, -1.0175e-04,  1.0478e-04,  1.1230e-04,
         -9.4414e-05,  5.0670e-05,  9.1913e-05, -5.4251e-05, -7.1621e-05,
          3.0616e-05, -1.5507e-04, -5.1265e-06, -1.8455e-05, -6.1797e-05,
         -3.2221e-05, -1.0536e-04, -6.8728e-05, -1.2008e-04,  8.1906e-05,
          1.0701e-04,  9.1299e-05,  4.8648e-05,  7.0971e-05,  1.1074e-04,
          1.1886e-04, -5.9389e-05,  6.9227e-05, -1.9204e-06, -2.0485e-05,
          9.0255e-05, -1.0029e-04, -1.0155e-04,  1.3579e-04, -1.3323e-04,
          1.0747e-04, -8.3294e-05, -1.8589e-05,  1.4190e-04, -1.1266e-04,
          1.2413e-04,  9.2760e-07, -8.1866e-05,  9.2729e-05, -2.7806e-05,
         -6.6820e-05, -4.1799e-05,  8.6045e-06,  1.3622e-04, -7.9479e-05,
         -1.8873e-04,  7.1042e-05,  1.4346e-04,  5.3324e-05, -1.4813e-04,
         -2.8687e-04, -2.4376e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 4/20 — Train Loss: 0.00010138027950318487 | Val Loss: 6.751192296517319e-06
Preds: tensor([[-6.7867e-05, -3.9009e-05,  4.1678e-05,  ...,  4.5301e-05,
         -2.0127e-05, -1.1757e-04],
        [-6.7867e-05, -3.9009e-05,  4.1678e-05,  ...,  4.5301e-05,
         -2.0127e-05, -1.1757e-04],
        [-6.7867e-05, -3.9009e-05,  4.1678e-05,  ...,  4.5301e-05,
         -2.0127e-05, -1.1757e-04],
        ...,
        [-6.7867e-05, -3.9009e-05,  4.1678e-05,  ...,  4.5301e-05,
         -2.0127e-05, -1.1757e-04],
        [-6.7867e-05, -3.9009e-05,  4.1678e-05,  ...,  4.5301e-05,
         -2.0127

100%|██████████| 965/965 [09:57<00:00,  1.61it/s]

Preds: tensor([[ 8.0593e-05,  6.1983e-05,  1.0965e-04, -1.1377e-04, -7.1138e-05,
          4.1661e-04, -1.2584e-05, -2.5285e-07, -1.1549e-04, -1.1493e-04,
          1.0904e-04,  1.8638e-04,  8.5948e-05, -6.4978e-05, -5.5281e-05,
          2.6411e-04, -6.3810e-05,  2.9357e-05,  2.9520e-04,  2.3905e-04,
          7.9032e-05,  9.1329e-05, -4.0226e-04, -3.3947e-04, -8.2335e-05,
         -4.9220e-05,  1.4485e-04,  8.4018e-05,  4.6834e-05,  2.5255e-04,
         -1.0098e-04,  3.1520e-05,  2.3216e-04, -1.7444e-04,  9.7297e-05,
         -3.2658e-04,  7.1801e-05,  2.2898e-05, -1.7032e-04,  1.1922e-04,
         -1.2579e-05, -4.5596e-05,  4.6020e-05,  4.1791e-05,  2.2881e-05,
          2.5206e-04,  8.2869e-06, -7.8113e-05,  1.3882e-04,  3.3060e-05,
          1.7128e-04,  3.2221e-05, -2.7430e-04, -4.4769e-06, -8.8971e-05,
          5.0608e-05,  1.2078e-04, -8.7139e-06, -4.4078e-05,  3.3826e-06,
          8.7872e-05,  1.7066e-04,  1.1946e-04, -5.2962e-05,  3.2941e-05,
          2.5537e-06, -4.4378e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 5/20 — Train Loss: 9.419138079154464e-05 | Val Loss: 6.127624895752353e-06
Preds: tensor([[ 1.9457e-05, -1.7387e-05, -3.1874e-05,  ..., -1.1096e-04,
         -8.1150e-05,  2.3868e-05],
        [ 1.9457e-05, -1.7387e-05, -3.1874e-05,  ..., -1.1096e-04,
         -8.1150e-05,  2.3868e-05],
        [ 1.9457e-05, -1.7387e-05, -3.1874e-05,  ..., -1.1096e-04,
         -8.1150e-05,  2.3868e-05],
        ...,
        [ 1.9457e-05, -1.7387e-05, -3.1874e-05,  ..., -1.1096e-04,
         -8.1150e-05,  2.3868e-05],
        [ 1.9457e-05, -1.7387e-05, -3.1874e-05,  ..., -1.1096e-04,
         -8.1150e

100%|██████████| 965/965 [10:01<00:00,  1.60it/s]

Preds: tensor([[ 8.4005e-07,  1.3120e-05,  7.9274e-06,  1.3097e-04,  5.2063e-05,
          4.4840e-05,  2.3983e-05,  4.8282e-05, -5.6880e-05, -4.5415e-05,
         -5.6444e-05,  4.0898e-05,  5.2784e-05, -2.2115e-05, -1.3355e-05,
          1.9274e-06,  3.5366e-05, -8.8063e-05, -6.4038e-05,  8.7899e-05,
         -5.0846e-05, -4.2361e-06, -7.3972e-05, -2.9290e-05, -2.6447e-05,
          2.3541e-05, -2.4833e-05,  8.8643e-05,  3.2684e-05, -1.4266e-05,
         -4.1137e-05, -8.1139e-05, -3.9173e-05, -3.1589e-05, -1.2665e-04,
          1.4112e-04, -3.8168e-05, -6.9219e-05, -2.8478e-05,  6.5080e-05,
         -6.9542e-06,  9.7526e-05, -4.3418e-06, -2.8674e-05,  5.7468e-05,
         -1.0898e-04,  4.7991e-06,  9.3282e-05, -2.2171e-04,  2.1441e-05,
          2.5867e-05,  3.4147e-05, -5.5632e-05, -2.7971e-05, -8.2567e-05,
         -1.7554e-05,  8.8472e-05, -5.1636e-05, -1.9381e-05,  1.2081e-05,
          2.3678e-05,  2.4293e-05, -6.7296e-05, -7.8939e-05, -6.6325e-05,
          2.1150e-05,  4.5088e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 6/20 — Train Loss: 8.49898346160524e-05 | Val Loss: 6.758665132984592e-06
Preds: tensor([[-3.5624e-04, -3.6363e-05, -2.7640e-05,  ...,  1.0120e-04,
          6.1245e-05,  1.5176e-04],
        [-3.5624e-04, -3.6363e-05, -2.7640e-05,  ...,  1.0120e-04,
          6.1245e-05,  1.5176e-04],
        [-3.5624e-04, -3.6363e-05, -2.7640e-05,  ...,  1.0120e-04,
          6.1245e-05,  1.5176e-04],
        ...,
        [-3.5624e-04, -3.6363e-05, -2.7640e-05,  ...,  1.0120e-04,
          6.1245e-05,  1.5176e-04],
        [-3.5624e-04, -3.6363e-05, -2.7640e-05,  ...,  1.0120e-04,
          6.1245e-

100%|██████████| 965/965 [09:57<00:00,  1.62it/s]

Preds: tensor([[-6.6938e-05, -7.0119e-05,  4.5160e-05, -1.1307e-04,  1.2832e-04,
         -3.1315e-05,  4.0757e-05,  5.0644e-05, -9.2020e-05, -8.3702e-05,
          1.1743e-04, -5.3272e-05,  3.4509e-05, -5.5522e-05, -3.2225e-05,
         -4.3747e-05,  1.9819e-05, -7.2368e-05,  1.0718e-04,  2.9671e-05,
         -7.5722e-05,  2.5279e-05, -6.0902e-05, -1.8831e-06, -2.7324e-05,
          6.5499e-05, -5.9545e-05,  5.0696e-05,  8.2199e-05, -7.5173e-05,
         -8.2152e-05, -6.7636e-05, -1.0839e-04,  2.2828e-05, -9.5988e-05,
          5.5624e-05, -3.9092e-05,  4.6641e-05,  8.9137e-05,  1.2471e-04,
          1.1421e-04,  8.9085e-05,  1.5351e-04, -1.3820e-05,  9.0006e-05,
         -1.1667e-05,  4.2169e-05,  8.7569e-05, -7.2105e-05,  4.8924e-05,
         -7.6560e-05, -1.7675e-05, -4.3482e-05, -1.6620e-05,  1.3377e-04,
          1.4275e-05,  3.5703e-05, -1.9455e-05, -5.2143e-05,  1.2629e-04,
          5.3655e-05,  8.2793e-05, -7.5152e-05, -5.4933e-05,  1.0076e-04,
          6.6327e-05,  2.2096e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 7/20 — Train Loss: 7.929003709179851e-05 | Val Loss: 6.665065074248203e-06
Preds: tensor([[ 1.0558e-04,  5.9061e-05, -8.0577e-05,  ..., -1.1427e-06,
          5.5878e-05,  9.2570e-05],
        [ 1.0558e-04,  5.9061e-05, -8.0577e-05,  ..., -1.1427e-06,
          5.5878e-05,  9.2570e-05],
        [ 1.0558e-04,  5.9061e-05, -8.0577e-05,  ..., -1.1427e-06,
          5.5878e-05,  9.2570e-05],
        ...,
        [ 1.0558e-04,  5.9061e-05, -8.0577e-05,  ..., -1.1427e-06,
          5.5878e-05,  9.2570e-05],
        [ 1.0558e-04,  5.9061e-05, -8.0577e-05,  ..., -1.1427e-06,
          5.5878e

100%|██████████| 965/965 [10:01<00:00,  1.61it/s]

Preds: tensor([[ 1.0500e-04,  5.5657e-05,  5.8466e-05,  1.5143e-05, -3.1509e-05,
          1.3173e-05, -7.1364e-05,  6.0048e-05,  2.6170e-05,  6.0775e-05,
         -4.4784e-05,  2.3542e-04, -4.1481e-06,  6.4766e-05,  5.2569e-05,
          1.7617e-05, -2.0931e-05,  8.2436e-06, -1.2414e-05,  6.4268e-05,
          5.3212e-05,  1.5655e-05,  4.1386e-05, -3.0741e-05, -2.4605e-05,
          2.8234e-05, -6.6347e-06, -2.3615e-04, -1.3798e-05,  3.8940e-05,
          2.2031e-05,  1.1231e-04,  8.2720e-06,  6.5809e-05,  6.8906e-05,
          8.1654e-05,  6.8903e-05,  1.1258e-05, -6.1877e-05, -4.4221e-05,
         -3.9113e-06, -8.3065e-05, -1.5108e-04, -3.3586e-05, -2.8622e-05,
          5.1701e-05, -5.0472e-05,  6.9849e-07,  1.7784e-04, -8.1059e-05,
          1.2400e-04,  3.8908e-06, -5.4594e-05,  2.5696e-05, -7.8183e-05,
         -3.9489e-05, -5.5115e-05,  2.7298e-05,  9.8352e-05, -1.9473e-04,
         -5.5380e-05,  6.6915e-05,  9.8506e-06,  2.7502e-05, -1.3899e-04,
         -1.2454e-05, -5.7136e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 8/20 — Train Loss: 7.51758041493643e-05 | Val Loss: 5.183457545296965e-06
Preds: tensor([[-1.1614e-04, -3.7353e-05, -6.2612e-05,  ...,  9.8683e-05,
         -1.2056e-04, -3.5147e-05],
        [-1.1614e-04, -3.7353e-05, -6.2612e-05,  ...,  9.8683e-05,
         -1.2056e-04, -3.5147e-05],
        [-1.1614e-04, -3.7353e-05, -6.2612e-05,  ...,  9.8683e-05,
         -1.2056e-04, -3.5147e-05],
        ...,
        [-1.1614e-04, -3.7353e-05, -6.2612e-05,  ...,  9.8683e-05,
         -1.2056e-04, -3.5147e-05],
        [-1.1614e-04, -3.7353e-05, -6.2612e-05,  ...,  9.8683e-05,
         -1.2056e-

100%|██████████| 965/965 [09:59<00:00,  1.61it/s]

Preds: tensor([[ 9.2803e-05,  7.3406e-05, -3.1057e-05,  3.8063e-05, -2.3180e-05,
         -4.9011e-06, -5.2730e-05,  5.5794e-05, -5.7238e-05,  6.1305e-05,
         -1.5559e-05,  9.2620e-05, -4.7183e-05,  4.3695e-05,  2.4846e-05,
         -1.1273e-05,  8.8152e-06,  9.1990e-05,  2.4773e-07,  2.6836e-05,
          4.2180e-05, -9.8115e-05, -1.2282e-05, -1.9584e-07,  3.7481e-05,
         -6.8334e-05,  1.0221e-04, -8.1753e-05, -3.3900e-06,  1.2234e-05,
          3.1555e-05,  1.4428e-05, -1.8359e-05,  7.4742e-05,  5.2311e-05,
          1.1639e-05,  1.1271e-05, -7.0453e-05, -4.3795e-05, -3.1338e-05,
          6.3384e-05, -6.8075e-05,  9.3882e-05,  5.5361e-05, -4.7297e-05,
         -8.9771e-06, -2.7678e-05,  2.2783e-05, -3.7158e-05, -3.3560e-05,
          3.9540e-05, -8.1663e-06,  7.9426e-06,  2.2710e-05, -7.2531e-06,
          2.0060e-05, -2.5341e-06,  2.3949e-05,  4.2187e-05, -1.2565e-04,
         -5.9251e-05, -2.5643e-05,  4.4703e-06,  7.4448e-05, -5.6233e-05,
         -6.0302e-05, -8.6228e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 9/20 — Train Loss: 7.065422231662617e-05 | Val Loss: 6.479187041073672e-06
Preds: tensor([[-2.3611e-04, -9.6397e-05,  4.3830e-05,  ...,  8.7673e-05,
         -9.3366e-05, -9.3970e-05],
        [-2.3611e-04, -9.6397e-05,  4.3830e-05,  ...,  8.7673e-05,
         -9.3366e-05, -9.3970e-05],
        [-2.3611e-04, -9.6397e-05,  4.3830e-05,  ...,  8.7673e-05,
         -9.3366e-05, -9.3970e-05],
        ...,
        [-2.3611e-04, -9.6397e-05,  4.3830e-05,  ...,  8.7673e-05,
         -9.3366e-05, -9.3970e-05],
        [-2.3611e-04, -9.6397e-05,  4.3830e-05,  ...,  8.7673e-05,
         -9.3366e

100%|██████████| 965/965 [10:08<00:00,  1.59it/s]

Preds: tensor([[ 1.2340e-04,  1.0930e-05, -5.6759e-06,  2.3937e-05, -1.2620e-05,
         -9.8501e-06,  9.7882e-06,  1.2134e-05, -7.7950e-05,  3.3522e-05,
         -2.7184e-05, -6.3652e-05, -8.3840e-05,  1.4056e-04, -6.5845e-06,
         -4.1222e-05,  6.9660e-05,  3.2730e-05, -3.5544e-06,  7.4813e-06,
         -4.6957e-06, -2.9545e-05,  1.8254e-06,  5.0020e-05,  1.0733e-05,
         -1.7982e-05,  6.9290e-06, -1.8102e-04, -1.9046e-05,  1.6008e-04,
          2.1266e-05,  5.3917e-05, -4.8670e-06, -3.5010e-05,  5.3961e-06,
         -2.8888e-04, -2.3522e-05, -3.9606e-05, -1.4025e-05, -2.6138e-05,
         -2.7691e-05, -9.6234e-06, -1.0594e-04,  2.0957e-05, -1.3400e-05,
         -4.1180e-05, -1.9702e-05,  2.8156e-05,  1.6220e-05, -9.0133e-06,
         -2.8849e-05,  2.7826e-05,  4.0412e-05, -1.5733e-05,  1.8648e-05,
         -3.3519e-05, -2.5741e-05,  1.0811e-05,  2.0245e-05, -7.2997e-05,
         -2.6831e-06,  2.7332e-05,  1.9279e-05,  8.7423e-06, -7.7896e-06,
         -3.6071e-05, -6.0195e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 10/20 — Train Loss: 6.743530717791767e-05 | Val Loss: 4.583208875356279e-06
Preds: tensor([[ 4.4597e-05, -8.1868e-06,  1.1691e-05,  ...,  1.6415e-05,
          7.5533e-06, -1.5642e-05],
        [ 4.4597e-05, -8.1868e-06,  1.1691e-05,  ...,  1.6415e-05,
          7.5533e-06, -1.5642e-05],
        [ 4.4597e-05, -8.1868e-06,  1.1691e-05,  ...,  1.6415e-05,
          7.5533e-06, -1.5642e-05],
        ...,
        [ 4.4597e-05, -8.1868e-06,  1.1691e-05,  ...,  1.6415e-05,
          7.5533e-06, -1.5642e-05],
        [ 4.4597e-05, -8.1868e-06,  1.1691e-05,  ...,  1.6415e-05,
          7.5533

100%|██████████| 965/965 [09:58<00:00,  1.61it/s]

Preds: tensor([[-6.3417e-05,  8.6506e-06,  2.8937e-05, -8.8898e-05,  3.8968e-05,
         -1.5961e-05,  4.0644e-05, -8.3625e-06, -7.5496e-05, -6.6007e-05,
          1.3629e-05,  3.0550e-05,  1.5114e-05, -2.2062e-05, -3.7149e-05,
         -6.0159e-06,  1.7055e-05, -1.0015e-05,  3.0085e-05, -3.9062e-05,
         -5.0354e-05,  4.5001e-05, -2.0310e-05,  7.4931e-05,  9.8068e-06,
         -6.0080e-06, -7.2968e-05,  6.0193e-05,  6.9033e-05, -7.7769e-05,
         -8.8477e-05, -8.1737e-05,  5.0333e-05,  6.6860e-06, -6.8182e-05,
          5.3041e-05, -3.1842e-06, -5.5569e-05,  5.3640e-05,  4.8081e-05,
         -3.5331e-05,  7.8097e-05,  7.5007e-05, -9.7159e-06,  1.3792e-05,
         -1.5819e-05,  3.2985e-05,  2.0445e-05, -8.1623e-05,  5.0037e-05,
         -5.4159e-05,  1.6675e-05,  3.5990e-05, -1.3634e-05,  5.9319e-05,
          5.0562e-05,  6.2482e-06, -4.4279e-05, -5.9078e-05,  7.4890e-05,
          6.6069e-05,  2.2627e-05, -6.0368e-06, -8.3557e-05,  4.4188e-05,
          4.6102e-05,  6.5329e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 11/20 — Train Loss: 6.456717128646463e-05 | Val Loss: 5.158416267173986e-06
Preds: tensor([[ 9.4842e-05,  1.6107e-05, -4.6796e-05,  ..., -1.1192e-04,
         -4.5142e-05,  8.1727e-05],
        [ 9.4842e-05,  1.6107e-05, -4.6796e-05,  ..., -1.1192e-04,
         -4.5142e-05,  8.1727e-05],
        [ 9.4842e-05,  1.6107e-05, -4.6796e-05,  ..., -1.1192e-04,
         -4.5142e-05,  8.1727e-05],
        ...,
        [ 9.4842e-05,  1.6107e-05, -4.6796e-05,  ..., -1.1192e-04,
         -4.5142e-05,  8.1727e-05],
        [ 9.4842e-05,  1.6107e-05, -4.6796e-05,  ..., -1.1192e-04,
         -4.5142

100%|██████████| 965/965 [10:00<00:00,  1.61it/s]

Preds: tensor([[-3.6769e-06, -2.6201e-05, -1.2440e-05, -4.9130e-05, -4.6589e-05,
         -4.5843e-05,  5.5638e-05,  9.8241e-05, -3.1143e-05, -4.9807e-05,
          2.2290e-05, -6.6972e-05,  4.0860e-05, -6.2148e-05, -3.8483e-05,
          3.4900e-05, -4.0980e-05, -3.3177e-05, -2.6761e-05, -2.8112e-05,
         -7.5067e-05, -5.0709e-05, -5.0083e-05, -9.1703e-05, -3.1579e-05,
         -7.1493e-06, -2.3575e-05,  1.2833e-04,  9.0282e-05, -5.8919e-05,
         -2.1891e-05, -2.3631e-05,  2.9668e-05, -3.6649e-05, -5.2372e-05,
          2.3397e-05, -3.3514e-05, -2.7380e-05,  4.2600e-05,  2.9138e-05,
         -5.7566e-05,  6.0024e-05,  1.4087e-05,  2.9725e-05,  6.2430e-05,
          6.2792e-06,  5.2324e-05,  4.9218e-05, -7.7756e-05, -1.6063e-05,
         -1.9449e-04, -3.4065e-07,  7.2222e-05, -2.9043e-05,  4.3080e-05,
          2.1436e-05,  4.7392e-05,  4.7338e-05, -2.6128e-05,  2.6876e-05,
          3.0303e-05, -5.2938e-05, -3.5942e-05, -3.8842e-05,  9.6969e-06,
          4.9597e-05,  6.6057e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 12/20 — Train Loss: 6.072777951494525e-05 | Val Loss: 4.559993719928624e-06
Preds: tensor([[ 1.9131e-05,  1.7460e-05,  2.5089e-05,  ..., -2.4927e-05,
          2.8173e-05,  4.8647e-05],
        [ 1.9131e-05,  1.7460e-05,  2.5089e-05,  ..., -2.4927e-05,
          2.8173e-05,  4.8647e-05],
        [ 1.9131e-05,  1.7460e-05,  2.5089e-05,  ..., -2.4927e-05,
          2.8173e-05,  4.8647e-05],
        ...,
        [ 1.9131e-05,  1.7460e-05,  2.5089e-05,  ..., -2.4927e-05,
          2.8173e-05,  4.8647e-05],
        [ 1.9131e-05,  1.7460e-05,  2.5089e-05,  ..., -2.4927e-05,
          2.8173

100%|██████████| 965/965 [10:07<00:00,  1.59it/s]

Preds: tensor([[-6.8249e-05, -9.8674e-07,  2.9052e-05, -1.0131e-05,  1.7298e-04,
          2.9746e-06,  1.4040e-06,  1.4807e-05,  1.2036e-05,  3.0627e-05,
          6.3712e-06,  4.1947e-06, -1.2154e-05,  5.2243e-05, -3.6526e-05,
          2.7761e-05, -1.3114e-06, -3.6900e-05, -1.3459e-05,  6.4131e-05,
         -2.2966e-06,  4.1763e-05, -6.9704e-05, -1.7473e-04,  1.3253e-06,
         -2.5678e-05, -2.6142e-05,  1.3143e-04,  8.4133e-05,  9.5621e-05,
          4.6575e-05,  2.0732e-04, -1.0203e-05, -1.0952e-05,  4.1444e-05,
          1.7688e-05,  2.4943e-05,  8.4787e-05, -1.0702e-05, -1.4994e-06,
          7.4172e-06, -9.5153e-06, -2.1607e-07, -8.9117e-06, -3.1516e-06,
         -8.2722e-05,  7.1089e-05, -5.2149e-05, -7.8468e-05,  5.3537e-05,
         -4.4195e-05, -1.5129e-04, -8.4843e-07, -1.6318e-05,  4.7898e-06,
         -6.6156e-06,  2.8232e-05,  7.1972e-05,  2.8959e-05, -1.0395e-05,
         -2.1960e-05,  6.6798e-05,  7.6070e-06, -7.1852e-05,  2.0271e-05,
         -5.6359e-06, -1.0163e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 13/20 — Train Loss: 5.691141643578915e-05 | Val Loss: 4.469117365992062e-06
Preds: tensor([[-2.6390e-05, -2.4053e-05, -1.9168e-05,  ..., -4.5363e-05,
         -4.9888e-05,  2.6260e-05],
        [-2.6390e-05, -2.4053e-05, -1.9168e-05,  ..., -4.5363e-05,
         -4.9888e-05,  2.6260e-05],
        [-2.6390e-05, -2.4053e-05, -1.9168e-05,  ..., -4.5363e-05,
         -4.9888e-05,  2.6260e-05],
        ...,
        [-2.6390e-05, -2.4053e-05, -1.9168e-05,  ..., -4.5363e-05,
         -4.9888e-05,  2.6260e-05],
        [-2.6390e-05, -2.4053e-05, -1.9168e-05,  ..., -4.5363e-05,
         -4.9888

100%|██████████| 965/965 [10:10<00:00,  1.58it/s]

Preds: tensor([[ 6.2879e-05,  4.9745e-05,  3.8075e-05,  8.1615e-05, -1.0882e-05,
          6.7346e-06, -1.6323e-05,  1.8209e-05,  6.3027e-05,  3.4903e-05,
         -6.4393e-05,  4.9928e-05, -5.3164e-05,  2.5342e-05,  5.1668e-05,
         -3.6456e-05,  1.4290e-05,  2.7730e-06, -5.6043e-05,  2.9031e-05,
          2.1285e-05,  6.3870e-06,  2.3968e-06, -1.6285e-05,  8.0577e-05,
         -4.7302e-05,  1.5086e-04, -6.8517e-05, -6.0548e-05,  3.1639e-05,
         -2.5190e-05,  1.1010e-04, -1.4961e-05,  9.1504e-05,  4.9037e-05,
         -6.7004e-05,  1.8956e-05,  1.1263e-05, -3.6392e-05, -4.9088e-05,
          1.9064e-05, -7.5192e-05, -1.2258e-04,  2.7055e-05, -4.8945e-05,
          2.2063e-06, -2.6735e-05, -5.9809e-05,  1.1282e-04, -4.5468e-05,
          1.1104e-04,  3.6549e-06, -6.4785e-05,  3.4416e-05, -7.3393e-05,
         -6.6351e-05, -4.7555e-05,  1.6820e-05,  5.2234e-05, -1.1742e-04,
         -6.0505e-05, -1.7607e-05, -1.4246e-04,  1.8931e-04, -5.6098e-05,
         -5.0955e-05, -1.8941e-




🧪 Eval Batch 22/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5002624.0 | Density Mask Sum: 3364.0
🧪 Eval Batch 28/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5302084.0 | Density Mask Sum: 3319.0
🧪 Eval Batch 30/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 5539086.0 | Density Mask Sum: 3358.0
🧪 Eval Batch 47/51 — OMNI2 Mask Sum: 656640.0 | GOES Mask Sum: 6195294.0 | Density Mask Sum: 3456.0

📊 Epoch 14/20 — Train Loss: 5.641426763891175e-05 | Val Loss: 4.098086470085214e-06
Preds: tensor([[-6.9173e-05, -3.3191e-05, -3.8318e-05,  ...,  8.2931e-05,
         -3.5743e-05,  1.3523e-06],
        [-6.9173e-05, -3.3191e-05, -3.8318e-05,  ...,  8.2931e-05,
         -3.5743e-05,  1.3523e-06],
        [-6.9173e-05, -3.3191e-05, -3.8318e-05,  ...,  8.2931e-05,
         -3.5743e-05,  1.3523e-06],
        ...,
        [-6.9173e-05, -3.3191e-05, -3.8318e-05,  ...,  8.2931e-05,
         -3.5743e-05,  1.3523e-06],
        [-6.9173e-05, -3.3191e-05, -3.8318e-05,  ...,  8.2931e-05,
         -3.5743

 89%|████████▉ | 857/965 [08:59<01:07,  1.59it/s]
Exception in thread Thread-45 (_pin_memory_loop):
Traceback (most recent call last):
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1075, in _bootstrap_inner
    self.run()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/threading.py", line 1012, in run
    self._target(*self._args, **self._kwargs)
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 59, in _pin_memory_loop
    do_one_step()
  File "/root/miniconda3/envs/runpod_conda/lib/python3.12/site-packages/torch/utils/data/_utils/pin_memory.py", line 35, in do_one_step
    r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/runpod_conda/lib/pyth

KeyboardInterrupt: 