In [None]:
import os
import xarray as xr
import pandas as pd
import numpy as np
import rioxarray as rxr
import rasterio
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import import_ipynb
from datetime import datetime
import warnings
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
import sklearn.metrics as metrics
from itertools import product
from tqdm import tqdm
import folium
from IPython.display import display

# Silence NumPy's VisibleDeprecationWarning for the whole notebook.
# NumPy 2.0 removed the top-level alias; the class lives in np.exceptions
# (available since NumPy 1.25). Fall back to the old location otherwise.
try:
  _VisibleDeprecationWarning = np.exceptions.VisibleDeprecationWarning
except AttributeError:
  _VisibleDeprecationWarning = np.VisibleDeprecationWarning

warnings.filterwarnings("ignore", category=_VisibleDeprecationWarning)

In [None]:
def _get_info(col_name):

  wrf_lu_col = {
  '1 Evergreen Needleleaf Forest': '#006600',
  '2 Evergreen Broadleaf Forest': '#007700',
  '3 Deciduous Needleleaf Forest': '#33cc32',
  '4 Deciduous Broadleaf Forest': '#34cc66',
  '5 Mixed Forests': '#349933',
  '6 Closed Shrublands': '#4db31b',
  '7 Open Shurblands': '#d1691f',
  '8 Woody Savannas': '#beb569',
  '9 Savannas': '#ffd600',
  '10 Grasslands': '#00ff27',
  '11 Permanant Wetlands': '#10fdfd',
  '12 Croplands': '#745a00',
  '13 Urban and Built-up': '#ff0e00',
  '14 Cropland/Natual Vegation Mosaic': '#b1e34d',
  '15 Snow and Ice': '#e1fcff',
  '16 Barren or Sparsely Vegetated': '#e9e9b4',
  '17 Water (like oceans)': '#80b3ff',
  '18 Wooded Tundra': '#d9133a',
  '19 Mixed Tundra': '#f7804f',
  '20 Barren Tundra': '#e8967d',
  '21 Lake': '#0400de',
  }

  wrf_sim_info = {
  'NOURBAN_SIM': ('geo_em.d03_NoUrban', '#077F41', '-'),
  'LCZ_NORMS_EXTENT_SIM':  ('geo_em.d03_LCZ_extent', '#FF79AE', '-'),
  'LCZ_NORMS_SIM':  ('geo_em.d03_LCZ_params', '#BB0003', '-'),
  'LCZ_CR_SIM':  ('geo_em.d03_LCZ_params', '#00ECFF', '-'),
  'LCZ_GR_SIM':  ('geo_em.d03_LCZ_params', '#4CFFA8', '-'),
  'LCZ_PVP_SIM':  ('geo_em.d03_LCZ_params', '#FCD601', '-'),
  }

  lcz_col = {
  '31 Compact high-rise (1)': '#8C0000',
  '32 Compact midrise (2)': '#D10000',
  '33 Compact low-rise (3)': '#FF0000',
  '34 Open high-rise (4)': '#BF4D00',
  '35 Open midrise (5)': '#FF6600',
  '36 Open low-rise (6)': '#FF9955',
  '37 Lightweight low-rise (7)': '#FAEE05',
  '38 Large low-rise (8)': '#BCBCBC',
  '39 Sparsely built (9)': '#FFCCAA',
  '40 Heavy industry (10)': '#555555',
  'Dense trees (A)': '#006a00',
  'Scattered trees (B)': '#00aa00',
  'Bush, scrub (C)': '#648525',
  'Low plants (D)': '#b9db79',
  'Bare rock or paved (E)': '#000000',
  'Bare soil or sand (F)': '#fbf7ae',
  'Water (G)': '#6a6aff',
  }

  no_col = {
    '22 ': '#FFFFFF',
    '23 ': '#FFFFFF',
    '24 ': '#FFFFFF',
    '25 ': '#FFFFFF',
    '26 ': '#FFFFFF',
    '27 ': '#FFFFFF',
    '28 ': '#FFFFFF',
    '29 ': '#FFFFFF',
    '30 ': '#FFFFFF',
  }


  if col_name == "wrf":
    col_scheme = wrf_lu_col
  elif col_name == "lcz":
    col_scheme = lcz_col
  elif col_name == "wrf_lcz":
    wrf_lcz_col = {**wrf_lu_col,**no_col, **dict(list(lcz_col.items())[:10])}
    col_scheme = wrf_lcz_col
  elif col_name == "wrf_sim":
    col_scheme = wrf_sim_info

  return col_scheme

In [None]:
def plot_lu_index(FN_GEO_NAME, RUN_INFO):

  """
  Plot the LU_INDEX field of a geo_em file and save the figure as .jpg.

  Parameters
  ----------
  FN_GEO_NAME : str
    File name (without '.nc') of the geo_em file in <WRF_DIR>/INPUT.
    Names containing 'LCZ_params' use the combined WRF + LCZ color scheme
    and count urban pixels as LU_INDEX > 30; all other names use the WRF
    land-use scheme and count urban pixels as LU_INDEX == 13.
  RUN_INFO : dict
    Needs the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  """

  # Define the file
  WRF_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f'{FN_GEO_NAME}.nc')

  # Read the field into memory and release the file handle immediately
  with xr.open_dataset(WRF_FILE) as ds:
    lu_index = ds['LU_INDEX'].load()

  # One flag drives both the color scheme and the urban-pixel count, so the
  # two can no longer disagree for LCZ files of other domains.
  has_lcz = 'LCZ_params' in FN_GEO_NAME
  coldict = _get_info('wrf_lcz') if has_lcz else _get_info('wrf')

  # Set the colors
  cols = list(coldict.values())
  cmap = mpl.colors.ListedColormap(cols)
  cmap.set_bad(color='white')
  cmap.set_under(color='white')

  f, ax = plt.subplots(1,1,figsize=(14,10))

  im = lu_index.plot(
      cmap=cmap, vmin=1, vmax=len(cols),
      ax=ax, add_colorbar=False,
  )

  cb = plt.colorbar(
      im,
      ticks=np.linspace(1.5, len(cols)-0.5,len(cols)),
      orientation='vertical', pad=0.08,
  )

  cb.set_ticklabels(list(coldict.keys()))
  cb.set_label(label='LU_INDEX', fontsize=11)
  cb.ax.tick_params(labelsize=11)

  ax.tick_params(axis='both', which='major', labelsize=11)
  ax.set_xlabel('')
  ax.set_ylabel('')

  # Urban pixels: LCZ classes are stored as 31-40, the default urban class is 13
  if has_lcz:
    n_urban = int((lu_index > 30).sum())
  else:
    n_urban = int((lu_index == 13).sum())
  ax.set_title(f"DEFAULT: {n_urban} urban pixels", fontsize=18)

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'{FN_GEO_NAME}_lu_index.jpg'
  )
  plt.tight_layout()
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])

  print(f"Done! Figure available at: {OFILE}")


In [None]:
def plot_urb_frac(RUN_INFO):

  """
  Plot the urban fraction of 'LCZ_extent', of 'LCZ_params', and their
  difference side by side, and save the figure as .jpg.

  For 'LCZ_extent' the urban fraction is not read from the file: it is set
  to 0.9 where LU_INDEX == 13 and to 0 elsewhere.

  Parameters
  ----------
  RUN_INFO : dict
    Needs the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  """

  import copy  # local import: only used to copy the colormap below

  print("Plotting URB_FRAC for 'LCZ_extent', 'LCZ_params', and their difference") 

  # Define the files
  WRF1_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'geo_em.d03_LCZ_extent.nc')
  WRF2_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'geo_em.d03_LCZ_params.nc')

  # Read what is needed into memory, then close both files
  with xr.open_dataset(WRF1_FILE) as ds1, xr.open_dataset(WRF2_FILE) as ds2:
    frc_ext = xr.where(ds1['LU_INDEX'] == 13, 0.9, 0).rename('FRC_URB2D')[0,:,:].load()
    frc_lcz = ds2['FRC_URB2D'][0,:,:].load()

  # Get the colors. plt.cm.get_cmap was removed in matplotlib 3.9;
  # plt.get_cmap works across versions. Copy before set_bad so the
  # registered 'inferno' colormap is never modified.
  cmap = copy.copy(plt.get_cmap('inferno'))
  cmap.set_bad('white')

  f, axs = plt.subplots(1,3,figsize=(23,7), sharey=True)

  frc_ext.plot(
      vmin=0, vmax=1,
      cmap=cmap,
      ax=axs[0]
      )
  axs[0].set_title('geo_em.d03_LCZ_extent')

  frc_lcz.plot(
      vmin=0, vmax=1,
      cmap=cmap,
      ax=axs[1]
      )
  axs[1].set_title('geo_em.d03_LCZ_params')

  (frc_lcz-frc_ext).plot(
      vmin=-0.5, vmax=0.5, extend = 'both',
      cmap=plt.get_cmap('PiYG_r'),
      ax=axs[2]
      )
  axs[2].set_title('LCZ Params - LCZ Extent')
  
  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      'LCZ_params_vs_LCZ_extent_urb_frac.jpg'
  )
  plt.tight_layout()
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])

  print(f"Done! Figure available at: {OFILE}")


In [None]:
def _replace_values(da, to_replace, value):

  """
  Helper function to replace LCZ classes into UCP values (in plot_ucp)

  Inspired by: https://github.com/pydata/xarray/issues/6377
  """
    
  flat = da.values.ravel()

  sorter = np.argsort(to_replace)
  insertion = np.searchsorted(to_replace, flat, sorter=sorter)
  indices = np.take(sorter, insertion, mode="clip")
  replaceable = (to_replace[indices] == flat)

  out = flat.copy()
  out[replaceable] = value[indices[replaceable]]
  return da.copy(data=out.reshape(da.shape))


