## Importing

In [31]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn_som.som import SOM
import salishsea_tools.viz_tools as sa_vi
from sklearn import preprocessing
import os


## Dataset Preparation

In [32]:
def datasets_preparation(folder=None):
    """Load one day of model output and compute vertically averaged fields.

    Opens the daily ``grid_T`` and ``biol_T`` files of
    ``/results2/SalishSea/nowcast-green.202111/<folder>/``. For each variable
    it computes ``sum(value * e3t) / layer_thickness`` over a range of model
    T-levels, excluding points where ``mask != 1``. Layer thicknesses come
    from the W-level depths ``mesh.gdepw_0``.

    Parameters
    ----------
    folder : str, optional
        Day folder name laid out as ``DDmonYY`` (e.g. ``'01mar21'``). The month
        abbreviation is translated with ``dict_month``. Defaults to the global
        loop variable ``i`` so existing ``datasets_preparation()`` calls keep
        working.

    Returns
    -------
    tuple
        ``(date, temp_i1, temp_i2, saline_i1, saline_i2, sil_i, nitr_i,
        ammo_i, diat_i, flag_i, micro_i, meso_i)``.

        - ``date`` is a ``pandas.DatetimeIndex`` built from ``time_counter``.
        - The other entries are (y, x) DataArrays already loaded into memory.
        - ``*_i1`` average T-levels 0-14 and ``*_i2`` average T-levels 15-26.
        - The biological fields average T-levels 0-26.
        - A column is NaN when any level in its range is masked, because
          ``min_count`` equals the number of levels.

    NOTE(review): dividing by the layer thickness yields depth *averages*,
    but the plot labels in the main cell describe the biological fields as
    "mmol m-2" (an integrated unit). Confirm which is intended.

    Uses the globals ``dict_month``, ``mask`` and ``mesh``.
    """
    if folder is None:
        folder = i  # legacy: the main loop passes the day via the global ``i``

    # DDmonYY -> YYYYMMDD, e.g. '01mar21' -> '20210301'
    stamp = '20' + folder[5:7] + dict_month[folder[2:5]] + folder[0:2]
    prefix = ('/results2/SalishSea/nowcast-green.202111/' + folder
              + '/SalishSea_1d_' + stamp + '_' + stamp)
    ds_name = prefix + '_grid_T.nc'
    ds_bio_name = prefix + '_biol_T.nc'

    ocean = mask == 1        # evaluated once instead of once per field
    gdepw = mesh.gdepw_0[0]  # W-level depths at time index 0

    def layer_mean(field, e3t, k0, k1):
        """Thickness-weighted mean of *field* over T-levels k0 .. k1-1."""
        total = (field.where(ocean)[0, k0:k1] * e3t[0, k0:k1]).sum(
            'deptht', skipna=True, min_count=k1 - k0)
        # T-levels k0..k1-1 lie between W-levels k0 and k1. A layer starting
        # at the surface (k0 == 0) is divided by gdepw[k1] alone.
        thickness = gdepw[k1] if k0 == 0 else gdepw[k1] - gdepw[k0]
        # load() pulls data and coordinates into memory so the result stays
        # usable after the source files are closed below.
        return (total / thickness).load()

    # Close both files when done; every returned field is already in memory.
    with xr.open_dataset(ds_name) as ds, xr.open_dataset(ds_bio_name) as ds_bio:
        date = pd.DatetimeIndex(ds['time_counter'].values)
        e3t = ds.e3t.where(ocean)  # masked cell thicknesses, shared by all fields

        temp_i1 = layer_mean(ds.votemper, e3t, 0, 15)
        saline_i1 = layer_mean(ds.vosaline, e3t, 0, 15)

        # Levels 15-26 start at W-level 15. The former divisor
        # gdepw_0[27] - gdepw_0[14] also counted level 14's thickness.
        temp_i2 = layer_mean(ds.votemper, e3t, 15, 27)
        saline_i2 = layer_mean(ds.vosaline, e3t, 15, 27)

        sil_i = layer_mean(ds_bio.silicon, e3t, 0, 27)
        nitr_i = layer_mean(ds_bio.nitrate, e3t, 0, 27)
        ammo_i = layer_mean(ds_bio.ammonium, e3t, 0, 27)
        diat_i = layer_mean(ds_bio.diatoms, e3t, 0, 27)
        flag_i = layer_mean(ds_bio.flagellates, e3t, 0, 27)
        micro_i = layer_mean(ds_bio.microzooplankton, e3t, 0, 27)
        meso_i = layer_mean(ds_bio.mesozooplankton, e3t, 0, 27)

    return (date, temp_i1, temp_i2, saline_i1, saline_i2, sil_i, nitr_i, ammo_i, diat_i,  flag_i, micro_i, meso_i)


