# Functional Clustering

## Importing

In [None]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.compose import make_column_transformer
from sklearn.preprocessing import MinMaxScaler

from skfda.ml.clustering import KMeans
from skfda.representation.grid import FDataGrid

import salishsea_tools.viz_tools as sa_vi

from tqdm import tqdm

# Silence the RuntimeWarnings that np.nanmean emits for all-NaN slices
# ("Mean of empty slice"). `np.warnings` was only an undocumented alias of the
# stdlib module and has been removed from NumPy, so use `warnings` directly.
import warnings

warnings.filterwarnings('ignore', category=RuntimeWarning)


## Drivers Preparation

In [None]:
def drivers_preparation(dataset2):
    """Build the multivariate functional data of the atmospheric drivers.

    For each grid point, every driver is averaged day by day across the years
    in ``dataset2`` (29 February is dropped so all years have equal length).
    The points are then restricted to a region of interest, each driver is
    min-max scaled to [0, 1], and the result is wrapped in an ``FDataGrid``.

    Parameters
    ----------
    dataset2 : xarray.Dataset
        Holds ``time_counter`` plus the variables
        ``Summation_of_solar_radiation``, ``Mean_wind_speed`` and
        ``Mean_air_temperature``, each with time as the leading axis. The
        remaining axes are flattened in C order and matched against the
        (y, x) grid coordinates, so they are assumed to be (y, x).
        TODO confirm.

    Returns
    -------
    inputs2 : FDataGrid
        One sample per kept grid point. The data matrix has shape
        (n_points, n_days, n_drivers).
    indx : tuple of ndarray
        ``np.where`` result holding the flattened grid indices that were kept.
    """
    variables = ('Summation_of_solar_radiation', 'Mean_wind_speed', 'Mean_air_temperature')
    time = dataset2.time_counter.dt

    # (n_drivers, n_times, n_points): flatten every non-time axis of each driver
    arrays = [dataset2[v].to_numpy() for v in variables]
    inputs = np.stack([arr.reshape(arr.shape[0], -1) for arr in arrays])

    # Deleting 29 of February so every year has the same number of days
    leap_days = np.where((time.month == 2) & (time.day == 29))[0]
    inputs = np.delete(inputs, leap_days, axis=1)

    # Splitting in (equal-length) years and averaging matching days
    n_years = len(np.unique(time.year))
    inputs = np.nanmean(np.split(inputs, n_years, axis=1), axis=0)

    # Flattened grid coordinates, in the same order as the point axis
    xs = np.asarray(dataset2.x)
    ys = np.asarray(dataset2.y)
    x = np.tile(xs, len(ys))
    y = np.repeat(ys, len(xs))

    # Keep points with no NaN in ANY driver on ANY day. Checking only the first
    # driver would let NaNs from the other drivers reach KMeans. Also exclude
    # x <= 10 and the corner where x <= 100 and y >= 880.
    complete = ~np.isnan(inputs).any(axis=(0, 1))
    indx = np.where(complete & (x > 10) & ((x > 100) | (y < 880)))
    inputs = inputs[:, :, indx[0]]

    # Scaling each driver (one column per driver) to [0, 1] over all days and points
    n_drivers, n_days, n_points = inputs.shape
    flat = inputs.reshape(n_drivers, n_days * n_points).T
    flat = MinMaxScaler().fit_transform(flat)
    inputs = flat.T.reshape(n_drivers, n_days, n_points)

    # FDataGrid expects (n_samples, n_grid_points, n_outputs); the grid is the day index
    inputs = np.transpose(inputs, axes=(2, 1, 0))
    inputs2 = FDataGrid(inputs, np.arange(0, n_days))

    return (inputs2, indx)
    

## Targets Preparation