In [None]:
def plot_ucp(UCP_NAME, RUN_INFO):

  """
  Plot one urban canopy parameter (UCP) for 'LCZ_extent', for 'LCZ_params',
  and their difference side by side, and save the figure as .jpg.

  Parameters
  ----------
  UCP_NAME : str
    One of 'BH', 'BHstd', 'ALB_R', 'ALB_W', 'ALB_G' (KeyError otherwise).
  RUN_INFO : dict
    Needs the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.

  Notes
  -----
  * 'LCZ_extent': the default value is used where LU_INDEX == 13, 0 elsewhere.
  * 'LCZ_params': 'BH'/'BHstd' are read from URB_PARAM; the albedos are
    derived from the LCZ class in LU_INDEX through the URBPARM_LCZ table.
  """

  import copy  # local import: only used to copy the colormap below

  ucp_values = {
      'BH': (7.5, 92, 0, 30, -15, 15), # (MODIS Value, URB PARAM index, Min, Max, deltamin, deltamax)
      'BHstd': (3.0, 93, 0, 4, -4, 4),
      # The URB PARAM index of the albedo entries is never read (see below)
      'ALB_R': (0.15, 93, 0.1, 0.25, -0.08, 0.08), # Roof
      'ALB_W': (0.15, 93, 0.1, 0.25, -0.08, 0.08), # Wall
      'ALB_G': (0.15, 93, 0.1, 0.25, -0.08, 0.08), # Road
  }

  # Values available in URBPARM_LCZ.TBL: https://github.com/wrf-model/WRF/blob/master/run/URBPARM_LCZ.TBL
  URBPARM_LCZ = {
      'ALB_R': [0.13, 0.18, 0.15, 0.13, 0.13, 0.13, 0.15, 0.18, 0.13, 0.10], # LCZ 1 to 10
      'ALB_W': [0.25, 0.20, 0.20, 0.25, 0.25, 0.25, 0.20, 0.25, 0.25, 0.20], 
      'ALB_G': [0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.18, 0.14, 0.14, 0.14], 
  }

  default_val, urb_param_idx, vmin, vmax, dmin, dmax = ucp_values[UCP_NAME]

  WRF1_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'geo_em.d03_LCZ_extent.nc')
  WRF2_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'geo_em.d03_LCZ_params.nc')

  # Read what is needed into memory, then close both files
  with xr.open_dataset(WRF1_FILE) as ds1, xr.open_dataset(WRF2_FILE) as ds2:

    ucp_ext = xr.where(ds1['LU_INDEX'] == 13, default_val, 0).rename(UCP_NAME)[0,:,:].load()

    if UCP_NAME in ['BH', 'BHstd']:
      # URB_PARAM index is 1-based, hence the -1
      ucp_lcz = ds2['URB_PARAM'][0,urb_param_idx-1,:,:].rename(UCP_NAME).load()
    else:
      # Need to replace LCZ majority class with corresponding UCP
      lcz_class = xr.where(ds2['LU_INDEX']>30, ds2['LU_INDEX'], 0)
      ucp_lcz = _replace_values(
          lcz_class, np.arange(31,41,1), np.array(URBPARM_LCZ[UCP_NAME])
      ).rename(UCP_NAME)

  # Get the colors. plt.cm.get_cmap was removed in matplotlib 3.9;
  # plt.get_cmap works across versions. Copy before set_bad so the
  # registered 'inferno' colormap is never modified.
  cmap = copy.copy(plt.get_cmap('inferno'))
  cmap.set_bad('white')

  f, axs = plt.subplots(1,3,figsize=(23,7), sharey=True)

  ucp_ext.plot(
      vmin=vmin, vmax=vmax,
      cmap=cmap,
      ax=axs[0]
      )
  axs[0].set_title(f'geo_em.d03_LCZ_extent (Default: {default_val})')

  ucp_lcz.plot(
      vmin=vmin, vmax=vmax,
      cmap=cmap,
      ax=axs[1]
      )
  axs[1].set_title('geo_em.d03_LCZ_params')

  (ucp_lcz-ucp_ext).plot(
      vmin=dmin, vmax=dmax,
      extend = 'both',
      cmap=plt.get_cmap('PiYG_r'),
      ax=axs[2]
      )
  axs[2].set_title('LCZ Params - LCZ Extent') 

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'LCZ_params_vs_LCZ_extent_{UCP_NAME}.jpg'
  )
  plt.tight_layout()
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])

  print(f"Done! Figure available at: {OFILE}")

In [None]:
def _wrf_to_xr(ds):

  """
  Make a WRF output dataset easier to use with xarray.

  WRF files aren't exactly CF compliant: you'll need a special parser for
  the timestamp, the coordinate names are a bit exotic and do not correspond
  to the dimension names, they contain so-called staggered variables (and
  their corresponding coordinates), etc.

  Note that salem also has a parser function doing this.
  Yet this will not work here, as we have removed many of the vertical
  layers to save space.
  Info: https://salem.readthedocs.io/en/v0.3.7/wrf.html#wrf-tools

  What this function does:
    * decodes the byte strings in `Times` into datetimes, stored as `time`
    * renames the dimensions Time -> time, south_north -> y, west_east -> x
    * attaches the first time slice of XLAT / XLONG as coordinates lat / lon
    * drops `Times`, `XLAT` and `XLONG`

  Returns a new dataset.
  """

  # Make WRF output compatible with xarray
  # Info: https://gallery.pangeo.io/repos/NCAR/notebook-gallery/notebooks/Run-Anywhere/WRF/wrf_ex.html

  # Problem #1: Time in bytes
  times = [
    datetime.strptime(t.decode('UTF-8'), "%Y-%m-%d_%H:%M:%S")
    for t in ds.Times.data
  ]

  # drop_vars replaces the deprecated Dataset.drop
  ds_time = (
    ds.rename({'Time':'time'})
      .assign(time=times)
      .drop_vars('Times')
  )

  # Problem #2: NO COORDINATES
  # https://gallery.pangeo.io/repos/NCAR/notebook-gallery/notebooks/Run-Anywhere/WRF/wrf_ex.html#Problem-#2:-NO-COORDINATES
  ds_latlon = ds_time.assign_coords(
    lat=ds['XLAT'][0, :, :], 
    lon=ds['XLONG'][0, :, :],
  )

  return (
    ds_latlon
      .rename({'south_north':'y', 'west_east':'x'})
      .drop_vars(['XLAT', 'XLONG'])
  )

In [None]:
def _get_2d_idx_point(stnlat, stnlon, ds):

  """
  Get index of grid cell that is the nearest neighbour of a point coordinate
  """

  wrf_lat_min, wrf_lat_max = ds.lat.min(), ds.lat.max()
  wrf_lon_min, wrf_lon_max = ds.lon.min(), ds.lon.max()

  error_code = 0
  if stnlon < wrf_lon_min or stnlon > wrf_lon_max:
    print(f"ERROR: {stnlon} is outside of model domain")
    error_code = 1
  if stnlat < wrf_lat_min or stnlat > wrf_lat_max:
    print(f"ERROR: {stnlat} is outside of model domain")
    error_code = 1

  # Only continue if error_code = 0
  if error_code == 0:

    dist = (ds.lat.data-stnlat)**2 + (ds.lon.data-stnlon)**2
    #y_nr, x_nr = np.argwhere(dist == np.min(dist))
    out = np.argwhere(dist == np.min(dist))
    y_nr = out[0][0]
    x_nr = out[0][1]
    
    #print(f"x_nr: {str(x_nr)}, Lon: {float(ds.lon[y_nr,x_nr])}")
    #print(f"y_nr: {str(y_nr)}, Lat: {float(ds.lat[y_nr,x_nr])}")

    return x_nr, y_nr


In [None]:
def get_clean_cws(RUN_INFO, QC_LEVEL="o1"):

  """
  Read the CWS observations, join the station metadata, and keep only rows
  that lie strictly inside the lat/lon range of the WRF domain and that
  have the value 't' in the `QC_LEVEL` column.

  Parameters
  ----------
  RUN_INFO : dict
    Needs the keys 'CWS_DIR' and 'WRF_DIR'.
  QC_LEVEL : str
    Name of the quality-control column to filter on.

  Returns
  -------
  pandas.DataFrame
    Observations joined with metadata; the index is built from the first
    two columns of the data file.
  """

  # Get the CWS data
  # 1. Station metadata from .csv
  fn_metadata = os.path.join(RUN_INFO['CWS_DIR'], 'cws_heatwave_2019_bucss_metadata.csv')
  metadata = pd.read_csv(fn_metadata, index_col=0)

  # 2. Actual observations from .csv and set a multiindex
  fn_data = os.path.join(RUN_INFO['CWS_DIR'], 'cws_heatwave_2019_bucss_data.csv')
  data = pd.read_csv(
      fn_data,
      parse_dates=True,
      # set 1st level of multi index to intern_id, 2nd to date
      index_col=[0, 1],
  )

  # Merge the two datasources
  df = data.join(metadata)

  # remove stations with coordinates outside of WRF domain
  WRF_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'geo_em.d03_LCZ_params.nc')

  # Only the outer coordinates are needed, so close the file right after
  with xr.open_dataset(WRF_FILE) as ds:
    x_wrf_min, x_wrf_max = float(ds.XLONG_M.min()), float(ds.XLONG_M.max())
    y_wrf_min, y_wrf_max = float(ds.XLAT_M.min()), float(ds.XLAT_M.max())
  print(f'WRF boundaries (xmin, xmax, ymin, ymax): {x_wrf_min, x_wrf_max, y_wrf_min, y_wrf_max}')

  # Only select data within these boundaries
  net_sel_data = df[
  (df['lat'] > y_wrf_min) & 
  (df['lat'] < y_wrf_max) & 
  (df['lon'] > x_wrf_min) & 
  (df['lon'] < x_wrf_max) &
  (df[QC_LEVEL] == 't')
  ]

  print(f"# Unique CWS stations:\
    {len(net_sel_data.index.get_level_values('intern_id').unique())},\
    resulting in {net_sel_data.shape[0]} observations")
  
  return net_sel_data


