In [None]:
import pandas as pd
import pathlib
import networkx
import matplotlib.pyplot as plt

from clustering.ClusterSets.HierarchicalClusterSets import HierarchicalClusterSets
from clustering.ClusterSets.KMeansClusterSets import KMeansClusterSets
from clustering.Snapshots import Snapshots
from networks.TemporalNetwork import TemporalNetwork

## Declare parameters used throughout the notebook

In [None]:
# Clustering configuration. These three values are handed to the snapshot and
# cluster-set constructors in later cells.
distance_metric = 'jaccard'
cluster_method = 'ward'
cluster_limit_type = 'maxclust'

# Cluster limits to sweep over: 2, 3, ..., 11.
cluster_limit_range = list(range(2, 12))

# One (position, label, style) entry every 12 units from 0 to 48 inclusive.
# NOTE(review): presumably rendered as unlabelled dashed marker lines - confirm
# against the plotting helper.
events = [(position, '', 'dashed') for position in range(0, 49, 12)]

# Consecutive 12-unit (start, end, label) windows covering 0 to 48.
phase_labels = ('Day 1', 'Night 1', 'Day 2', 'Night 2')
phases = [
    (12 * index, 12 * (index + 1), label)
    for index, label in enumerate(phase_labels)
]

# Tick positions every 6 units from 0 to 48 inclusive.
time_ticks = list(range(0, 49, 6))

# Directory for saved figures; None means figures are not written to disk.
output_directory = None

### Load/create temporal networks

In [None]:
# Input files: the temporal node table and the static network edge list.
node_table_filepath = '../data/temporal_data/circadian/circadian_temporal_node_data_mean_48.csv'
static_network_filepath = '../data/static_networks/circadian_net.edgelist'

# Options forwarded unchanged to every temporal network built below.
binary = True
normalise = True

# Thresholds to sweep: 0.0, 0.1, ..., 0.9. Rounding to one decimal removes
# floating-point noise such as 0.30000000000000004.
thresholds = [round(0.1 * step, 1) for step in range(10)]

# The node table is read as tab-separated (despite the .csv extension), with
# its first column used as the index.
node_table = pd.read_csv(node_table_filepath, sep='\t', index_col=0)
static_network = networkx.read_edgelist(static_network_filepath)


def _temporal_network_for(threshold):
    """Build the temporal network for a single threshold from the shared inputs.

    Node weights are combined by multiplication.
    """
    return TemporalNetwork.from_static_network_and_node_table_dataframe(
        static_network,
        node_table,
        combine_node_weights=lambda x, y: x * y,
        threshold=threshold,
        binary=binary,
        normalise=normalise,
    )


# One temporal network per threshold, in the same order as `thresholds`.
temporal_networks = [_temporal_network_for(threshold) for threshold in thresholds]

### Calculate cluster sets for each threshold

In [None]:
# The cluster-set class depends only on the configured method, so pick it once:
# 'k_means' uses KMeansClusterSets, every other method uses the hierarchical one.
constructor = KMeansClusterSets if cluster_method == 'k_means' else HierarchicalClusterSets

# Collect (cluster_sets, threshold) pairs for every threshold whose clustering
# succeeds. Failures are reported and skipped so the remaining thresholds
# still run.
valid_cluster_sets = []
for threshold, temporal_network in zip(thresholds, temporal_networks):
    try:
        snapshots = Snapshots.from_temporal_network(temporal_network, cluster_method, distance_metric)
        cluster_sets = constructor(snapshots, cluster_limit_type, cluster_limit_range)
    except Exception as e:
        print(f'Error when threshold = {threshold}: {e}')
    else:
        valid_cluster_sets.append((cluster_sets, threshold))

### Plot across different thresholds

In [None]:
# One row of three panels per threshold that produced valid cluster sets.
n_rows = len(valid_cluster_sets)
if n_rows == 0:
    # Fail early with an actionable message instead of an opaque matplotlib
    # error about a zero-row grid.
    raise ValueError('No valid cluster sets to plot: clustering failed for every threshold')

gridspec_kw = {"width_ratios": [3, 1, 2]}
figsize = (9, 4 * n_rows)
# squeeze=False keeps `axs` two-dimensional even when there is only one row.
# With the default squeeze=True a single row yields a 1-D array and the
# `axs[i, j]` indexing below raises IndexError.
fig, axs = plt.subplots(
    n_rows,
    3,
    figsize=figsize,
    gridspec_kw=gridspec_kw,
    sharey='row',
    sharex='col',
    squeeze=False,
)

fontdict = {'horizontalalignment': 'left'}
for i, (cluster_sets, threshold) in enumerate(valid_cluster_sets):
    row = (axs[i, 0], axs[i, 1], axs[i, 2])
    cluster_sets.plot_and_format_with_average_silhouettes(row, events, phases, time_ticks)
    axs[i, 0].set_title(f'Clusters and silhouette scores (threshold = {threshold})', fontdict=fontdict, pad=12)

# Spacing is the same for every row, so apply it once after all rows are drawn.
fig.subplots_adjust(wspace=0.4, hspace=0.4)

# Figure-level title recording the data source and clustering configuration.
title = f'data={pathlib.Path(node_table_filepath).stem}, binary={binary}, normalise={normalise}'
title += f'\nmetric={distance_metric}, cluster_method={cluster_method}, cluster_limit_type={cluster_limit_type}'
fig.suptitle(title, y=0.91, weight='bold')

# Save (only when an output directory has been configured)
if output_directory is not None:
    filename = f"{output_directory}/silhouette_scores_over_thresholds"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
