# Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn_som.som import SOM
import salishsea_tools.viz_tools as sa_vi
from sklearn import preprocessing
import os


# Datasets Preparation

In [None]:
def datasets_preparation(i):
    """Open one day's physics and biology output and load the land mask.

    Parameters
    ----------
    i : str
        Name of the daily results folder, laid out as ``ddmmmyy``
        (e.g. ``'01mar22'``): characters 0-1 are the day, 2-4 the lower-case
        month abbreviation (translated through the module-level
        ``dict_month``) and 5-6 the two-digit year.

    Returns
    -------
    tuple
        ``(ds, temp, saline, date, mask, flag, diat)`` where ``ds`` is the
        ``grid_T`` dataset, ``temp``/``saline`` are its ``votemper`` and
        ``vosaline`` variables, ``date`` is a ``pandas.DatetimeIndex`` built
        from ``time_counter``, ``mask`` is the ``tmask`` variable of the mesh
        mask file as a NumPy array, and ``flag``/``diat`` are the
        ``flagellates`` and ``diatoms`` variables of the ``biol_T`` dataset.
    """
    results_dir = '/results2/SalishSea/nowcast-green.202111/'

    # The year is taken from the folder name instead of being hard-coded;
    # for the '..22' folders selected by the main loop this is still '2022'.
    stamp = '20' + i[5:7] + str(dict_month[i[2:5]]) + str(i[0:2])

    # Both files share everything but the suffix, so build the stem once.
    stem = results_dir + i + '/SalishSea_1d_' + stamp + '_' + stamp

    # These two datasets are handed back to (or still used by) the caller,
    # so they are deliberately not closed here.
    ds = xr.open_dataset(stem + '_grid_T.nc')
    ds_bio = xr.open_dataset(stem + '_biol_T.nc')

    # Variable selection
    temp = ds.votemper
    saline = ds.vosaline
    date = pd.DatetimeIndex(ds['time_counter'].values)

    # Phytoplankton variables
    flag = ds_bio.flagellates
    diat = ds_bio.diatoms

    # Open the mesh mask; the mask is copied into a NumPy array, so the file
    # can be released as soon as the copy has been made.
    with xr.open_dataset('/home/sallen/MEOPAR/grid/mesh_mask202108.nc') as mesh:
        mask = mesh.tmask.to_numpy()

    return (ds, temp, saline, date, mask, flag, diat)


# Masking

In [None]:
def masking (depth, temp, saline, mask, flag, diat):
    """Blank out every point that is not flagged 1 in the mask at one level.

    Parameters
    ----------
    depth : int
        Index of the level whose mask slice ``mask[0, depth]`` is applied.
    temp, saline, flag, diat
        Fields exposing a ``.where(condition)`` method; each one is filtered
        with the same condition.
    mask
        Array indexed as ``mask[0, depth]``; entries equal to 1 are kept.

    Returns
    -------
    tuple
        ``(temp, saline, flag, diat)`` after applying ``.where``.
    """
    # One boolean condition shared by all four fields.
    keep = mask[0, depth] == 1

    return tuple(field.where(keep) for field in (temp, saline, flag, diat))

# SOM

In [None]:
def som (depth, temp, saline):
    """Cluster one level of temperature and salinity with a 3x2 SOM.

    Parameters
    ----------
    depth : int
        Index of the level to cluster; the slices ``temp[0, depth]`` and
        ``saline[0, depth]`` are used.
    temp, saline
        Fields indexed as ``[0, depth]`` that yield a 2-D slice exposing
        ``.values`` and ``.shape``.

    Returns
    -------
    tuple
        ``(unique, counts, inputs, predictions, clusters)``:
        ``unique`` and ``counts`` come from ``np.unique`` on the predicted
        labels; ``inputs`` is the un-normalised ``(n_points, 2)`` array of
        temperature and salinity at the points that entered the SOM;
        ``predictions`` holds one label per such point; ``clusters`` is the
        label map reshaped to the 2-D slice, with NaN where either input
        was NaN.
    """
    level_temp = temp[0, depth]
    # Taken from the data instead of a hard-coded (898, 398), so the function
    # also works on grids of a different size.
    grid_shape = level_temp.shape

    # Pre processing: stack the two variables as rows and drop every point
    # where either of them is NaN.
    inputs = np.stack((level_temp.values.flatten(), saline[0,depth].values.flatten()))
    indx = np.argwhere(~np.isnan(inputs[0]) & ~np.isnan(inputs[1]))
    inputs2 = inputs[:,indx[:,0]]
    # Rows are still the variables here, so the row-wise 'max' normalisation
    # scales temperature and salinity independently of each other.
    inputs3 = preprocessing.normalize(inputs2, norm= 'max')
    inputs3 = inputs3.transpose()

    # SOM
    temp_som = SOM(m=3, n=2, dim= inputs3[0,:].size, lr = 0.1)
    temp_som.fit(inputs3, epochs = 5)
    predictions = temp_som.predict(inputs3)

    # Post processing: scatter the labels back to their flat grid positions,
    # leaving NaN at the points that were dropped above.
    unique, counts = np.unique(predictions, return_counts=True)
    indx2 = np.full(inputs[0,:].size,np.nan)
    indx2[indx[:,0]] = predictions
    clusters = np.reshape(indx2, grid_shape)

    return (unique, counts, inputs2.transpose(), predictions, clusters)