In [None]:
def plot_LCZ_CWS_static(RUN_INFO):

  """
  Draw the LCZ raster, cropped to the station network plus a 0.1 degree
  margin, with the CWS stations on top, and save the map as .jpg.

  RUN_INFO needs the keys 'WRF_DIR', 'CWS_DIR', 'FIG_DIR' and 'DPI'.
  """

  NRLCZ = 17

  # Colorbar tick labels: built classes 1-10, land-cover classes A-G
  cb_labels = [str(nr) for nr in range(1, 11)] + list('ABCDEFG')

  # LCZ raster
  ds_lcz = rxr.open_rasterio(
      os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', 'global_lcz_filter_v1_ruhrarea.tif')
  ).squeeze()

  # Station metadata
  metadata = pd.read_csv(
      os.path.join(RUN_INFO['CWS_DIR'], 'cws_heatwave_2019_bucss_metadata.csv'),
      index_col=0,
  )

  # Crop the raster to the station extent; y is sliced from max to min
  west, east = metadata.lon.min()-0.1, metadata.lon.max()+0.1
  south, north = metadata.lat.min()-0.1, metadata.lat.max()+0.1
  ds_lcz_net = ds_lcz.sel(x=slice(west, east), y=slice(north, south))

  # LCZ color scheme; masked and below-range cells are drawn white
  lcz_cmap = mpl.colors.ListedColormap(list(_get_info('lcz').values()))
  lcz_cmap.set_bad(color='white')
  lcz_cmap.set_under(color='white')

  f, ax = plt.subplots(figsize=(10, 10))
  im = ds_lcz_net.plot(
      cmap=lcz_cmap, vmin=1, vmax=NRLCZ,
      add_colorbar=False, ax=ax,
  )

  cb = plt.colorbar(
      im,
      ticks=np.linspace(1.5, NRLCZ-0.5, NRLCZ),
      orientation='horizontal', pad=0.08,
  )
  cb.set_ticklabels(cb_labels)
  cb.set_label(label='LCZ Class', fontsize=11)
  cb.ax.tick_params(labelsize=11)

  ax.set_title("")

  # CWS stations on top of the raster
  ax.scatter(metadata['lon'], metadata['lat'], s=10, marker='o', c="#FF748C")

  OFILE = os.path.join(RUN_INFO['FIG_DIR'], 'LCZ_Map_with_Netatmo_Static.jpg')
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def plot_LCZ_CWS_interactive(RUN_INFO):

  """
  Build an interactive folium map with the LCZ raster as image overlay and
  the CWS stations as circle markers; display it and save it as .html.

  The raster is first rendered to an intermediate 'lcz_map.png' in FIG_DIR,
  which is then used as the overlay image.

  RUN_INFO needs the keys 'WRF_DIR', 'CWS_DIR', 'FIG_DIR' and 'DPI'.
  """

  # Define and read the data
  FN_LCZ = os.path.join(
      RUN_INFO['WRF_DIR'],
      'INPUT',
      'global_lcz_filter_v1_ruhrarea.tif'
  )
  ds_lcz = rxr.open_rasterio(FN_LCZ).squeeze()

  # Crop tif according to available CWS coordinates, with small buffer
  fn_metadata = os.path.join(RUN_INFO['CWS_DIR'], 'cws_heatwave_2019_bucss_metadata.csv')
  metadata = pd.read_csv(fn_metadata, index_col=0)

  # Set index as column, for plotting
  metadata = metadata.reset_index(level=0)

  xmin_net, xmax_net = metadata.lon.min()-0.1, metadata.lon.max()+0.1
  ymin_net, ymax_net = metadata.lat.min()-0.1, metadata.lat.max()+0.1

  # y is sliced from max to min
  ds_lcz_net = ds_lcz.sel(
      x=slice(xmin_net, xmax_net), 
      y=slice(ymax_net, ymin_net)
    )
  
  NRLCZ = 17

  # Create .png for interactive plotting
  # Set color scheme
  coldict = _get_info('lcz')
  cb_lcz = list(coldict.values())
  lcz_cmap = mpl.colors.ListedColormap(cb_lcz)
  lcz_cmap.set_bad(color='white')
  lcz_cmap.set_under(color='white')

  # The array shape is (rows, columns) = (y, x) while figsize expects
  # (width, height), so the two entries must be swapped.
  n_rows, n_cols = ds_lcz_net.shape
  figsize = (n_cols / 100, n_rows / 100)

  fig, ax = plt.subplots(figsize=figsize)
  ds_lcz_net.plot(cmap=lcz_cmap, vmin=1, vmax=NRLCZ, ax=ax, add_colorbar=False)
  ax.set_title('')
  plt.axis('off')

  PNG_FILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      "lcz_map.png"
  )
  plt.savefig(
      fname=PNG_FILE,
      facecolor=fig.get_facecolor(),
      transparent=True,
      dpi=RUN_INFO['DPI'],
      bbox_inches='tight',
      pad_inches=0,
  )
  plt.close('all')
  print(f"Intermediate LCZ .PNG available in {RUN_INFO['FIG_DIR']}/\nUse now to plot interactively ...")

  # Create a map , centered on study area with set zoom level
  m = folium.Map(location=[(ymin_net+ymax_net)/2, (xmin_net+xmax_net)/2],
                zoom_start = 9)

  # [[south, west], [north, east]] of the cropped raster
  lcz_map_bounds = [
      [float(ds_lcz_net.y.min()), float(ds_lcz_net.x.min())], 
      [float(ds_lcz_net.y.max()), float(ds_lcz_net.x.max())]
    ]

  # Overlay the rendered PNG (opacity and bounding box set)
  folium.raster_layers.ImageOverlay(
      image=PNG_FILE,
      name='LCZ',
      opacity=0.7,
      bounds=lcz_map_bounds,
      interactive=True,
      cross_origin=False,
      zindex=1,
  ).add_to(m)
  folium.LayerControl().add_to(m)

  # Add the CWS stations
  metadata.apply(lambda row:folium.CircleMarker(
      location=[row["lat"], row["lon"]], 
      radius=5, 
      fill=True, color="#FF748C",fill_color="#FF748C", fill_opacity=0.5,
      popup=f"ID: {row['intern_id']}<br>Lat: {row['lat']}°N<br>Lon: {row['lon']}°E", 
  ).add_to(m), axis=1)

  # Display map 
  display(m)

  # Save the map as html
  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'LCZ_Map_with_Netatmo_Interative.html'
  )
  m.save(OFILE)

  print(f"Done! Figure available at: {OFILE}")

In [None]:
def plot_cws_grouped(NET_VAR, RUN_INFO, FONT_SIZE=14, QC_LEVEL="o1"):

  """
  Plot mean +/- standard deviation of `NET_VAR` per LCZ class (1-17) in
  three panels: all hours, day-time (hour 6 to 21) and night-time
  (hour 22 to 5). The number of contributing stations is printed at the
  bottom of each panel. The figure is saved as .jpg.

  RUN_INFO needs 'FIG_DIR' and 'DPI' plus whatever get_clean_cws reads.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  df = get_clean_cws(RUN_INFO, QC_LEVEL)

  # Check sunset and sunrise hours here: 
  # https://www.timeanddate.com/sun/germany/bochum?month=7&year=2019
  obs_hour = df.index.get_level_values('date').to_series().dt.hour
  day_mask = (obs_hour >= 6) & (obs_hour < 22)
  night_mask = (obs_hour < 6) | (obs_hour >= 22)

  # Slice the dataframe on the second index level
  df_day = df.loc[pd.IndexSlice[:,day_mask],:]
  df_night = df.loc[pd.IndexSlice[:,night_mask],:]

  # Get the colors
  cb_lcz = list(_get_info('lcz').values())

  # Start the plot
  fig, ax = plt.subplots(1,3, figsize=(20,7))

  # (axis, data subset, panel title)
  panels = (
    (ax[0], df, 'Daily mean'),
    (ax[1], df_day, 'Day-time'),
    (ax[2], df_night, 'Night-time'),
  )

  # Loop over LCZ classes
  for lcz_i in range(1,18,1):
    for panel, frame, title in panels:

      frame_lcz = frame[frame['lcz'] == lcz_i]

      # Count number of stations
      n_stn = len(frame_lcz.index.get_level_values('intern_id').unique())

      # Mean with standard deviation as error bar
      panel.errorbar(lcz_i, frame_lcz[NET_VAR].mean(),
                     frame_lcz[NET_VAR].std(),
                     linestyle='None', marker='o',lw=1, markersize=10,
                     color = cb_lcz[lcz_i-1])
      panel.set_title(title)

      # Station count, positioned in axes coordinates
      panel.text(lcz_i/18, 0.02, str(n_stn),
                 verticalalignment='bottom', horizontalalignment='center',
                 transform=panel.transAxes,
                 color='purple', fontsize=11)

  for panel in ax:
    panel.set_xlabel('LCZ Class')
    panel.set_xticks(range(0,19,1))
    panel.set_xticklabels(range(0,19,1))
    # Hide the labels of the two padding ticks (0 and 18)
    xticks = panel.xaxis.get_major_ticks()
    xticks[0].label1.set_visible(False)
    xticks[-1].label1.set_visible(False)
    panel.grid(ls='--',color='0.8',zorder=-10)

  ax[0].set_ylabel(NET_VAR)

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'Netatmo_{NET_VAR}_perLCZ_day_night_mean_std.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def plot_cws_diurnal_cycle(NET_VAR, REF1, REF2, RUN_INFO, FONT_SIZE=14, QC_LEVEL="o1"):

  """
  Plot the mean diurnal cycle (+/- one standard deviation) of `NET_VAR` per
  LCZ class, and the hourly difference between two LCZ classes.

  Panel 1: LCZ classes 1-10, panel 2: LCZ classes 11-17,
  panel 3: hourly mean of class REF1 minus hourly mean of class REF2.

  Parameters
  ----------
  NET_VAR : str
    Column of the CWS data to plot.
  REF1, REF2 : int
    LCZ class numbers used for the difference panel.
  RUN_INFO : dict
    Needs 'FIG_DIR' and 'DPI' plus whatever get_clean_cws reads.
  FONT_SIZE : int
  QC_LEVEL : str
    Passed on to get_clean_cws.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  # Get the colors
  cb_lcz = list(_get_info('lcz').values())

  # Get the CWS data
  df = get_clean_cws(RUN_INFO, QC_LEVEL)

  def _by_hour(frame):
    """Group the NET_VAR column of `frame` by hour of its 'date' index level."""
    # Only NET_VAR is aggregated: the frame also holds string columns (the
    # QC flags), and pandas >= 2.0 raises a TypeError when asked for the
    # mean of a non-numeric column.
    return frame[NET_VAR].groupby(frame.index.get_level_values('date').hour)

  # Make the plot
  fig, axes = plt.subplots(1, 3, figsize=(22,10))

  for lcz_i in range(1,18,1):

    if lcz_i < 11:
      ax_i = 0
      ax_title = 'Built LCZs'
    else:
      ax_i = 1
      ax_title = 'Land cover types'

    hourly = _by_hour(df[df['lcz'] == lcz_i])
    lcz_m = hourly.mean()
    lcz_s = hourly.std()

    axes[ax_i].plot(lcz_m.index, lcz_m, color = cb_lcz[lcz_i-1], zorder=10)
    axes[ax_i].fill_between(lcz_m.index,
                    lcz_m-lcz_s,
                    lcz_m+lcz_s,
                    alpha=0.1, color = cb_lcz[lcz_i-1], 
                    zorder=10)
    axes[ax_i].set_title(ax_title)
    axes[ax_i].set_ylabel('T2M [°C]')
    axes[ax_i].set_xlabel('Hour of the day (UTC)')


  # Also plot UHI signal
  lcz_uhi = _by_hour(df[df['lcz'] == REF1]).mean() - \
            _by_hour(df[df['lcz'] == REF2]).mean()

  axes[2].plot(lcz_uhi.index, lcz_uhi, color = "0.2", zorder=10)
  axes[2].set_title(f'LCZ {REF1} - LCZ {REF2}')
  axes[2].set_ylabel(r'$\Delta$ T2M [°C]')
  axes[2].set_xlabel('Hour of the day (UTC)')

  for ax in axes:
    ax.grid(ls='--',color='0.8',zorder=-10)

  # Share y-axis between first two plots (the limits of the second panel
  # are copied onto the first)
  axes[0].set_ylim(axes[1].get_ylim())

  # Add shaded area for times where sun is down
  for ax in axes:
    ax.axvspan(0, 3+47/60, alpha=0.2, color='0.7', zorder=-20) # in UTC
    ax.axvspan(19+27/60,23, alpha=0.2, color='0.7', zorder=-20) # in UTC 

  # Add LCZ Legend
  lcz_elements = [
      Patch(facecolor=cb_lcz[key-1], edgecolor='0.5', label=f"LCZ {key}")
      for key in range(1,18,1)
  ]
  lcz_legend = axes[2].legend(handles=lcz_elements,
                        bbox_to_anchor=(1.0, 0.0), ncol=1,
                        loc='lower left', fontsize=12,
                        numpoints=1, scatterpoints=1
                        )
  plt.gca().add_artist(lcz_legend)

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'Netatmo_{NET_VAR}_perLCZ_DiurnalCycle_DeltaT{REF1}-{REF2}.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")


