In [25]:
import torch
from torch.utils.data import Dataset
import numpy as np
import xarray as xr
import os
from sklearn.preprocessing import QuantileTransformer
from functools import lru_cache
from torch.utils.data import Dataset, DataLoader


In [10]:
# NOTE(review): `get_config` is defined in a later cell of this notebook, so on a
# fresh kernel (Restart & Run All, top to bottom) this line raises NameError.
# `config` is created again right before the dataset is built further down, so this
# cell looks like leftover state -- move it below the definition or remove it.
config = get_config()

In [2]:
def get_config():
    """Return a freshly built configuration dictionary.

    A new dict (with new nested lists and a new ``"optimization"`` sub-dict)
    is created on every call, so callers may mutate the result freely.

    Returns:
        dict: model, diffusion and data settings at the top level, plus the
        optimizer settings nested under ``"optimization"``.
    """
    # Optimizer / training-loop settings live in their own nested dict.
    optimization = {
        "epochs": 400,
        "lr": 0.01,
        "wd": 0.05,
        "batch_size": 32,
        "scheduler": True,
    }

    return {
        "input_channels": 1,
        "output_channels": 1,
        "context_image": True,
        "context_channels": 1,
        "num_blocks": [2, 2],
        "hidden_channels": 32,
        "hidden_context_channels": 8,
        "time_embedding_dim": 256,
        "image_size": 128,
        "noise_sampling_coeff": 0.85,
        "denoise_time": 970,
        "activation": "gelu",
        "norm": True,
        "subsample": 100000,
        "save_name": "model_weights.pt",
        "dim_mults": [4, 4],
        "base_dim": 32,
        "timesteps": 1000,
        # NOTE(review): the key is spelled "pading" (sic). It is kept unchanged
        # because the code that reads it is not visible here.
        "pading": "reflect",
        "scaling": "std",
        "optimization": optimization,
    }

def _load_split(data_root, lead_nums, split):
    """Read the ``forecast`` variable of one split for every lead and stack on axis 0."""
    arrays = []
    for lead in lead_nums:
        # The context manager releases the file handle once the values are in memory.
        with xr.open_dataset(f'{data_root}/lead_{lead}/{split}.nc') as ds:
            arrays.append(ds.forecast.values)
    return np.concatenate(arrays, axis=0)


