In [1]:
import xarray as xr
import xcdat
import pandas as pd
import numpy as np
import cftime
import datetime
import sys
import json

In [2]:
class TimeSeriesData():
    def __init__(self, ds, ds_var):
        self.ds = ds
        self.ds_var = ds_var
        self._set_years()
        self._set_calendar()
    
    def _set_years(self):
        self.year_beg = self.ds.isel({"time": 0}).time.dt.year.item()
        self.year_end = self.ds.isel({"time": -1}).time.dt.year.item()

        if self.year_end < self.year_beg + 1:
            raise Exception("Error: Final year must be greater than beginning year.")

        self.year_range = np.arange(self.year_beg,self.year_end+1,1)
        
    def _set_calendar(self):
        self.calendar = self.ds.time.encoding["calendar"]
        
    def rolling_5day(self):
        # Use on daily data
        return self.ds[self.ds_var].rolling(time=5).mean()

    def daily_total(self):
        # Use on sub-daily data
        return  self.ds[self.ds_var].resample(time='1D').sum(dim="time")
    
    def daily_min(self):
        # Use on sub-daily data
        return  self.ds[self.ds_var].resample(time='1D').min(dim="time")
        
    def daily_max(self):
        # Use on sub-daily data
        return self.ds[self.ds_var].resample(time='1D').max(dim="time")