In [None]:
def _get_wrf_gridcell_coords(ds, x, y):

  """
  Helper function to get bbox of WRF grid cell, based on center coordinate lat lon value.
  """
  # Position of grid cell center, decimal degrees
  lat = float(ds.lat[y,x].lat.data)
  lon = float(ds.lon[y,x].lon.data)

  # Earth’s radius, sphere
  R= 6378137

  # offsets in meters
  dn = 500
  de = 500

  # Coordinate offsets in radians
  dLat = dn/R
  dLon = de/(R*np.cos(np.pi*lat/180))

  # OffsetPosition, decimal degrees, for both directions
  lat_min = lat - dLat * 180/np.pi
  lat_max = lat + dLat * 180/np.pi
  lon_min = lon - dLon * 180/np.pi
  lon_max = lon + dLon * 180/np.pi
  #print(lat_min, lat_max)
  #print(lon_min, lon_max)

  return lat_min, lat_max, lon_min, lon_max

In [None]:
def _get_net_id_wrf_gridcell(RUN_INFO, ds, x, y):

  """
  Return the metadata rows of all CWS stations located inside the
  bounding box of WRF grid cell (x, y); box edges are included.

  The result is empty when no station falls in the cell.
  """

  # Station metadata (index = first column of the .csv)
  stations = pd.read_csv(
      os.path.join(RUN_INFO['CWS_DIR'], 'cws_heatwave_2019_bucss_metadata.csv'),
      index_col=0,
  )

  # BBOX of WRF grid cell in lat lon
  south, north, west, east = _get_wrf_gridcell_coords(ds, x, y)

  # Stations that fall in this box
  in_cell = stations['lat'].between(south, north) & stations['lon'].between(west, east)

  return stations.loc[in_cell]

In [None]:
def _get_stats(mod, obs):

  """
  Helper function for evaluation statistics

  Compares modelled (`mod`) against observed (`obs`) values and returns a
  dict with MAE, MSE, RMSE, MBE (mean of mod - obs) and R2, each rounded
  to two decimals.
  """

  # Computed once; RMSE is its square root
  mse = metrics.mean_squared_error(mod, obs)

  stats_dict = {}
  stats_dict['MAE'] = np.round(metrics.mean_absolute_error(mod, obs), 2)
  stats_dict['MSE'] = np.round(mse, 2)
  stats_dict['RMSE'] = np.round(np.sqrt(mse), 2)
  stats_dict['MBE'] = np.round(np.mean(mod-obs), 2)
  # r2_score takes the observations first, the predictions second
  stats_dict['R2'] = np.round(metrics.r2_score(obs, mod), 2)

  return stats_dict

In [None]:
def eval_T2_1px_wrf_cws_time(stnlat, stnlon, RUN_INFO, QC_LEVEL="o1", FONT_SIZE=14):

  """
  Plot the T2 time series (in °C) of three WRF simulations at the grid cell
  nearest to (stnlat, stnlon) and, if CWS stations lie inside that cell,
  overlay their observations and print evaluation statistics in the plot.

  Parameters
  ----------
  stnlat, stnlon : float
    Latitude / longitude of the point of interest.
  RUN_INFO : dict
    Reads the keys 'WRF_DIR', 'FIG_DIR' and 'DPI' here; the helpers called
    also read 'CWS_DIR'.
  QC_LEVEL : str
    Passed on to get_clean_cws.
  FONT_SIZE : int
    Written to matplotlib's global rcParams.

  The figure is saved as 'plot_x{x}_y{y}_wrf_netatmo_T2.jpg' in FIG_DIR.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  # Simulation name -> (geo_em file name, color, line style)
  wrf_dict = _get_info('wrf_sim')

  # Read the WRF data, each of the land use simulations
  WRF_FILE_NOU = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', 'NOURBAN_SIM.nc')
  WRF_FILE_EXT = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', 'LCZ_NORMS_EXTENT_SIM.nc')
  WRF_FILE_LCZ = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', 'LCZ_NORMS_SIM.nc')

  # Open the files
  ds_nou = xr.open_dataset(WRF_FILE_NOU)
  ds_ext = xr.open_dataset(WRF_FILE_EXT)
  ds_lcz = xr.open_dataset(WRF_FILE_LCZ)

  # Read the domain info (geo_em file) that belongs to each simulation
  WRF_FILE_GEO_NOU = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f"{wrf_dict['NOURBAN_SIM'][0]}.nc")
  WRF_FILE_GEO_EXT = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f"{wrf_dict['LCZ_NORMS_EXTENT_SIM'][0]}.nc")
  WRF_FILE_GEO_LCZ = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f"{wrf_dict['LCZ_NORMS_SIM'][0]}.nc")
  ds_nou_geo = xr.open_dataset(WRF_FILE_GEO_NOU)
  ds_ext_geo = xr.open_dataset(WRF_FILE_GEO_EXT)
  ds_lcz_geo = xr.open_dataset(WRF_FILE_GEO_LCZ)

  # Make WRF compliant with xarray
  ds_nou_xr = _wrf_to_xr(ds_nou)
  ds_ext_xr = _wrf_to_xr(ds_ext)
  ds_lcz_xr = _wrf_to_xr(ds_lcz)

  # Get index location of lat lon coordinate.
  # The index found on the NOURBAN grid is reused for all three simulations.
  x, y = _get_2d_idx_point(stnlat, stnlon, ds_nou_xr)

  # Add CWS stations, if available in WRF grid cell?
  # Use CWS metadata to find appropriate CWS station ids
  df_net_ids = _get_net_id_wrf_gridcell(RUN_INFO, ds_nou_xr, x, y)

  # Get CWS data if stations available
  if len(df_net_ids.index) != 0:

    # Read all CWS data
    df = get_clean_cws(RUN_INFO, QC_LEVEL=QC_LEVEL)

    # CWS lines take color of corresponding LCZ
    coldict = _get_info('lcz')
    cb_lcz = [i for i in coldict.values()]

  # Plot the timeseries (T2 converted from K to °C)
  fig, ax = plt.subplots(1,1, figsize=(15,8))
  (ds_nou_xr['T2']-273.15).sel(x=x,y=y).plot(ax=ax, color='0.2', ls=":", lw=2, label="NOURBAN_SIM")
  (ds_ext_xr['T2']-273.15).sel(x=x,y=y).plot(ax=ax, color='0.2', ls="--",lw=2, label="LCZ_NORMS_EXTENT_SIM")
  (ds_lcz_xr['T2']-273.15).sel(x=x,y=y).plot(ax=ax, color='0.2', lw=2, label="LCZ_NORMS_SIM")
  ax.set_title("") # We add a custom title below

  # Plot CWS information if available
  # Also plot the CWS average, as a solid magenta line
  if len(df_net_ids.index) != 0:

    # Plot the CWS mean, also used for the statistics:
    # mean over the stations per time stamp, then the last value of each hour
    net_mean = df.loc[list(df_net_ids.index)]['ta_int'].groupby('date').mean().resample('1H').last()
    net_mean.plot(color="#f320f3", lw=2, ax=ax, label='Netatmo Mean')

    # Plot individual stations (note: the loop variable shadows the builtin `id`)
    for id in list(df_net_ids.index):
      #print(df_net_ids.loc[id, 'lcz'], cb_lcz[df_net_ids.loc[id, 'lcz']-1])
      df_stn = df.loc[[id]]['ta_int'].droplevel('intern_id').resample('1H').last()
      df_stn.plot(
          ax=ax, 
          color=cb_lcz[df_net_ids.loc[id, 'lcz']-1], 
          ls=(0, (1, 1)), lw=1,
      )

    # Get statistics: MAE, MBE, RMSE, R2, for common time stamps
    # (join on the time index, then drop rows where either side is missing)
    wrf_nou_to_df = (ds_nou_xr['T2']-273.15).sel(x=x,y=y).to_dataframe()['T2']
    t2_join_nou = net_mean.to_frame().join(wrf_nou_to_df).dropna()
    stats_nou = _get_stats(t2_join_nou['T2'], t2_join_nou['ta_int'])
    #print(stats_nou)
    wrf_ext_to_df = (ds_ext_xr['T2']-273.15).sel(x=x,y=y).to_dataframe()['T2']
    t2_join_ext = net_mean.to_frame().join(wrf_ext_to_df).dropna()
    stats_ext = _get_stats(t2_join_ext['T2'], t2_join_ext['ta_int'])
    #print(stats_ext)
    wrf_lcz_to_df = (ds_lcz_xr['T2']-273.15).sel(x=x,y=y).to_dataframe()['T2']
    t2_join_lcz = net_mean.to_frame().join(wrf_lcz_to_df).dropna()
    stats_lcz = _get_stats(t2_join_lcz['T2'], t2_join_lcz['ta_int'])
    #print(stats_lcz)

  # Add info to title: land-use class of this grid cell in each geo_em file
  lu_nou_val = int(ds_nou_geo.LU_INDEX[0,y,x])
  lu_ext_val = int(ds_ext_geo.LU_INDEX[0,y,x])
  lu_lcz_val = int(ds_lcz_geo.LU_INDEX[0,y,x])
  #imp_val = np.round(float(ds_lcz_geo.FRC_URB2D[0,y,x])*100,1)

  plt.title(f"Lat: {str(np.round((ds_nou_xr['T2']-273.15).sel(x=x,y=y).lat.data,3))}°N (y={y}) " +
            f"| Lon: {str(np.round((ds_nou_xr['T2']-273.15).sel(x=x,y=y).lon.data,3))}°E (x={x}) " +
            f"| LU (NOU | EXT | LCZ): {lu_nou_val} | {lu_ext_val} | {lu_lcz_val} "+
            f"| # Netatmo: {int(len(df_net_ids.index))}",
            loc='left')

  # Hand-made legend. The 'Single Netatmo' entry uses one fixed color,
  # whereas the station lines above are drawn in the color of their LCZ.
  custom_lines = [Line2D([0], [0], color='#BF40BF', lw=1, ls=':'),
                  Line2D([0], [0], color='#f320f3', lw=2),
                  Line2D([0], [0], color='0.2', lw=2, ls=':'),
                  Line2D([0], [0], color='0.2', lw=2, ls='--'),
                  Line2D([0], [0], color='0.2', lw=2)
  ]
  legend_labels = [
      'Single Netatmo', 
      'Mean Netatmo', 
      'WRF-NOURBAN_SIM',
      'WRF-LCZ_NORMS_EXTENT_SIM',
      'LCZ_NORMS_SIM',
  ]
  ax.legend(custom_lines, legend_labels,
            loc='upper center', bbox_to_anchor=(0.5, -0.2),
            ncol=5, fancybox=True, shadow=True
  )

  # Print the statistics, if they exist (MSE is computed but not shown).
  if len(df_net_ids.index) != 0:
    ax.text(0.01, 0.95, 'WRF:  NOU | EXT | LCZ', 
              horizontalalignment='left',
              verticalalignment='center', 
              transform=ax.transAxes,
              )
    for s_i, stat in enumerate(['MAE', 'MBE', 'RMSE', 'R2']):
      ax.text(0.01, 0.90-(0.05*s_i), f"{stat}: {stats_nou[stat]} | {stats_ext[stat]} | {stats_lcz[stat]}", 
              horizontalalignment='left',
              verticalalignment='center', 
              transform=ax.transAxes,
              )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_x{x}_y{y}_wrf_netatmo_T2.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")



In [None]:
def eval_vars_1px_wrf_cws_time(stnlat, stnlon, RUN_INFO, FONT_SIZE=14):

  """
  Plot time series of six WRF output variables at a single grid cell for
  the NOURBAN_SIM, LCZ_NORMS_EXTENT_SIM and LCZ_NORMS_SIM simulations.

  stnlat, stnlon: handed to _get_2d_idx_point to locate the grid cell (x, y).
  RUN_INFO: dict, read for the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  sim_info = _get_info('wrf_sim')

  # Simulations in drawing order, each with the extra line keywords that
  # tell it apart (the last one keeps matplotlib's default line style).
  sim_styles = [
      ('NOURBAN_SIM', {'ls': ':'}),
      ('LCZ_NORMS_EXTENT_SIM', {'ls': '--'}),
      ('LCZ_NORMS_SIM', {}),
  ]
  sim_names = [name for name, _ in sim_styles]

  # Simulation output, one dataset per run
  wrf_raw = {
      name: xr.open_dataset(
          os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{name}.nc'))
      for name in sim_names
  }

  # Matching domain files, only used for the land use index in the title
  wrf_geo = {
      name: xr.open_dataset(
          os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f"{sim_info[name][0]}.nc"))
      for name in sim_names
  }

  # Make WRF compliant with xarray
  wrf_xr = {name: _wrf_to_xr(wrf_raw[name]) for name in sim_names}

  # Grid cell belonging to the requested coordinate
  x, y = _get_2d_idx_point(stnlat, stnlon, wrf_xr['NOURBAN_SIM'])

  # One panel per variable, all three simulations in every panel
  var_names = ["SWDOWN", "GLW", "HFX", "LH", "TSK", "PBLH"]
  fig, axes = plt.subplots(len(var_names), 1, figsize=(15,25), sharex=True)

  for panel, var_name in zip(axes, var_names):
    for sim_name, style_kw in sim_styles:
      wrf_xr[sim_name][var_name].sel(x=x, y=y).plot(
          ax=panel, color='0.2', lw=2, label=sim_name, **style_kw)
    panel.set_title("") # The custom title is added on the top panel below

  # Land use index of the grid cell in each of the three domain files
  lu_nou_val, lu_ext_val, lu_lcz_val = [
      int(wrf_geo[name].LU_INDEX[0,y,x]) for name in sim_names
  ]

  # lat / lon are read from the coordinates of the selected pixel
  ref_pixel = wrf_xr['NOURBAN_SIM']['T2'].sel(x=x, y=y)
  lat_txt = str(np.round(ref_pixel.lat.data, 3))
  lon_txt = str(np.round(ref_pixel.lon.data, 3))

  axes[0].set_title(
      f"| Lat: {lat_txt}°N (y={y}) " +
      f"| Lon: {lon_txt}°E (x={x}) " +
      f"| LU (NOU | EXT | LCZ): {lu_nou_val} | {lu_ext_val} | {lu_lcz_val} ",
      loc='left'
      )

  # Hand-made legend, placed below the bottom panel
  legend_handles = [
      Line2D([0], [0], color='0.2', lw=2, **style_kw)
      for _, style_kw in sim_styles
  ]
  legend_labels = [
      'WRF-NOURBAN_SIM',
      'WRF-LCZ_NORMS_EXTENT_SIM',
      'LCZ_NORMS_SIM',
  ]
  axes[-1].legend(
      legend_handles, legend_labels,
      loc='upper center', bbox_to_anchor=(0.5, -0.3),
      ncol=3, fancybox=True, shadow=True
  )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_x{x}_y{y}_wrf_netatmo_ALL_VAR.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")


