In [1]:
import xarray as xr
import numpy as np
import os
import json
import torch
from tqdm import tqdm

class ClimSimNumpySharder:
    def __init__(self):
        # Configuration of Variables (The 557/128 architecture)
        self.input_profiles = ['state_t', 'state_q0001', 'state_q0002', 'state_q0003', 'state_u', 'state_v', 'pbuf_ozone', 'pbuf_CH4', 'pbuf_N2O']
        self.input_scalars = ['state_ps', 'pbuf_SOLIN', 'pbuf_LHFLX', 'pbuf_SHFLX', 'pbuf_TAUX', 'pbuf_TAUY', 'cam_in_ALDIF', 'cam_in_ALDIR', 'cam_in_ASDIF', 'cam_in_ASDIR', 'cam_in_ICEFRAC', 'cam_in_LANDFRAC', 'cam_in_LWUP', 'cam_in_OCNFRAC', 'cam_in_SNOWHICE', 'cam_in_SNOWHLAND', 'pbuf_COSZRS']
        
        self.target_profiles = ['state_t', 'state_q0001'] 
        self.target_scalars = ['cam_out_NETSW', 'cam_out_FLWDS', 'cam_out_PRECC', 'cam_out_PRECSC', 'cam_out_SOLL', 'cam_out_SOLLD', 'cam_out_SOLS', 'cam_out_SOLSD']

        # Build Internal Mappings
        self.input_indices, self.total_input_dim = self._build_index_map(self.input_profiles, self.input_scalars)
        self.target_indices, self.total_target_dim = self._build_index_map(self.target_profiles, self.target_scalars)

    def _build_index_map(self, profiles, scalars):
        mapping = {}
        curr = 0
        for p in profiles:
            mapping[p] = {"start": curr, "end": curr + 60}
            curr += 60
        for s in scalars:
            mapping[s] = {"start": curr, "end": curr + 1}
            curr += 1
        return mapping, curr

    def create_shards(self, mli_paths, mlo_paths, output_dir, shard_size=100):
        os.makedirs(output_dir, exist_ok=True)
        all_input_means, all_input_sq_means = [], []
        num_files = len(mli_paths)

        for shard_idx, start_i in enumerate(range(0, num_files, shard_size)):
            X_shard, Y_shard = [], []
            end_i = min(start_i + shard_size, num_files)
            
            for mli, mlo in tqdm(zip(mli_paths[start_i:end_i], mlo_paths[start_i:end_i]), 
                                 total=end_i-start_i, desc=f"Shard {shard_idx}"):
                try:
                    with xr.open_dataset(mli) as ds_in, xr.open_dataset(mlo) as ds_out:
                        # Vectorized stacking
                        X_file = np.hstack([ds_in[v].values.T if v in self.input_profiles else ds_in[v].values.reshape(-1, 1) for v in self.input_profiles + self.input_scalars])
                        Y_file = np.hstack([ds_out[v].values.T if v in self.target_profiles else ds_out[v].values.reshape(-1, 1) for v in self.target_profiles + self.target_scalars])
                        X_shard.append(X_file)
                        Y_shard.append(Y_file)
                except Exception as e:
                    print(f"Error processing {mli}: {e}")

            # vstack rows (384 * shard_size, features)
            X_final = np.vstack(X_shard).astype(np.float32)
            Y_final = np.vstack(Y_shard).astype(np.float32)

            # Accumulate statistics
            all_input_means.append(np.mean(X_final, axis=0))
            all_input_sq_means.append(np.mean(X_final**2, axis=0))

            np.save(os.path.join(output_dir, f"X_shard_{shard_idx}.npy"), X_final)
            np.save(os.path.join(output_dir, f"Y_shard_{shard_idx}.npy"), Y_final)

        # Save Final Metadata
        final_mean = np.mean(all_input_means, axis=0)
        final_std = np.sqrt(np.mean(all_input_sq_means, axis=0) - final_mean**2)
        
        metadata = {
            "input_indices": self.input_indices,
            "target_indices": self.target_indices,
            "input_mean": final_mean.tolist(),
            "input_std": final_std.tolist(),
            "total_input_dim": self.total_input_dim,
            "total_target_dim": self.total_target_dim
        }
        
        with open(os.path.join(output_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=4)
        print(f"Done! Metadata saved to {output_dir}/metadata.json")

    @staticmethod
    def get_variable(data, var_name, mapping):
        """Helper to extract a variable from a loaded numpy shard or torch tensor."""
        if var_name not in mapping:
            raise ValueError(f"Variable {var_name} not found in mapping.")
        start = mapping[var_name]['start']
        end = mapping[var_name]['end']
        return data[..., start:end]

In [2]:
def get_data_folders(path):
    data_folders = os.listdir(path)
    data_folders.sort()

    print(f"Found {len(data_folders)} data folders.")

    mli_samples = []
    mlo_samples = []
    for dir_name in data_folders:
        files = os.listdir(os.path.join(path, dir_name))
        for f in files:
            if f.split('.')[1] == 'mli':
                mli_samples.append(os.path.join(path, dir_name, f))
            elif f.split('.')[1] == 'mlo':
                mlo_samples.append(os.path.join(path, dir_name, f))
    
    return mli_samples, mlo_samples

def read_sample(file_path):
    return xr.open_dataset(file_path)

In [3]:
mli_samples, mlo_samples = get_data_folders("ClimSim_low-res/train/")

sharder = ClimSimNumpySharder()
sharder.create_shards(mli_samples, mlo_samples, "ClimSimLowResShards", shard_size=1000)

Found 5 data folders.


Shard 0: 100%|██████████| 1000/1000 [01:36<00:00, 10.41it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 1: 100%|██████████| 1000/1000 [01:36<00:00, 10.35it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 2: 100%|██████████| 1000/1000 [01:37<00:00, 10.21it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 3: 100%|██████████| 1000/1000 [01:39<00:00, 10.10it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 4: 100%|██████████| 1000/1000 [01:38<00:00, 10.19it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 5: 100%|██████████| 1000/1000 [01:39<00:00, 10.06it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 6: 100%|██████████| 1000/1000 [01:39<00:00, 10.05it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 7: 100%|██████████| 1000/1000 [01:38<00:00, 10.12it/s]
  all_input_sq_means.append(np.mean(X_final**2, axis=0))
Shard 8: 100%|██████████| 1000/1000 [01:39<00:00, 10.08i

Done! Metadata saved to ClimSimLowResShards/metadata.json
