In [None]:
!pip install interpret
!pip install kaleido
!pip install tabulate

In [None]:
# Make Google Drive available to this Colab runtime; all project data and
# results live under `working_directory` on the mounted drive.
from google.colab import drive, files

drive.mount('/content/drive')
working_directory = '/content/drive/My Drive/COS Seesaw Research'

Mounted at /content/drive


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import warnings
import shutil
from sklearn.model_selection import train_test_split
from interpret import show
from interpret import data
from tabulate import tabulate
from dateutil.relativedelta import relativedelta
from datetime import datetime as dt

# Install a global "ignore" filter: every warning raised after this point in the
# kernel session is suppressed, not just those from one library.
# NOTE(review): this may also hide pandas/interpret deprecation warnings that
# matter for reproducibility; consider narrowing by category/module.
warnings.filterwarnings("ignore")



In [None]:
# Sliding-window marginal-correlation analysis per COS site.
# For each site, a two-month window is stepped forward one month at a time;
# in every window the marginal correlation of each feature with the site's
# 'COS_<site>' column is computed (interpret's data.Marginal) and written as a
# lag-by-variable table. After all windows, two ranked "short lists" are written:
# one per base variable (lags pooled) and one per individual lagged feature.

cos_sites = ['alt', 'brw', 'cgo', 'hfm', 'kum', 'lef', 'mhd', 'mlo', 'nwr', 'psa', 'smo', 'spo', 'sum', 'thd']
base_path = working_directory + '/Sensitivity/interval_correlation'
interval = relativedelta(months=1)   # step between consecutive windows
time_series_length = 5               # not used in this cell
short_list_length = 10               # a blank line separates the top N entries of each short list
num_windows = 12                     # number of windows evaluated per site
correlation_threshold = 0.5          # |r| at or below this is treated as "no correlation"


def _reset_dir(path):
  """Remove `path` if it is an existing directory, then recreate it empty.

  Uses os.makedirs so missing parent directories are created too.
  """
  if os.path.isdir(path):
    shutil.rmtree(path)
  os.makedirs(path)


def _rank_by_score(scores):
  """Return the items of `scores` as (key, score) pairs, highest score first.

  sorted() is stable even with reverse=True, so equal scores keep the dict's
  insertion order. NaN scores are placed last rather than at an arbitrary spot.
  """
  def sort_key(item):
    score = item[1]
    is_number = score == score  # False only for NaN
    return (is_number, score if is_number else 0.0)
  return sorted(scores.items(), key=sort_key, reverse=True)


def _write_ranked_list(path, ranked, break_after):
  """Write `ranked` (name, score) pairs as lines 'N. name: score' (N is 1-based).

  A blank line is inserted after the first `break_after` entries.
  """
  with open(path, 'w') as out:
    for rank, (name, score) in enumerate(ranked, start=1):
      if rank == break_after + 1:
        out.write('\n')
      out.write(str(rank) + '. ' + str(name) + ': ' + str(score) + '\n')


_reset_dir(base_path)

