# Script to analyse models for province 2

In [None]:
import joblib
import xarray as xr
import pandas as pd

## Load model

In [None]:
# Path of the serialized "random_forest_2_otherhalf" model (loaded as rf_2).
# NOTE: `filename` is reused by later cells, so cell execution order matters.
filename = "/home/jovyan/lustre_scratch/models/random_forest_2_otherhalf.joblib"

In [None]:
# Load the fitted model stored at `filename`.
# The path is passed directly (instead of an open() handle) so joblib opens
# and closes the file itself and no file object is left dangling.
rf_2 = joblib.load(filename)

In [None]:
# Path of the serialized "random_forest_2" model (loaded as rf_1).
filename = "/home/jovyan/lustre_scratch/models/random_forest_2.joblib"

In [None]:
# Load the fitted model stored at `filename`.
# The path is passed directly (instead of an open() handle) so joblib opens
# and closes the file itself and no file object is left dangling.
rf_1 = joblib.load(filename)

## Load .nc in xarray for province

In [None]:
# Path of the NetCDF file holding the province 2 data.
filename = '/home/jovyan/lustre_scratch/province_dataframes/full_province2'+'.nc'

In [None]:
# Read the province dataset from `filename` into an xarray Dataset.
d = xr.load_dataset(filename)

In [None]:
# Split the province into two longitude bands: d_1 spans -25..10, d_2 spans 10..50.
# NOTE(review): both slices use 10 as a bound; if label slicing is inclusive and
# the grid has a point at exactly longitude 10, that column is in both halves — confirm.
d_1 = d.sel(longitude=slice(-25,10))
d_2 = d.sel(longitude=slice(10,50))

## Split to test data

In [None]:
# Time windows used to select data along the `time` coordinate.
# slice_1 covers 1998-2000; the remaining four cover four calendar years each.
# Only slice_3 is used further down in this script.
slice_1 = slice('1998-01-01','2000-12-01')
slice_2 = slice('2001-01-01','2004-12-01')
slice_3 = slice('2005-01-01','2008-12-01')
slice_4 = slice('2009-01-01','2012-12-01')
slice_5 = slice('2013-01-01','2016-12-01')

In [None]:
# Fold 3: the 2005-2008 window (slice_3) of each longitude band is the test data.
test_1 = d_1.sel(time=slice_3)
test_2 = d_2.sel(time=slice_3)

In [None]:
# Flatten the whole first longitude band into a table, discard the coordinate
# helper columns and incomplete rows, and renumber the rows from zero.
stacked = (
    d_1.stack(coord=['longitude', 'latitude'])
    .to_dataframe()
    .drop(columns=['spatial_ref', 'longitude', 'latitude'])
    .dropna()
    .reset_index(drop=True)
)
# Standard deviation of `rrs` over the band; used later as the class threshold.
std_1 = stacked['rrs'].std()

In [None]:
# Flatten the whole second longitude band into a table, discard the coordinate
# helper columns and incomplete rows, and renumber the rows from zero.
stacked = (
    d_2.stack(coord=['longitude', 'latitude'])
    .to_dataframe()
    .drop(columns=['spatial_ref', 'longitude', 'latitude'])
    .dropna()
    .reset_index(drop=True)
)
# Standard deviation of `rrs` over the band; used later as the class threshold.
std_2 = stacked['rrs'].std()

## Prep data

In [None]:
# Build a class-balanced test set (X_test_1, y_test_1) from the fold-3 data of
# the first longitude band.

# Flatten the test dataset into a table, drop helper columns and incomplete rows.
stacked_test = test_1.stack(coord=['longitude', 'latitude']).to_dataframe()
stacked_test = stacked_test.drop(columns=['spatial_ref', 'longitude', 'latitude'])
stacked_test = stacked_test.dropna().reset_index(drop=True)

# Binarise the target: values at or above std_1 become 1, values below keep
# their original value. The result is assigned back to the column because a
# chained `df['rrs'].where(..., inplace=True)` operates on a temporary Series
# and does not update the frame under pandas Copy-on-Write.
stacked_test['rrs'] = stacked_test['rrs'].where(stacked_test['rrs'] < std_1, other=1)

# Class counts. Only rows whose rrs is exactly 0 count as the negative class;
# NOTE(review): rows strictly between 0 and std_1 (if any) fall in neither class.
rrs_ones = (stacked_test['rrs'] == 1).sum()
rrs_zeros = (stacked_test['rrs'] == 0).sum()

non_zero = stacked_test.loc[stacked_test['rrs'] == 1.]
zero = stacked_test.loc[stacked_test['rrs'] == 0.]

# Down-sample the zeros to the number of ones so both classes are equally
# represented. The draw is unseeded, so the sampled rows differ between runs,
# and it requires at least `rrs_ones` zero rows.
zero_samp = zero.sample(rrs_ones)

full_test_1 = pd.concat([zero_samp, non_zero])
X_test_1 = full_test_1.drop(columns='rrs')
y_test_1 = full_test_1['rrs']

In [None]:
# Score rf_1 on the balanced test set of the first longitude band.
test_accuracy = rf_1.score(X_test_1, y_test_1)

In [None]:
# Report the score obtained for rf_1.
print('test_accuracy model 1:' ,test_accuracy)

In [None]:
# Build a class-balanced test set (X_test_2, y_test_2) from the fold-3 data of
# the second longitude band.

# Flatten the test dataset into a table, drop helper columns and incomplete rows.
stacked_test = test_2.stack(coord=['longitude', 'latitude']).to_dataframe()
stacked_test = stacked_test.drop(columns=['spatial_ref', 'longitude', 'latitude'])
stacked_test = stacked_test.dropna().reset_index(drop=True)