In [None]:
def get_error_table_wrf_vs_cws(SIM_NAME, RUN_INFO, QC_LEVEL="o1"):

  """
  Build a per-grid-cell error table of WRF T2 against the mean of the
  CWS (Netatmo) stations located in that grid cell, and save it as .csv.

  SIM_NAME: simulation name; selects '<SIM_NAME>.nc' in the OUTPUT folder
            and, via _get_info('wrf_sim'), the domain file in INPUT.
  RUN_INFO: dict, read for the keys 'WRF_DIR' and 'OUT_DIR' (it is also
            passed on to get_clean_cws and _get_net_id_wrf_gridcell).
  QC_LEVEL: forwarded unchanged to get_clean_cws.

  Only grid cells with at least one station and a non-empty overlap between
  the station and WRF series get a row. Nothing is returned; the table is
  written to RUN_INFO['OUT_DIR'].
  """

  wrf_dict = _get_info('wrf_sim')

  # Read the WRF data
  WRF_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{SIM_NAME}.nc')
  WRF_FILE_GEO = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f'{wrf_dict[SIM_NAME][0]}.nc')

  # Open the files
  ds = xr.open_dataset(WRF_FILE)
  ds_geo = xr.open_dataset(WRF_FILE_GEO)

  # Make WRF compliant with xarray
  ds_xr = _wrf_to_xr(ds)

  # Column order of the output table
  columns = ['x', 'y', 'Lon', 'Lat',
             'NrNetatmo',
             'LU_INDEX', 'URB_FRAC', 'HGT_M',
             'RMSE', 'MBE', 'R2']

  # Read all CWS data
  df = get_clean_cws(RUN_INFO, QC_LEVEL=QC_LEVEL)

  # NOTE(review): ix / iy are used both as coordinate labels (.sel) and as
  # positional indices ([0, iy, ix]) below; this presumes the x / y
  # coordinates are 0-based integer positions - confirm against _wrf_to_xr.
  xrange = ds_xr.x.data
  yrange = ds_xr.y.data

  # Rows are collected in a list and turned into a DataFrame once at the end:
  # DataFrame.append no longer exists in pandas >= 2.0 and was quadratic.
  rows = []

  total_len = len(xrange) * len(yrange)
  for ix, iy in tqdm(product(xrange, yrange), total=total_len):

    # Get CWS ids; only compute when CWS available
    df_net_ids = _get_net_id_wrf_gridcell(RUN_INFO, ds_xr, ix, iy)
    if len(df_net_ids) == 0:
      continue

    # Get pixel value WRF. The pixel is selected before subtracting 273.15,
    # so the full field is not converted again for every grid cell.
    wrf_df = (ds_xr['T2'].sel(x=ix,y=iy) - 273.15).to_dataframe()['T2']

    net_mean = df.loc[list(df_net_ids.index)]['ta_int'].groupby('date').mean()
    t2_join = net_mean.to_frame().join(wrf_df).dropna()

    # Joined series might be empty, if CWS series is too short.
    if t2_join.empty:
      continue

    stats = _get_stats(t2_join['T2'], t2_join['ta_int'])

    # Get imperviousness: fixed values for the two simulations handled
    # explicitly, the FRC_URB2D field of the domain file otherwise.
    if SIM_NAME == 'NOURBAN_SIM':
      imp_val = 0
    elif SIM_NAME == 'LCZ_NORMS_EXTENT_SIM':
      imp_val = 0.9
    else:
      imp_val = np.round(float(ds_geo.FRC_URB2D[0,iy,ix]),3)

    lu_val = int(ds_geo.LU_INDEX[0,iy,ix])
    hgt_val = float(ds_geo.HGT_M[0,iy,ix])

    # Get coordinates WRF
    wrf_lat = float(ds_xr.lat[iy, ix])
    wrf_lon = float(ds_xr.lon[iy, ix])

    rows.append({
        'x': ix,
        'y': iy,
        'Lon': wrf_lon,  # the original wrote the latitude here (swapped)
        'Lat': wrf_lat,
        'NrNetatmo': len(df_net_ids),
        'LU_INDEX': lu_val,
        'URB_FRAC': np.round(imp_val*100,2),
        'HGT_M': np.round(hgt_val,2),
        'RMSE': stats['RMSE'],
        'MBE': stats['MBE'],
        'R2': stats['R2'],
    })

  # An empty `rows` list still yields a table with the expected header
  df_error = pd.DataFrame(rows, columns=columns)

  # Save to .csv
  OFILE = os.path.join(
      RUN_INFO['OUT_DIR'],
      f"WRF_{SIM_NAME}_Netatmo_stats_alldomain.csv",
  )
  df_error.to_csv(OFILE, index=False)
  print(f"File available at: {OFILE}")