## SOM (Drivers)

In [33]:
def som(inputs, m, n, template=None, epochs=5, lr=0.1):
    """Cluster grid points with a self-organizing map.

    Parameters
    ----------
    inputs : numpy.ndarray
        2-D array shaped (n_variables, n_grid_points), one row per variable
        (as built in the main loop with ``np.stack`` of ``np.ravel`` fields).
        Grid points (columns) containing any NaN are excluded from training
        and stay NaN in the output.
    m, n : int
        Dimensions of the SOM grid (``m * n`` neurons).
    template : xarray.DataArray, optional
        2-D field with ``y`` and ``x`` dimensions whose sizes and coordinates
        the cluster map takes. Defaults to the global ``temp_i1`` (the former
        hard-coded behaviour).
    epochs : int, optional
        Training epochs passed to ``SOM.fit`` (default 5).
    lr : float, optional
        Learning rate passed to ``SOM`` (default 0.1).

    Returns
    -------
    clusters : xarray.DataArray
        Cluster label per grid point, NaN where the point was excluded.
    unique, counts : numpy.ndarray
        The labels that occur and the number of grid points carrying each.
    """
    if template is None:
        template = temp_i1

    # Keep only grid points where every variable is defined.
    valid = np.flatnonzero(~np.isnan(inputs).any(axis=0))

    # normalize() works row-wise here, so each variable is scaled by its own
    # maximum (absolute) value. Transpose to (n_samples, n_features) for the SOM.
    samples = preprocessing.normalize(inputs[:, valid], norm='max').T

    net = SOM(m, n, dim=samples.shape[1], lr=lr)
    net.fit(samples, epochs=epochs)
    predictions = net.predict(samples)

    unique, counts = np.unique(predictions, return_counts=True)

    # Scatter the labels back onto the full grid; excluded points stay NaN.
    labels = np.full(inputs.shape[1], np.nan)
    labels[valid] = predictions
    grid = labels.reshape(template.sizes['y'], template.sizes['x'])

    clusters = xr.DataArray(grid,
                            coords={'y': template.y, 'x': template.x},
                            dims=['y', 'x'],
                            attrs=dict(description="Clusters of the performed self organizing map algorithm",
                                       long_name="Cluster",
                                       units="count"),
                            )

    return (clusters, unique, counts)
    

## Printing

In [34]:
def printing():
    """Append per-cluster statistics for the current day to ``Statistics.txt``.

    For every cluster label in the global ``unique`` it writes the number of
    grid boxes in that cluster. It then writes the minimum, maximum and mean
    (rounded to 2 decimals) of each of these fields inside the cluster:
    temperature, salinity, silicon, nitrate, ammonium, diatoms, flagellates,
    microzooplankton and mesozooplankton.

    Temperature and salinity statistics pool the two depth layers
    (``*_i1`` and ``*_i2``).

    Reads the globals ``date``, ``clusters``, ``unique``, ``counts`` and the
    fields returned by ``datasets_preparation``. The file is opened in append
    mode relative to the current working directory.
    """

    def summary(label, quantity, field, units, gap):
        # min / max / mean entries for one cluster-masked field; *gap* is the
        # blank-line suffix written after the mean entry.
        head = ' ' + quantity + ' for cluster ' + str(label)
        entries = []
        for word, reducer in (('minimum', field.min), ('maximum', field.max), ('mean', field.mean)):
            value = str(np.round(reducer().values, 2))
            entries.append(['The ' + word + head, ' is ' + value, units])
        entries[-1].append(gap)
        return entries

    # Fields reported in ' mmol m-2', with the gap that follows each group.
    tracers = (
        ('silicon concentration', sil_i, '\n'),
        ('nitrate concentration', nitr_i, '\n'),
        ('ammonium concentration', ammo_i, '\n'*2),
        ('diatom concentration', diat_i, '\n'),
        ('flagellate concentration', flag_i, '\n'*2),
        ('microzooplankton concentration', micro_i, '\n'),
        ('mesozooplankton concentration', meso_i, '\n'*2),
    )

    lines = ['Date: ' + str(date.date[0]), '\n']
    # Pair each label with its own count. Indexing ``counts`` by the label
    # value is wrong, or raises IndexError, as soon as some SOM neuron
    # receives no grid points and the labels skip a number.
    for label, count in zip(unique, counts):
        in_cluster = clusters == label
        lines.append(['The amount of grid boxes for cluster ' + str(label), ' is ' + str(count), '\n'])

        temperature = xr.concat([temp_i1.where(in_cluster), temp_i2.where(in_cluster)], 'y')
        salinity = xr.concat([saline_i1.where(in_cluster), saline_i2.where(in_cluster)], 'y')
        lines += summary(label, 'temperature', temperature, ' degrees_C m', '\n')
        lines += summary(label, 'salinity', salinity, ' g m kg-1', '\n'*2)

        for quantity, field, gap in tracers:
            lines += summary(label, quantity, field.where(in_cluster), ' mmol m-2', gap)

    # ``with`` closes the file even if a write fails.
    with open("Statistics.txt", "a") as f:
        for line in lines:
            f.writelines(line)  # an entry is a str or a list of str fragments
            f.write('\n')