In [None]:
def targets_preparation(dataset, name):
    """Return the day-by-day multi-year mean of one target variable as functional data.

    29 February is removed, the remaining series is split into equal-length
    years, and matching days are averaged with ``np.nanmean``. Only grid points
    with no NaN on any day that also pass the region filter are kept.

    Parameters
    ----------
    dataset : xarray.Dataset
        Must hold ``time_counter`` and the variable ``name``, with time as
        the leading axis. The other axes are flattened in C order and matched
        against the (y, x) grid coordinates, so they are assumed to be
        (y, x). TODO confirm.
    name : str
        Variable to use as the clustering target.

    Returns
    -------
    targets2 : FDataGrid
        One sample per kept grid point, sampled at the day indices.
    indx : tuple of ndarray
        ``np.where`` result holding the flattened grid indices that were kept.
    """
    time = dataset.time_counter.dt

    field = dataset[name].to_numpy()
    series = field.reshape(field.shape[0], -1)  # (n_times, n_points)

    # Remove 29 February so that all years share the same length
    is_leap_day = (time.month == 2) & (time.day == 29)
    series = np.delete(series, np.where(is_leap_day), axis=0)

    # Average matching days across the years
    year_blocks = np.split(series, len(np.unique(time.year)), axis=0)
    climatology = np.nanmean(year_blocks, axis=0)

    # Flattened grid coordinates, aligned with the point axis
    x_coord = np.tile(dataset.x, len(dataset.y))
    y_coord = np.tile(np.repeat(dataset.y, len(dataset.x)), 1)

    complete = ~np.isnan(climatology).any(axis=0)
    in_region = (x_coord > 10) & ((x_coord > 100) | (y_coord < 880))
    indx = np.where(complete & in_region)

    # FDataGrid rows are grid points, columns are days
    samples = climatology[:, indx[0]].T
    targets2 = FDataGrid(samples, np.arange(0, len(samples[0])))

    return targets2, indx


## Summary

In [None]:
def summary(name, clusters, unique, cluster_mean, counts, ind_cluster, period='2007-2024'):
    """Plot the cluster map, display a per-cluster table and plot each cluster centre.

    Parameters
    ----------
    name : str
        ``'drivers'`` for the three-driver clustering; otherwise the target
        variable name. The name is also used in the map title.
    clusters : xarray.DataArray
        2-D (y, x) map of 0-based cluster labels, NaN outside the clustered points.
    unique : ndarray of int
        0-based cluster labels, one per table row.
    cluster_mean : ndarray
        Mean of each cluster centre, ordered like ``unique``. Shape is (k,)
        for a target and (k, 3) for drivers (one column per driver).
    counts : ndarray of int
        Number of grid points in each cluster, ordered like ``unique``.
    ind_cluster : ndarray
        Cluster centres, one per cluster, as returned by ``clustering``.
    period : str, optional
        Period shown in the map title.
    """
    labels = unique + 1  # 1-based labels for display

    # Table: one row per cluster, counts first
    if name == 'drivers':
        columns = ['counts', 'Summation of solar radiation', 'Mean wind speed', 'Mean Temperature']
    else:
        columns = ['counts', 'mean']
    table = pd.DataFrame(np.column_stack((counts, cluster_mean)), columns=columns, index=labels)
    table.index.name = 'Cluster'

    # Cluster map with a discrete colour per label; NaN cells are drawn grey
    fig, ax = plt.subplots(figsize=(5, 9))
    cmap = plt.get_cmap('tab20', unique.max() + 1)
    cmap.set_bad('gray')
    clus = clusters.plot(ax=ax, cmap=cmap, vmin=unique.min(), vmax=unique.max() + 1, add_colorbar=False)
    cbar = fig.colorbar(clus, ticks=unique + 0.5)  # ticks centred on each colour band
    cbar.set_ticklabels(labels)
    cbar.set_label('Clusters [count]')
    ax.set_title(f'Functional Clustering for {name} ({period})')
    sa_vi.set_aspect(ax)
    plt.show()

    display(table.transpose())

    # One panel per cluster centre, two panels per row; spare panels are hidden
    n_clusters = len(ind_cluster)
    n_rows = max(1, -(-n_clusters // 2))  # ceil(n_clusters / 2)
    fig, axs = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows), squeeze=False)
    for c, panel in enumerate(axs.flat):
        if c < n_clusters:
            panel.plot(ind_cluster[c])
            panel.set_title(f'Cluster {c + 1}')
        else:
            panel.axis('off')

    if name == 'drivers':
        fig.legend(('Summation_of_solar_radiation', 'Mean_wind_speed', 'Mean_air_temperature'))
    plt.show()
    

