# Extraction of cell cycle phases by hierarchical clustering

We try to extract the phases of the cell cycle from our temporal network. We do so by:
1. Computing similarity/distance between every pair of adjacency snapshots
2. Performing hierarchical clustering on these distances

In [None]:
import pathlib
import scipy
import drawing
import numpy as np
import matplotlib.pyplot as plt
import scipy.cluster.hierarchy as sch
# ToDo: use 'with' statement instead?
import seaborn as sb
sb.set_context("paper")

from labellines import labelLines
from clustering import clustering, silhouetting
from clustering.Silhouettes import Silhouettes
from temporal_networks.TemporalNetwork import TemporalNetwork
from ODEs.ODEsSolutions import ODEsSolutions
from ODEs import ODEs

## Declare parameters to use throughout

In [None]:
# --- Hierarchical clustering settings -------------------------------------
distance_type = 'euclidean'  # metric used to compare adjacency snapshots
method = 'ward'              # linkage method handed to scipy's `linkage`
max_distance = 2             # threshold passed to the dendrogram plot
max_clusters = 6             # 'maxclust' value used to cut flat clusters
max_clusters_limit = 15      # exclusive upper bound of the silhouette scan range

# --- Reference annotations for the plots ----------------------------------
# Events as (time, label, line style) tuples, drawn by ODEs.plot_events.
events = [
    (5, 'START', 'dashed'),
    (33, 'bud', 'solid'),
    (36, 'ori', 'solid'),
    (70, 'E3', 'dashed'),
    (84, 'spn', 'solid'),
    (100, 'mass', 'solid'),
]
# Phases as (start, end, name) tuples, drawn by ODEs.plot_phases.
phases = [
    (0, 35, 'G1'),
    (35, 70, 'S'),
    (70, 78, 'G2'),
    (78, 100, 'M'),
]
# ODE variables whose concentrations are plotted ('ori' could be appended).
variables = ['cln3', 'cln2', 'clb5', 'clb2', 'mass']

# --- Input / output locations ---------------------------------------------
ode_filepath = 'example_data/bychen04_xpp.ode'
temporal_network_filepath = "example_data/tedges_combined_weighted_binary_method_percentage_minmax_p_0.5_clean2.tedges"
# Set to e.g. 'output/' (with trailing slash: it is prefixed directly to the
# file names) to save figures; None disables saving.
output_directory = None

## Load our cell cycle temporal network

In [None]:
# Load the temporal network from its edge-list file.
# The separator is a regex: a tab, optionally surrounded by whitespace.
# A raw string is required: in a plain string "\s" is an invalid escape
# sequence (DeprecationWarning, SyntaxWarning on Python 3.12+). The regex
# engine itself interprets the "\t" in the raw string as a tab.
temporal_network_separator = r"\s*\t\s*"
temporal_network = TemporalNetwork.from_file(temporal_network_filepath, temporal_network_separator)
# File name without its extension, used to label plots and output files.
temporal_network_name = pathlib.Path(temporal_network_filepath).stem

### Run the Chen (2004) cell-cycle ODE model

In [None]:
# Time axes of the temporal network: shifted so they start at zero (used for
# plotting), and the original ("true") values (used to bound the simulation).
times = np.array(temporal_network.time_points(starting_at_zero=True))
true_times = np.array(temporal_network.time_points(starting_at_zero=False))

# Simulate the ODE model over the network's time span. The +1 presumably makes
# the last time point inclusive -- confirm against ODEsSolutions.
start_time, end_time = int(true_times[0]), int(true_times[-1]) + 1
odes_solutions = ODEsSolutions(ode_filepath, start_time, end_time)

In [None]:
# `import scipy` alone does not guarantee that the `signal` submodule is
# loaded (older SciPy versions do not auto-import submodules), so import it
# explicitly before using scipy.signal below.
import scipy.signal

# Plot the simulated mass and mark its strict local minima.
fig, ax = plt.subplots()