# Binarise the target: values at or above std_2 become 1, values below keep
# their original value. The result is assigned back to the column because a
# chained `df['rrs'].where(..., inplace=True)` operates on a temporary Series
# and does not update the frame under pandas Copy-on-Write.
stacked_test['rrs'] = stacked_test['rrs'].where(stacked_test['rrs'] < std_2, other=1)

# Class counts. Only rows whose rrs is exactly 0 count as the negative class;
# NOTE(review): rows strictly between 0 and std_2 (if any) fall in neither class.
rrs_ones = (stacked_test['rrs'] == 1).sum()
rrs_zeros = (stacked_test['rrs'] == 0).sum()

non_zero = stacked_test.loc[stacked_test['rrs'] == 1.]
zero = stacked_test.loc[stacked_test['rrs'] == 0.]

# Down-sample the zeros to the number of ones so both classes are equally
# represented. The draw is unseeded, so the sampled rows differ between runs,
# and it requires at least `rrs_ones` zero rows.
zero_samp = zero.sample(rrs_ones)

full_test_2 = pd.concat([zero_samp, non_zero])
X_test_2 = full_test_2.drop(columns='rrs')
y_test_2 = full_test_2['rrs']

In [None]:
# Score rf_2 on the balanced test set of the second longitude band and report
# it. A separate variable is used so the model 1 score in `test_accuracy` is
# not overwritten.
test_accuracy_2 = rf_2.score(X_test_2, y_test_2)
print('test_accuracy model 2:', test_accuracy_2)

## SHAP

In [None]:
import shap

In [None]:
# SHAP explainer built on the fitted model rf_1.
sh_1 = shap.TreeExplainer(rf_1)

In [None]:
# Random subset of 2000 test rows to explain. The draw is unseeded, so the
# rows (and the resulting plot) differ between runs.
# NOTE(review): presumably fails if X_test_1 has fewer than 2000 rows — confirm.
sample_1 = X_test_1.sample(2000)

In [None]:
# Bare expression: the notebook displays the sampled rows.
sample_1

In [None]:
# SHAP values of rf_1 for the sampled rows.
# NOTE(review): later cells index this result with [1], presumably the
# positive class; the return layout depends on the shap version — confirm.
sh_val_1 = sh_1.shap_values(sample_1)

In [None]:
import numpy as np

In [None]:
# Mean absolute SHAP value along axis 1 of sh_val_1[1]: the axis is moved to
# the front so that iterating yields one slice per axis-1 entry
# (presumably one per feature — assumes rows are samples; confirm).
global_values = [
    np.mean(column)
    for column in np.moveaxis(np.abs(sh_val_1[1]), 1, 0)
]

In [None]:
# Bare expression: the notebook displays the mean absolute SHAP values.
global_values

In [None]:
import matplotlib.pyplot as plt

In [None]:
from matplotlib.colors import LinearSegmentedColormap
# RGB anchor colours (components scaled to 0..1) for the custom colormap,
# ordered from dark blue to pale cyan.
colors = [
    (0, 24/255, 95/255),
    (0, 154/255, 162/255),
    (126/255, 201/255, 201/255),
    (173/255, 255/255, 251/255),
]
# Colormap interpolated between the anchors and quantised to 100 levels.
cmap = LinearSegmentedColormap.from_list(
    'coccolithphores',
    colors,
    N=100,
)

In [None]:

# SHAP "dot" summary plot of sh_val_1[1] for the sampled rows.
# show=False leaves the figure open so it can be adjusted and saved first.
shap.summary_plot(
    sh_val_1[1],
    sample_1,
    plot_type='dot',
    cmap=cmap,
    show=False,
)
ax = plt.gca()
# Fixed x-range, presumably so the province figures share one scale.
ax.set_xlim(-0.5, 0.5)
ax.set_title('Province 2a')
# Save before plt.show(), which displays the figure.
plt.savefig('/home/jovyan/lustre_scratch/Figures/model_shap_analysis_final_tree_redo/province_2a.png')
plt.show()

In [None]:
# SHAP explainer built on the fitted model rf_2.
sh_2 = shap.TreeExplainer(rf_2)

In [None]:
# Random subset of 2000 test rows to explain. The draw is unseeded, so the
# rows (and the resulting plot) differ between runs.
# NOTE(review): presumably fails if X_test_2 has fewer than 2000 rows — confirm.
sample_2 = X_test_2.sample(2000)

In [None]:
# Bare expression: the notebook displays the sampled rows.
sample_2

In [None]:
# SHAP values of rf_2 for the sampled rows.
# NOTE(review): later cells index this result with [1], presumably the
# positive class; the return layout depends on the shap version — confirm.
sh_val_2 = sh_2.shap_values(sample_2)

In [None]:
# Mean absolute SHAP value along axis 1 of sh_val_2[1]: the axis is moved to
# the front so that iterating yields one slice per axis-1 entry
# (presumably one per feature — assumes rows are samples; confirm).
global_values = [
    np.mean(column)
    for column in np.moveaxis(np.abs(sh_val_2[1]), 1, 0)
]

In [None]:
# Bare expression: the notebook displays the mean absolute SHAP values.
global_values

In [None]:
# SHAP "dot" summary plot of sh_val_2[1] for the sampled rows.
# show=False leaves the figure open so it can be adjusted and saved first.
shap.summary_plot(
    sh_val_2[1],
    sample_2,
    plot_type='dot',
    cmap=cmap,
    show=False,
)
ax = plt.gca()
# Fixed x-range, presumably so the province figures share one scale.
ax.set_xlim(-0.5, 0.5)
ax.set_title('Province 2b')
# Save before plt.show(), which displays the figure.
plt.savefig('/home/jovyan/lustre_scratch/Figures/model_shap_analysis_final_tree_redo/province_2b.png')
plt.show()