In [2]:
import os
import pandas as pd
import numpy as np
import torch
import glob
import re
from sklearn.preprocessing import normalize
from tqdm import tqdm

# ---------------------------
# Directory Configuration
# ---------------------------
# Inputs: one folder of CSV files per data source.
density_dir, goes_dir, omni2_dir = (
    'data/dataset/test/sat_density',
    'data/dataset/test/goes',
    'data/dataset/test/omni2',
)

# Output: one .pt file per File ID. The folder is created eagerly;
# exist_ok=True keeps re-runs of this cell harmless.
output_dir = 'data/processed/pt_files'
os.makedirs(output_dir, exist_ok=True)

# Fixed number of rows each source is padded/truncated to.
density_length, goes_length, omni2_length = 432, 86400, 1440
# ---------------------------
# Utility Functions
# ---------------------------
def extract_file_id(filename):
    """Return the first run of exactly five digits enclosed by hyphens.

    For example ``"goes-00123-rest.csv"`` yields ``"00123"``. The ID is
    returned as a string (leading zeros kept). Returns None when *filename*
    has no ``-ddddd-`` segment.
    """
    found = re.search(r"-(\d{5})-", filename)
    if found is None:
        return None
    return found.group(1)

def pad_df(df, target_len, cols=None, pad_val=np.nan):
    """Return ``df`` truncated or padded so it has exactly ``target_len`` rows.

    Surplus rows at the end are dropped. Missing rows are appended at the end
    with every cell set to ``pad_val``. Existing values, including existing
    NaNs, are never overwritten. The result always carries a fresh
    ``0 .. target_len - 1`` index, whether it was truncated, padded, or left
    alone. Integer columns become float when NaN padding is required.

    Args:
        df: Frame to resize; row order is preserved.
        target_len: Required number of rows (must be >= 0).
        cols: Optional column labels the result must contain. Labels missing
            from ``df`` are added as columns filled with ``pad_val``. Passing
            ``df.columns`` (or None) leaves the columns unchanged.
        pad_val: Value used for every newly created cell.

    Returns:
        A new DataFrame; ``df`` itself is not modified.

    Raises:
        ValueError: if ``target_len`` is negative.
    """
    if target_len < 0:
        raise ValueError(f"target_len must be >= 0, got {target_len}")
    # reindex truncates and pads in one step and avoids pd.concat with an
    # all-NA frame, whose dtype resolution pandas has deprecated.
    out = df.reset_index(drop=True).reindex(
        index=range(target_len), fill_value=pad_val
    )
    if cols is not None:
        for col in cols:
            if col not in out.columns:
                out[col] = pad_val
    return out

# ---------------------------
# Main Conversion Loop
# ---------------------------
# For each density CSV, find the GOES and OMNI2 CSVs sharing its File ID,
# pad/truncate every source to a fixed length, and save value tensors plus
# validity masks to <output_dir>/<file_id>.pt. File IDs whose .pt already
# exists are skipped, so the cell can be re-run to resume after an interruption.
_DENSITY_COL = 'Orbit Mean Density (kg/m^3)'
_OMNI2_N_COLS = 57  # number of leading OMNI2 columns kept as features


def _find_source_file(source_dir, file_id, label):
    """Return the CSV in ``source_dir`` whose name contains ``-<file_id>-``.

    Prints a warning and returns None when nothing matches. When several files
    match, prints a warning and uses the lexicographically first one, rather
    than whichever file ``glob`` happened to list first.
    """
    matches = sorted(glob.glob(os.path.join(source_dir, f"*-{file_id}-*.csv")))
    if not matches:
        print(f"⚠️ Missing {label} file for File ID {file_id}")
        return None
    if len(matches) > 1:
        print(f"⚠️ {len(matches)} {label} files match File ID {file_id}; "
              f"using {os.path.basename(matches[0])}")
    return matches[0]


def _density_tensors(path):
    """Load a density CSV into (values, mask) float32 tensors of length density_length.

    Densities >= 1 are treated as invalid. Invalid and padded rows get value
    0.0 and mask 0.0; every other row gets mask 1.0.
    """
    df = pd.read_csv(path)
    density = df[_DENSITY_COL]
    df[_DENSITY_COL] = density.mask(density >= 1)
    df = pad_df(df, density_length, df.columns)
    density = df[_DENSITY_COL]
    values = torch.tensor(density.fillna(0.0).values, dtype=torch.float32)
    mask = torch.tensor(density.notna().astype(float).values, dtype=torch.float32)
    return values, mask