# Printing

In [None]:
def printing (inputs, predictions, unique, counts):
    """Append per-cluster temperature and salinity statistics to a text file.

    The output file is ``Statistics_<depth>m.txt`` in the current working
    directory, opened in append mode. The depth in the file name is read
    from the module-level ``ds`` and ``depth`` variables, which the calling
    loop must have set before this function is invoked.

    Parameters
    ----------
    inputs : array, shape (n_points, 2)
        Column 0 is temperature, column 1 is salinity.
    predictions : array, shape (n_points,)
        Cluster label of every row of ``inputs``.
    unique, counts : arrays
        Cluster labels and their sizes as returned together by
        ``np.unique(predictions, return_counts=True)``.
    """
    # Preparation of the dataframe
    df = pd.DataFrame({'temperature': inputs[:,0], 'salinity': inputs[:,1], 'cluster': predictions})
    by_cluster = df.groupby('cluster')

    # Calculating the metrics (each result is indexed by cluster label)
    mean_temp = np.round(by_cluster['temperature'].mean(), 2)
    mean_sal = np.round(by_cluster['salinity'].mean(), 2)
    min_temp = np.round(by_cluster['temperature'].min(), 2)
    max_temp = np.round(by_cluster['temperature'].max(), 2)
    min_sal = np.round(by_cluster['salinity'].min(), 2)
    max_sal = np.round(by_cluster['salinity'].max(), 2)

    # Printing
    lines = []
    # `counts` is aligned with the *positions* in `unique`, not with the label
    # values, so the two are walked together. Indexing counts by label gives
    # the wrong number (or an IndexError) as soon as one SOM node is empty.
    for i, count in zip(unique, counts):
        lines.append(['The amount of grid boxes for cluster ' + str(i), ' is ' + str(count),'\n'])

        lines.append(['The minimum temperature for cluster '+ str(i), ' is ' + str(min_temp[i]), ' degrees Celsius'])
        lines.append(['The maximum temperature for cluster '+ str(i), ' is ' + str(max_temp[i]), ' degrees Celsius'])
        lines.append(['The mean temperature for cluster '+ str(i), ' is ' + str(mean_temp[i]), ' degrees Celsius', '\n'])

        lines.append(['The minimum salinity for cluster '+ str(i), ' is ' + str(min_sal[i]), ' g/kg'])
        lines.append(['The maximum salinity for cluster '+ str(i), ' is ' + str(max_sal[i]), ' g/kg'])
        lines.append(['The mean salinity for cluster '+ str(i), ' is ' + str(mean_sal[i]), ' g/kg', '\n'*2])

    # The context manager closes the file even if a write fails.
    with open("Statistics_" + str(np.round(ds['deptht'][depth].values,2)) + 'm.txt', "a") as f:
        for line in lines:
            f.writelines(line)
            f.write('\n')

# Plotting