# find local minima: argrelextrema returns a tuple of index arrays (one per axis)
mass_series = odes_solutions.series('mass')
mass_relative_minima_times = scipy.signal.argrelextrema(mass_series, np.less)
mass_relative_minima = mass_series[mass_relative_minima_times]

# NOTE(review): assumes mass_series has one value per entry of `times` -- confirm.
ax.plot(times, mass_series, 'o-')
ax.plot(times[mass_relative_minima_times], mass_relative_minima, 'ro')

print(f"min of mass: {mass_relative_minima} at indices {mass_relative_minima_times}")


## Combined plot: dendrogram, extracted phases, and ODE concentrations

In [None]:
# Cluster the snapshots hierarchically and draw three stacked panels:
# dendrogram, flat clusters over time (with events), ODE concentrations (with phases).
distance_matrix, distance_matrix_condensed = clustering.compute_snapshot_distances(temporal_network, distance_type)
clusters = sch.linkage(distance_matrix_condensed, method=method)
drawing.utils.configure_colour_map()
fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(8, 6))

# 1. Plot clustering (max_distance is only used by the dendrogram plot)
dendrogram_title = 'Dendrogram: hierarchical clustering of snapshots'
clustering.plot_dendrogram_from_clusters(clusters, ax=ax1, max_distance=max_distance, title=dendrogram_title)

# 2a. Plot scatter graph of clusters.
# The flat clusters are cut with the 'maxclust' criterion (at most
# `max_clusters` clusters), NOT with the distance threshold `max_distance`;
# the title reports the criterion that was actually used.
flat_clusters = sch.fcluster(clusters, max_clusters, criterion='maxclust')
number_of_clusters = len(set(flat_clusters))
scatter_graph_title = \
    f"'Phases' extracted by hierarchical clustering of snapshots, \n " \
    f"with 'maxclust' criterion $n_{{max}} = {max_clusters}$ => {number_of_clusters} clusters, " \
    f"with '{method}' method"
clustering.plot_scatter_of_phases_from_flat_clusters(
    flat_clusters,
    times,
    number_of_colours=10,
    ax=ax2,
    title=scatter_graph_title)

# 2b. Overlay events
sb.set_palette('Dark2', n_colors=8)
ODEs.plot_events(events, ax=ax2, y_pos=0.005, text_x_offset=1)

# 3. Plot concentrations and phases.
# NOTE(review): xvals gives one label position per line -- presumably one per
# entry of `variables`, in order; keep them in sync when editing `variables`.
ODEs.plot_concentrations(odes_solutions, variables, times, ax=ax3, norm=True)
labelLines(ax3.get_lines(), zorder=2.5, xvals=[10, 90, 95, 55, 30])
ODEs.plot_phases(phases, ax=ax3, y_pos=1.1)

# Tidy up formatting.
# Axes.sharex replaces `get_shared_x_axes().join(...)`, which was deprecated
# in Matplotlib 3.6 and removed in 3.8.
ax3.sharex(ax2)
ax3.autoscale()
fig.tight_layout()

# Save plots (explicitly from `fig`, before showing it)
if output_directory is not None:
    filename = f"{output_directory}phases_from_clustering_maxclust_{max_clusters}_mtd_{method}_{temporal_network_name}"
    for extension in ("png", "pdf"):
        fig.savefig(f"{filename}.{extension}", dpi=250, bbox_inches="tight")

plt.show()


## Silhouettes

In [None]:
# Silhouette analysis of flat clusterings cut from the hierarchical linkage,
# presumably one clustering per cluster count in `max_cluster_range`.
# NOTE(review): the range starts at 1, and a silhouette score is undefined for a
# single cluster -- confirm that Silhouettes handles that case.
max_cluster_range = range(1, max_clusters_limit)
flat_clusters_silhouettes = Silhouettes.via_flat_clusters(
    clusters, distance_matrix, max_cluster_range,
    metric='precomputed',  # distance_matrix already holds the snapshot distances
    number_of_time_points=temporal_network.T,
)

In [None]:
# Plot clusters and silhouette scores