In [3]:
class SeasonalAverager():
    # Make seasonal averages of data in TimeSeriesData class

    def __init__(self, tsds, dec_mode="DJF", drop_incomplete_djf=True, annual_strict=True):
        self.ds = tsds
        self.dec_mode = dec_mode
        self.drop_incomplete_djf = drop_incomplete_djf
        self.annual_strict = annual_strict
        self.del1d = datetime.timedelta(days=1)
        self.del0d = datetime.timedelta(days=0)
    
    def calc_daily_max(self):
        self.daily_max = self.ds.daily_max()
    
    def calc_daily_min(self):
        self.daily_min = self.ds.daily_min()
        
    def calc_5day_mean(self):
        self.rolling = self.ds.rolling_5day()
        
    def annual_stats(self,daily_stat,stat):
                
        if daily_stat == "daily_max":
            ds = self.daily_max
        elif daily_stat == "daily_min":
            ds = self.daily_min
        elif daily_stat == "rolling_5day":
            ds = self.rolling

        if self.annual_strict:
            # This setting is for means using 5 day rolling average values, where
            # we do not want to include any data from the prior year
            cal = self.ds.calendar
            year_range = self.ds.year_range

            # Only use data from that year - start on Jan 5 avg
            date_range = [xr.cftime_range(
                            start=cftime.datetime(year,1,5,calendar=cal)-self.del0d,
                            end = cftime.datetime(year+1,1,1,calendar=cal)-self.del1d,
                            freq='D',
                            calendar=cal) for year in year_range]
            date_range = [item for sublist in date_range for item in sublist]
            if stat=="max":
                ds_ann = ds.sel(time=date_range).groupby("time.year").max(dim="time")
            elif stat=="min":
                ds_ann = ds.sel(time=date_range).groupby("time.year").min(dim="time")
        else:
            # Group by date
            if stat=="max":
                ds_ann = ds.groupby("time.year").max(dim="time")
            elif stat=="min":
                ds_ann = ds.groupby("time.year").min(dim="time")
        ds_ann = ds_ann.rename({"year": "time"})
        return ds_ann

    def seasonal_stats(self,daily_stat,season,stat):
        # Seasons can be "DJF","MAM","JJA","SON"
        # Stat can be "max", "min"

        year_range = self.ds.year_range
        
        if daily_stat == "daily_max":
            ds = self.daily_max
        elif daily_stat == "daily_min":
            ds = self.daily_min
        elif daily_stat == "rolling_5day":
            ds = self.rolling

        if season == "DJF" and self.dec_mode =="DJF":
            # Resample DJF to count prior DJF in current year
            if stat == "max":
                ds_stat = ds.resample(time='QS-DEC').max(dim="time")
            elif stat=="min":
                ds_stat = ds.resample(time='QS-DEC').min(dim="time")

            ds_stat = ds_stat.isel(time=ds_stat.time.dt.month.isin([12]))
            
            if self.drop_incomplete_djf:
                ds_stat = ds_stat.sel(time=slice(str(year_range[0]),str(year_range[-1]-1)))
                ds_stat["time"] = np.arange(year_range[0]+1,year_range[-1]+1)
            else:
                ds_stat = ds_stat.sel(time=slice(str(year_range[0]-1),str(year_range[-1])))
                ds_stat["time"] = np.arange(year_range[0],year_range[-1]+2)
    
        elif season == "DJF" and self.dec_mode == "JFD":
            cal = self.ds.calendar

            # Make date lists that capture JF and D in all years, then merge and sort
            date_range_1 = [xr.cftime_range(
                                start=cftime.datetime(year,1,1,calendar=cal)-self.del0d,
                                end=cftime.datetime(year,3,1,calendar=cal)-self.del1d,
                                freq='D',
                                calendar=cal) for year in year_range]
            date_range_1 = [item for sublist in date_range_1 for item in sublist]
            date_range_2 = [xr.cftime_range(
                                start=cftime.datetime(year,12,1,calendar=cal)-self.del0d,
                                end=cftime.datetime(year+1,1,1,calendar=cal)-self.del1d,
                                freq='D',
                                calendar=cal) for year in year_range]
            date_range_2 = [item for sublist in date_range_2 for item in sublist]
            date_range = sorted(date_range_1 + date_range_2)
            
            if stat=="max":
                ds_stat = ds.sel(time=date_range).groupby("time.year").max(dim="time")
            elif stat=="min":
                ds_stat = ds.sel(time=date_range).groupby("time.year").min(dim="time")
            ds_stat = ds_stat.rename({"year": "time"})
        
        else:  # Other 3 seasons
            dates = { # Month/day tuples
                "MAM": [(3,1), (6,1)],
                "JJA": [(6,1), (9,1)],
                "SON": [(9,1), (12,1)]
            }
            
            mo_st = dates[season][0][0]
            day_st = dates[season][0][1]
            mo_en = dates[season][1][0]
            day_en = dates[season][1][1]
            
            cal = self.ds.calendar

            date_range = [xr.cftime_range(
                            start=cftime.datetime(year,mo_st,day_st,calendar=cal)-self.del0d,
                            end=cftime.datetime(year,mo_en,day_en,calendar=cal)-self.del1d,
                            freq='D',
                            calendar=cal) for year in year_range]
            date_range = [item for sublist in date_range for item in sublist]
            
            if stat=="max":
                ds_stat = ds.sel(time=date_range).groupby("time.year").max(dim="time")
            elif stat=="min":
                ds_stat = ds.sel(time=date_range).groupby("time.year").min(dim="time")
            ds_stat = ds_stat.rename({"year": "time"})  
            
        return ds_stat

In [118]:
season_list = ["DJF","MAM","JJA","SON"]
ds = xr.open_dataset("test_data/lowres_ts_rand_1980-1999.nc")

In [121]:
sftlf = xr.zeros_like(ds)
sftlf = sftlf.rename({"TS":"sftlf"})
sftlf["sftlf"] = sftlf["sftlf"]+1
sftlf = sftlf.isel({"time": 0})


In [122]:
TS = TimeSeriesData(ds,"TS")

S = SeasonalAverager(TS,dec_mode="DJF", drop_incomplete_djf=True, annual_strict=False)

S.calc_daily_max()
S.calc_daily_min()

TXx = xr.Dataset()
TXn = xr.Dataset()
TNx = xr.Dataset()
TNn = xr.Dataset()

TXx["ANN"] = S.annual_stats("daily_max","max")
TXn["ANN"] = S.annual_stats("daily_max","min")
TNx["ANN"] = S.annual_stats("daily_min","max")
TNn["ANN"] = S.annual_stats("daily_min","max")

for season in season_list:
    TXx[season] = S.seasonal_stats("daily_max",season,"max")
    TXn[season] = S.seasonal_stats("daily_max",season,"min")
    TNx[season] = S.seasonal_stats("daily_min",season,"max")
    TNn[season] = S.seasonal_stats("daily_min",season,"min")
    