def _goes_tensors(path):
    """Load a GOES CSV into (values, mask) float32 tensors.

    Both tensors have ``goes_length`` rows and every CSV column except the
    first (presumably a timestamp; confirm). A mask cell is 1 only when the
    cell is non-null AND its row has ``xrsa_flag == 0`` and ``xrsb_flag == 0``.
    Padded rows have NaN flags and are therefore masked out.
    """
    df = pd.read_csv(path)
    df = pad_df(df, goes_length, df.columns)
    good_rows = ((df['xrsa_flag'] == 0.0) & (df['xrsb_flag'] == 0.0)).astype(int)
    mask = df.notna().astype(int).mul(good_rows.values, axis=0)
    # NOTE(review): sklearn's normalize() defaults to axis=1, scaling each ROW
    # (time step) of this one file to unit L2 norm. Absolute levels are
    # therefore not comparable across time steps or files. The flag columns
    # presumably sit inside iloc[:, 1:] and enter the norm as well. Confirm
    # this is the intended scaling.
    features = df.iloc[:, 1:].fillna(0.0).values
    values = torch.tensor(normalize(features, norm='l2'), dtype=torch.float32)
    mask_tensor = torch.tensor(mask.iloc[:, 1:].values, dtype=torch.float32)
    return values, mask_tensor


def _omni2_tensors(path):
    """Load an OMNI2 CSV into (values, mask) float32 tensors.

    Both tensors have ``omni2_length`` rows and the first ``_OMNI2_N_COLS``
    CSV columns. Mask cells are 1.0 where the source cell is non-null.
    """
    df = pd.read_csv(path)
    df = pad_df(df, omni2_length, df.columns)
    features = df.iloc[:, :_OMNI2_N_COLS]
    # NOTE(review): unlike GOES, column 0 is kept here. Confirm it is a feature
    # rather than a timestamp. The same per-row L2 scaling caveat as GOES applies.
    scaled = normalize(features.fillna(0.0).values.astype(float), norm='l2')
    values = torch.tensor(scaled, dtype=torch.float32)
    mask = torch.tensor(features.notna().astype(float).values, dtype=torch.float32)
    return values, mask


def _save_atomically(obj, path):
    """``torch.save`` *obj* so that *path* only ever holds a complete file.

    Data is written to ``<path>.tmp`` and moved into place with ``os.replace``.
    An interruption mid-write therefore cannot leave a truncated ``.pt`` that
    the ``os.path.exists`` resume check would skip on the next run.
    """
    tmp_path = f"{path}.tmp"
    try:
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # Only still present if saving or renaming failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


print("📦 Processing and saving .pt files...")
failed_ids = []

# Sorted for a reproducible processing order across runs.
for density_file in tqdm(sorted(glob.glob(os.path.join(density_dir, "*.csv")))):
    file_id = extract_file_id(os.path.basename(density_file))

    if not file_id or not file_id.isdigit():
        print(f"⚠️ Skipping malformed file name: {os.path.basename(density_file)}")
        continue

    pt_path = os.path.join(output_dir, f"{file_id}.pt")
    if os.path.exists(pt_path):
        continue  # Already processed (only complete files exist; see _save_atomically)

    # Locate the companion inputs before doing any expensive CSV parsing.
    goes_file = _find_source_file(goes_dir, file_id, "GOES")
    if goes_file is None:
        failed_ids.append(file_id)
        continue
    omni2_file = _find_source_file(omni2_dir, file_id, "OMNI2")
    if omni2_file is None:
        failed_ids.append(file_id)
        continue

    try:
        density_tensor, density_mask = _density_tensors(density_file)
        goes_tensor, goes_mask_tensor = _goes_tensors(goes_file)
        omni2_tensor, omni2_mask_tensor = _omni2_tensors(omni2_file)

        data_dict = {
            "file_id": file_id,
            "density": density_tensor,
            "density_mask": density_mask,
            "goes": goes_tensor,
            "goes_mask": goes_mask_tensor,
            "omni2": omni2_tensor,
            "omni2_mask": omni2_mask_tensor
        }
        _save_atomically(data_dict, pt_path)

    except Exception as e:
        # Batch boundary: record the failure and keep converting other files.
        print(f"❌ Error processing File ID {file_id}: {e}")
        failed_ids.append(file_id)

# ---------------------------
# Done
# ---------------------------
print(f"\n✅ Processing complete. Saved .pt files in: {output_dir}")
if failed_ids:
    print(f"\n❌ {len(failed_ids)} files failed:")
    for fid in failed_ids:
        print(f" - {fid}")


📦 Processing and saving .pt files...


  0%|          | 0/8119 [00:00<?, ?it/s]

100%|██████████| 8119/8119 [41:56<00:00,  3.23it/s]  


✅ Processing complete. Saved .pt files in: data/processed/pt_files