for site in cos_sites:
  print(site)
  window_end = dt(year=2016, month=1, day=1)
  window_start = window_end + relativedelta(months=-2)

  cos_target = 'COS_' + site
  # NOTE: read_pickle can execute arbitrary code; only load trusted project files.
  df = pd.read_pickle(working_directory + '/Data/Pickles/correlation_pickles/' + site + '_dataframe.pkl')
  df = df.set_index('time')

  # Every column except the COS target is a candidate feature. Row slicing does
  # not change the columns, so this check is done once per site, not per window.
  feature_columns = list(df.columns)
  if cos_target not in feature_columns:
    # Raise instead of quit(): quit() would end the interpreter/kernel rather
    # than report which site is missing its target column.
    raise KeyError('Error, target column not in dataframe: ' + cos_target)
  feature_columns.remove(cos_target)

  save_path = base_path + '/' + site
  short_list_dict = {}       # base variable -> |r| per (window, lag); 0 when not above threshold
  short_list_time_dict = {}  # full feature name (incl. lag suffix) -> |r| per window

  _reset_dir(save_path)

  for _ in range(num_windows):
    interval_df = df[window_start:window_end]
    print(window_start.date())
    print(window_end.date())

    x = interval_df[feature_columns]
    y = interval_df[cos_target]

    marginal = data.Marginal()
    marginal_explanation = marginal.explain_data(x, y)

    current_dict = {}
    fifteen_dict = {}
    month_dict = {}
    month_fifteen_dict = {}
    two_month_dict = {}
    # Lag suffix (text after the first '-' in a feature name) -> table-row dict,
    # keyed by base variable. No suffix means the un-lagged observation.
    row_dicts = {None: current_dict, '15d': fifteen_dict, '1m': month_dict,
                 '1m15d': month_fifteen_dict, '2m': two_month_dict}
    all_dict = {}  # full feature name -> signed correlation

    for index, variable in enumerate(marginal_explanation.feature_names):
      correlation = marginal_explanation.data(key=index)['correlation']
      abs_corr = abs(correlation)
      all_dict[variable] = correlation
      short_list_time_dict.setdefault(variable, []).append(abs_corr)

      temp = variable.split('-')
      base_name = temp[0]
      is_correlated = abs_corr > correlation_threshold

      # Short-list score: weak correlations count as 0 so they pull the average down.
      short_list_dict.setdefault(base_name, []).append(abs_corr if is_correlated else 0)

      suffix = temp[1] if len(temp) > 1 else None
      if suffix in row_dicts:
        # Weak correlations are shown as '-' in the table.
        row_dicts[suffix][base_name] = correlation if is_correlated else '-'
      else:
        print("Error, what is this: ", temp[0], ", ", temp[1])

    # Table header: 'Time' followed by every un-lagged base variable; one row per lag.
    columns = ['Time'] + list(current_dict)
    labelled_rows = [('current', current_dict), ('-15d', fifteen_dict), ('-1m', month_dict),
                     ('-1m15d', month_fifteen_dict), ('-2m', two_month_dict)]
    correlation_data = [[label] + [row[var] for var in current_dict] for label, row in labelled_rows]

    save_table = tabulate(correlation_data, headers=columns, tablefmt='tsv')
    print(save_table)
    # NOTE: the content is tab-separated despite the '.csv' extension; the file
    # name is kept unchanged so existing output paths stay the same.
    csv_path = save_path + '/' + site + '_correlation_' + str(window_end.date()) + '.csv'
    with open(csv_path, 'w') as csv_file:
      csv_file.write(save_table)

    window_end += interval
    window_start += interval

  # Average each score list (short_list_dict pools all windows and all lags).
  short_list_dict = {key: sum(vals) / len(vals) for key, vals in short_list_dict.items()}
  short_list_time_dict = {key: sum(vals) / len(vals) for key, vals in short_list_time_dict.items()}

  short_list_time = _rank_by_score(short_list_time_dict)
  short_list = _rank_by_score(short_list_dict)

  short_path = save_path + '/shortlist'
  _reset_dir(short_path)

  _write_ranked_list(short_path + '/' + site + '_short_list.txt', short_list, short_list_length)
  _write_ranked_list(short_path + '/' + site + '_short_list_time.txt', short_list_time, short_list_length)



alt
2015-11-01
2016-01-01
Time   	  alt_sst	  brw_sst	  cgo_sst	  hfm_sst	  kum_sst	  lef_sst	  mhd_sst	  mlo_sst	  nwr_sst	  psa_sst	  smo_sst	  spo_sst	  sum_sst	  thd_sst	  alt_cdom	brw_cdom           	cgo_cdom           	  hfm_cdom	kum_cdom          	  lef_cdom	  mhd_cdom	mlo_cdom          	nwr_cdom           	psa_cdom           	  smo_cdom	spo_cdom           	sum_cdom           	thd_cdom           	  alt_dswrf	  brw_dswrf	cgo_dswrf         	hfm_dswrf          	  kum_dswrf	  lef_dswrf	  mhd_dswrf	  mlo_dswrf	  nwr_dswrf	  psa_dswrf	smo_dswrf          	  spo_dswrf	  sum_dswrf	  thd_dswrf	alt_vwnd           	brw_vwnd          	cgo_vwnd           	hfm_vwnd           	kum_vwnd           	lef_vwnd           	mhd_vwnd           	mlo_vwnd           	nwr_vwnd           	psa_vwnd           	smo_vwnd           	spo_vwnd          	sum_vwnd           	thd_vwnd           	alt_uwnd          	brw_uwnd          	cgo_uwnd           	hfm_uwnd           	kum_uwnd           	lef_uwnd           	mhd_uw