TXx = TXx.bounds.add_missing_bounds()
TXn = TXn.bounds.add_missing_bounds()
TNx = TNx.bounds.add_missing_bounds()
TNn = TNn.bounds.add_missing_bounds()



In [115]:
#TXx.to_netcdf("TXx.nc")
#TXn.to_netcdf("TXn.nc")
#TNx.to_netcdf("TNx.nc")
#TNn.to_netcdf("TNn.nc")

In [80]:
# Initialize metrics
metrics = {
    "Dimensions": {
        "dimensions": ["Metric","Region","Season","Year"],
        "Region": ["land"],
        "Season": ["ANN","DJF","MAM","JJA","SON"]
    },
    "Results": {}
}

for m,ds_m in zip(["TXx","TXn","TNx","TNn"],[TXx,TXn,TNx,TNn]):
    metrics["Results"][m] = {
        "land": {
            "ANN": {},
            "DJF": {},
            "MAM": {},
            "JJA": {},
            "SON": {}
        }
    }
    for season in ["ANN"]+season_list:
        # TODO: Need to mask out land 
        tmp = ds_m.spatial.average(season)[season]
        tmp_list = [{int(yr.data): float(tmp.sel({"time":yr}).data)} for yr in tmp.time]
        tmp_dict ={}
        for d in tmp_list:
            tmp_dict.update(d)
        metrics["Results"][m]["land"][season] = tmp_dict
        

In [85]:
metrics