def load_datasets(config):
    """Load, scale and (for training) rotate-augment the forecast arrays.

    Args:
        config: Configuration dict. ``config['scaling']`` must be ``'quantile'``
            or ``'std'``. Optional keys: ``'data_root'`` (directory holding the
            ``lead_<n>/{train,test,val}.nc`` files) and ``'quantile_scaler_path'``
            (pickled, already fitted scaler); both default to the previously
            hard-coded locations.

    Returns:
        tuple: ``(train_dataset, test_dataset, val_dataset, config)``. With
        ``'std'`` scaling, ``config`` is updated in place with the training
        ``'mean'`` and ``'std'``.

    Raises:
        KeyError: if ``config`` has no ``'scaling'`` entry.
        ValueError: if ``config['scaling']`` is not a supported method.
    """
    # Validate before touching the disk so a typo fails immediately instead of
    # after every file has been read.
    scaling = config['scaling']
    if scaling not in ('quantile', 'std'):
        raise ValueError(
            f"Unsupported scaling method {scaling!r}; expected 'quantile' or 'std'."
        )

    data_root = config.get(
        'data_root',
        '/glade/derecho/scratch/timothyh/data/diffusion_forecasts/processed',
    )
    # Leads 0, 3, ..., 168 (57 values) -- the same labels np.linspace(0, 168, 57)
    # produced, without a float -> int round trip.
    lead_nums = [str(lead) for lead in range(0, 169, 3)]

    # Concatenate forecasts of all leads along the first (time) dimension
    xtrain = _load_split(data_root, lead_nums, 'train')
    xtest = _load_split(data_root, lead_nums, 'test')
    xval = _load_split(data_root, lead_nums, 'val')

    if scaling == 'quantile':
        print('quantiling')
        # Assumes the scaler was saved with pickle -- TODO confirm (the original
        # called an un-imported `load`).
        # SECURITY: unpickling runs arbitrary code; only load trusted files.
        import pickle
        scaler_path = config.get('quantile_scaler_path', 'quantile_transform_scaler.pkl')
        with open(scaler_path, 'rb') as scaler_file:
            qtpie = pickle.load(scaler_file)

        print('... transforming ...')
        # The scaler works on a single feature column, hence the reshape round trip.
        xtrain_t = qtpie.transform(xtrain.reshape(-1, 1)).reshape(xtrain.shape)
        print('... done 1 ...')
        xtest_t = qtpie.transform(xtest.reshape(-1, 1)).reshape(xtest.shape)
        xval_t = qtpie.transform(xval.reshape(-1, 1)).reshape(xval.shape)
    else:  # scaling == 'std'
        print('standardizing')
        # Statistics come from the training split only and are reused for val/test.
        meanx = np.mean(xtrain)
        stdx = np.std(xtrain)

        xtrain_t = (xtrain - meanx) / stdx
        xval_t = (xval - meanx) / stdx
        xtest_t = (xtest - meanx) / stdx
        config['mean'] = meanx
        config['std'] = stdx

    print('... rotating ...')

    # Data augmentation: original plus 90/180/270 degree rotations in the
    # (1, 2) plane, stacked along the first dimension.
    xtrain_all = np.concatenate(
        [xtrain_t] + [np.rot90(xtrain_t, k=turns, axes=(1, 2)) for turns in (1, 2, 3)],
        axis=0,
    )

    # NOTE(review): the DataProcessed class defined later in this notebook takes
    # (file_paths, config, mean, std); these single-array calls match a different
    # definition -- confirm which one is intended before running this function.
    train_dataset = DataProcessed(xtrain_all)
    test_dataset = DataProcessed(xtest_t)
    val_dataset = DataProcessed(xval_t)

    return train_dataset, test_dataset, val_dataset, config

In [7]:
# 57 evenly spaced lead values from 0 to 168, rendered as integer strings
# because they are substituted into directory names below.
lead_nums = [str(int(lead_value)) for lead_value in np.linspace(0, 168, 57)]
# One train.nc path per lead. These are plain path strings -- nothing is opened here.
train_datasets = [
    f'/glade/derecho/scratch/timothyh/data/diffusion_forecasts/processed/lead_{lead}/train.nc'
    for lead in lead_nums
]


In [19]:

