In [4]:
#$ Imports
import os
import shutil
import sys
import json
from datetime import datetime

import numpy as np
from scipy import stats
import xarray as xr
import cftime
from tqdm import tqdm

import matplotlib as mpl
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import cmocean as cmo

import torch
import utils
from utils import helpers

#$ Global variables
# open() does not expand "~" by itself, so resolve the home directory explicitly;
# otherwise the path is looked up relative to the current working directory.
with open(os.path.expanduser("~/s2s/paths.json")) as paths_json:
    PATHS = json.load(paths_json)
with open(os.path.expanduser("~/s2s/globals.json")) as globals_json:
    GLOBALS = json.load(globals_json)
# Prefer the GPU when torch can see one, and report which device was picked up.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
if DEVICE == 'cuda':
    print(f"Using cuda device {torch.cuda.get_device_name(0)}")

In [5]:
def save_all_the_data(tau, output_path = PATHS['s2s_predictions']):
    """Gather the standardizers, predictions and test inputs for one ``tau``.

    Everything is written to ``<output_path>/tau_<tau>`` (created if missing):

    * the X standardizer, saved as ``X_standardizer``;
    * the y standardizer restricted to every second ``lat`` and ``lon`` index,
      saved as ``y_standardizer``;
    * a residual standardizer assembled from the per-location
      ``residual_standardizer.npy`` files found under
      ``PATHS['full_globe']/networks/tau_<tau>/loc_<i>``, saved as
      ``residual_standardizer``;
    * copies of ``target.nc``, ``pred_mean.nc``, ``pred_logvar.nc``,
      ``dp_pred.nc`` and ``X_test.nc``;
    * ``X_dp_test.nc``, rebuilt from its values and its time/lat/lon
      coordinates only.

    Parameters
    ----------
    tau
        Value used to select the ``tau_<tau>`` sub-directories
        (presumably the forecast lead -- TODO confirm the unit).
    output_path : str
        Root directory that receives the ``tau_<tau>`` output folder.
        Defaults to ``PATHS['s2s_predictions']`` as read when the function
        is defined.

    Returns
    -------
    None
        The function is used for its side effects on disk.
    """
    tau_dir = f'tau_{tau}'
    data_path = os.path.join(PATHS['full_globe'], 'data', tau_dir)
    prediction_path = os.path.join(PATHS['full_globe'], 'predictions', tau_dir)
    network_root = os.path.join(PATHS['full_globe'], 'networks', tau_dir)
    output_path = os.path.join(output_path, tau_dir)
    os.makedirs(output_path, exist_ok=True)

    # Standardizers: X is passed through untouched, y is thinned to every
    # second grid index along lat and lon before being saved.
    X_scaler = utils.processing.load_standardizer(data_path, 'X')
    utils.processing.save_standardizer(X_scaler, output_path, 'X_standardizer')
    y_scaler = utils.processing.load_standardizer(data_path, 'y')
    y_scaler = y_scaler.isel(lat=slice(0, None, 2), lon=slice(0, None, 2))
    utils.processing.save_standardizer(y_scaler, output_path, 'y_standardizer')

    # The residual standardizer is stored per location, one file per network,
    # so it has to be reassembled point by point. The thinned y standardizer
    # serves as the template: flatten (lat, lon) into one axis 's' and drop
    # the NaN points so that index i lines up with directory loc_<i>.
    # NOTE(review): assumes loc_<i> follows exactly this stacked, NaN-dropped
    # ordering -- confirm against the code that trained the networks.
    y_scaler_flattened = y_scaler.stack(s=('lat', 'lon')).dropna(dim='s')
    residual_scaler = xr.zeros_like(y_scaler_flattened)
    for i in tqdm(range(y_scaler_flattened.sizes['s'])):
        network_path = os.path.join(network_root, f'loc_{i}')
        # The .npy file holds a pickled dict wrapped in a 0-d object array;
        # .item() unwraps it. allow_pickle=True executes pickle on load, so
        # this must only ever be pointed at trusted files.
        point_scaler = np.load(
            os.path.join(network_path, 'residual_standardizer.npy'),
            allow_pickle=True,
        ).item()
        residual_scaler['mean'][i] = point_scaler['mean']
        residual_scaler['std'][i] = point_scaler['std']
    # Back from the flat 's' axis to the (lat, lon) grid.
    residual_scaler = residual_scaler.unstack()
    utils.processing.save_standardizer(residual_scaler, output_path, 'residual_standardizer')

    # Prediction products are copied verbatim.
    for name in ('target', 'pred_mean', 'pred_logvar', 'dp_pred'):
        shutil.copy(os.path.join(prediction_path, f'{name}.nc'), output_path)

    # X_test
    shutil.copy(os.path.join(data_path, 'X_test.nc'), output_path)

    # X_dp: rebuild the array from its raw values and coordinates rather than
    # copying the file. The context manager closes the source file once the
    # new file has been written.
    # NOTE(review): this step previously also opened the time axis of
    # target.nc without using it (and without closing the file). If X_dp_test
    # was meant to be re-indexed onto those times, that is still to be done.
    with xr.open_dataarray(os.path.join(data_path, 'X_dp_test.nc')) as source:
        rebuilt = xr.DataArray(
            source.data,
            # Explicit dims instead of relying on inference from the order
            # of the coords dict, which not every xarray version supports.
            dims=('time', 'lat', 'lon'),
            coords={
                'time': source['time'].values,
                'lat': source['lat'],
                'lon': source['lon'],
            },
        )
        rebuilt.to_netcdf(os.path.join(output_path, 'X_dp_test.nc'))

In [7]:
%%time
# Write the tau=10 standardizers, prediction files and test inputs to the default output directory.
save_all_the_data(tau=10)

100%|██████████| 6590/6590 [01:03<00:00, 104.50it/s]


CPU times: user 6.27 s, sys: 1min 40s, total: 1min 46s
Wall time: 2min 52s


In [6]:
%%time
# Write the tau=20 standardizers, prediction files and test inputs to the default output directory.
save_all_the_data(tau=20)

100%|██████████| 6590/6590 [01:00<00:00, 108.54it/s]


CPU times: user 6.43 s, sys: 1min 27s, total: 1min 34s
Wall time: 2min 37s


In [8]:
%%time
# Write the tau=60 standardizers, prediction files and test inputs to the default output directory.
save_all_the_data(tau=60)

100%|██████████| 6590/6590 [01:02<00:00, 104.65it/s]


CPU times: user 6.19 s, sys: 1min 5s, total: 1min 11s
Wall time: 2min 12s


In [9]:
%%time
# Write the tau=120 standardizers, prediction files and test inputs to the default output directory.
save_all_the_data(tau=120)

100%|██████████| 6590/6590 [00:58<00:00, 112.50it/s]


CPU times: user 6.55 s, sys: 19.3 s, total: 25.8 s
Wall time: 1min 22s