{'Dimensions': {},
 'Results': {'TXx': {'land': {'ANN': {1980: 99.00000000000004,
     1981: 99.00000000000004,
     1982: 99.00000000000004,
     1983: 99.00000000000004,
     1984: 99.00000000000004,
     1985: 99.00000000000004,
     1986: 99.00000000000004,
     1987: 99.00000000000004,
     1988: 99.00000000000004,
     1989: 99.00000000000004,
     1990: 99.00000000000004,
     1991: 99.00000000000004,
     1992: 99.00000000000004,
     1993: 99.00000000000004,
     1994: 99.00000000000004,
     1995: 99.00000000000004,
     1996: 99.00000000000004,
     1997: 99.00000000000004,
     1998: 99.00000000000004,
     1999: 99.00000000000004},
    'DJF': {1980: nan,
     1981: 99.00000000000004,
     1982: 99.00000000000004,
     1983: 99.00000000000004,
     1984: 99.00000000000004,
     1985: 99.00000000000004,
     1986: 99.00000000000004,
     1987: 99.00000000000004,
     1988: 99.00000000000004,
     1989: 99.00000000000004,
     1990: 99.00000000000004,
     1991: 99.0000000000

In [89]:
ds=xr.open_dataset("test_data/lowres_randint_1980-1999.nc")

In [109]:
PR = TimeSeriesData(ds,"pr")

S = SeasonalAverager(PR,dec_mode="DJF",drop_incomplete_djf=True,annual_strict=True)
S.calc_5day_mean()

Px = xr.Dataset()

Px["ANN"] = S.annual_stats("rolling_5day","max")

for season in season_list:
    Px[season] = S.seasonal_stats("rolling_5day",season,"max")

Px = Px.bounds.add_missing_bounds()

In [114]:
#Px.to_netcdf("precipitation_rolling_5day_max.nc")

In [111]:
metrics["Results"]["Pr"] = {
    "land": {
        "ANN": [],
        "DJF": [],
        "MAM": [],
        "JJA": [],
        "SON": []

    }
}

for season in ["ANN"]+season_list:
    tmp = Px.spatial.average(season)[season]
    tmp_list = [{int(yr.data): float(tmp.sel({"time":yr}).data)} for yr in tmp.time]
    tmp_dict ={}
    for d in tmp_list:
        tmp_dict.update(d)
    metrics["Results"]["Pr"]["land"][season] = tmp_dict

In [112]:
metrics

{'Dimensions': {},
 'Results': {'TXx': {'land': {'ANN': {1980: 99.00000000000004,
     1981: 99.00000000000004,
     1982: 99.00000000000004,
     1983: 99.00000000000004,
     1984: 99.00000000000004,
     1985: 99.00000000000004,
     1986: 99.00000000000004,
     1987: 99.00000000000004,
     1988: 99.00000000000004,
     1989: 99.00000000000004,
     1990: 99.00000000000004,
     1991: 99.00000000000004,
     1992: 99.00000000000004,
     1993: 99.00000000000004,
     1994: 99.00000000000004,
     1995: 99.00000000000004,
     1996: 99.00000000000004,
     1997: 99.00000000000004,
     1998: 99.00000000000004,
     1999: 99.00000000000004},
    'DJF': {1980: nan,
     1981: 99.00000000000004,
     1982: 99.00000000000004,
     1983: 99.00000000000004,
     1984: 99.00000000000004,
     1985: 99.00000000000004,
     1986: 99.00000000000004,
     1987: 99.00000000000004,
     1988: 99.00000000000004,
     1989: 99.00000000000004,
     1990: 99.00000000000004,
     1991: 99.0000000000

## Testing with actual data


In [11]:
# 3 hourly obs
fname = "/p/user_pub/PCMDIobs/obs4MIPs/NASA-GSFC/IMERG-v06B-Final/3hr/pr/2x2/latest/pr_3hr_IMERG-v06B-Final_PCMDI_2x2_201812010000-201812312100.nc"

In [12]:
ds = xr.open_dataset(fname)

In [16]:
xr.infer_freq(ds.time)


'3H'

In [17]:
fname = "/p/user_pub/cmip/CMIP6/CMIP/MIROC/MIROC6/historical/r1i1p1f1/day/tasmax/gn/v20191016/tasmax_day_MIROC6_historical_r1i1p1f1_gn_20100101-20141231.nc"

In [18]:
ds = xr.open_dataset(fname)

In [19]:
xr.infer_freq(ds.time)

'D'

## New Section

In [1]:
import xarray as xr
import xcdat
import pandas as pd
import numpy as np
import cftime
import datetime
import sys
import os

#from pcmdi_metrics.extremes.lib import (
#    compute_metrics,
#    create_extremes_parser
#)
from lib import (
    compute_metrics,
    create_extremes_parser
)

In [9]:
case_id = "test_case_1"
model_list = ["MRI-ESM2-0"]
realization = "r1i1p1f1"
variable_list = ["pr"]
#reference_data_set = parameter.reference_data_set
filename_template = "%(model)/historical/%(realization)/day/pr/gn/v20190603/pr_day_%(model)_historical_%(realization)_gn_20000101-20141231.nc"
sftlf_filename_template = "/p/css03/esgf_publish/CMIP6/CMIP/MRI/MRI-ESM2-0/historical/r1i1p1f1/fx/sftlf/gn/v20190603/sftlf_fx_MRI-ESM2-0_historical_r1i1p1f1_gn.nc"
generate_sftlf = False
test_data_path = "/p/user_pub/cmip/CMIP6/CMIP/MRI/"
#reference_data_path = parameter.reference_data_path
metrics_output_path = "./test/"
debug = False
cmec = True
chunk_size = None
strict_annual = True
exclude_leap = False
dec_mode = "DJF"
drop_incomplete_djf = True

In [10]:
if metrics_output_path is not None:
    metrics_output_path = metrics_output_path.replace('%(case_id)', case_id)

find_all_realizations = False
if realization is None:
    realization = ""
    realizations = [realization]
elif isinstance(realization, str):
    if realization.lower() in ["all", "*"]:
        find_all_realizations = True
    else:
        realizations = [realization]

metrics_dict = compute_metrics.init_metrics_dict()

# Loop over models
for model in model_list:
    sftlf = xr.open_dataset(sftlf_filename_template.replace('%(model)', model).replace('%(model_version)', model))

    if find_all_realizations:
        test_data_full_path = os.path.join(
            test_data_path,
            filename_template).replace('%(model)', model).replace('%(model_version)', model).replace('%(realization)', '*')
        ncfiles = glob.glob(test_data_full_path)
        realizations = []
        for ncfile in ncfiles:
            realizations.append(ncfile.split('/')[-1].split('.')[3])
        print('=================================')
        print('model, runs:', model, realizations)
    
    for run in realizations:
        for varname in variable_list:
            test_data_full_path = os.path.join(
                test_data_path,
                filename_template
                ).replace('%(variable)', varname).replace('%(model)', model).replace('%(model_version)', model).replace('%(realization)', run)
            if os.path.exists(test_data_full_path):
                print('-----------------------')
                print('model, run:', model, run)
                print('test_data (model in this case) full_path:', test_data_full_path)

            # TODO: mfdataset option?
            if chunk_size:
                ds = xr.load_dataset(test_data_full_path,chunks={"latitude":chunk_size,"longitude": chunk_size})
            else:
                ds = xr.load_dataset(test_data_full_path)

            if ds.time.encoding["calendar"] != "noleap" and exclude_leap:
                ds = self.ds.convert_calendar('noleap')

            # TODO convert 3 hourly to daily option
            if varname == "tasmax":
                TXx,TXn = compute_metrics.temperature_metrics(ds,varname)
                tmp_dict = {"TXx": TXx, "TXn": TXn}
                result_dict = compute_metrics.temperature_metrics_json(tmp_dict,sftlf)
                met_dict["Results"][model] = {
                    realization: result_dict
                }
                filepath = os.path.join(metrics_output_path,"TXx_{0}.nc".format("_".join([model,realization])))
                TXx.to_netcdf(filepath)
                filepath = os.path.join(metrics_output_path,"TXn_{0}.nc".format("_".join([model,realization])))
                TXn.to_netcdf(filepath)   
   
            if varname == "tasmin":
                TNx,TNn = compute_metrics.temperature_metrics(ds,varname)
                tmp_dict = {"TNx": TNx, "TNn": TNn}
                result_dict = compute_metrics.temperature_metrics_json(tmp_dict,sftlf)
                met_dict["Results"][model] = {
                    realization: result_dict
                }
                filepath = os.path.join(metrics_output_path,"TNx_{0}.nc".format("_".join([model,realization])))
                TNx.to_netcdf(filepath)
                filepath = os.path.join(metrics_output_path,"TNn_{0}.nc".format("_".join([model,realization])))
                TNn.to_netcdf(filepath)   

            if varname in ["pr","PRECT","precip"]:
                # Rename possible precipitation variable names for consistency
                if varname in ["precip","PRECT"]:
                    ds = ds.rename({variable: "pr"})
                P = compute_metrics.precipitation_metrics(ds)
                # Update metrics
                #print(P)
                #result_dict = compute_metrics.precipitation_metrics_json(P,sftlf)
                #met_dict["Results"][model] = {
                #    realization: result_dict
                #}
                #P.to_netcdf("Rx5day_{0}.nc".format("_".join([model,realization])))



-----------------------
model, run: MRI-ESM2-0 r1i1p1f1
test_data (model in this case) full_path: /p/user_pub/cmip/CMIP6/CMIP/MRI/MRI-ESM2-0/historical/r1i1p1f1/day/pr/gn/v20190603/pr_day_MRI-ESM2-0_historical_r1i1p1f1_gn_20000101-20141231.nc


In [11]:
season="ANN"
tmp = P.where(sftlf > 50).spatial.average(season)[season]

KeyError: "The data variable 'ANN' does not exist in the Dataset."

In [17]:
P.where(sftlf.sftlf > 50).spatial.average(season)[season]

In [None]:
met_dict = {"land": {
                    "ANN": [],
                    "DJF": [],
                    "MAM": [],
                    "JJA": [],
                    "SON": []
                }
            }

for season in ["ANN","DJF","MAM","JJA","SON"]:
    tmp = P.where(sftlf > 50).spatial.average(season)[season]
    tmp_list = [{int(yr.data): float(tmp.sel({"time":yr}).data)} for yr in tmp.time]
    tmp_dict ={}
    for d in tmp_list:
        tmp_dict.update(d)
    met_dict["Rx5day"]["land"][season] = tmp_dict