In [None]:
import pandas as pd
import pathlib
import networkx
import matplotlib.pyplot as plt

from ODEs import ODEs
from clustering.ClusterSets.HierarchicalClusterSets import HierarchicalClusterSets
from clustering.ClusterSets.KMeansClusterSets import KMeansClusterSets
from clustering.Snapshots import Snapshots
from drawing.silhouettes import calculate_and_plot_silhouettes
from drawing.utils import display_name
from networks.TemporalNetwork import TemporalNetwork

## Parameters used throughout the notebook

In [None]:
# Clustering configuration forwarded to calculate_and_plot_silhouettes.
# NOTE(review): Ward linkage is normally defined for Euclidean distances;
# confirm that pairing it with Jaccard distances is intended.
distance_metric = 'jaccard'
cluster_method = 'ward'
cluster_limit_type = 'maxclust'
# Cluster limits to evaluate: 2, 3, ..., 11.
cluster_limit_range = list(range(2, 12))

# Event markers at 0, 12, 24, 36 and 48 (presumably hours).
# Each tuple appears to be (position, label, line style) — TODO confirm against drawing.silhouettes.
events = [(12 * k, '', 'dashed') for k in range(4 + 1)]
# (start, end, label) intervals for the alternating day/night phases.
phases = [
    (0, 12, 'Day 1'),
    (12, 24, 'Night 1'),
    (24, 36, 'Day 2'),
    (36, 48, 'Night 2'),
]

# Where figures are written; set to None to skip saving.
# Kept relative to the notebook's working directory, matching the '../data/...'
# input paths, so the notebook runs on any machine.
output_directory = '../data/output'

### Load input data and build one temporal network per threshold

In [None]:
# Input files, relative to the notebook's working directory.
node_table_filepath = '../data/temporal_data/circadian/circadian_temporal_node_data_mean_48.csv'
static_network_filepath = '../data/static_networks/circadian_net.edgelist'

# Options applied to every temporal network built below.
binary = True
normalise = True
# Thresholds 0.0, 0.1, ..., 0.9; rounding removes float artefacts such as 0.30000000000000004.
thresholds = [round(step / 10, 1) for step in range(10)]

# The node table is read as tab-separated despite its .csv extension; column 0 is the index.
node_table = pd.read_csv(node_table_filepath, sep='\t', index_col=0)
static_network = networkx.read_edgelist(static_network_filepath)

# Keyword arguments shared by every network; only the threshold varies.
# The node weights are combined by multiplication (presumably the two endpoint
# weights of each edge — confirm in TemporalNetwork).
temporal_network_options = dict(
    combine_node_weights=lambda x, y: x * y,
    binary=binary,
    normalise=normalise,
)

# One temporal network per threshold, in the same order as `thresholds`.
temporal_networks = [
    TemporalNetwork.from_static_network_and_node_table_dataframe(
        static_network,
        node_table,
        threshold=cutoff,
        **temporal_network_options,
    )
    for cutoff in thresholds
]

### Plot silhouette scores for each edge threshold

In [None]:
# One row of three panels per temporal network (i.e. per threshold).
gridspec_kw = {"width_ratios": [3, 1, 2]}
n_rows = len(temporal_networks)
figsize = (9, 4 * n_rows)
# squeeze=False keeps `axs` two-dimensional even with a single row, so axs[row, col] always works.
fig, axs = plt.subplots(
    n_rows, 3, figsize=figsize, gridspec_kw=gridspec_kw, sharey=True, squeeze=False)

# Ticks every 6 units from 0 to 48 (presumably hours); identical for every row, so built once.
time_ticks = [6 * tick for tick in range(8 + 1)]

# Rows are filled top-down. The helper appears to return a falsy value when it drew nothing;
# in that case the same row is reused for the next threshold.
n_plotted = 0
for temporal_network, threshold in zip(temporal_networks, thresholds):
    row = tuple(axs[n_plotted, :])
    plotted = calculate_and_plot_silhouettes(
        row, temporal_network, cluster_method, distance_metric, cluster_limit_type, cluster_limit_range, events,
        phases, variable_name='threshold', variable=threshold, time_ticks=time_ticks)
    if plotted:
        n_plotted += 1

# Remove rows left empty because some thresholds produced no plot.
for unused_row in axs[n_plotted:]:
    for ax in unused_row:
        fig.delaxes(ax)

title = f'data={pathlib.Path(node_table_filepath).stem}, binary={binary}, normalise={normalise}'
title += f'\ndistance_metric={distance_metric}, cluster_method={cluster_method}, cluster_limit_type={cluster_limit_type} '
fig.suptitle(title, y=0.90, weight='bold')

# Save this figure explicitly (not pyplot's "current" figure), creating the directory if needed.
if output_directory is not None:
    output_path = pathlib.Path(output_directory)
    output_path.mkdir(parents=True, exist_ok=True)
    filename = output_path / "silhouette_scores_over_thresholds"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