In [None]:
def plot_error_wrf_cws_spatially(ERROR_VAR, SIM_NAME, RUN_INFO, FONT_SIZE=14):

  """
  Read the error matrix and plot spatially
  Size of the error dot scaled according to # CWS stations.

  ERROR_VAR: column of the error table to colour the dots with; one of
             'RMSE', 'MBE' or 'R2' (anything else raises KeyError).
  SIM_NAME: simulation name; selects the domain file (via
            _get_info("wrf_sim")) and the error table .csv.
  RUN_INFO: dict, read for the keys 'WRF_DIR', 'OUT_DIR', 'FIG_DIR', 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The map is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  # (vmin, vmax, colormap name, colorbar extend) per error metric.
  # Colormaps are passed by name: plt.cm.get_cmap is gone in Matplotlib >= 3.9.
  error_range_dict = {
      'RMSE': (0, 5, 'YlOrBr', 'max'),
      'MBE': (-3, 3, 'RdYlBu_r', 'both'),
      'R2': (0.5, 1, 'YlOrBr', 'min'),
  }
  err_vmin, err_vmax, err_cmap, err_extend = error_range_dict[ERROR_VAR]

  wrf_info = _get_info("wrf_sim")

  # Read the lower boundary conditions
  WRF_FILE = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f'{wrf_info[SIM_NAME][0]}.nc')
  ds = xr.open_dataset(WRF_FILE)

  # Read the error matrix
  fn_error = os.path.join(
      RUN_INFO['OUT_DIR'],
      f"WRF_{SIM_NAME}_Netatmo_stats_alldomain.csv"
  )
  df = pd.read_csv(fn_error)

  # Plot results on map
  plt.rcParams.update({'font.size': FONT_SIZE})

  f, ax = plt.subplots(1,1,figsize=(20,20))

  # Terrain height as labelled contour lines
  im_hgt = ds['HGT_M'][0,:,:].plot.contour(
      vmin=0, vmax=500,
      cmap='gist_earth',
      ax=ax,
      add_colorbar=False,
      levels=np.arange(0, 501, 50),
  )
  ax.clabel(im_hgt, inline=True, fontsize=FONT_SIZE)

  # For these two simulations the urban fraction shown on the map is
  # rebuilt from LU_INDEX: all zeros, or 0.9 wherever LU_INDEX == 13.
  if SIM_NAME == 'NOURBAN_SIM':
    ds['FRC_URB2D'] = xr.where(ds['LU_INDEX'] == 13, 0, 0)
  elif SIM_NAME == 'LCZ_NORMS_EXTENT_SIM':
    ds['FRC_URB2D'] = xr.where(ds['LU_INDEX'] == 13, 0.9, 0)

  ds['FRC_URB2D'][0,:,:].plot(
        vmin=0, vmax=1,
        cmap='Greys',
        add_colorbar=True,
        cbar_kwargs={'orientation': 'horizontal', 'shrink':0.6, 'pad': -0.05},
        ax=ax,
    )

  # Add the errors from the table: all grid cells in a single scatter call.
  # This also keeps `sc` defined when the table has no rows.
  sc = ax.scatter(
      df['x'],
      df['y'],
      s=df['NrNetatmo']*20,
      c=df[ERROR_VAR],
      cmap=err_cmap,
      vmin=err_vmin,
      vmax=err_vmax,
      )

  # Add a title
  ax.set_title(f"{ERROR_VAR} (WRF {SIM_NAME} vs Netatmo)")

  # Add colorbar for error metric
  plt.colorbar(sc, orientation = 'horizontal', shrink = 0.6, pad=0.05,
               label=f"{ERROR_VAR} (WRF vs Netatmo)",
               extend=err_extend)

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'Map_WRF_{SIM_NAME}_Netatmo_{ERROR_VAR}_alldomain.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")


In [None]:
def plot_error_wrf_cws_bxp_lu_urb_hgt(ERROR_VAR, RUN_INFO, FONT_SIZE=14):

  """
  Plot error metric stratified per LCZ | HGT | URB_FRC2D

  Draws a 3 x 3 grid of boxplots: one row per simulation (NOURBAN_SIM,
  LCZ_NORMS_EXTENT_SIM, LCZ_NORMS_SIM) and one column per grouping
  (LU_INDEX, URB_FRAC in bins of 10, HGT_M in bins of 50).

  ERROR_VAR: column of the error tables to plot ('RMSE', 'R2' or 'MBE').
  RUN_INFO: dict, read for the keys 'OUT_DIR', 'FIG_DIR' and 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  units = {
    'RMSE': "°C",
    'R2': "-",
    'MBE': "°C",
  }

  def _box_kwargs():
    # Built fresh for every call so that each boxplot receives its own
    # property dicts, exactly as with separately written-out calls.
    return dict(
        rot=90,
        showfliers=False,
        whis=[5,95],
        showmeans=True,
        boxprops=dict(facecolor='0.5', color='0.5'),
        whiskerprops=dict(color='0.65', linestyle='-'),
        capprops=dict(color='0.65', linestyle='-'),
        meanprops=dict(marker='o', markerfacecolor='0.2', markeredgecolor='0.2'),
        medianprops=dict(color='0.8', lw=2),
        patch_artist=True,
    )

  plt.rcParams.update({'font.size': FONT_SIZE})

  fig, axes = plt.subplots(3,3,figsize=(15,20), sharey=True)

  simulations = (
    'NOURBAN_SIM',
    'LCZ_NORMS_EXTENT_SIM',
    'LCZ_NORMS_SIM',
  )
  for row, SIM_NAME in enumerate(simulations):

    # Error table of this simulation
    table = pd.read_csv(os.path.join(
        RUN_INFO['OUT_DIR'],
        f"WRF_{SIM_NAME}_Netatmo_stats_alldomain.csv"
    ))

    # Binned versions of urban fraction and terrain height
    table['URB_FRAC_GROUP'] = pd.cut(
        table['URB_FRAC'], bins=np.arange(0, 100.1,10), include_lowest=True)
    table['HGT_M_GROUP'] = pd.cut(
        table['HGT_M'], bins=np.arange(0, 500.1,50))

    # (grouping column, panel title) from left to right
    panels = (
        ('LU_INDEX', ""),
        ('URB_FRAC_GROUP', f"WRF {SIM_NAME} vs Netatmo"),
        ('HGT_M_GROUP', ""),
    )
    for col, (group_col, panel_title) in enumerate(panels):
      panel = axes[row, col]

      table.boxplot(column=ERROR_VAR, by=group_col, ax=panel, **_box_kwargs())

      # Only the left-most panel carries the y-axis label
      if col == 0:
        panel.set_ylabel(f"{ERROR_VAR} [{units[ERROR_VAR]}]")

      panel.grid(color='0.8', linestyle=':', linewidth=1)
      panel.set_title(panel_title)

      # Zero line as reference for the bias
      if ERROR_VAR == 'MBE':
        panel.axhline(0, color='0.2', ls=":", lw=2)

    # Blank out the figure-level title again after the grouped boxplots
    panel.get_figure().suptitle("")

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'WRF_vs_Netatmo_{ERROR_VAR}_BXP_LU_URB_HGT.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")


