In [None]:
# Colab environment setup. Run this cell once per fresh runtime, before the
# import cell.
# System development libraries for GDAL and PROJ.
# NOTE(review): presumably needed to build/install cartopy -- confirm.
!apt-get install -qq libgdal-dev libproj-dev
# Reinstall shapely from source instead of from a prebuilt wheel.
# NOTE(review): presumably so shapely and cartopy share one GEOS build -- confirm.
!pip install --no-binary shapely shapely --force
# Map projections (imported as ccrs).
!pip install cartopy
# Region objects used to draw the per-site regions.
!pip install regionmask
# Provides the `data.Marginal` explainer used for the correlations.
!pip install interpret
# NOTE(review): kaleido is not imported anywhere in the visible notebook;
# presumably a static-image export backend -- confirm it is still required.
!pip install kaleido

In [None]:
#import required packages
import os
import warnings
import time
import regionmask
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xarray as xr
import geopandas as gpd
import cartopy.crs as ccrs
import shapely
from datetime import datetime as dt
from shapely.geometry import Point
from shapely.geometry.polygon import Polygon
from scipy.interpolate import interp1d
from dateutil.relativedelta import relativedelta
from google.colab import drive
from google.colab import files
import matplotlib as mpl
from interpret import show
from interpret import data
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.gridspec as gridspec

from os import listdir
from os.path import isfile, join

In [None]:
# Mount Google Drive inside the Colab VM and record the project root folder.
# The input pickle directory and the output directory are both built from
# base_path, so this cell must run before any cell that reads or writes data.
drive.mount('/content/drive')
base_path = '/content/drive/My Drive/COS Seesaw Research'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Site codes; cos_site_centers holds one center point per code, in the same order.
cos_sites = ['alt', 'brw', 'cgo', 'hfm', 'kum', 'lef', 'mhd', 'mlo', 'nwr', 'psa', 'smo', 'spo', 'sum', 'thd' ]

# Half-open batch [start_site, end_site) of indices into cos_sites that the
# driver cell processes. end_site is derived from the list length: the previous
# hard-coded 15 pointed one past the last valid index (13) and made the driver
# loop fail with an IndexError after the final site.
start_site = 12
end_site = len(cos_sites)

# Center point of each site's region.
# NOTE(review): values look like (longitude, latitude) in degrees -- confirm.
cos_site_centers = [(-62.3, 82.5), (-156.6, 71.3), (144.7,-40.7), (-72.2, 42.5), (-154.8, 19.5), (-90.3, 45.6), (-9.9, 53.3), (-155.6, 19.5), (-105.5, 40.1), (-64.0, -64.6), (-170.6, -14.2), (0, -90), (-38.4, 72.6), (-124.1,41.0)]

# Fail fast on a mismatched configuration instead of part-way through a run.
assert len(cos_sites) == len(cos_site_centers), "each site needs exactly one center point"
assert 0 <= start_site <= end_site <= len(cos_sites), "site batch must lie within cos_sites"

# Radius passed to shapely's buffer() when the site regions are built; it is
# expressed in the same units as the center coordinates.
region_size = 30
divider = ('-------------------------------------------------------------------------------------------------------')

# Blanket suppression of every warning for the rest of the session.
warnings.filterwarnings("ignore")


In [None]:
# One circular region per site: a shapely buffer of radius region_size drawn
# around the site's center point. region_list keeps the regions in cos_sites
# order (for regionmask), region_dict gives lookup by site code (for plotting).
names = list(cos_sites)
abbrevs = list(cos_sites)

region_list = [
  Point(cos_site_centers[k][0], cos_site_centers[k][1]).buffer(region_size)
  for k in range(len(cos_sites))
]
region_dict = dict(zip(names, region_list))

# overlap=True because buffers around nearby sites may intersect.
regions = regionmask.Regions(
  region_list,
  names=names,
  abbrevs=abbrevs,
  name='Ocean Regions',
  overlap=True,
)

In [None]:
def _correlation_color(correlation):
  """Return the heat-map color bucket for an absolute correlation value.

  Values below 0.5 map to 'white'; so does NaN, because it fails every
  comparison below.
  """
  buckets = ((0.9, 'firebrick'), (0.8, 'orangered'), (0.7, 'orange'),
             (0.6, 'gold'), (0.5, 'khaki'))
  for threshold, color in buckets:
    if correlation >= threshold:
      return color
  return 'white'


def _group_correlations(marginal_explanation):
  """Arrange the per-column correlations as features[feature][time][site]."""
  features = {}
  for index, variable in enumerate(marginal_explanation.feature_names):
    correlation = marginal_explanation.data(key=index)['correlation']

    # Column names are parsed as '<site>_<feature>' or '<site>_<feature>-<time>'.
    # NOTE(review): only the second '_'-separated token is read, so a feature
    # name that itself contains '_' would be truncated -- confirm the schema.
    variable_split = variable.split('_')
    site_name = variable_split[0]
    feature_split = variable_split[1].split('-')
    feature_name = feature_split[0]
    feature_time = 'current' if len(feature_split) == 1 else feature_split[1]

    features.setdefault(feature_name, {}).setdefault(feature_time, {})[site_name] = correlation
  return features


