# Imports

In [28]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn_som.som import SOM
import salishsea_tools.viz_tools as sa_vi


# Dataset Preparation

In [29]:
# T-grid model results for 2022-03-31 (per the file name).
ds = xr.open_dataset(
    '/results2/SalishSea/nowcast-green.202111/31mar22/'
    'SalishSea_1d_20220331_20220331_grid_T.nc'
)

# Fields to cluster, plus the output timestamps (used later in plot titles).
temp = ds['votemper']
saline = ds['vosaline']
date = pd.DatetimeIndex(ds.time_counter.values)

# Mesh mask; tmask is kept as a plain NumPy array and indexed as mask[0, depth].
mesh = xr.open_dataset('/home/sallen/MEOPAR/grid/mesh_mask202108.nc')
mask = mesh['tmask'].to_numpy()


# Masking

In [30]:
def masking(depth, temp, saline, tmask=None):
    """Blank out masked points of temperature and salinity at one depth level.

    Parameters
    ----------
    depth : int
        Depth index into the mask; the 2-D mask used is ``tmask[0, depth]``.
    temp, saline : xarray.DataArray
        Fields to mask (any object with a ``.where(cond)`` method that
        broadcasts against a 2-D boolean array).
    tmask : numpy.ndarray, optional
        Mask indexed as ``tmask[0, depth]``; points equal to 1 are kept.
        Defaults to the module-level ``mask``.

    Returns
    -------
    tuple
        ``(temp, saline)`` with every point where ``tmask[0, depth] != 1``
        replaced by NaN. The inputs themselves are not modified.
    """
    if tmask is None:
        tmask = mask
    keep = tmask[0, depth] == 1
    # `.where` returns new objects rather than modifying in place, so the
    # masked fields must be returned for callers to use them.
    return temp.where(keep), saline.where(keep)
    

# SOM

In [36]:
def som(depth, temp, saline, m=2, n=3):
    """Cluster (temperature, salinity) pairs at one depth level with a SOM.

    Parameters
    ----------
    depth : int
        Depth index; the fields are read as ``temp[0, depth]`` and
        ``saline[0, depth]`` (first time step).
    temp, saline : xarray.DataArray
        Fields assumed to be laid out as (time, depth, y, x) -- inferred from
        the ``[0, depth]`` indexing and the ``y``/``x`` coordinates used
        below; confirm against the dataset.
    m, n : int, optional
        SOM grid size passed to ``SOM(m=..., n=...)``. The defaults (2, 3)
        give at most 6 clusters.

    Returns
    -------
    unique : numpy.ndarray
        Sorted cluster labels assigned to at least one grid point.
    cluster_map : xarray.DataArray
        (y, x) map of cluster labels; NaN wherever temperature or salinity
        is NaN.

    Raises
    ------
    ValueError
        If no grid point has both a temperature and a salinity value.
    """
    temp2d = temp[0, depth]
    saline2d = saline[0, depth]

    # One row per grid point, columns (temperature, salinity).
    inputs = np.column_stack((temp2d.values.ravel(), saline2d.values.ravel()))
    # Only points where both variables are defined can be clustered.
    valid = ~np.isnan(inputs).any(axis=1)
    if not valid.any():
        raise ValueError(
            f'No grid points with both temperature and salinity at depth index {depth}'
        )
    samples = inputs[valid]

    model = SOM(m=m, n=n, dim=samples.shape[1])
    model.fit(samples)
    predictions = model.predict(samples)
    unique = np.unique(predictions)

    # Scatter the labels back onto the full grid; invalid points stay NaN.
    labels = np.full(inputs.shape[0], np.nan)
    labels[valid] = predictions

    # Reshape to the field's own grid instead of a hard-coded size.
    cluster_map = xr.DataArray(
        labels.reshape(temp2d.shape),
        coords={'y': temp2d.y, 'x': temp2d.x},
        dims=['y', 'x'],
        attrs=dict(description="Clusters of the performed self organizing map algorithm",
                   long_name="Cluster",
                   units="count"),
    )
    return unique, cluster_map


# Plotting

In [32]:
def plotting(unique, map, depth_index=None):
    """Plot a 2-D map of SOM cluster labels for one depth level.

    Reads the module-level ``ds`` (depth value for the subplot title) and
    ``date`` (figure title).

    Parameters
    ----------
    unique : numpy.ndarray
        Sorted integer cluster labels present in *map* (e.g. ``np.unique``
        of the SOM predictions).
    map : xarray.DataArray
        (y, x) array of cluster labels; NaN points are drawn in gray.
    depth_index : int, optional
        Index into ``ds.deptht`` used for the subplot title. Defaults to the
        module-level ``depth`` (the main-loop variable), which is what this
        function previously read implicitly.
    """
    if depth_index is None:
        depth_index = depth

    # With a single subplot, plt.subplots returns one Axes object rather
    # than an array, so it must not be indexed.
    fig, ax = plt.subplots(ncols=1)

    n_colors = int(unique.max()) + 1
    cmap = plt.get_cmap('viridis', n_colors)
    cmap.set_bad('gray')

    # One colour band per label: boundaries at every label plus one past the
    # largest, so this works for any SOM size rather than a fixed bound.
    levels = np.concatenate((unique, [n_colors]))
    map.plot.pcolormesh(ax=ax, cmap=cmap, levels=levels,
                        cbar_kwargs={'ticks': unique})
    sa_vi.set_aspect(ax)

    # float() extracts the scalar; str() of a DataArray would give its repr.
    depth_m = float(ds.deptht[depth_index])
    ax.set_title(f'Depth of {depth_m:.1f} meters')
    # strftime zero-pads month and day correctly for every date.
    fig.suptitle('SOM clustering using temperature and salinity for '
                 + date[0].strftime('%Y/%m/%d'))
    plt.show()


# Main Loop: Call All Functions for Each Depth Level

In [33]:

# Run masking and SOM clustering for each depth index; currently only the
# first level (index 0) is processed.
# NOTE(review): nothing computed by masking()/som() is captured here, so the
# plotting step has no `unique`/`map` to work with yet.
for depth in range(1):
    masking(depth, temp, saline)
    som(depth, temp, saline)
    # TODO: call plotting(unique, map) once the som() results are kept.


In [34]:
# Display the depth coordinate of the temperature field (notebook cell output).
temp['deptht']