## Clustering

In [None]:
def clustering(dataset, quant, indx, name, n_clusters=6):
    """Run functional k-means and relabel the clusters by the mean of their centres.

    After relabelling, label 0 is the cluster whose centre has the smallest
    mean and label ``k - 1`` the largest. For drivers, the mean of the first
    driver decides the order.

    Parameters
    ----------
    dataset : xarray.Dataset
        Only its ``y``/``x`` lengths are used, to rebuild the 2-D map.
    quant : FDataGrid
        Functional samples, one per kept grid point.
    indx : tuple of ndarray
        Flattened grid indices of the samples (as returned by the
        ``*_preparation`` functions).
    name : str
        ``'drivers'`` keeps the per-driver means in ``cluster_mean``.
        Any other value returns one mean per cluster.
    n_clusters : int, optional
        Number of k-means clusters (default 6).

    Returns
    -------
    clusters2 : xarray.DataArray
        (y, x) map of sorted labels, NaN where no sample exists.
    unique : ndarray of int
        All labels ``0 .. k-1``, including any empty cluster.
    cluster_mean : ndarray
        Centre means in label order: shape (k, n_drivers) for drivers, (k,) otherwise.
    counts : ndarray of int
        Number of samples carrying each label (0 for an empty cluster).
    ind_cluster : ndarray
        Centre data matrices in label order.
    """
    # Training
    kmeans = KMeans(n_clusters=n_clusters)
    labels = np.asarray(kmeans.fit_predict(quant))

    # Order the centres by the mean (over the grid axis) of their first output
    centres = kmeans.cluster_centers_.data_matrix
    centre_means = np.mean(centres, axis=1)        # (k, n_outputs)
    order = np.argsort(centre_means[:, 0])         # centre indices, ascending mean
    rank = np.argsort(order)                       # new label of each original centre

    # Relabel every sample at once; this also covers labels of empty clusters
    sorted_labels = rank[labels]

    # Counts for all k labels so they stay aligned with cluster_mean/ind_cluster
    counts = np.bincount(sorted_labels, minlength=len(centres))
    unique = np.arange(len(counts))

    # Creating the map: NaN everywhere except the clustered points
    flat_map = np.full(len(dataset.y) * len(dataset.x), np.nan)
    flat_map[indx[0]] = sorted_labels
    clusters2 = xr.DataArray(np.reshape(flat_map, (len(dataset.y), len(dataset.x))), dims=['y', 'x'])

    # Centres and their means in label order
    ind_cluster = centres[order]
    if name == 'drivers':
        cluster_mean = centre_means[order]
    else:
        cluster_mean = centre_means[order, 0]

    return (clusters2, unique, cluster_mean, counts, ind_cluster)


## Reading the Files

In [None]:
# Targets (model output) and drivers (external inputs)
ds = xr.open_dataset('/data/ibougoudis/MOAD/files/integrated_original.nc')
ds2 = xr.open_dataset('/data/ibougoudis/MOAD/files/external_inputs.nc')


def _subsample_grid(data, step=5):
    """Keep every `step`-th y and x index, from the first value up to (not including) the last."""
    return data.isel(
        y=np.arange(data.y[0], data.y[-1], step),
        x=np.arange(data.x[0], data.x[-1], step),
    )