class DataProcessed(Dataset):
    """Map-style dataset that serves forecast samples lazily from netCDF files.

    Each file is opened with xarray, its ``forecast`` variable is scaled (and
    optionally rotation-augmented) and converted to a float32 tensor. Whole
    files are kept in a per-instance least-recently-used cache of at most
    ``config['cache_size']`` entries. Per-file sample counts are read once and
    reused by ``__len__`` and ``__getitem__``.
    """

    # `_augment_data` returns the original samples plus three rotated copies.
    _AUGMENT_COPIES = 4

    def __init__(self, file_paths, config, mean, std):
        """
        Args:
            file_paths: List of paths to netCDF files
            config: Configuration dict. Reads ``'scaling'`` (required,
                ``'quantile'`` or ``'std'``), ``'augment'`` (default False)
                and ``'cache_size'`` (default 10, number of files held in
                memory at once).
            mean: Value subtracted from the data when scaling is ``'std'``.
            std: Value the data is divided by when scaling is ``'std'``.
        """
        self.file_paths = list(file_paths)
        self.config = config
        self.mean = mean
        self.std = std
        self.augmentation = config.get('augment', False)
        self.scaler = None
        if config['scaling'] == 'quantile':
            self.scaler = QuantileTransformer()
        elif config['scaling'] == 'std':
            # Record the statistics here rather than during a file load: with
            # DataLoader worker processes, loads run on a copy of the dataset
            # and a write made there would never reach the caller's dict.
            config['mean'], config['std'] = mean, std

        # Maximum number of files kept in memory (tune to memory constraints;
        # every DataLoader worker holds its own cache).
        self.cache_size = config.get('cache_size', 10)
        self.cached_data = {}  # file path -> tensor, ordered oldest -> newest use
        self._file_offsets = None  # cumulative sample counts, filled on first use

    def _get_offsets(self):
        """Return cumulative per-file sample counts, reading file sizes only once."""
        if self._file_offsets is None:
            copies = self._AUGMENT_COPIES if self.augmentation else 1
            counts = []
            for path in self.file_paths:
                with xr.open_dataset(path) as ds:
                    # Augmentation multiplies the samples a file contributes.
                    counts.append(ds.sizes['time'] * copies)
            self._file_offsets = np.cumsum(counts, dtype=np.int64)
        return self._file_offsets

    def __len__(self):
        """Total number of samples across all files (including augmented copies)."""
        offsets = self._get_offsets()
        return int(offsets[-1]) if offsets.size else 0

    def _apply_scaling(self, data):
        """Scale ``data`` according to ``config['scaling']``.

        Raises:
            ValueError: if the configured scaling method is not supported.
        """
        if self.config['scaling'] == 'quantile':
            # NOTE(review): the transformer is re-fitted on every file, so each
            # file gets its own mapping -- confirm this is intended; a scaler
            # fitted once on the training data would be consistent across files.
            data_t = self.scaler.fit_transform(data.reshape(-1, 1)).reshape(data.shape)
        elif self.config['scaling'] == 'std':
            data_t = (data - self.mean) / self.std
        else:
            raise ValueError("Invalid scaling method specified.")
        return data_t

    def _augment_data(self, data):
        """Return ``data`` followed by its 90/180/270 degree rotations.

        Rotations act in the (1, 2) plane; the four blocks are concatenated
        along the first dimension in that order.
        """
        rotations = [np.rot90(data, k=turns, axes=(1, 2)) for turns in (1, 2, 3)]
        return np.concatenate([data] + rotations, axis=0)

    def _load_data_from_file(self, file_path):
        """
        Return the preprocessed tensor for ``file_path``, from cache when possible.
        """
        cache = self.cached_data
        if file_path in cache:
            # Re-insert so the entry becomes the most recently used one.
            tensor = cache.pop(file_path)
            cache[file_path] = tensor
            return tensor

        with xr.open_dataset(file_path) as ds:
            data = ds.forecast.values  # Replace 'forecast' with the relevant key in your dataset

        data = self._apply_scaling(data)
        if self.augmentation:
            data = self._augment_data(data)
        tensor = torch.tensor(data, dtype=torch.float32)

        if self.cache_size > 0:
            # Dicts keep insertion order, so the first key is the least recently used.
            while len(cache) >= self.cache_size:
                cache.pop(next(iter(cache)))
            cache[file_path] = tensor
        return tensor

    def __getitem__(self, idx):
        """
        Return the sample at global index ``idx``; negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        offsets = self._get_offsets()
        total = int(offsets[-1]) if offsets.size else 0

        position = int(idx)
        if position < 0:
            position += total
        if not 0 <= position < total:
            raise IndexError(f"Index {idx} is out of bounds")

        # First file whose cumulative count exceeds the position.
        file_idx = int(np.searchsorted(offsets, position, side='right'))
        first_in_file = int(offsets[file_idx - 1]) if file_idx else 0
        data = self._load_data_from_file(self.file_paths[file_idx])
        return data[position - first_in_file]



In [20]:
# Fresh configuration, then a training dataset built over the per-lead train.nc paths.
config = get_config()
# NOTE(review): mean=152 and std=152 are hard-coded and identical -- confirm they are
# the intended training-set statistics rather than placeholders.
train_ds = DataProcessed(file_paths=train_datasets, config=config, mean=152, std=152)

In [21]:
# Total number of training samples across all files.
len(train_ds)

215503

In [26]:
# Shuffled training batches; the batch size comes from config["optimization"]["batch_size"]
# and samples are fetched by 6 worker processes.
# NOTE(review): DataProcessed reads a whole file to serve one sample on a cache miss,
# so shuffled access across the 57 files may be slow -- consider a file-grouped sampler.
train_loader = DataLoader(train_ds, batch_size=config["optimization"]["batch_size"], shuffle=True, num_workers=6)

