In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from netCDF4 import Dataset, num2date
import os
import random
import copy
import csv
import pandas as pd
from timeit import default_timer
from fourier_2d import Net2d

# --------------------
# General Setup
# --------------------
# NOTE(review): this setup block is repeated at the top of the next cell
# ("# Initializations"); keep the two in sync or drop one of them.

deBug = True

# Seed every RNG the notebook draws from so random splits, shuffling and
# weight initialisation are repeatable across kernel restarts.
seed_num = 1
np.random.seed(seed_num)
random.seed(seed_num)
torch.manual_seed(seed_num)
torch.cuda.manual_seed_all(seed_num)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "FullField"
# Passed as `lead_time` to both dataset classes: the target frame lies this
# many steps after the last input frame.
lead_time_width = 2

# --------------------
# Data Directories
# --------------------
# The input directory defaults to the original cluster location but can be
# redirected by setting MJO_DATA_DIR, so the notebook is not tied to one
# filesystem layout.
_default_data_dir = "/gpfs/gibbs/project/lu_lu/ax59/MJO_Project/Data/AllLeadTms/"
mdl_directory = os.environ.get("MJO_DATA_DIR", _default_data_dir)
mdl_out_dir = "Result_01/"
obs_directory = os.environ.get("MJO_DATA_DIR", _default_data_dir)
obs_out_dir = "Result_02/"
bcor_directory = "Result_03/"
time_dir = "Data/Time/"

In [4]:
# Initializations

import matplotlib.pyplot as plt
import numpy as np
import torch
from netCDF4 import Dataset, num2date
import os
import random
import copy
import csv
import pandas as pd
from timeit import default_timer
from fourier_2d import Net2d

# --------------------
# General Setup
# --------------------
# NOTE(review): this block repeats the setup from the preceding cell; keep
# the two in sync or drop one of them.

deBug = True

# Seed every RNG the notebook draws from so random splits, shuffling and
# weight initialisation are repeatable across kernel restarts.
seed_num = 1
np.random.seed(seed_num)
random.seed(seed_num)
torch.manual_seed(seed_num)
torch.cuda.manual_seed_all(seed_num)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "FullField"
# Passed as `lead_time` to both dataset classes: the target frame lies this
# many steps after the last input frame.
lead_time_width = 2

# --------------------
# Data Directories
# --------------------
# The input directory defaults to the original cluster location but can be
# redirected by setting MJO_DATA_DIR, so the notebook is not tied to one
# filesystem layout.
_default_data_dir = "/gpfs/gibbs/project/lu_lu/ax59/MJO_Project/Data/AllLeadTms/"
mdl_directory = os.environ.get("MJO_DATA_DIR", _default_data_dir)
mdl_out_dir = "Result_01/"
obs_directory = os.environ.get("MJO_DATA_DIR", _default_data_dir)
obs_out_dir = "Result_02/"
bcor_directory = "Result_03/"
time_dir = "Data/Time/"

# --------------------
# Section 0: Class and Helper Function Definitions
# --------------------