ds = _subsample_grid(ds)
ds2 = _subsample_grid(ds2)

_analysis_years = slice('2007', '2024')
dataset, dataset2 = (d.sel(time_counter=_analysis_years) for d in (ds, ds2))

# Analysis switch (note: `id` shadows the builtin; later cells read it)
# id = 0 # For drivers

id = 1 # For targets


## Drivers Analysis

In [None]:
if id == 0:

    drivers, indx = drivers_preparation(dataset2)
    clusters_all, unique, clusters_mean, counts, ind_clusters = clustering(dataset2, drivers, indx, 'drivers')
    # summary() needs the label map that clustering() just returned
    summary('drivers', clusters_all, unique, clusters_mean, counts, ind_clusters)

    # Centres as (n_grid_points, n_outputs, n_clusters) for the per-year cells below
    ind_clusters = ind_clusters.transpose(1, 2, 0)
    

## All Years (Targets)

In [None]:
if id == 1:

    name = 'Diatom'
    targets, indx = targets_preparation(dataset, name)
    clusters_all, unique, clusters_mean, counts, ind_clusters = clustering(dataset, targets, indx, name)
    # summary() needs the label map that clustering() just returned
    summary(name, clusters_all, unique, clusters_mean, counts, ind_clusters)

    # Centres as (n_grid_points, n_outputs, n_clusters) for the per-year cells below
    ind_clusters = ind_clusters.transpose(1, 2, 0)

## Individual Years (based on the all-years clustering)

In [None]:
# For every year, average the target over the points of each all-years cluster
years = np.unique(ds.time_counter.dt.year)
cluster_labels = np.ravel(clusters_all)

# Number of clusters = distinct non-NaN labels in the all-years map
n_clusters = np.unique(cluster_labels[~np.isnan(cluster_labels)]).size

# (n_days, n_years, n_clusters)
ind_clusters2 = np.zeros((ind_clusters.shape[0], len(years), n_clusters))

for i, year in enumerate(tqdm(years)):

    dataset = ds.sel(time_counter=slice(str(year), str(year)))
    dataset2 = ds2.sel(time_counter=slice(str(year), str(year)))

    # # Drivers
    # NOTE(review): as written, nanmean(axis=1) of a (3, n_days, n_points) array
    # does not fit ind_clusters2[:, i, j]; revisit before enabling.
    # drivers, indx = drivers_preparation(dataset2)
    # drivers = drivers.data_matrix.transpose(2,1,0)
    # clusters2 = cluster_labels[indx]
    # for j in range(n_clusters):
    #     temp = np.where(clusters2==j, drivers, np.nan)
    #     ind_clusters2[:,i,j] = np.nanmean(temp,axis=1)

    # Targets: (n_days, n_points) for this year
    targets, indx = targets_preparation(dataset, name)
    targets = targets.data_matrix[..., 0].T
    clusters2 = cluster_labels[indx]

    # Loop over every cluster label; one missing this year gives an all-NaN curve
    for j in range(n_clusters):
        in_cluster = np.where(clusters2 == j, targets, np.nan)
        ind_clusters2[:, i, j] = np.nanmean(in_cluster, axis=1)

# Per-year mean of each cluster curve, plus the all-years cluster means as the last row
clusters_mean2 = np.round(np.mean(ind_clusters2, axis=0), 10)
clusters_mean2 = np.append(clusters_mean2, np.expand_dims(clusters_mean, 0), axis=0)

years2 = np.append(years, '2007-2024')


In [None]:
# a = dataset[name].mean('time_counter')