## Plotting

In [35]:
def plotting(labels, titles, variables, filename):
    """Plot three or four 2-D fields on a 2x2 grid of maps and save the figure.

    Parameters
    ----------
    labels, titles : list of str
        Colour-bar labels and panel titles, one per entry of *variables*.
        When ``titles`` has three entries the bottom-right panel is hidden.
    variables : list of xarray.DataArray
        Fields to draw, filling the panels top-left, top-right, bottom-left,
        bottom-right. If the first entry is the global ``clusters`` object, it
        is drawn with a discrete ``tab20`` map and one colour-bar tick per
        label. Every other panel uses ``cubehelix``; NaNs are shown in gray.
    filename : str
        Output path passed to ``Figure.savefig``.

    Uses the globals ``clusters``, ``unique`` and ``date`` (for the title).
    """
    fig, ax = plt.subplots(2, 2, figsize=(10, 15))

    # Identity check: the main loop passes the ``clusters`` object itself.
    # Comparing ``variables[0].all() == clusters.all()`` only compared two
    # "are all values non-zero?" booleans and could misclassify a field.
    if variables[0] is clusters:
        # tab20 resampled to unique.max()+1 discrete colours.
        cmap = plt.get_cmap('tab20', unique.max()+1)
        cmap.set_bad('gray')
        clus = variables[0].plot(ax=ax[0, 0], cmap=cmap, vmin=unique.min(),
                                 vmax=unique.max()+1, add_colorbar=False)

        # Ticks offset by 0.5 sit mid-band (bands are unit-wide when the
        # labels start at 0). Attach the colour bar to the cluster panel.
        cbar = fig.colorbar(clus, ax=ax[0, 0], ticks=unique+0.5)
        cbar.set_ticklabels(unique)
        cbar.set_label(labels[0])
    else:
        cmap = plt.get_cmap('cubehelix')
        cmap.set_bad('gray')
        variables[0].plot(ax=ax[0, 0], cmap=cmap, cbar_kwargs={'label': labels[0]})

    plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.95,
                        wspace=0.15, hspace=0.15)

    cmap = plt.get_cmap('cubehelix')
    cmap.set_bad('gray')

    variables[1].plot(ax=ax[0, 1], cmap=cmap, cbar_kwargs={'label': labels[1]})
    variables[2].plot(ax=ax[1, 0], cmap=cmap, cbar_kwargs={'label': labels[2]})

    for k, panel in enumerate((ax[0, 0], ax[0, 1], ax[1, 0])):
        sa_vi.set_aspect(panel)
        panel.title.set_text(titles[k])

    if len(titles) == 3:
        # Only three fields: hide the unused bottom-right panel.
        ax[1, 1].axis('off')
    else:
        variables[3].plot(ax=ax[1, 1], cmap=cmap, cbar_kwargs={'label': labels[3]})
        sa_vi.set_aspect(ax[1, 1])
        ax[1, 1].title.set_text(titles[3])

    fig.suptitle(date.date[0])
    fig.savefig(filename)

    # Release the figure: this is called four times per day in the main loop.
    plt.close(fig)


## Main Body

In [36]:
# Output root: one sub-directory per processed day is created below it.
parent_dir = '/data/ibougoudis/MOAD/analysis-ilias/notebooks/integration_r'
os.makedirs(parent_dir, exist_ok=True)

