In [None]:
import pathlib
import networkx
import drawing
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd

from clustering.Snapshots import Snapshots
from clustering.ClusterSet.HierarchicalClusterSet import HierarchicalClusterSet
from clustering.ClusterSet.KMeansClusterSet import KMeansClusterSet
from networks.TemporalNetwork import TemporalNetwork
from ODEs.ODEsSolutions import ODEsSolutions
from ODEs import ODEs

## Declare parameters used throughout the notebook

In [None]:
# Clustering method; 'k_means' selects k-means, any other value (e.g. 'ward')
# selects hierarchical clustering of the snapshots.
distance_metric = 'euclidean'
cluster_method = 'ward'
# Criterion and value used to cut the clustering into phases (both are passed to
# the ClusterSet constructor).
cluster_limit_type = 'maxclust'
cluster_limit = 7

# Experimental time course: alternating day/night phases of equal length
# (time units presumably hours — TODO confirm against the node table).
hours_per_phase = 12
n_phases = 4

# Event markers at every phase boundary (0, 12, ..., 48 by default), unlabelled, dashed.
events = [(hours_per_phase * i, '', 'dashed') for i in range(n_phases + 1)]
# (start, end, label) per phase: Day 1, Night 1, Day 2, Night 2, ...
phases = [
    (hours_per_phase * i,
     hours_per_phase * (i + 1),
     f"{'Day' if i % 2 == 0 else 'Night'} {i // 2 + 1}")
    for i in range(n_phases)]

# Optional ODE model whose solutions are plotted beneath the phases; None disables it.
ode_filepath = None
ode_variables = []  # ODE variables to plot
xpp_alias = 'xppmac64'  # presumably the name of the XPP executable — TODO confirm

# Directory for saved figures; None disables saving.
output_directory = None

## Load temporal network

This can be done in several ways; see the `TemporalNetwork` constructors for the full list of options.

In [None]:
# Inputs: a table of (presumably per-node, per-time-point) values and the static
# network those values are mapped onto.
# NOTE(review): the node table has a .csv extension but is parsed as tab-separated — confirm the file format.
node_table_filepath = '../data/temporal_data/circadian/circadian_temporal_node_data_mean_48.csv'
static_network_filepath = '../data/static_networks/circadian_net.edgelist'

node_table = pd.read_csv(node_table_filepath, sep='\t', index_col=0)
static_network = networkx.read_edgelist(static_network_filepath)


def multiply_node_weights(x, y):
    """Combine two node weights into one by multiplying them."""
    return x * y


# Build the temporal network with a 0.5 threshold, binary and normalised weights
# (see TemporalNetwork for the exact meaning of each option).
temporal_network = TemporalNetwork.from_static_network_and_node_table_dataframe(
    static_network,
    node_table,
    combine_node_weights=multiply_node_weights,
    threshold=0.5,
    binary=True,
    normalise=True,
)
# Label used in output filenames: the edge-list file name without its extension.
temporal_network_name = pathlib.Path(static_network_filepath).stem

## Solve system of ODEs (if applicable)

In [None]:
# Only solve the ODE system when a model file was configured.
if ode_filepath is None:
    odes_solutions = None
else:
    # Cover the network's time span; the end is one past the truncated last time point.
    start_time = int(temporal_network.true_times[0])
    end_time = int(temporal_network.true_times[-1]) + 1
    odes_solutions = ODEsSolutions(ode_filepath, start_time, end_time, xpp_alias)

## Compute single set of clusters

In [None]:
# Build the snapshots to cluster from the temporal network, using the chosen
# method and distance metric.
snapshots = Snapshots.from_temporal_network(temporal_network, cluster_method, distance_metric)

# 'k_means' gets a k-means cluster set; every other method is hierarchical.
if cluster_method == 'k_means':
    constructor = KMeansClusterSet
else:
    constructor = HierarchicalClusterSet
cluster_set = constructor(snapshots, cluster_limit_type, cluster_limit)

## Plot dendrogram, scatter graph and ODE variables

In [None]:
# Three stacked panels: dendrogram (hierarchical methods only), extracted phases,
# and ODE concentrations (only when ODE solutions exist), the last two sharing time.
drawing.utils.configure_sch_color_map()
sb.set_palette('Dark2', n_colors=8)
norm = True  # plot normalised ODE concentrations
show_dendrogram = cluster_method != 'k_means'  # k-means yields no dendrogram
# Check the solutions themselves (not just the path) so that a stale ODE cell
# cannot make this cell call a method on None.
show_odes = odes_solutions is not None
fig = plt.figure(figsize=(8, 6))

# Plot
if show_dendrogram:
    ax1 = fig.add_subplot(3, 1, 1)
    cluster_set.plot_dendrogram(ax=ax1)

ax2 = fig.add_subplot(3, 1, 2)
cluster_set.plot(ax=ax2)
ODEs.plot_events(events, ax=ax2, y_pos=0.005, text_x_offset=1)
ODEs.plot_phases(phases, ax=ax2, y_pos=0.15, ymax=0.3)

if show_odes:
    # Share the x-axis with the phases panel at creation time. The previous
    # `ax2.get_shared_x_axes().join(ax2, ax3)` is deprecated since Matplotlib 3.6
    # and raises AttributeError from 3.8 onwards.
    ax3 = fig.add_subplot(3, 1, 3, sharex=ax2)
    odes_solutions.plot_concentrations(ode_variables, ax=ax3, norm=norm)

# Format
if show_dendrogram:
    ax1.set_ylabel('Distance threshold')
    ax1.set_xlabel("Times")
    ax1.set_title('Dendrogram: hierarchical clustering of snapshots', weight="bold")

title = f"Phases extracted by '{cluster_method}' clustering with '{cluster_limit_type}' = {cluster_limit}"
ax2.set_title(title, weight="bold")
ax2.set_yticks([])
# Ticks every `xtick_step` time units up to the end of the last phase
# (0, 6, ..., 48 with the default phases).
xtick_step = 6
ax2.set_xticks(list(range(0, int(phases[-1][1]) + 1, xtick_step)))
sb.despine(ax=ax2, left=True)
ax2.grid(axis='x')

if show_odes:
    ax3.set_xlabel('Time')
    ax3.set_ylabel('Concentration (normed)' if norm else 'Concentration')
    sb.despine(ax=ax3)
    ax3.autoscale()

fig.tight_layout()

# Save
if output_directory is not None:
    output_path = pathlib.Path(output_directory)
    output_path.mkdir(parents=True, exist_ok=True)  # savefig does not create directories
    filename = (f"phases_from_clustering_{cluster_limit_type}_{cluster_limit}"
                f"_method_{cluster_method}_{temporal_network_name}")
    for extension in ('png', 'pdf'):
        fig.savefig(output_path / f"{filename}.{extension}", dpi=250, bbox_inches="tight")

## Plot distance matrix heatmap

In [None]:
# Triangular heatmap of the distances between snapshots, titled with the metric used.
snapshot_distances = cluster_set.snapshots.distance_matrix
fig, ax = plt.subplots(figsize=(9, 9))
snapshot_distances.plot_heatmap(ax=ax, triangular=True)
title = f"Snapshots distance matrix heatmap using '{snapshot_distances.metric}' metric"
ax.set_title(title, weight='bold')
plt.show()