In [None]:
import pandas as pd
import pathlib
import networkx
import matplotlib.pyplot as plt

from clustering.ClusterSets.HierarchicalClusterSets import HierarchicalClusterSets
from clustering.ClusterSets.KMeansClusterSets import KMeansClusterSets
from clustering.Snapshots import Snapshots
from utils import paths
from networks.TemporalNetwork import TemporalNetwork

## Declare parameters to use throughout

In [None]:
# --- Output -----------------------------------------------------------------
# Directory the final figure is written to; the plotting cell skips saving
# when this is None.
output_directory = '../data/output'

# --- Clustering -------------------------------------------------------------
# Any cluster_method other than 'k_means' is handled by HierarchicalClusterSets.
cluster_method = 'ward'
cluster_limit_type = 'maxclust'

# One row of plots is produced per metric listed here.
distance_metrics = ['cosine', 'euclidean', 'dice']

# Other metric names that were considered (kept for reference):
# distance_metrics = [
#     'cityblock', 'cosine', 'euclidean', 'l1', 'l2', 'manhattan', 'braycurtis', 'canberra', 'chebyshev', 'correlation',
#     'dice', 'hamming', 'jaccard', 'kulsinski', 'mahalanobis', 'minkowski', 'rogerstanimoto', 'russellrao',
#     'seuclidean', 'sokalmichener', 'sokalsneath', 'sqeuclidean', 'yule',
# ]

# --- Circadian parameters ---------------------------------------------------
# Cluster limits 2..11 inclusive.
cluster_limit_range = list(range(2, 12))

# (time, label, line style) markers every 12 time units: 24, 36, 48, 60.
events = [(event_time, '', 'dashed') for event_time in range(24, 72, 12)]

# (start, end, label) spans.
phases = [
    (18, 24, 'N'),
    (24, 36, 'Day'),
    (36, 48, 'Night'),
    (48, 60, 'Day'),
    (60, 66, 'N'),
]

# Axis ticks every 6 time units: 18, 24, ..., 66.
time_ticks = list(range(18, 72, 6))


# --- Cell cycle parameters (alternative; uncomment to use) ------------------
# cluster_limit_range = [2 + (1 * i) for i in range(0, 10)]
# events = [
#     (5, 'START', 'dashed'),
#     (33, 'bud', 'solid'),
#     (36, 'ori', 'solid'),
#     (70, 'E3', 'dashed'),
#     (84, 'spn', 'solid'),
#     (100, 'mass', 'solid')]
# phases = [
#     (0, 35, 'G1'),
#     (35, 70, 'S'),
#     (70, 78, 'G2'),
#     (78, 100, 'M')]
# time_ticks = None


### Load/create temporal network

In [None]:
# Circadian network: input files and construction options.
node_table_filepath = '../data/temporal_data/circadian/circadian_temporal_node_data_mean_normalised_full.csv'
static_network_filepath = '../data/static_networks/circadian_full.edgelist'
binary, normalise, threshold = False, None, 0.0

# Human-readable label for this network; later used in the figure title and
# (slugified) in the output file name.
temporal_network_name = (
    f'{pathlib.Path(node_table_filepath).stem}, binary={binary}, normalise={normalise}, threshold={threshold}'
)

# The node table is tab-separated with its first column as the index; the
# edge list separates fields with ', '.
node_table = pd.read_csv(node_table_filepath, sep='\t', index_col=0)
static_network = networkx.read_edgelist(static_network_filepath, delimiter=', ')

# Combine the static network with the node table, multiplying the two node
# weights handed to combine_node_weights.
temporal_network = TemporalNetwork.from_static_network_and_node_table_dataframe(
    static_network,
    node_table,
    binary=binary,
    normalise=normalise,
    threshold=threshold,
    combine_node_weights=lambda x, y: x*y,
)

# # Cell cycle network (alternative; uncomment to use)
# temporal_network_filepath = '../data/temporal_networks/cell_cycle/tedges_combined_weighted_binary_method_percentage_p_0.5_clean2.tedges'
# temporal_network_separator = '\\s*\\t\\s*'
#
# edges = pd.read_csv(temporal_network_filepath, sep=temporal_network_separator, engine='python')
# temporal_network = TemporalNetwork.from_edge_list_dataframe(edges)
# temporal_network_name = pathlib.Path(temporal_network_filepath).stem

### Calculate cluster sets for each distance metric

In [None]:
# Collect (cluster_sets, distance_metric) pairs for every metric that can be
# clustered. A metric whose snapshot or cluster-set construction raises is
# reported and left out, so one bad metric does not abort the others.
valid_cluster_sets = []
for distance_metric in distance_metrics:
    try:
        snapshots = Snapshots.from_temporal_network(temporal_network, cluster_method, distance_metric)
        # Only 'k_means' uses the k-means implementation; everything else is hierarchical.
        cluster_sets_class = KMeansClusterSets if cluster_method == 'k_means' else HierarchicalClusterSets
        cluster_sets = cluster_sets_class(snapshots, cluster_limit_type, cluster_limit_range)
    except Exception as error:
        print(f'Error when distance_metric = {distance_metric}: {error}')
    else:
        valid_cluster_sets.append((cluster_sets, distance_metric))

### Plot across different distance metrics

In [None]:
# Draw one row of three panels per successfully clustered distance metric,
# title the figure with the network and clustering settings, and optionally
# save it as a PNG under output_directory.
if not valid_cluster_sets:
    # Without any rows there is no grid to build; fail with a message that
    # names the actual cause instead of an error from inside matplotlib.
    raise ValueError('No cluster sets were computed successfully, so there is nothing to plot.')

gridspec_kw = {"width_ratios": [3, 1, 2]}
figsize = (9, 4*len(valid_cluster_sets))
# squeeze=False keeps axs two-dimensional even when there is only one row;
# otherwise a single valid metric yields a 1-D array and axs[i, 0] fails.
fig, axs = plt.subplots(
    len(valid_cluster_sets), 3,
    figsize=figsize,
    gridspec_kw=gridspec_kw,
    sharey='row',
    sharex='col',
    squeeze=False)

# Title alignment is the same for every row, so build it once.
fontdict = {'horizontalalignment': 'left'}

for i, (cluster_sets, distance_metric) in enumerate(valid_cluster_sets):
    row = (axs[i, 0], axs[i, 1], axs[i, 2])
    cluster_sets.plot_and_format_with_average_silhouettes(row, events, phases, time_ticks)
    axs[i, 0].set_title(f'Clusters and silhouette scores (distance_metric = {distance_metric})', fontdict=fontdict, pad=12)

# Spacing applies to the whole figure; set it once after all rows are drawn.
fig.subplots_adjust(wspace=0.4, hspace=0.4)

title = temporal_network_name + f'\ncluster_method={cluster_method}, cluster_limit_type={cluster_limit_type}'
fig.suptitle(title, y=0.97, weight='bold')

# Save
if output_directory is not None:
    # Create the target directory on first use so saving works on a fresh checkout.
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    filename = f"{output_directory}/silhouette_scores_over_distance_metrics_{paths.slugify(temporal_network_name)}"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")