# Lower-case month abbreviation -> two-digit month number ('jan' -> '01').
dict_month = {abbr: str(num).zfill(2)
              for num, abbr in enumerate(('jan', 'feb', 'mar', 'apr', 'may', 'jun',
                                          'jul', 'aug', 'sep', 'oct', 'nov', 'dec'),
                                         start=1)}

# Day folders of the nowcast run restricted to March-May. The sort is
# lexical on the folder name, not chronological.
path = os.listdir('/results2/SalishSea/nowcast-green.202111/')
folders = sorted(name for name in path if name[2:5] in ('mar', 'apr', 'may'))

# Model mesh (W-level depths gdepw_0 etc.) and its T-point mask as numpy.
mesh = xr.open_dataset('/home/sallen/MEOPAR/grid/mesh_mask202108.nc')
mask = mesh.tmask.to_numpy()

# Colour-bar labels and panel titles for the four figures written per day.
labels_dr = (['Conservative Temperature [degree_C m]'] * 2
             + ['Reference Salinity [g kg-1 m]'] * 2)
titles_dr = [
    'Conservative Temperature (0m - 15m)',
    'Conservative Temperature (15m - 100m)',
    'Reference Salinity (0m - 15m)',
    'Reference Salinity (15m - 100m)',
]

labels_nu = ['Clusters [count]'] + [
    'Silicon Concentration [mmol m-2]',
    'Nitrate Concentration [mmol m-2]',
    'Ammonium Concentration [mmol m-2]',
]
titles_nu = ['Clusters'] + ['Silicon', 'Nitrate', 'Ammonium']

labels_ph = ['Clusters [count]'] + [
    'Diatoms Concentration [mmol m-2]',
    'Flagellates Concentration [mmol m-2]',
]
titles_ph = ['Clusters'] + ['Diatoms', 'Flagellates']

labels_zo = ['Clusters [count]'] + [
    'Microzooplankton Concentration [mmol m-2]',
    'Mesozooplankton Concentration [mmol m-2]',
]
titles_zo = ['Clusters'] + ['Microzooplankton', 'Mesozooplankton']

# Dimensions of the self-organizing map (m x n neurons); the same every day.
m = 3
n = 2

for done, i in enumerate(folders, start=1):
    # The loop variable must stay named ``i``: datasets_preparation() reads it
    # as a global to build the input file names.
    day_dir = os.path.join(parent_dir, i)

    # Process the day when its output directory is missing or still empty;
    # days that already hold output are skipped.
    # NOTE(review): a day that fails part-way leaves a non-empty directory
    # and will not be retried on the next run.
    if not os.path.exists(day_dir) or not os.listdir(day_dir):

        os.makedirs(day_dir, exist_ok=True)
        # printing() and plotting() write relative to the working directory.
        os.chdir(day_dir)

        date, temp_i1, temp_i2, saline_i1, saline_i2, sil_i, nitr_i, ammo_i, diat_i, flag_i, micro_i, meso_i = datasets_preparation()

        # Potential input variables, each shaped (n_variables, n_grid_points).
        # Only ``drivers`` currently feeds the SOM.
        drivers = np.stack([np.ravel(temp_i1), np.ravel(temp_i2), np.ravel(saline_i1), np.ravel(saline_i2)])
        nutrients = np.stack([np.ravel(sil_i), np.ravel(nitr_i), np.ravel(ammo_i)])
        phyto = np.stack([np.ravel(diat_i), np.ravel(flag_i)])
        zoo = np.stack([np.ravel(micro_i), np.ravel(meso_i)])

        clusters, unique, counts = som(drivers, m, n)

        printing()

        plotting(labels_dr, titles_dr, variables=[temp_i1, temp_i2, saline_i1, saline_i2], filename='Drivers.png')
        plotting(labels_nu, titles_nu, variables=[clusters, sil_i, nitr_i, ammo_i], filename='Nutrients.png')
        plotting(labels_ph, titles_ph, variables=[clusters, diat_i, flag_i], filename='Phytoplankton.png')
        plotting(labels_zo, titles_zo, variables=[clusters, micro_i, meso_i], filename='Zooplankton.png')

    # Progress counter; enumerate avoids an O(n) folders.index() per day.
    print(str(done), '/', str(len(folders)))


1 / 1564
2 / 1564
3 / 1564
4 / 1564
5 / 1564
6 / 1564
7 / 1564
8 / 1564
9 / 1564
10 / 1564
11 / 1564
12 / 1564
13 / 1564
14 / 1564
15 / 1564
16 / 1564
17 / 1564
18 / 1564


KeyboardInterrupt: 