class MDLDataset(torch.utils.data.Dataset):
    """Sliding-window forecast samples built from the MDL anomaly files.

    Each sample pairs three consecutive frames (t-2, t-1, t), concatenated
    along their leading (variable) axis, with the single frame at
    t + lead_time. Windows never cross a run boundary (a point where the time
    axis steps backwards) nor a seasonal gap (consecutive records more than
    one day apart).

    Args:
        mdl_dir: Directory holding one NetCDF file per variable.
        time_dir: Unused; accepted so existing call sites keep working.
        lead_time: Number of steps between the last input frame and the target.

    Raises:
        ValueError: If ``lead_time`` is negative.

    Attributes:
        lead_time: The ``lead_time`` argument.
        times: Time axis read from the first variable's file.
        valid_pairs: List of ``(feature, target, feature_time, target_time)``
            tuples, one per sample.
    """

    def __init__(self, mdl_dir, time_dir, lead_time):
        # A negative lead would index before/after the window and previously
        # surfaced as an opaque IndexError deep inside the pairing loop.
        if lead_time < 0:
            raise ValueError(f"lead_time must be non-negative, got {lead_time}")
        self.lead_time = lead_time

        # Load Data
        all_features_raw, self.times = self._load_fields(mdl_dir)
        print(f"MDL raw data loaded. {len(all_features_raw)} total entries.")

        # Boundaries of each run: a new run starts wherever time steps backwards.
        run_boundaries = np.where(self.times[1:] < self.times[:-1])[0]
        run_starts = [0] + [int(boundary) + 1 for boundary in run_boundaries]
        run_ends = run_starts[1:] + [len(all_features_raw)]

        print(f"Found {len(run_starts)} independent runs in the file.")

        # One sample needs the three input frames plus `lead_time` further
        # frames up to the target. (This used to be a hard-coded 5, which is
        # only correct for lead_time == 2 and dropped valid short runs for
        # smaller leads.)
        min_frames = 3 + self.lead_time

        # Process runs individually
        self.valid_pairs = []
        for run_start, run_end in zip(run_starts, run_ends):
            run_features = all_features_raw[run_start:run_end]
            run_times = self.times[run_start:run_end]

            if len(run_times) < min_frames:
                continue

            # Find seasonal gaps: positions where the next record is more
            # than one day later.
            day_steps = (run_times[1:] - run_times[:-1]).astype('timedelta64[D]').astype(int)
            season_ends = [int(gap) + 1 for gap in np.where(day_steps > 1)[0]]
            season_ends.append(len(run_times))

            season_start = 0
            for season_end in season_ends:
                self._append_pairs(
                    run_features[season_start:season_end],
                    run_times[season_start:season_end],
                )
                season_start = season_end

        print(f"MDL dataset fully processed. Found {len(self.valid_pairs)} valid input-target pairs.")

    @staticmethod
    def _load_fields(mdl_dir):
        """Read every variable file; return (fields stacked on axis 1, time axis)."""
        variable_names = ["TMQ", "FLUT", "U200", "U850", "TREFHT"]
        time_objects = None
        lead_time_vars = []

        for i, var_name in enumerate(variable_names):
            file_path = os.path.join(mdl_dir, f"CML2025_Step0C_TROP30_MDL_remapped_90x180_daily_DJFM_Anom_nonFltr_{var_name}_leadTm1.nc")
            with Dataset(file_path) as f:
                lead_time_vars.append(f.variables[var_name][:])
                if i == 0:
                    # The time axis is taken from the first file only.
                    # NOTE(review): assumes every variable file shares it — confirm.
                    time_var = f.variables['time']
                    time_objects = num2date(time_var[:], units=time_var.units, calendar=time_var.calendar)

        return np.stack(lead_time_vars, axis=1), time_objects

    def _append_pairs(self, season_features, season_times):
        """Append every input/target sample contained in one gap-free season."""
        # Distance from the first input frame (t-2) to the target frame.
        target_offset = 2 + self.lead_time

        # range() is empty when the season is too short for a single sample.
        for i in range(len(season_features) - target_offset):
            # Input: concatenate t-2, t-1, t
            feature = np.concatenate(
                [season_features[i], season_features[i + 1], season_features[i + 2]],
                axis=0,
            )

            # Target: t + lead_time
            target_idx = i + target_offset
            target = season_features[target_idx]

            # Times of t and of the target are kept as metadata for debugging.
            self.valid_pairs.append((feature, target, season_times[i + 2], season_times[target_idx]))

    def __len__(self):
        """Return the number of input/target samples."""
        return len(self.valid_pairs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as float32 ``(feature, target)`` tensors.

        Both arrays are permuted with (1, 2, 0), i.e. their leading axis is
        moved last. The stored time metadata is not returned.
        """
        feature, target, feature_time, target_time = self.valid_pairs[idx]

        return (
            torch.tensor(feature, dtype=torch.float32).permute(1, 2, 0),
            torch.tensor(target, dtype=torch.float32).permute(1, 2, 0),
        )

class OBSDataset(torch.utils.data.Dataset):
    """Sliding-window forecast samples built from the OBS anomaly files.

    A sample is three consecutive frames (t-2, t-1, t) joined along their
    leading axis, paired with the frame at t + lead_time. The record is cut
    wherever two consecutive time stamps are more than one day apart, and no
    window spans such a cut.

    Args:
        obs_dir: Directory holding one NetCDF file per variable.
        time_dir: Not used by this class.
        lead_time: Steps between the last input frame and the target frame.
        indices: Optional index selection applied to the time axis before
            windows are built; ``None`` keeps every record.
    """

    def __init__(self, obs_dir, time_dir, lead_time, indices=None):
        self.lead_time = lead_time

        # Read each variable; the time axis comes from the first file only.
        per_variable = []
        timestamps = None
        for position, name in enumerate(["tcwv", "olr", "u200", "u850", "trefht"]):
            nc_path = os.path.join(obs_dir, f"CML2025_Step0C_TROP30_OBS_remapped_90x180_daily_DJFM_Anom_nonFltr_{name}_leadTm1.nc")
            with Dataset(nc_path) as nc:
                per_variable.append(nc.variables[name][:])
                if position == 0:
                    time_axis = nc.variables['time']
                    timestamps = num2date(time_axis[:], units=time_axis.units, calendar=time_axis.calendar)

        # Variables go on a new axis 1, directly after the time axis.
        fields = np.stack(per_variable, axis=1)

        # Optional restriction to a subset of records.
        if indices is None:
            self.times = timestamps
            print(f"OBS dataset loaded. Processing {len(fields)} time steps.")
        else:
            fields = fields[indices]
            self.times = timestamps[indices]
            print(f"OBS dataset subset applied. Processing {len(fields)} time steps for this fold.")

        self.valid_pairs = []

        if len(self.times) == 0:
            print("Warning: No data to process (empty indices). Dataset will be empty.")
            return

        # A segment ends at every step longer than one day, and at the last record.
        day_steps = (self.times[1:] - self.times[:-1]).astype('timedelta64[D]').astype(int)
        segment_stops = np.append(np.where(day_steps > 1)[0], len(self.times) - 1) + 1

        # Offset from the first input frame (t-2) to the target frame.
        span = 2 + self.lead_time

        segment_start = 0
        for segment_stop in segment_stops:
            seg_fields = fields[segment_start:segment_stop]
            seg_times = self.times[segment_start:segment_stop]

            # Segments shorter than span + 1 frames give an empty range.
            for first in range(len(seg_fields) - span):
                last_input = first + 2
                goal = first + span
                stacked = np.concatenate(
                    [seg_fields[first], seg_fields[first + 1], seg_fields[last_input]],
                    axis=0,
                )
                self.valid_pairs.append(
                    (stacked, seg_fields[goal], seg_times[last_input], seg_times[goal])
                )

            segment_start = segment_stop

        print(f"OBS dataset processed. Found {len(self.valid_pairs)} valid input-target pairs.")

    def __len__(self):
        """Number of samples available."""
        return len(self.valid_pairs)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a pair of float32 tensors with the leading axis moved last."""
        stacked, goal_frame, _, _ = self.valid_pairs[idx]
        as_input = torch.tensor(stacked, dtype=torch.float32).permute(1, 2, 0)
        as_target = torch.tensor(goal_frame, dtype=torch.float32).permute(1, 2, 0)
        return (as_input, as_target)

In [5]:
# Load dataset
mdl_dataset = MDLDataset(mdl_dir=mdl_directory, time_dir=time_dir, lead_time=lead_time_width)

# Defined before the split because the training-set size is rounded down to
# a whole number of batches (previously a bare 121 duplicated this value).
batch_size = 121

# Split into training and testing sets
# NOTE(review): samples are overlapping sliding windows, so a random split
# puts near-identical neighbours on both sides of the split; consider
# splitting by run or season if an independent test score is needed.
mdl_total_samples = len(mdl_dataset)
mdl_train_size = int(0.8 * mdl_total_samples)
mdl_train_size -= mdl_train_size % batch_size
mdl_test_size = mdl_total_samples - mdl_train_size
mdl_train_dataset, mdl_test_dataset = torch.utils.data.random_split(mdl_dataset, [mdl_train_size, mdl_test_size])

# Create DataLoaders
mdl_train_dataloader = torch.utils.data.DataLoader(mdl_train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=4)
mdl_test_dataloader = torch.utils.data.DataLoader(mdl_test_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=4)

# `train_dataloader` is rebound to the OBS loader by the later cell that
# builds `obs_dataset`; the mdl_-prefixed names above keep the MDL loaders
# reachable after that happens.
train_dataloader = mdl_train_dataloader
test_dataloader = mdl_test_dataloader

Loading MDL dataset... Target will be t+2 (Input: t-2, t-1, t)
MDL raw data loaded. 43076 total entries.
Found 8 independent runs in the file.
MDL dataset fully processed. Found 41652 valid input-target pairs.




In [6]:
# Sanity check: pull a single batch from the current `train_dataloader`
# (the MDL training loader at this point in the notebook).
batch_data, target_data = next(iter(train_dataloader))

# Last expression is displayed by the notebook; the recorded output is
# (121, 30, 180, 15): batch_size samples with 15 trailing channels
# (3 stacked frames x 5 variables, moved last by the dataset's permute).
batch_data.shape

torch.Size([121, 30, 180, 15])

In [7]:
obs_dataset = OBSDataset(obs_dir=obs_directory, time_dir=time_dir, lead_time=lead_time_width)
total_samples = len(obs_dataset)

# Random train/validation split; the train size is computed once and the
# remainder goes to validation.
train_test_split = 0.8
obs_train_size = int(train_test_split * total_samples)
obs_train_dataset, obs_test_dataset = torch.utils.data.random_split(obs_dataset, [obs_train_size, total_samples - obs_train_size])

batch_size = 121
obs_train_dataloader = torch.utils.data.DataLoader(obs_train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=4)
# Validation batches are not shuffled: ordering does not affect an averaged
# metric, and a fixed order keeps evaluation runs comparable.
obs_val_dataloader = torch.utils.data.DataLoader(obs_test_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=4)

# NOTE(review): `train_dataloader` was already bound to the MDL training
# loader by the earlier MDL cell; rebinding it here means every later use
# silently switches to OBS data. The legacy names are kept so downstream
# cells still run — prefer the obs_-prefixed names in new code.
train_dataloader = obs_train_dataloader
val_dataloader = obs_val_dataloader

Loading OBS dataset... Target will be t+2 (Input: t-2, t-1, t)
OBS dataset loaded. Processing 5092 time steps.
OBS dataset processed. Found 4924 valid input-target pairs.


In [8]:
# Sanity check: pull a single batch from the current `train_dataloader`
# (rebound to the OBS training loader by the preceding cell).
batch_data, target_data = next(iter(train_dataloader))

# Last expression is displayed by the notebook; the recorded output is
# (121, 30, 180, 15), matching the MDL batches.
batch_data.shape

torch.Size([121, 30, 180, 15])

In [11]:
# Model hyper-parameters forwarded to Net2d.
modes = 12
width = 32
input_channels = 15  # 3 stacked frames x 5 variables per input sample

# out_channels = 5: the target is a single frame of the 5 variables.
model = Net2d(modes, width, in_channels=input_channels, out_channels = 5).to(device)

lr = 0.001
epochs = 500
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
# Cosine decay of the learning rate from `lr` down to eta_min over `epochs` steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min = 1e-6)
loss_fn = torch.nn.MSELoss()


def _mean_loss(net, loader, criterion, target_device):
    """Return the sample-weighted mean of `criterion` over all batches of `loader`.

    Each batch loss is weighted by its batch size so that a smaller final
    batch does not skew the average; the sum is divided by the dataset length.
    The caller is responsible for eval mode and for disabling gradients.
    """
    running_total = 0.0
    for batch_data, batch_target in loader:
        batch_data = batch_data.to(target_device)
        batch_target = batch_target.to(target_device)
        batch_loss = criterion(net(batch_data), batch_target)
        running_total += batch_loss.item() * batch_data.size(0)
    return running_total / len(loader.dataset)


# Initial evaluation
train_losses = []
test_losses = []

# NOTE(review): at this point `train_dataloader` is the OBS training loader
# (rebound by the OBS cell) while `test_dataloader` is still the MDL test
# loader, so the two numbers below come from different datasets — confirm
# this is intended.
with torch.no_grad():
    model.eval()
    avg_train_loss = _mean_loss(model, train_dataloader, loss_fn, device)
    avg_test_loss = _mean_loss(model, test_dataloader, loss_fn, device)
    print(f"Initial Performance of Untrained Model: \nAvg Train Loss: {avg_train_loss}, Avg Test Loss: {avg_test_loss}")
    train_losses.append(avg_train_loss)
    test_losses.append(avg_test_loss)

Initial Performance of Untrained Model: 
Avg Train Loss: 1.1982220709096905, Avg Test Loss: 1.2055670367770623
