# In [None]:
#!/usr/bin/env python3
#
#
# usage: xarray_mapplot.py [-h] [--proj PROJ]
#                               [--output OUTPUT]
#                               [--config config-file]
#                               [--shift]
#                               [-v]
#                               input varname
#
# positional arguments:
#  input            input filename with geographical coordinates (netCDF
#                   format); several files may be given separated by commas
#  varname          Specify which variable to plot (case sensitive)
#
# optional arguments:
#  -h, --help       show this help message and exit
#  --proj PROJ      config file with the projection on which we draw
#                   (accepted on the command line, not used when plotting)
#  --output OUTPUT  output filename to store resulting image (png format)
#  --config         plotting parameters are passed via a config file
#                   (overwrite other plotting options)
#  --shift          shift longitudes if specified
#  -v, --verbose    switch on verbose mode
#
# keys accepted in the --config file, written as a Python dict literal,
# e.g. {'time': '0,1,2,3', 'title': 'my title'}:
#  time             time indices from the file for multiple plots ("0,1,2,3")
#  title            plot or subplot title
#  xlim             limited geographical area longitudes "x1,x2"
#                   (only applied when xylim_supported is enabled)
#  ylim             limited geographical area latitudes "y1,y2"
#                   (only applied when xylim_supported is enabled)
#  x_label          label for the x-axis
#  y_label          label for the y-axis
#  linestyle        linestyle for the plot (not yet used when plotting)
#  marker           marker for the data points (not yet used when plotting)
#  range            "valmin,valmax" for plotting
#  threshold        do not plot values at or below threshold
#  label            label for the colorbar
#


# In [3]:
import argparse
import ast
import warnings
from pathlib import Path
import netCDF4

# In [4]:
import pandas as pd
import matplotlib as mpl
mpl.use('Agg')
from matplotlib import pyplot
import xarray as xr