In [None]:
def plot_vars_1px_wrf_rms_time(stnlat, stnlon,
                      RUN_INFO,
                      FONT_SIZE=14):

  """
  Plot time series of eight WRF variables at a single grid cell for the
  LCZ_NORMS_SIM, LCZ_CR_SIM, LCZ_GR_SIM and LCZ_PVP_SIM simulations.

  stnlat, stnlon: handed to _get_2d_idx_point to locate the grid cell (x, y).
  RUN_INFO: dict, read for the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  sim_info = _get_info('wrf_sim')

  # Variable -> unit shown on the y-axis; also fixes the panel order
  var_units = {
    'SWDOWN': r"W/m$^2$",
    'GLW': r"W/m$^2$",
    'HFX': r"W/m$^2$",
    'LH': r"W/m$^2$",
    'CM_AC_URB3D': r"W/m$^2$",
    'EP_PV_URB3D': r"W/m$^2$",
    'TSK': "°C",
    'T2': "°C",
  }

  # Simulation -> line colour, in drawing and legend order
  sim_colors = {
      'LCZ_NORMS_SIM': '#FF5733',
      'LCZ_CR_SIM': '#5968D1',
      'LCZ_GR_SIM': '#56B561',
      'LCZ_PVP_SIM': '#1E3220',
  }

  # Open the output of every simulation
  wrf_raw = {
      sim: xr.open_dataset(
          os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{sim}.nc'))
      for sim in sim_colors
  }

  # Read domain info for LCZ NORMS only
  ds_geo = xr.open_dataset(os.path.join(
      RUN_INFO['WRF_DIR'], 'INPUT', f"{sim_info['LCZ_NORMS_SIM'][0]}.nc"))

  # Make WRF compliant with xarray
  wrf_xr = {sim: _wrf_to_xr(wrf_raw[sim]) for sim in sim_colors}

  # Grid cell belonging to the requested coordinate
  x, y = _get_2d_idx_point(stnlat, stnlon, wrf_xr['LCZ_NORMS_SIM'])

  fig, axes = plt.subplots(len(var_units), 1, figsize=(12,25), sharex=True)

  for panel, (VAR, unit) in zip(axes, var_units.items()):

    # The two temperature variables are shifted by 273.15 to match the
    # °C label; all other variables are plotted as stored.
    offset = 273.15 if VAR in ['T2', 'TSK'] else 0

    for sim, line_color in sim_colors.items():
      (wrf_xr[sim][VAR] - offset).sel(x=x,y=y).plot(
          ax=panel, color=line_color, lw=2)

    panel.set_title("") # The custom title is added on the top panel below
    panel.set_ylabel(f"{VAR} [{unit}]")

  # Land use index and urban fraction (as %) of the grid cell
  lu_lcz_val = int(ds_geo.LU_INDEX[0,y,x])
  imp_val = np.round((float(ds_geo.FRC_URB2D[0,y,x]))*100,2)

  # lat / lon are read from the coordinates of the selected pixel
  ref_pixel = wrf_xr['LCZ_NORMS_SIM']['T2'].sel(x=x,y=y)
  lat_txt = str(np.round(ref_pixel.lat.data,3))
  lon_txt = str(np.round(ref_pixel.lon.data,3))

  axes[0].set_title(
      f"Lat: {lat_txt}°N (y={y}) " +
      f"| Lon: {lon_txt}°E (x={x}) " +
      f"| LU (LCZ): {str(lu_lcz_val)} " +
      f"| Imperviousness: {str(imp_val)}%",
      loc='left'
      )

  # Hand-made legend, placed below the bottom panel
  legend_handles = [
      Line2D([0], [0], color=line_color, lw=2)
      for line_color in sim_colors.values()
  ]
  axes[-1].legend(
      legend_handles, list(sim_colors.keys()),
      loc='upper center', bbox_to_anchor=(0.5, -0.5),
      ncol=4, fancybox=True, shadow=True
  )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_vars_x{x}_y{y}_WRF_RMS_time.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def plot_vars_1px_wrf_rms_time_diff(stnlat, stnlon,
                      RUN_INFO,
                      FONT_SIZE=14):

  """
  Plot, at a single grid cell, time series of the difference between the
  LCZ_CR_SIM / LCZ_GR_SIM / LCZ_PVP_SIM simulations and LCZ_NORMS_SIM
  for eight WRF variables (one panel per variable).

  stnlat, stnlon: handed to _get_2d_idx_point to locate the grid cell (x, y).
  RUN_INFO: dict, read for the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  wrf_dict = _get_info('wrf_sim')

  # Variable -> unit shown on the y-axis; also fixes the panel order
  var_dict = {
    'SWDOWN': r"W/m$^2$",
    'GLW': r"W/m$^2$",
    'HFX': r"W/m$^2$",
    'LH': r"W/m$^2$",
    'CM_AC_URB3D': r"W/m$^2$",
    'EP_PV_URB3D': r"W/m$^2$",
    'TSK': "°C",
    'T2': "°C",
  }

  # Scenario simulation -> line colour, in drawing and legend order
  scenario_colors = {
      'LCZ_CR_SIM': '#5968D1',
      'LCZ_GR_SIM': '#56B561',
      'LCZ_PVP_SIM': '#1E3220',
  }

  def _open_sim(sim_name):
    # Open one simulation output and make it compliant with xarray
    return _wrf_to_xr(xr.open_dataset(
        os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{sim_name}.nc')))

  # Reference simulation and the scenarios compared against it
  ds_norm_xr = _open_sim('LCZ_NORMS_SIM')
  scenarios_xr = {sim: _open_sim(sim) for sim in scenario_colors}

  # Read domain info for LCZ NORMS only
  WRF_FILE_GEO_LCZ = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f"{wrf_dict['LCZ_NORMS_SIM'][0]}.nc")
  ds_geo = xr.open_dataset(WRF_FILE_GEO_LCZ)

  # Get index location of lat lon coordinate
  x, y = _get_2d_idx_point(stnlat, stnlon, ds_norm_xr)

  fig, axes = plt.subplots(len(var_dict.keys()),1, figsize=(12,25), sharex=True)

  for panel, VAR in zip(axes, var_dict.keys()):

    # The two temperature variables are shifted by 273.15 before
    # differencing, as in the non-difference plot; others are used as stored.
    var_scale = 273.15 if VAR in ['T2', 'TSK'] else 0

    # Reference series, selected once per variable
    norm_px = ds_norm_xr[VAR].sel(x=x,y=y) - var_scale

    for sim, line_color in scenario_colors.items():
      sim_px = scenarios_xr[sim][VAR].sel(x=x,y=y) - var_scale
      (sim_px - norm_px).plot(ax=panel, color=line_color, lw=2)

    panel.set_title("") # We add a custom title below
    # Raw f-string: "\D" is not a valid escape sequence in a normal string
    panel.set_ylabel(rf"$\Delta$ {VAR} [{var_dict[VAR]}]")
    panel.axhline(0, color='0.2', ls=":", lw=2)

  # Add info to title
  lu_lcz_val = int(ds_geo.LU_INDEX[0,y,x])
  imp_val = np.round((float(ds_geo.FRC_URB2D[0,y,x]))*100,2)

  # On the top panel
  axes[0].set_title(
      f"Lat: {str(np.round((ds_norm_xr['T2']).sel(x=x,y=y).lat.data,3))}°N (y={y}) " +
      f"| Lon: {str(np.round((ds_norm_xr['T2']).sel(x=x,y=y).lon.data,3))}°E (x={x}) " +
      f"| LU (LCZ): {str(lu_lcz_val)} " +
      f"| Imperviousness: {str(imp_val)}%",
      loc='left'
      )

  # Hand-made legend, placed below the bottom panel
  custom_lines = [
      Line2D([0], [0], color=line_color, lw=2)
      for line_color in scenario_colors.values()
  ]
  legend_labels = list(scenario_colors.keys())
  axes[len(var_dict.keys())-1].legend(custom_lines, legend_labels,
            loc='upper center', bbox_to_anchor=(0.5, -0.5),
            ncol=3, fancybox=True, shadow=True
  )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_vars_x{x}_y{y}_WRF_RMS_time_diff.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def map_var_rms(VAR, SIM_NAME1, SIM_NAME2, RUN_INFO, FONT_SIZE=14):

  """
  Map the time-mean of one WRF variable for two simulations and their
  difference (SIM_NAME2 - SIM_NAME1) in a 3 x 3 figure:
  rows = all hours | hours 6-21 | hours 22-5,
  columns = SIM_NAME1 | SIM_NAME2 | difference.

  VAR: variable name; must be a key of the unit table below.
  SIM_NAME1, SIM_NAME2: simulation names; select '<name>.nc' in OUTPUT.
  RUN_INFO: dict, read for the keys 'WRF_DIR', 'FIG_DIR' and 'DPI'.
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  Only the period 2019-07-21 to 2019-07-27 is used. The figure is saved as
  .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  var_dict = {
    'SWDOWN': r"W/m$^2$",
    'GLW': r"W/m$^2$",
    'HFX': r"W/m$^2$",
    'LH': r"W/m$^2$",
    'CM_AC_URB3D': r"W/m$^2$",
    'EP_PV_URB3D': r"W/m$^2$",
    'TSK': "°C",
    'T2': "°C",
  }

  # Read the WRF data, each of the land use simulations
  WRF_FILE1 = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{SIM_NAME1}.nc')
  WRF_FILE2 = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{SIM_NAME2}.nc')

  # Open the files
  ds_sim1 = xr.open_dataset(WRF_FILE1)
  ds_sim2 = xr.open_dataset(WRF_FILE2)

  # Make WRF compliant with xarray - heat-wave period only
  ds1_xr = _wrf_to_xr(ds_sim1).sel(time=slice('2019-07-21', '2019-07-27'))
  ds2_xr = _wrf_to_xr(ds_sim2).sel(time=slice('2019-07-21', '2019-07-27'))

  fig, axes = plt.subplots(3,3,figsize=(15,15), sharey=True, sharex=True)

  is_temperature = VAR in ['T2', 'TSK']
  value_label = f"{VAR} [{var_dict[VAR]}]"
  # Raw f-string: "\D" is not a valid escape sequence in a normal string
  diff_label = rf"$\Delta${VAR} [{var_dict[VAR]}]"

  def _plot_period_row(axes_row, da1, da2):
    # Draw one figure row: time-mean of each simulation and their difference.
    mean1 = da1.mean(axis=0)
    mean2 = da2.mean(axis=0)
    if is_temperature:
      # Shift by 273.15 to match the °C label of the temperature variables
      mean1 = mean1 - 273.15
      mean2 = mean2 - 273.15
    diff = mean2 - mean1

    # Shared colour range for the two simulations, taken from the data
    vmin = np.min([float(mean1.min()), float(mean2.min())])
    vmax = np.max([float(mean1.max()), float(mean2.max())])

    # Colour range of the difference map, symmetric around zero
    diff_lim = max(abs(float(diff.min())), abs(float(diff.max())))

    # Colormaps are passed by name: plt.cm.get_cmap is gone in Matplotlib >= 3.9
    mean1.plot(x='lon', y='lat', ax=axes_row[0],
               vmin=vmin, vmax=vmax, cmap='YlOrRd',
               cbar_kwargs={'label': value_label})
    mean2.plot(x='lon', y='lat', ax=axes_row[1],
               vmin=vmin, vmax=vmax, cmap='YlOrRd',
               cbar_kwargs={'label': value_label})
    diff.plot(x='lon', y='lat', ax=axes_row[2],
              vmin=-diff_lim, vmax=diff_lim, cmap='RdYlBu_r',
              cbar_kwargs={'label': diff_label})

  da1 = ds1_xr[VAR]
  da2 = ds2_xr[VAR]
  hour1 = ds1_xr.time.dt.hour
  hour2 = ds2_xr.time.dt.hour

  # Row 0: all hours
  _plot_period_row(axes[0], da1, da2)

  # Row 1: day-time, hours 6 up to and including 21
  _plot_period_row(
      axes[1],
      da1[(hour1 >= 6) & (hour1 < 22),:,:],
      da2[(hour2 >= 6) & (hour2 < 22),:,:],
  )

  # Row 2: night-time, the complement of the day-time hours. The original
  # condition was "hour > 22", which left hour 22 out of both subsets.
  _plot_period_row(
      axes[2],
      da1[(hour1 < 6) | (hour1 >= 22),:,:],
      da2[(hour2 < 6) | (hour2 >= 22),:,:],
  )

  # Add aesthetics
  # Add title
  axes[0,0].set_title(f"{SIM_NAME1}", fontsize=FONT_SIZE-2)
  axes[0,1].set_title(f"{SIM_NAME2}", fontsize=FONT_SIZE-2)
  axes[0,2].set_title(f"{SIM_NAME2} - {SIM_NAME1}", fontsize=FONT_SIZE-2)

  # Remove labels x- and y-axis
  for panel in axes.flat:
    panel.set_ylabel('')
    panel.set_xlabel('')

  # Row labels on the left-most column
  axes[0,0].set_ylabel('DAY (24h)')
  axes[1,0].set_ylabel('DAY-TIME')
  axes[2,0].set_ylabel('NIGHT-TIME')

  plt.tight_layout()

  # NOTE(review): the file name only contains VAR, so calls with different
  # simulation pairs write to the same path.
  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'Map_WRF_RMS_{VAR}.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def _store_diurnal_lu_T_for_RMS_to_table(VAR, SIM_NAME, RUN_INFO):

  """
  Compute, per hour of the day and per land use class, the mean and the
  standard deviation of one WRF variable, and store both tables as .csv.

  Land use classes with LU_INDEX > 30 get their own row; all grid cells
  with LU_INDEX < 31 are pooled under row index 0.

  VAR: variable name in the simulation output.
  SIM_NAME: simulation name; selects '<SIM_NAME>.nc' in OUTPUT and, via
            _get_info('wrf_sim'), the domain file in INPUT.
  RUN_INFO: dict, read for the keys 'WRF_DIR' and 'OUT_DIR'.

  Only the period 2019-07-21 to 2019-07-27 is used. Nothing is returned.
  """

  sim_info = _get_info('wrf_sim')

  # Simulation output (heat-wave event only) and the matching domain file
  fn_output = os.path.join(RUN_INFO['WRF_DIR'], 'OUTPUT', f'{SIM_NAME}.nc')
  fn_domain = os.path.join(RUN_INFO['WRF_DIR'], 'INPUT', f'{sim_info[SIM_NAME][0]}.nc')

  ds_xr = _wrf_to_xr(xr.open_dataset(fn_output)).sel(
      time=slice('2019-07-21', '2019-07-27'))
  ds_geo = xr.open_dataset(fn_domain)

  # Classes present in the domain with an index above 30
  lcz_classes = [lu for lu in np.unique(ds_geo['LU_INDEX'].data) if lu > 30]

  # Two empty tables (rows: 0 + classes above 30, columns: hour 0..23),
  # one for the mean and one for the standard deviation.
  row_index = [0] + lcz_classes
  df_mean = pd.DataFrame(index=row_index, columns=range(0, 24, 1))
  df_std = pd.DataFrame(index=row_index, columns=range(0, 24, 1))

  table_rows = list(df_mean.index)
  n_steps = len(range(0, 24, 1)) * len(table_rows)
  hour_lu_pairs = product(range(0, 24, 1), table_rows)

  for hour, lu_class in tqdm(hour_lu_pairs, total=n_steps,
                    desc=f"Processing {VAR} per HOUR and LU for {SIM_NAME}"):

    # Grid cells belonging to this row of the table
    lu_field = ds_geo['LU_INDEX'][0,:,:]
    lu_mask = (lu_field < 31) if lu_class == 0 else (lu_field == lu_class)

    # Time steps of other hours are masked out by .where; the land use
    # mask then picks the grid cells, and the nan-aware statistics are
    # taken over what remains.
    subset = ds_xr[VAR].where(ds_xr.time.dt.hour == hour).data[:, lu_mask]
    hour_mean = np.nanmean(subset)
    hour_std = np.nanstd(subset)

    # Only the mean of the temperature variables is shifted by 273.15
    if VAR in ['T2', 'TSK']:
      hour_mean = hour_mean - 273.15

    df_mean.loc[lu_class, hour] = float(hour_mean)
    df_std.loc[lu_class, hour] = float(hour_std)

  # Save as .csv to output/ folder when done.
  OFILE_MEAN = os.path.join(
      RUN_INFO['OUT_DIR'],
      f"WRF_{SIM_NAME}_{VAR}_H_LU_mean.csv",
  )
  OFILE_STD = os.path.join(
      RUN_INFO['OUT_DIR'],
      f"WRF_{SIM_NAME}_{VAR}_H_LU_std.csv",
  )
  df_mean.to_csv(OFILE_MEAN)
  df_std.to_csv(OFILE_STD)
  print(f"Files available at:\n- {OFILE_MEAN}\n- {OFILE_STD}")


In [None]:
def plot_WRF_RMS_var_diurnal_cycle_lcz(VAR, RUN_INFO, METRIC="mean", FONT_SIZE=14):

  """
  Plot the hourly table of one variable (one line per land use class) for
  the LCZ_NORMS_SIM, LCZ_CR_SIM, LCZ_GR_SIM and LCZ_PVP_SIM simulations,
  one panel per simulation.

  VAR: variable name; part of the .csv file name and key of the unit table.
  RUN_INFO: dict, read for the keys 'OUT_DIR', 'FIG_DIR' and 'DPI'.
  METRIC: suffix of the .csv files to read ("mean" by default).
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  lu_colors = _get_info('wrf_lcz')

  units = {
    'SWDOWN': r"W/m$^2$",
    'GLW': r"W/m$^2$",
    'HFX': r"W/m$^2$",
    'LH': r"W/m$^2$",
    'CM_AC_URB3D': r"W/m$^2$",
    'EP_PV_URB3D': r"W/m$^2$",
    'TSK': "°C",
    'T2': "°C",
  }

  simulations = (
    'LCZ_NORMS_SIM',
    'LCZ_CR_SIM',
    'LCZ_GR_SIM',
    'LCZ_PVP_SIM',
  )

  # One panel per simulation, sharing both axes
  fig, axes = plt.subplots(1,4, figsize=(25,7), sharey=True, sharex=True)

  for panel, SIM_NAME in zip(axes, simulations):

    # Hourly table of this simulation (rows: land use class, columns: hour)
    table = pd.read_csv(
        os.path.join(
            RUN_INFO['OUT_DIR'],
            f"WRF_{SIM_NAME}_{VAR}_H_LU_{METRIC}.csv",
        ),
        index_col=0,
    )

    # Line colours: class index i takes entry i-1 of the colour table,
    # and the first row is pinned to the first entry.
    palette = list(lu_colors.values())
    positions = [int(lu - 1) for lu in table.index]
    positions[0] = 0
    cols = [palette[pos] for pos in positions]

    # Hours on the x-axis, one line per land use class
    table.T.plot(color=cols, ax=panel, legend=False)

    panel.set_ylabel(f"{VAR} [{units[VAR]}]")
    panel.set_xlabel("Hour of the day (UTC)")
    panel.set_title(SIM_NAME)
    panel.grid(color='0.7', ls=":")

  # Legend to the right of the last panel, built from the colours of the
  # last table read; the labels are a fixed list.
  legend_handles = [Line2D([0], [0], color=line_color, lw=2) for line_color in cols]
  legend_labels = [
      'Natural ',
      'LCZ 1',
      'LCZ 2',
      'LCZ 4',
      'LCZ 5',
      'LCZ 6',
      'LCZ 8',
      'LCZ 9',
      'LCZ 10',
  ]
  axes[-1].legend(
      legend_handles, legend_labels,
      loc='lower left', bbox_to_anchor=(1.05,0),
      ncol=1, fancybox=True, shadow=True
  )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_WRF_RMS_{VAR}_diurnal_cycle_lcz.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")

In [None]:
def plot_WRF_RMS_var_diurnal_cycle_lcz_diff(VAR, RUN_INFO, METRIC="mean", FONT_SIZE=14):

  """
  Plot the hourly table of one variable (one line per land use class) as a
  difference to LCZ_NORMS_SIM, with one panel each for LCZ_CR_SIM,
  LCZ_GR_SIM and LCZ_PVP_SIM.

  VAR: variable name; part of the .csv file names and key of the unit table.
  RUN_INFO: dict, read for the keys 'OUT_DIR', 'FIG_DIR' and 'DPI'.
  METRIC: suffix of the .csv files to read ("mean" by default).
  FONT_SIZE: font size applied globally through matplotlib's rcParams.

  The figure is saved as .jpg in RUN_INFO['FIG_DIR']; nothing is returned.
  """

  plt.rcParams.update({'font.size': FONT_SIZE})

  col_dict = _get_info('wrf_lcz')

  var_dict = {
    'SWDOWN': r"W/m$^2$",
    'GLW': r"W/m$^2$",
    'HFX': r"W/m$^2$",
    'LH': r"W/m$^2$",
    'CM_AC_URB3D': r"W/m$^2$",
    'EP_PV_URB3D': r"W/m$^2$",
    'TSK': "°C",
    'T2': "°C",
  }

  # (panel title prefix, simulation name) of the scenarios, left to right
  scenarios = [
    ('CR', 'LCZ_CR_SIM'),
    ('GR', 'LCZ_GR_SIM'),
    ('PVP', 'LCZ_PVP_SIM'),
  ]

  def _read_table(sim_name):
    # Hourly table of one simulation (rows: land use class, columns: hour)
    fn_table = os.path.join(
        RUN_INFO['OUT_DIR'],
        f"WRF_{sim_name}_{VAR}_H_LU_{METRIC}.csv",
    )
    return pd.read_csv(fn_table, index_col=0)

  # Initialize plot before looping
  fig, axes = plt.subplots(1,3, figsize=(18,7), sharey=True, sharex=True)

  # Reference table every scenario is compared against
  df_norm = _read_table('LCZ_NORMS_SIM')

  # Line colours: class index i takes entry i-1 of the colour table,
  # and the first row is pinned to the first entry.
  cols = [i for i in col_dict.values()]
  col_idx = [int(i-1) for i in list(df_norm.index)]
  col_idx[0] = 0
  cols = [cols[i] for i in col_idx]

  for panel, (short_name, sim_name) in zip(axes, scenarios):

    # Each scenario reads its own table. The original read the GR file for
    # CR as well, so the "CR - NORMS" panel showed the GR difference.
    df_sim = _read_table(sim_name)

    # plot the data, as a difference
    (df_sim.T - df_norm.T).plot(color=cols, ax=panel, legend=False)

    # Aesthetics
    panel.set_xlabel("Hour of the day (UTC)")
    panel.grid(color='0.7', ls=":")
    panel.axhline(0, color='0.2', ls=":", lw=2)
    panel.set_title(f"{short_name} - NORMS")

  # Raw f-string: "\D" is not a valid escape sequence in a normal string
  axes[0].set_ylabel(rf"$\Delta$ {VAR} [{var_dict[VAR]}]")

  # Legend to the right of the last panel; the labels are a fixed list
  custom_lines = [Line2D([0], [0], color=col_i, lw=2) for col_i in cols]
  legend_labels = [
      'Natural ',
      'LCZ 1',
      'LCZ 2',
      'LCZ 4',
      'LCZ 5',
      'LCZ 6',
      'LCZ 8',
      'LCZ 9',
      'LCZ 10',
  ]
  axes[2].legend(custom_lines, legend_labels,
            loc='lower left', bbox_to_anchor=(1.05,0),
            ncol=1, fancybox=True, shadow=True
  )

  plt.tight_layout()

  OFILE = os.path.join(
      RUN_INFO['FIG_DIR'],
      f'plot_WRF_RMS_{VAR}_diurnal_cycle_lcz_diff.jpg'
  )
  plt.savefig(OFILE, dpi=RUN_INFO['DPI'])
  print(f"Done! Figure available at: {OFILE}")