def _save_heatmap(site, window_end, feature_name, feature_time, site_correlations, heatmap_path):
  """Draw one correlation heat map over the site regions and save it as a PNG."""
  fig = plt.figure(figsize=(24,12), constrained_layout=True)
  try:
    # Top nine grid rows hold the map, the last row holds the colorbar.
    gs = fig.add_gridspec(10,3)
    ax1 = fig.add_subplot(gs[0:9, :], projection=ccrs.PlateCarree(central_longitude=0))
    ax2 = fig.add_subplot(gs[9, :])

    regions.plot(ax=ax1, label='abbrev')

    title = str(window_end.date()) + ' -- ' + site + " sensitivity to " + feature_name + " - " + feature_time
    ax1.set_title(title)

    # Fill each site's region according to the magnitude of its correlation.
    for site_name, correlation in site_correlations.items():
      region_color = _correlation_color(abs(correlation))
      gpd.GeoSeries(region_dict[site_name]).plot(ax=ax1, color=region_color)

    # Colorbar whose buckets mirror the thresholds in _correlation_color.
    cmap = mpl.colors.ListedColormap(['white', 'khaki', 'gold', 'orange', 'orangered', 'firebrick'])
    bounds = [0, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
    colorbar = mpl.colorbar.ColorbarBase(ax=ax2, cmap=cmap, norm=norm, boundaries=bounds,
                                         spacing='uniform',
                                         orientation='horizontal')
    colorbar.set_label('Absolute Pearson Correlation')

    # NOTE(review): the space after "_sensitivity_to_" looks accidental, but it
    # is kept because files already written (and anything reading them) use it.
    fig_name = heatmap_path + '/' + str(window_end.date()) + '--' + site + "_sensitivity_to_ " + feature_name + "-" + feature_time + '.png'
    fig.savefig(fig_name)
  finally:
    # Always release the figure, even if plotting or saving failed.
    plt.close(fig)


def execute(site):
  """Save correlation heat maps for one site over 12 consecutive time windows.

  Loads '<site>_dataframe.pkl' from the correlation_pickles directory under
  base_path, indexes it by its 'time' column, and then, for each of 12 windows
  (the first spans the two months ending 2016-01-01, each later one is shifted
  by the global `interval`), computes the correlation of every other column
  with the 'COS_<site>' column and writes one PNG per (feature, time) pair into
  '<save_directory>/<site>/heatmaps'.

  Reads the module-level names base_path, save_directory, interval, regions
  and region_dict, so the cells defining them must have been run first.

  Args:
    site: site code used to build the pickle file name and the target column
      name (presumably one of the entries of cos_sites).

  Raises:
    ValueError: if the dataframe has no 'COS_<site>' column.
  """
  window_end = dt(year=2016, month=1, day=1)
  window_start = window_end + relativedelta(months=-2)

  working_directory = base_path + '/Data/Pickles/correlation_pickles'
  site_path = working_directory + '/' + site + '_dataframe.pkl'

  cos_target = 'COS_' + site
  # Unpickling executes code stored in the file; only load trusted pickles.
  df = pd.read_pickle(site_path)
  df = df.set_index('time')

  # Validate once, up front. The previous per-window check called quit(), which
  # is meant for the interactive shell and gives callers nothing to catch.
  if cos_target not in df.columns:
    raise ValueError('Error, target column not in dataframe: ' + cos_target)
  feature_columns = [column for column in df.columns if column != cos_target]

  # makedirs also creates the missing '<save_directory>/<site>' parent, which
  # a plain os.mkdir of the heatmaps folder cannot do.
  heatmap_path = save_directory + '/' + site + '/heatmaps'
  os.makedirs(heatmap_path, exist_ok=True)

  for _ in range(12):
    interval_df = df.loc[window_start:window_end]
    x = interval_df[feature_columns]
    y = interval_df[cos_target]

    marginal = data.Marginal()
    marginal_explanation = marginal.explain_data(x, y)
    features = _group_correlations(marginal_explanation)

    # Everything is sorted for this window; make the plots and save them.
    for feature_name, by_time in features.items():
      for feature_time, site_correlations in by_time.items():
        _save_heatmap(site, window_end, feature_name, feature_time, site_correlations, heatmap_path)
        # Collect after every figure to keep memory flat over the many plots.
        gc.collect()

    window_end += interval
    window_start += interval

In [None]:
# Produce the interval-correlation heat maps for the selected batch of sites.
# execute() reads `interval` and `save_directory` as globals, so both must be
# assigned before the loop below runs.
interval = relativedelta(months=1)

save_directory = base_path + '/Sensitivity/interval_correlation'

# Clamp the upper bound to the number of sites: an end_site past the end of
# cos_sites would otherwise raise an IndexError after the last real site.
for i in range(start_site, min(end_site, len(cos_sites))):
  site = cos_sites[i]
  execute(site)
  gc.collect()
  print(site, " completed")