# Plot the average silhouette scores and the flat clusters for each cluster count.
title = f"Hier. clust. method: '{method}' ({temporal_network_name})"
fig, (ax1, ax2) = silhouetting.plot_average_silhouettes_and_clusters(
    flat_clusters_silhouettes,
    max_cluster_range,
    times,
    title)

# Overlay the reference events and phases on ax1.
ODEs.plot_events(events, ax=ax1)
ODEs.plot_phases(phases, ax=ax1, y_pos=-1, ymax=0.1)

# Save the figure returned above explicitly, rather than relying on pyplot's
# implicit "current figure".
if output_directory is not None:
    filename = f"{output_directory}phase_clusters_all_method_{method}_{temporal_network_name}"
    for extension in ("png", "pdf"):
        fig.savefig(f"{filename}.{extension}", dpi=250, bbox_inches="tight")


In [None]:
# Per-sample silhouette plots for the hierarchical flat clusterings.
fig, axs = silhouetting.plot_silhouette_samples(flat_clusters_silhouettes, columns=4)
fig.suptitle(f"Sample silhouette, method: '{method}' ({temporal_network_name})")
fig.subplots_adjust(hspace=0.4)

# Save BEFORE plt.show(): after the figure has been shown (and closed by the
# inline/interactive backend), pyplot's current figure is a fresh empty one,
# so a later plt.savefig would write blank images.
if output_directory is not None:
    filename = f"{output_directory}phase_clusters_silhouette_sample_method_{method}_{temporal_network_name}"
    for extension in ("png", "pdf"):
        fig.savefig(f"{filename}.{extension}", dpi=250, bbox_inches="tight")

plt.show()

## K-means

In [None]:
# Turn each adjacency snapshot into one feature row, so that k-means clusters time points.
# NOTE(review): assumes df_to_array() puts time on its last axis -- confirm.
snapshots = np.swapaxes(temporal_network.df_to_array(), 0, 2)  # time becomes axis 0
flat_snapshots = snapshots.reshape(temporal_network.T, -1)     # one row per time point

k_means_silhouettes = Silhouettes.via_k_means(
    flat_snapshots, max_cluster_range,
    metric='euclidean', number_of_time_points=temporal_network.T,
)

In [None]:
# NOTE: this overwrites the global `method` (previously the linkage method);
# the next cell relies on it being 'k-means'. Re-run the parameter cell before
# re-running any hierarchical-clustering cell.
method = 'k-means'
fig, axs = silhouetting.plot_silhouette_samples(k_means_silhouettes, columns=4)
fig.suptitle(f"Sample silhouette, method: '{method}' ({temporal_network_name})")
fig.subplots_adjust(hspace=0.4)

# Save BEFORE plt.show(): once the figure has been shown, pyplot's current
# figure is a fresh empty one, so a later plt.savefig would write blank images.
if output_directory is not None:
    filename = f"{output_directory}phase_clusters_silhouette_sample_method_{method}_{temporal_network_name}"
    for extension in ("png", "pdf"):
        fig.savefig(f"{filename}.{extension}", dpi=250, bbox_inches="tight")

plt.show()


In [None]:
# Plot clusters and silhouette scores for the k-means clusterings.
# NOTE(review): relies on the previous cell having set `method = 'k-means'`.
# k-means is not hierarchical clustering, so the title no longer says "Hier. clust.".
title = f"Clustering method: '{method}' ({temporal_network_name})"
fig, (ax1, ax2) = silhouetting.plot_average_silhouettes_and_clusters(
    k_means_silhouettes,
    max_cluster_range,
    times,
    title)

# Overlay the reference events and phases on ax1.
ODEs.plot_events(events, ax=ax1)
ODEs.plot_phases(phases, ax=ax1, y_pos=-1, ymax=0.1)

# Save the figure returned above explicitly, rather than relying on pyplot's
# implicit "current figure".
if output_directory is not None:
    filename = f"{output_directory}phase_clusters_all_method_{method}_{temporal_network_name}"
    for extension in ("png", "pdf"):
        fig.savefig(f"{filename}.{extension}", dpi=250, bbox_inches="tight")