# In [None]:
class TimeSeries:
    """Plot one variable of a netCDF dataset with xarray and save images.

    Plotting options (time indices, title, axis labels, value range,
    threshold, colorbar label, ...) can be supplied through a config file
    holding a Python dict literal, e.g. ``{'time': '0,1', 'title': 'T2m'}``.

    Parameters
    ----------
    input : str
        netCDF filename, or several filenames separated by commas; several
        files are opened together with ``xr.open_mfdataset``.
    varname : str
        Name of the variable to plot (case sensitive).
    output : str or None
        Output image filename. When None it is the stem of the (first)
        input filename with a ``.png`` suffix.
    verbose : bool
        Print the resolved settings after loading the dataset.
    config_file : str or None
        Path of the config file; skipped when empty or None.
    shift : bool
        Wrap longitudes into [-180, 180) before plotting.
    """

    # Config keys holding comma-separated values, with their element type.
    _LIST_KEYS = {'time': int, 'xlim': float, 'ylim': float, 'range': float}
    # Config keys copied verbatim onto the attribute of the same name.
    _PLAIN_KEYS = ('linestyle', 'marker', 'x_label', 'y_label', 'title',
                   'cmap')

    def __init__(self, input, varname, output, verbose=False,
                 config_file="", shift=False):
        files = input.split(",")
        self.input = files if len(files) > 1 else input
        self.varname = varname
        self.shift = shift
        # Axis limits from xlim/ylim are only applied when this is True.
        self.xylim_supported = False
        if output is None:
            if isinstance(self.input, list):
                first = self.input[0]
            else:
                first = self.input
            self.output = Path(first).stem + '.png'
        else:
            self.output = output
        self.verbose = verbose
        self.label = {}  # passed to xarray as cbar_kwargs
        self.x_label = ""
        self.y_label = ""
        self.time = []
        self.xlim = []
        self.ylim = []
        self.range = []
        self.linestyle = "solid"
        self.marker = "."
        self.threshold = ""
        self.title = ""
        self.cmap = None  # None lets xarray choose its default colormap
        self.colorbar = True
        if config_file:
            self._apply_config(self._read_config(config_file))

        if isinstance(self.input, list):
            self.dset = xr.open_mfdataset(self.input, use_cftime=True)
        else:
            self.dset = xr.open_dataset(self.input, use_cftime=True)
        # Names of the longitude/latitude variables, or None when absent.
        self.longitude = self._find_coord(('longitude', 'lon'))
        self.latitude = self._find_coord(('latitude', 'lat'))

        if verbose:
            print("input: ", self.input)
            print("varname: ", self.varname)
            print("time: ", self.time)
            print("minval, maxval: ", self.range)
            print("title: ", self.title)
            print("x_label: ", self.x_label)
            print("y_label: ", self.y_label)
            print("output: ", self.output)
            print("shift: ", self.shift)
            print("xlim: ", self.xlim)
            print("ylim: ", self.ylim)
            print("label: ", self.label)

    @staticmethod
    def _read_config(config_file):
        """Return the dict literal stored in *config_file*.

        Only the text between the first '{' and the following '}' is parsed,
        with ``ast.literal_eval`` (not ``eval``), so the file cannot run code.
        """
        with open(config_file) as f:
            text = f.read().replace("\n", "")
        body = text.split('{')[1].split('}')[0]
        return ast.literal_eval('{' + body.strip() + '}')

    @staticmethod
    def _to_list(value, cast):
        """Convert a "a,b,c" string, a sequence or a scalar to a cast list."""
        if isinstance(value, str):
            return [cast(item) for item in value.split(",")]
        if isinstance(value, (list, tuple)):
            return [cast(item) for item in value]
        return [cast(value)]

    def _apply_config(self, options):
        """Override the plotting defaults with the entries of *options*.

        Unknown keys are ignored.
        """
        for key, value in options.items():
            if key in self._LIST_KEYS:
                setattr(self, key, self._to_list(value, self._LIST_KEYS[key]))
            elif key in self._PLAIN_KEYS:
                setattr(self, key, value)
            elif key == 'threshold':
                self.threshold = float(value)
            elif key == 'label':
                self.label['label'] = value

    def _find_coord(self, candidates):
        """Return the first name of *candidates* in the dataset, else None."""
        for name in candidates:
            if name in self.dset.variables:
                return name
        return None

    def _shift_longitudes(self):
        """Wrap the longitude coordinate of the dataset into [-180, 180)."""
        lon = self.longitude
        self.dset = self.dset.assign_coords(
            {lon: ((self.dset[lon] + 180) % 360) - 180})
        if lon in self.dset.dims:
            # Keep a 1-D longitude axis monotonic after wrapping.
            self.dset = self.dset.sortby(lon)

    def _limits(self, lim, coord):
        """Return (low, high) from *lim*, or the extent of variable *coord*."""
        if lim:
            return min(lim[0], lim[1]), max(lim[0], lim[1])
        return self.dset[coord].min(), self.dset[coord].max()

    def _time_output(self, ts):
        """Return the output filename for time index *ts* (stem_time<ts>)."""
        path = Path(self.output)
        return str(path.with_name(f"{path.stem}_time{ts}{path.suffix}"))

    def plot(self, ts=None):
        """Plot ``varname`` and save the figure.

        Parameters
        ----------
        ts : int or None
            Time index to plot (selected with ``isel(time=ts)``). When None
            the variable is plotted as-is and saved to ``self.output``;
            otherwise the image goes to ``<stem>_time<ts><suffix>``.

        NOTE(review): vmin/vmax/cmap/cbar_kwargs are keywords of xarray's
        2-D plots, so the selected data is presumably 2-D (e.g. lat/lon).
        """
        if self.shift and self.longitude is not None:
            self._shift_longitudes()

        fig = pyplot.figure(1, figsize=[20, 10])
        ax = fig.add_subplot(1, 1, 1)

        if self.xylim_supported:
            ax.set_xlim(*self._limits(self.xlim, self.longitude))
            ax.set_ylim(*self._limits(self.ylim, self.latitude))

        data = self.dset[self.varname]
        # Colour range comes from the unmasked data unless given explicitly.
        if self.range:
            minval, maxval = self.range[0], self.range[1]
        else:
            minval, maxval = data.min(), data.max()
        if self.verbose:
            print("minval: ", minval)
            print("maxval: ", maxval)

        # Mask values at or below the threshold; with no threshold set,
        # nothing is masked (so the minimum value stays visible).
        if self.threshold != "" and self.threshold is not None:
            threshold = float(self.threshold)
            data = data.where(data > threshold)

        if ts is not None:
            data = data.isel(time=ts)

        kwargs = {'ax': ax, 'vmin': minval, 'vmax': maxval, 'cmap': self.cmap}
        if ts is None or self.colorbar:
            data.plot(cbar_kwargs=self.label, **kwargs)
        else:
            data.plot(add_colorbar=False, **kwargs)

        if self.title:
            if ts is None:
                ax.set_title(self.title)
            else:
                ax.set_title(self.title + "(time = " + str(ts) + ')')
        if self.x_label:
            ax.set_xlabel(self.x_label)
        if self.y_label:
            ax.set_ylabel(self.y_label)

        fig.savefig(self.output if ts is None else self._time_output(ts))
        # Close the figure so axes and colorbars do not accumulate on
        # figure 1 across successive calls.
        pyplot.close(fig)
            
#########################################################################################################################
if __name__ == '__main__':
    # Suppress all warnings for command-line use.
    warnings.filterwarnings("ignore")
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'input',
        help='input filename with geographical coordinates (netCDF format)'
    )
    parser.add_argument(
        '--proj',
        help='Config file with the projection on which we draw'
    )
    parser.add_argument(
        'varname',
        help='Specify which variable to plot (case sensitive)'
    )
    parser.add_argument(
        '--output',
        help='output filename to store resulting image (png format)'
    )
    parser.add_argument(
        '--config',
        help='pass plotting parameters via a config file'
    )
    parser.add_argument(
        '--shift',
        help='shift longitudes if specified',
        action="store_true"
    )
    parser.add_argument(
        "-v", "--verbose",
        help="switch on verbose mode",
        action="store_true")
    args = parser.parse_args()

    # The plotting class defined in this file is TimeSeries. It takes no
    # projection argument, so --proj is still accepted on the command line
    # but not forwarded. NOTE(review): wire it up if projections return.
    dset = TimeSeries(input=args.input, varname=args.varname,
                      output=args.output, verbose=args.verbose,
                      config_file=args.config, shift=args.shift)

    if not dset.time:
        # No time indices configured: a single plot of the whole variable.
        dset.plot()
    else:
        for t in dset.time:
            dset.plot(t)
            dset.shift = False   # only shift once
            dset.colorbar = True