In [None]:
import pandas as pd
import pathlib
import networkx
import matplotlib.pyplot as plt
import drawing

from ODEs import ODEs
from clustering.ClusterSets.HierarchicalClusterSets import HierarchicalClusterSets
from clustering.ClusterSets.KMeansClusterSets import KMeansClusterSets
from clustering.Snapshots import Snapshots
from drawing.utils import display_name
from networks.TemporalNetwork import TemporalNetwork

## Declare parameters to use throughout

In [None]:
# --- Clustering configuration (read by the plotting cell further down) ---
distance_metric = 'euclidean'
cluster_method = 'ward'
cluster_limit_type = 'maxclust'
# Cluster limits to sweep over: 2, 3, ..., 11.
cluster_limit_range = list(range(2, 12))

# One entry every 12 time units from 0 to 48, handed to ODEs.plot_events.
# Presumably (position, label, line style) -- confirm against ODEs.plot_events.
events = [(time_point, '', 'dashed') for time_point in range(0, 49, 12)]

# Labelled 12-unit intervals handed to ODEs.plot_phases.
# Presumably (start, end, label) -- confirm against ODEs.plot_phases.
phases = [
    (0, 12, 'Day 1'),
    (12, 24, 'Night 1'),
    (24, 36, 'Day 2'),
    (36, 48, 'Night 2'),
]

# Leave as None to skip saving the figure; set to a directory path to save it.
output_directory = None

### Load/create temporal networks

In [None]:
# Input locations, relative to the notebook's working directory.
node_table_filepath = '../data/temporal_data/circadian_temporal_node_data_mean_normalised.csv'
static_network_filepath = '../data/static_networks/circadian_net.edgelist'

# The node table is read as tab-separated, with its first column as the index.
node_table = pd.read_csv(node_table_filepath, sep='\t', index_col=0)
static_network = networkx.read_edgelist(static_network_filepath)

# Thresholds 0.0, 0.1, ..., 0.9; rounding removes float noise such as 0.30000000000000004.
thresholds = [round(0.1 * step, 1) for step in range(10)]

# Build one temporal network per threshold, in the same order as `thresholds`
# (the plotting cell pairs them up by position).
temporal_networks = []
for threshold_value in thresholds:
    network_for_threshold = TemporalNetwork.from_static_network_and_node_table_dataframe(
        static_network,
        node_table,
        combine_node_weights=lambda x, y: x*y,
        threshold=threshold_value,
        binary=True,
        normalise=False,
    )
    temporal_networks.append(network_for_threshold)

### Plot across different thresholds

In [None]:
# Draw one row of three panels per thresholded temporal network:
#   column 0 - the cluster sets, with event and phase annotations on top,
#   column 1 - average silhouette,
#   column 2 - actual number of clusters.
gridspec_kw = {"width_ratios": [3, 1, 2]}
figsize = (9, 4*len(temporal_networks))
# squeeze=False keeps `axs` two-dimensional even when there is only one
# temporal network, so the `axs[i, 0]` indexing below always works.
fig, axs = plt.subplots(
    len(temporal_networks), 3,
    figsize=figsize, gridspec_kw=gridspec_kw, sharey=True, squeeze=False)
# Adjust spacing once, on this figure explicitly, rather than on whichever
# figure pyplot considers current on every loop iteration.
fig.subplots_adjust(wspace=0.4, hspace=0.4)

# These do not depend on the individual network, so compute them once.
constructor = HierarchicalClusterSets if cluster_method != 'k_means' else KMeansClusterSets
ymax = max(cluster_limit_range) + 1
time_ticks = [6*tick for tick in range(9)]

for i, temporal_network in enumerate(temporal_networks):

    snapshots = Snapshots.from_temporal_network(temporal_network, cluster_method, distance_metric)
    cluster_sets = constructor(snapshots, cluster_limit_type, cluster_limit_range)

    # Plot
    (ax1, ax2, ax3) = (axs[i, 0], axs[i, 1], axs[i, 2])
    cluster_sets.plot_with_average_silhouettes((ax1, ax2, ax3))
    ax1.set_ylim(0, ymax)
    ODEs.plot_events(events, ax=ax1)
    ODEs.plot_phases(phases, ax=ax1, y_pos=0.05, ymax=0.1)

    # Format
    ax1.set_xlabel("Time")
    ax1.set_xticks(time_ticks)
    ax1.set_axisbelow(True)
    ax1.set_ylabel(display_name(cluster_sets.limit_type))

    ax2.set_xlabel("Average silhouette")
    ax2.set_xlim((0, 1))
    # sharey=True hides the y tick labels on the non-leftmost columns; turn them back on.
    ax2.yaxis.set_tick_params(labelleft=True)

    ax3.set_xlabel("Actual # clusters")
    ax3.yaxis.set_tick_params(labelleft=True)

    # thresholds[i] is the threshold that temporal_networks[i] was built with.
    ax1.set_title(f"Hier. clust. method: '{cluster_method}' (threshold={thresholds[i]})")

# Save
if output_directory is not None:
    # Create the target directory if needed so savefig does not fail on a fresh checkout.
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    filename = f"{output_directory}/cluster_range_using_{cluster_method}_method"
    # Save through `fig` so the right figure is written even if a helper
    # changed pyplot's current figure.
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