In [None]:
def plotting (depth, clusters, unique, flag, diat):
    """Save a figure with the cluster map and both phytoplankton fields.

    A 2x2 grid of axes is created; the cluster map goes to the top-left
    panel, flagellates to the top-right and diatoms to the bottom-left. The
    figure is written to ``Depth_<depth>.png`` in the current working
    directory and then closed.

    The title and file name are built from the module-level ``ds`` and
    ``date`` variables, which the calling loop must have set beforehand.

    Parameters
    ----------
    depth : int
        Index of the level being plotted.
    clusters : 2-D array
        Cluster label per grid point (NaN where no label exists).
    unique : array
        Cluster labels present in ``clusters``.
    flag, diat
        Flagellate and diatom fields, indexed as ``[0, depth]``.
    """
    # Preparation of the dataarray (named cluster_map so that the builtin
    # `map` is not shadowed)
    level_flag = flag[0,depth]
    cluster_map = xr.DataArray(clusters,
                    coords = {'y': level_flag.y, 'x': level_flag.x},
                    dims = ['y','x'],
                    attrs=dict(description="Clusters of the performed self organizing map algorithm",
                                long_name ="Cluster",
                                units="count"),
                    )

    # One colour per integer step between the smallest and largest label, so
    # every colour band is exactly one unit wide and the ticks at label + 0.5
    # stay centred even when the lowest label is not 0.
    low, high = unique.min(), unique.max()
    cmap = plt.get_cmap('tab10', high - low + 1)
    cmap.set_bad('gray')
    fig, ax = plt.subplots(2,2, figsize=(10, 15))
    clus = cluster_map.plot.pcolormesh(ax=ax[0,0], cmap=cmap, vmin = low, vmax = high+1, add_colorbar=False)

    cbar = fig.colorbar(clus, ticks = unique+0.5)
    cbar.set_ticklabels(unique)
    cbar.set_label('Clusters [count]')

    plt.subplots_adjust(left=0.1,
        bottom=0.1,
        right=0.9,
        top=0.95,
        wspace=0.15,
        hspace=0.15)

    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')
    level_flag.plot.pcolormesh(ax=ax[0,1], cmap=cmap)
    diat[0,depth].plot.pcolormesh(ax=ax[1,0], cmap=cmap)

    panels = ((ax[0,0], 'Clustering'), (ax[0,1], 'Flagellates'), (ax[1,0], 'Diatoms'))
    for axis, _ in panels:
        sa_vi.set_aspect(axis)
    for axis, title in panels:
        axis.title.set_text(title)

    # Zero-padded month and day via format specs instead of manual branches.
    first = date[0]
    day_label = f"{first.year}/{first.month:02d}/{first.day:02d}"
    depth_label = str(np.round(ds['deptht'][depth].values,2))

    fig.suptitle('Depth: ' + depth_label + ' meters, ' + day_label)

    fig.savefig('Depth_' + depth_label + '.png')
    plt.close(fig)


# Main FOR Loop From Where All Functions are Called

In [None]:
# Driver: for every selected daily results folder, create a matching output
# directory, make it the working directory (the printing and plotting
# functions write relative to it), and run the clustering for the first ten
# depth levels.
parent_dir = '/data/ibougoudis/MOAD/analysis-ilias/notebooks/som_depths_driv_phy_r'
os.makedirs(parent_dir, exist_ok= True)

# Month abbreviation used in the folder names -> two-digit month number
dict_month = {'jan': '01',
         'feb': '02',
         'mar': '03',
         'apr': '04',
         'may': '05',
         'jun': '06',
         'jul': '07',
         'aug': '08',
         'sep': '09',
         'oct': '10',
         'nov': '11',
         'dec': '12'}

path = os.listdir('/results2/SalishSea/nowcast-green.202111/')

# Keep the ddmmmyy folders of March-May '22. Slices are used instead of single
# indices so that a directory entry shorter than seven characters is simply
# skipped rather than raising IndexError.
folders = [x for x in path if x[5:7] == '22' and x[2:5] in ('mar', 'apr', 'may')]
# Chronological order (month, then day); a plain sort would order by day first.
folders.sort(key=lambda x: (dict_month[x[2:5]], x[0:2]))

for i in folders:

    os.makedirs(os.path.join(parent_dir, i), exist_ok= True)
    os.chdir(os.path.join(parent_dir, i))

    # ds and date are also read as globals by printing() and plotting().
    ds, temp, saline, date, mask, flag, diat = datasets_preparation(i)

    # depth is also read as a global by printing().
    for depth in range (0, 10):

        # The masked fields get their own names: overwriting temp/saline/flag/
        # diat would make every later level start from fields that were
        # already masked with all the previous levels' masks.
        temp_lvl, saline_lvl, flag_lvl, diat_lvl = masking(depth, temp, saline, mask, flag, diat)
        unique, counts, inputs, predictions, clusters = som(depth, temp_lvl, saline_lvl)
        printing(inputs, predictions, unique, counts)
        plotting (depth, clusters, unique, flag_lvl, diat_lvl)

    print([i])