In [None]:
# One figure per all-years cluster. Each year gets one panel with the cluster's
# daily mean curve; the last panel shows the all-years cluster centre.
n_cols = 4
n_rows = -(-(len(years) + 1) // n_cols)  # ceil: one panel per year plus the all-years panel

for i in range(ind_clusters2.shape[-1]):

    temp = pd.DataFrame(clusters_mean2[:, i], columns=['mean'], index=years2)
    temp.index.name = 'Year'
    print(f'Cluster {i + 1}')
    display(temp.transpose())

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 3 * n_rows), squeeze=False)
    axes = ax.ravel()

    for j, year in enumerate(years):
        axes[j].plot(ind_clusters2[:, j, i])
        axes[j].set_title(str(year))

    # Hide the empty slots between the last year and the all-years panel
    for unused in axes[len(years):-1]:
        unused.axis('off')

    axes[-1].plot(ind_clusters[:, 0, i])
    axes[-1].set_title('2007-2024')

    fig.tight_layout(rect=[0, 0, 1, 0.97])
    fig.suptitle(f'Cluster {i + 1}')

    plt.show()


In [None]:
# Inspect the stacked per-year cluster maps.
# NOTE(review): `clusters_per_year` is created in the following cell, so this
# raises NameError when the notebook is run top to bottom.
clusters_per_year.shape

In [None]:
# Cluster each year independently and stack the resulting (y, x) label maps
years = np.unique(ds.time_counter.dt.year)
n_y, n_x = len(ds.y), len(ds.x)  # same grid that clustering() rebuilds

clusters_per_year = np.zeros((len(years), n_y, n_x))
counts_per_year = []          # cluster sizes of each yearly clustering
clusters_mean_per_year = []   # sorted cluster means of each yearly clustering

for i, year in enumerate(tqdm(years)):

    dataset = ds.sel(time_counter=slice(str(year), str(year)))

    targets, indx = targets_preparation(dataset, name)
    clusters, unique, clusters_mean, counts, ind_clusters = clustering(dataset, targets, indx, name)
    clusters_per_year[i] = clusters
    counts_per_year.append(counts)
    clusters_mean_per_year.append(clusters_mean)

a = xr.DataArray(clusters_per_year, dims=['years', 'y', 'x'])
b = xr.DataArray(clusters_all, dims=['y', 'x'])


    

In [None]:
# Grid of cluster maps: one panel per yearly clustering (`a`), with the
# all-years map (`b`) in the last panel.
n_cols = 4
n_rows = -(-(len(years) + 1) // n_cols)  # ceil: one panel per year plus the all-years panel
fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 3 * n_rows), squeeze=False)
axes = ax.ravel()

cmap = plt.get_cmap('tab20', unique.max() + 1)
cmap.set_bad('gray')

# NOTE(review): `unique`, `counts` and `clusters_mean` come from whichever
# clustering() call ran last (top to bottom: the final year of the per-year
# loop). They are not per-year, so the table is shown once instead of being
# repeated identically for every panel.
temp = pd.DataFrame(np.column_stack((counts, clusters_mean)), columns=['counts', 'mean'], index=unique + 1)
temp.index.name = 'Cluster'
display(temp.transpose())


def _add_cluster_map(axis, labels_map, title):
    """Draw one discrete-colour cluster map with its own colorbar on `axis`."""
    clus = labels_map.plot(ax=axis, cmap=cmap, vmin=unique.min(), vmax=unique.max() + 1, add_colorbar=False)
    cbar = fig.colorbar(clus, ticks=unique + 0.5, fraction=0.08, pad=0.08)
    cbar.set_ticklabels(unique + 1)
    axis.set_title(title)
    sa_vi.set_aspect(axis)


for j, year in enumerate(years):
    _add_cluster_map(axes[j], a[j], str(year))

# Hide the empty slots between the last year and the all-years panel
for unused in axes[len(years):-1]:
    unused.axis('off')

_add_cluster_map(axes[-1], b, '2007-2024')

fig.tight_layout(rect=[0, 0, 1, 0.97])
fig.suptitle('Functional Clustering for ' + str(name))

plt.show()


In [None]:
# Display the all-years cluster map (labels on the (y, x) grid, NaN outside the clustered points)
b