## Importing

In [None]:
import xarray as xr
import numpy as np

from sklearn import preprocessing
from sklearn_som.som import SOM

from tqdm.auto import tqdm


## SOM (Drivers)

In [None]:
def som(inputs, m, n, shape=(898, 398), lr=0.1, epochs=5):
    """Cluster the valid grid points of one time step with a self-organizing map.

    Parameters
    ----------
    inputs : numpy.ndarray
        2-D array laid out as (variables, points): each row is one flattened
        driver field, each column one grid point (the main loop builds it with
        ``np.stack`` over ``np.ravel`` of each variable).
    m, n : int
        Map dimensions passed to ``SOM``.
    shape : tuple of int, optional
        (y, x) shape used to fold the flat label vector back into a map.
        Defaults to the previously hard-coded (898, 398); its product must
        equal ``inputs.shape[1]``.
    lr : float, optional
        Learning rate passed to ``SOM`` (default 0.1, as before).
    epochs : int, optional
        Number of training epochs passed to ``SOM.fit`` (default 5, as before).

    Returns
    -------
    xarray.DataArray
        The label map wrapped by ``datasets_preparation``. Grid points where
        any input variable is NaN are left as NaN.
    """
    # A grid point is usable only if no variable is NaN there.
    valid = ~np.isnan(inputs).any(axis=0)

    # Start from an all-NaN label vector so masked points stay NaN.
    labels = np.full(inputs.shape[1], np.nan)

    # Skip training when the whole step is masked: normalize/fit would
    # otherwise fail on an empty array.
    if valid.any():
        # normalize() works row-wise by default, so each variable (row) is
        # scaled by its max absolute value. The transpose then gives SOM one
        # sample per grid point and one feature per variable.
        features = preprocessing.normalize(inputs[:, valid], norm='max').T

        model = SOM(m, n, dim=features.shape[1], lr=lr)
        model.fit(features, epochs=epochs)
        labels[valid] = model.predict(features)

    clusters = np.reshape(labels, shape)

    return datasets_preparation(clusters)
    

## Datasets Preparation

In [None]:
def datasets_preparation(clusters, reference=None):
    """Wrap a 2-D cluster map as a DataArray with a length-1 ``time_counter`` dim.

    Parameters
    ----------
    clusters : numpy.ndarray
        2-D (y, x) array of cluster labels, as produced by ``som``.
    reference : xarray.Dataset, optional
        Single-time-step dataset supplying the ``time_counter``, ``y`` and ``x``
        coordinates. When omitted, the module-level ``dataset`` (set by the
        main loop) is used, exactly as before.

    Returns
    -------
    xarray.DataArray
        ``clusters`` with dims (time_counter, y, x), ready to be concatenated
        along ``time_counter``.
    """
    if reference is None:
        # Original behaviour: read the global assigned by the main loop.
        reference = dataset

    coords = dict(time_counter=reference.time_counter, y=reference.y, x=reference.x)
    clusters = xr.DataArray(
        clusters,
        coords=coords,
        dims=['y', 'x'],
        # NOTE(review): the values are cluster labels; "count" as units may be
        # misleading to downstream tools -- confirm the intended metadata.
        attrs=dict(
            description="Clusters of the performed self organizing map algorithm",
            long_name="Cluster",
            units="count",
        ),
    )

    # time_counter is a scalar coordinate here (the input came from one isel
    # step). expand_dims promotes it to a length-1 leading dimension, which is
    # the same result as the original single-element xr.concat.
    return clusters.expand_dims('time_counter')


## File Creation

In [None]:
def file_creation(variable, name, path=r'D:\nc\clustering.nc', complevel=9):
    """Write ``variable`` into a netCDF file as a zlib-compressed variable.

    Parameters
    ----------
    variable : xarray.DataArray
        Data to store.
    name : str
        Variable name to use inside the file.
    path : str, optional
        Destination file. Defaults to the previously hard-coded location. It is
        written as a raw string because the old literal relied on an invalid
        escape sequence, which Python warns about; the path itself is unchanged.
    complevel : int, optional
        zlib compression level (default 9, as before).
    """
    temp = variable.to_dataset(name=name)
    # mode='a' writes into an existing file instead of replacing it; a variable
    # with the same name is overwritten.
    # NOTE(review): confirm the target file exists beforehand, since append
    # mode may not create it with every backend.
    temp.to_netcdf(
        path=path,
        mode='a',
        encoding={name: {"zlib": True, "complevel": complevel}},
    )


## Main Body

In [None]:
# Raw string: the original literal contained the invalid escape sequence "\i".
ds = xr.open_dataset(r'D:\nc\integrated_original.nc')

# Dimensions of the map
m = 3
n = 2

# Collect one (time_counter=1, y, x) map per step and concatenate once at the
# end. Concatenating inside the loop re-copies the growing array every step.
cluster_maps = []

for i in tqdm(range(len(ds.time_counter)), leave=False):

    # Keep the global name `dataset`: datasets_preparation reads it to build
    # the output coordinates.
    dataset = ds.isel(time_counter=i)

    # One row per driver variable, one column per grid point.
    drivers = np.stack([np.ravel(dataset['Temperature_(0m-15m)']),
        np.ravel(dataset['Temperature_(15m-100m)']),
        np.ravel(dataset['Salinity_(0m-15m)']),
        np.ravel(dataset['Salinity_(15m-100m)'])])

    # Alternative input groups that can be clustered instead of the drivers:

    # nutrients = np.stack([np.ravel(dataset['Silicon']),
    #     np.ravel(dataset['Nitrate']), np.ravel(dataset['Ammonium'])])

    # phyto = np.stack([np.ravel(dataset['Diatom']),
    #     np.ravel(dataset['Flagellate'])])

    # zoo = np.stack([np.ravel(dataset['Microzooplankton']),
    #     np.ravel(dataset['Mesozooplankton'])])

    clusters = som(drivers, m, n)
    cluster_maps.append(clusters)

clusters_all = xr.concat(cluster_maps, dim='time_counter')

# Calling file creation

file_creation(clusters_all, 'Clusters')
