In [6]:
import sys
import os
os.environ['PROJ_DATA'] = "/pscratch/sd/p/plutzner/proj_data"
import xarray as xr
import torch
import torchinfo
import random
import numpy as np
import importlib as imp
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import cartopy.crs as ccrs
import json
import pickle
import gzip
#import matplotlib.colors as mcolorsxx

%load_ext autoreload
%autoreload 2
import utils
import utils.filemethods as filemethods
import databuilder.data_loader as data_loader
import model.loss as module_loss
import model.metric as module_metric
from databuilder.data_generator import multi_input_data_organizer
import databuilder.data_loader as data_loader
from trainer.trainer import Trainer
from model.build_model import TorchModel
from utils import utils
from databuilder.sampleclass import SampleDict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


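### Check the `ClimateData` class

Build the train/validation/test splits for `exp001` and confirm that the target preprocessing runs end-to-end.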

In [1]:
class ClimateData:
    """Custom dataset for climate data and processing.

    Builds train / validation / test sample dictionaries from E3SM ensemble
    output: ``ens1`` feeds training, ``ens2`` validation and ``ens3`` testing
    (see ``ENSEMBLE_FILES``). For every split the input variables are stacked
    into a trailing ``channel`` dimension, cropped to ``input_region``,
    land/ocean masked, de-trended / de-seasonalized, smoothed with a rolling
    mean and lag-trimmed. The target is reduced to the grid point nearest
    ``target_region`` and passed through the same temporal steps.

    Parameters
    ----------
    config : dict
        Data-builder config. Keys read by this class: ``ensembles``,
        ``input_vars``, ``target_var``, ``target_region``, ``input_region``,
        ``input_mask``, ``averaging_length`` and ``lagtime``.
    expname, seed, figure_dir
        Stored on the instance; not otherwise used by this class.
    data_dir : str
        Directory holding one sub-directory per ensemble member plus
        ``landfrac.bilin.nc``.
    target_only : bool
        If True, only the ``"y"`` entries of each sample dictionary are filled.
    fetch : bool
        If True, :meth:`fetch_data` is called from the constructor.
    verbose : bool
        Print extra progress information.
    """

    # ensemble member -> (split name, file inside <data_dir>/<ensemble>/)
    ENSEMBLE_FILES = {
        "ens1": ("train", "input_vars.v2.LR.historical_0101.eam.h1.1850-2014.nc"),
        "ens2": ("val", "input_vars.v2.LR.historical_0151.eam.h1.1850-2014.nc"),
        "ens3": ("test", "input_vars.v2.LR.historical_0201.eam.h1.1850-2014.nc"),
    }
    SPLITS = ("train", "val", "test")
    _SPLIT_LABELS = {"train": "training", "val": "validation", "test": "testing"}

    # PRECT is assumed to be a rate in m/s (E3SM convention). Converting to
    # mm/day needs 1e3 mm/m * 86400 s/day; the old factor ``10e3`` is 1e4,
    # which over-scaled precipitation by 10x.
    M_PER_S_TO_MM_PER_DAY = 1.0e3 * 86400.0

    # Number of longitude *indices* (not degrees) processed per chunk when
    # de-seasonalizing / smoothing gridded data (presumably to bound memory).
    LON_CHUNK = 45

    def __init__(self, config, expname, seed, data_dir, figure_dir, target_only=False, fetch=True, verbose=False):
        self.config = config
        self.expname = expname
        self.seed = seed
        self.data_dir = data_dir
        self.figure_dir = figure_dir
        self.verbose = verbose
        self.target_only = target_only

        if fetch:
            self.fetch_data()

    def fetch_data(self, verbose=None):
        """(Re)build the three splits and return them.

        Parameters
        ----------
        verbose : bool, optional
            If given, overrides ``self.verbose`` for this and later calls.

        Returns
        -------
        tuple
            ``(d_train, d_val, d_test)`` sample dictionaries, also stored on
            the instance.
        """
        if verbose is not None:
            self.verbose = verbose

        self.d_train = SampleDict()
        self.d_val = SampleDict()
        self.d_test = SampleDict()

        self._create_data()

        return self.d_train, self.d_val, self.d_test

    def _create_data(self):
        """Open each configured ensemble file, process it, fill the splits.

        Raises
        ------
        ValueError
            If ``config["ensembles"]`` does not provide all three splits
            (previously this surfaced as an ``UnboundLocalError``).
        """
        datasets = {}
        for ens in self.config["ensembles"]:
            print("Opening .nc files")
            if self.verbose:
                print(ens)
            if ens not in self.ENSEMBLE_FILES:
                # Unknown members were silently ignored before; keep ignoring
                # them, but say so.
                print(f"Skipping unrecognized ensemble {ens!r}")
                continue
            split, fname = self.ENSEMBLE_FILES[ens]
            datasets[split] = filemethods.get_netcdf_da(os.path.join(self.data_dir, str(ens), fname))

        missing = [split for split in self.SPLITS if split not in datasets]
        if missing:
            raise ValueError(
                f"config['ensembles'] provides no data for split(s) {missing}; "
                f"expected members {sorted(self.ENSEMBLE_FILES)}"
            )

        sliced = {split: datasets[split].sel(time=slice("1850", "2014")) for split in self.SPLITS}

        # Process every split before touching the output dictionaries so a
        # failure part-way leaves them all empty.
        processed = {}
        for split in self.SPLITS:
            print(f"Processing {self._SPLIT_LABELS[split]}")
            processed[split] = self._process_data(sliced[split])

        self.d_train.concat(processed["train"])
        self.d_val.concat(processed["val"])
        self.d_test.concat(processed["test"])

    def _process_data(self, ds):
        """Create the sample dictionary for one dataset.

        Parameters
        ----------
        ds
            Dataset holding every input variable (and the target variable).

        Returns
        -------
        SampleDict
            The ``"y"`` key gets the processed target series. Every other key
            gets the masked, de-trended, de-seasonalized, smoothed and
            lag-trimmed input anomalies with a trailing ``channel``
            dimension, unless ``target_only`` is set.
        """
        f_dict = SampleDict()
        da = self._stack_channels(ds)

        # NOTE(review): assumes SampleDict() is created with its keys already
        # present; an empty SampleDict would leave f_dict unfilled.
        for key in f_dict:
            if key == "y":
                f_dict[key] = self._process_target(ds)
            elif not self.target_only:
                print("Processing inputs")
                f_dict[key] = self._process_inputs(da)

        return f_dict

    def _stack_channels(self, ds):
        """Stack ``config["input_vars"]`` along a new trailing ``channel`` dim.

        PRECT is converted to mm/day wherever it appears in the list
        (previously only when it was the first variable).
        """
        channels = []
        for ivar, var in enumerate(self.config["input_vars"]):
            da_var = ds[var]
            if ivar == 0:
                print("isolating variables from ds")
                print(f"da of isolated PRECT shape: {da_var.shape}")
            if var == "PRECT":
                da_var = da_var * self.M_PER_S_TO_MM_PER_DAY
            # Give each variable a length-1 channel axis at the end so the
            # concatenation below keeps ``channel`` as the last dimension.
            channels.append(da_var.expand_dims(dim={"channel": 1}, axis=-1))

        da = xr.concat(channels, dim="channel")
        da = da.rename('SAMPLES')
        da.attrs['long_name'] = None
        da.attrs['units'] = None
        da.attrs['cell_methods'] = None
        return da

    def _process_target(self, ds):
        """Return the smoothed, de-seasonalized, lag-shifted target series."""
        print("Processing target output")
        target_var = self.config["target_var"]
        y = ds[target_var]
        if target_var == "PRECT":
            y = y * self.M_PER_S_TO_MM_PER_DAY
        print(f"Length of target after temp conversion = {(len(y))}")

        # Reduce to the grid point nearest the target location.
        targetlat = self.config["target_region"][0]
        targetlon = self.config["target_region"][1]
        y = y.sel(lat=targetlat, lon=targetlon, method='nearest')
        print(f"Length of target after extract target location = {(len(y))}")

        y = self.trend_remove_seasonal_cycle(y)
        print(f"Length of target after trend remove seasonal cycle = {(len(y))}")

        # With averaging_length > 0 the first averaging_length - 1 values
        # become NaN (rolling mean over an incomplete window).
        y = self.rolling_ave(y)
        print(f"Length of target after rolling average = {(len(y))}")

        # Drop the first `lagtime` samples so that target i lies `lagtime`
        # steps after input i (inputs drop their last `lagtime` samples).
        lag = self.config["lagtime"]
        if lag != 0:
            y = y[lag:]
        print(f"Length of target after lag adjustment = {(len(y))}")
        return y

    def _process_inputs(self, da):
        """Crop, mask, de-seasonalize, smooth and lag-trim the stacked inputs.

        Each channel is de-seasonalized separately. This path serves one or
        many input variables: the former single-variable path handed the 4-D
        ``(time, lat, lon, channel)`` array to ``trend_remove_seasonal_cycle``,
        whose (lat, lon) stacking yields a 3-D array that ``polyfit`` rejects.
        """
        x = self._extractregion(da)
        x = self._masklandocean(x)
        for ichannel in range(x.shape[-1]):
            x[..., ichannel] = self.trend_remove_seasonal_cycle(x[..., ichannel])
        x = self.rolling_ave(x)
        return self._trim_input_lag(x)

    def _trim_input_lag(self, x):
        """Drop the last ``config["lagtime"]`` samples along the first axis.

        ``lagtime == 0`` keeps everything (the old ``x[0:-0]`` returned an
        empty array), and a lag longer than the series yields an empty result.
        """
        lag = self.config["lagtime"]
        return x[: max(len(x) - lag, 0), ...]

    def _extractregion(self, da):
        """Crop ``da`` to the ``input_region`` box.

        ``config["input_region"]`` is ``[min_lat, max_lat, min_lon, max_lon]``,
        or ``"None"`` / ``None`` for the whole globe.

        Raises
        ------
        NotImplementedError
            If ``da`` is not an ``xarray.DataArray``.
        """
        region = self.config["input_region"]
        if region is None or region == "None":
            min_lat, max_lat, min_lon, max_lon = -90, 90, 0, 360
            print("input region is none")
        else:
            min_lat, max_lat, min_lon, max_lon = region

        if not isinstance(da, xr.DataArray):
            raise NotImplementedError("data must be xarray")

        in_lon = (da.lon >= min_lon) & (da.lon <= max_lon)
        in_lat = (da.lat >= min_lat) & (da.lat <= max_lat)
        return da.where(in_lon & in_lat, drop=True)

    def _masklandocean(self, da):
        """Zero out land or ocean points according to ``config["input_mask"][0]``.

        ``"land"`` keeps points with land fraction > 0.5, ``"ocean"`` keeps the
        rest, and ``"None"`` returns ``da`` unchanged.

        Raises
        ------
        NotImplementedError
            For any other mask type.
        """
        mask_type = self.config["input_mask"][0]
        if mask_type == "None":
            return da
        if mask_type not in ("land", "ocean"):
            raise NotImplementedError(
                f"unsupported input_mask {mask_type!r}; expected 'None', 'land' or 'ocean'"
            )

        # Load the mask eagerly so the file handle can be closed right away.
        with xr.open_dataset(os.path.join(self.data_dir, "landfrac.bilin.nc")) as mask_ds:
            landfrac = mask_ds["LANDFRAC"][0, :, :].load()

        is_land = landfrac > 0.5
        keep = is_land if mask_type == "land" else ~is_land
        return da * xr.where(keep, 1.0, 0.0)

    def subtract_trend(self, x):
        """Remove a cubic polynomial fit (against sample index) from ``x``.

        ``x`` is 1-D ``(n,)`` or 2-D ``(n, k)``; each column is fitted
        separately.
        """
        detrend_order = 3
        t = np.arange(0, x.shape[0])
        coefs = np.polynomial.polynomial.polyfit(t, x, detrend_order)
        trend = np.polynomial.polynomial.polyval(t, coefs)
        if trend.ndim > 1:
            # 2-D coefficients make polyval return (k, n); align with x (n, k).
            trend = np.swapaxes(trend, 0, 1)
        return x - trend

    def trend_remove_seasonal_cycle(self, da):
        """Subtract a per-day-of-year cubic trend, then drop NaN time steps.

        For 1-D input the fit is done directly. Otherwise, ``da`` is assumed
        to be ``(time, lat, lon)``: it is processed in chunks of ``LON_CHUNK``
        longitude indices, stacking ``(lat, lon)`` into one axis per chunk.
        """
        if da.ndim == 1:
            return da.groupby("time.dayofyear").map(self.subtract_trend).dropna("time")

        da_copy = da.copy()
        n_lon = da_copy.shape[2]
        for start in range(0, n_lon, self.LON_CHUNK):
            end = min(start + self.LON_CHUNK, n_lon)
            stacked = da[:, :, start:end].stack(z=("lat", "lon"))
            da_copy[:, :, start:end] = stacked.groupby("time.dayofyear").map(self.subtract_trend).unstack()

        return da_copy.dropna("time")

    def rolling_ave(self, da):
        """Trailing rolling mean over ``config["averaging_length"]`` time steps.

        An averaging length of 0 returns ``da`` unchanged. Multi-dimensional
        input is processed in chunks of ``LON_CHUNK`` along axis 2.
        """
        window = self.config["averaging_length"]
        if window == 0:
            return da
        if da.ndim == 1:
            return da.rolling(time=window).mean()

        da_copy = da.copy()
        n_lon = da.shape[2]
        for start in range(0, n_lon, self.LON_CHUNK):
            end = min(start + self.LON_CHUNK, n_lon)
            da_copy[:, :, start:end] = da[:, :, start:end].rolling(time=window).mean()
        return da_copy

SyntaxError: '(' was never closed (572498212.py, line 54)

In [8]:
# Experiment to load; the first entry of its seed list drives this run.
EXPNAME = "exp001"
config = utils.get_config(EXPNAME)
seed = config["seed_list"][0]

In [9]:
# Re-import helper modules so edits to them are picked up in this kernel.
for _module in (utils, filemethods):
    imp.reload(_module)

# Build the dataset object without loading; fetch_data() is called next.
data = ClimateData(
    config["databuilder"],
    expname=config["expname"],
    seed=seed,
    data_dir=config["data_dir"],
    figure_dir=config["figure_dir"],
    target_only=True,  # only the "y" entries get processed
    fetch=False,
    verbose=False,
)

In [10]:
# Load and process all three splits (the object was built with fetch=False).
(d_train, d_val, d_test) = data.fetch_data(verbose=None)

Opening .nc files
Opening .nc files
Opening .nc files
train ds shape <xarray.DataArray 'time' (time: 60226)> Size: 482kB
array([cftime.DatetimeNoLeap(1850, 1, 1, 0, 0, 0, 0, has_year_zero=True),
       cftime.DatetimeNoLeap(1850, 1, 2, 0, 0, 0, 0, has_year_zero=True),
       cftime.DatetimeNoLeap(1850, 1, 3, 0, 0, 0, 0, has_year_zero=True), ...,
       cftime.DatetimeNoLeap(2014, 12, 30, 0, 0, 0, 0, has_year_zero=True),
       cftime.DatetimeNoLeap(2014, 12, 31, 0, 0, 0, 0, has_year_zero=True),
       cftime.DatetimeNoLeap(2015, 1, 1, 0, 0, 0, 0, has_year_zero=True)],
      dtype=object)
Coordinates:
  * time     (time) object 482kB 1850-01-01 00:00:00 ... 2015-01-01 00:00:00
Attributes:
    standard_name:  time
    long_name:      time
    bounds:         time_bnds
    axis:           T
post slice train ds shape <xarray.DataArray 'time' (time: 60225)> Size: 482kB
array([cftime.DatetimeNoLeap(1850, 1, 1, 0, 0, 0, 0, has_year_zero=True),
       cftime.DatetimeNoLeap(1850, 1, 2, 0, 0, 0,

KeyboardInterrupt: 