# Extraction of cell cycle phases by hierarchical clustering

We try to extract the phases of the cell cycle from our temporal network. We do so by:
1. Computing similarity/distance between every pair of adjacency snapshots
2. Performing hierarchical clustering on these distances

In [None]:
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.cluster.hierarchy as sch
import scipy.signal
import scipy.spatial.distance
import seaborn as sb
from labellines import labelLines
from sklearn import metrics
from sklearn.cluster import KMeans

from clustering import clustering, silhouettes
from ODEs import ODEs
from ODEs.ODEsSolutions import ODEsSolutions
from temporal_networks.TemporalNetwork import TemporalNetwork

# ToDo: use 'with' statement instead?
sb.set_context("paper")

## Declare parameters to use throughout

In [None]:
# --- Clustering settings ---
distance_type = 'euclidean'  # metric for snapshot-to-snapshot distances
method = 'ward'              # linkage method given to scipy's linkage()
max_distance = 2             # distance threshold handed to the dendrogram plot
max_clusters = 6             # 'maxclust' value used to cut the tree into flat clusters
max_clusters_limit = 15      # exclusive upper bound of the silhouette sweep

# --- Input / output locations ---
ode_filepath = 'example_data/bychen04_xpp.ode'
temporal_network_filepath = (
    "example_data/tedges_combined_weighted_binary_method_percentage_minmax_p_0.5_clean2.tedges"
)
output_directory = 'output/'  # set to None to skip saving figures

## Load our cell cycle temporal network

In [None]:
# Load the temporal network from `temporal_network_filepath`.
# The separator is a regex-style pattern (tab with optional surrounding
# whitespace), so it is written as a raw string: "\s" inside a normal string
# literal is an invalid escape sequence (SyntaxWarning on Python 3.12+).
temporal_network = TemporalNetwork.from_file(temporal_network_filepath, r"\s*\t\s*")
# File name without extension, used to tag plot titles and output files.
temporal_network_name = pathlib.Path(temporal_network_filepath).stem

### Solve the Chen ODE model (`bychen04_xpp.ode`) over the network's time span

In [None]:
# Time points of the network: `times` is shifted so it starts at 0 (this is
# what the plots below use); `true_times` keeps the original values.
times = np.array(temporal_network.time_points(starting_at_zero=True))
true_times = np.array(temporal_network.time_points(starting_at_zero=False))

# Solve the ODE model from the first network time point up to one past the last.
first_true_time = true_times[0]
last_true_time = true_times[-1]
start_time = int(first_true_time)
end_time = int(last_true_time) + 1

odes_solutions = ODEsSolutions(ode_filepath, start_time, end_time)

In [None]:
fig, ax = plt.subplots()

# Find the strict local minima of the simulated mass series.
# argrelextrema returns a tuple with one index array per dimension; the series
# is plotted against the 1-D `times`, so unwrap the single index array.
mass_series = odes_solutions.series('mass')
mass_relative_minima_times = scipy.signal.argrelextrema(mass_series, np.less)[0]
mass_relative_minima = mass_series[mass_relative_minima_times]

ax.plot(times, mass_series, 'o-', label='mass')
ax.plot(times[mass_relative_minima_times], mass_relative_minima, 'ro', label='local minima')
ax.set_xlabel("Time")
ax.set_ylabel("Mass")
ax.legend()

print(f"min of mass: {mass_relative_minima} at indices {mass_relative_minima_times}")


## Combined figure: dendrogram, clustered phases with events, and ODE concentrations with phases

In [None]:
# Hierarchically cluster the snapshots and build a three-panel figure:
# (1) dendrogram, (2) flat clusters ("phases") over time with reference events,
# (3) normalised ODE concentrations with reference cell-cycle phases.
distance_matrix, distance_matrix_condensed = clustering.compute_snapshot_distances(temporal_network, distance_type)
clusters = sch.linkage(distance_matrix_condensed, method=method)
clustering.configure_colour_map()
fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(8, 6))

# 1. Plot clustering (max_distance is passed to the dendrogram plot only)
dendrogram_title = 'Dendrogram: hierarchical clustering of snapshots'
clustering.plot_dendrogram_from_clusters(clusters, ax=ax1, max_distance=max_distance, title=dendrogram_title)

# 2a. Cut the tree into at most `max_clusters` flat clusters and plot them over time.
flat_clusters = sch.fcluster(clusters, max_clusters, criterion='maxclust')
number_of_clusters = len(set(flat_clusters))
# The cut uses the 'maxclust' criterion, so the title reports max_clusters;
# max_distance plays no role in producing these flat clusters.
scatter_graph_title = \
    f"'Phases' extracted by hierarchical clustering of snapshots, \n " \
    f"with maxclust = {max_clusters} => {number_of_clusters} clusters, " \
    f"with '{method}' method"
clustering.plot_scatter_of_phases_from_flat_clusters(
    flat_clusters,
    times,
    number_of_colours=10,
    ax=ax2,
    title=scatter_graph_title)

# 2b. Overlay reference events as (time, label, line style)
sb.set_palette('Dark2', n_colors=8)
events = [
    (5, 'START', 'dashed'),
    (33, 'bud', 'solid'),
    (36, 'ori', 'solid'),
    (70, 'E3', 'dashed'),
    (84, 'spn', 'solid'),
    (100, 'mass', 'solid')]
ODEs.plot_events(events, ax=ax2, y_pos=0.005, text_x_offset=1)

# 3. Plot concentrations and reference phases given as (start, end, name)
variables = ['cln3', 'cln2', 'clb5', 'clb2', 'mass']  # + ['ori']
phases = [
    (0, 35, 'G1'),
    (35, 70, 'S'),
    (70, 78, 'G2'),
    (78, 100, 'M')]
ODEs.plot_concentrations(odes_solutions, variables, times, ax=ax3, norm=True)
labelLines(ax3.get_lines(), zorder=2.5, xvals=[10, 90, 95, 55, 30])
ODEs.plot_phases(phases, ax=ax3, y_pos=1.1)

# Tidy up formatting. Axes.sharex replaces get_shared_x_axes().join(), which
# was deprecated in Matplotlib 3.6 and no longer works from 3.8 on.
ax3.sharex(ax2)
ax3.autoscale()
fig.tight_layout()

# Save plots (the directory is created if missing; None disables saving)
if output_directory is not None:
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    filename = f"{output_directory}phases_from_clustering_maxclust_{max_clusters}_mtd_{method}_{temporal_network_name}"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
    fig.savefig(f"{filename}.pdf", dpi=250, bbox_inches="tight")

plt.show()


## Silhouettes


In [None]:
# Sweep the maximum number of clusters from 1 up to max_clusters_limit - 1 and
# collect, for each value, the clustering and its silhouette scores
# (presumably one flat clustering per value — see the silhouettes module).
max_cluster_range = range(1, max_clusters_limit)
silhouette_data = silhouettes.compute_silhouettes_for_max_cluster_range(
    clusters,
    distance_matrix,
    max_cluster_range=max_cluster_range,
    number_of_time_points=temporal_network.T,
)
(
    range_of_clusters,
    numbers_of_clusters,
    silhouette_samples,
    average_silhouettes,
) = silhouette_data


In [None]:
# Plot array of clusters: cluster membership over time for each maximum number
# of clusters (left), and the average silhouette of each clustering (right).

gridspec_kw = {"width_ratios": [9, 2]}
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 3), gridspec_kw=gridspec_kw)

# Show a row label only where the actual number of clusters changes from the
# previous row. The diff is computed once instead of once per element.
cluster_count_changes = np.diff(numbers_of_clusters)
labels = [
    int(n) if (i == 0 or cluster_count_changes[i - 1] != 0) else ''
    for i, n in enumerate(numbers_of_clusters)]

clustering.plot_time_clusters(times, range_of_clusters, ax=ax1)
clustering.plot_time_clusters_right_axis(max_cluster_range, labels, ax=ax1)
silhouettes.plot_average_silhouettes(
    average_silhouettes,
    numbers_of_clusters,
    max_cluster_range,
    labels,
    ylim=ax1.get_ylim(),
    ax=ax2)
ODEs.plot_events(events, ax=ax1)
ODEs.plot_phases(phases, ax=ax1, y_pos=-1, ymax=0.1)

title = f"Hier. clust. method: '{method}' ({temporal_network_name})"
fig.suptitle(title)
fig.subplots_adjust(wspace=0.4, top=0.8)

# Save plots (the directory is created if missing; None disables saving)
if output_directory is not None:
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    filename = f"{output_directory}phase_clusters_all_method_{method}_{temporal_network_name}"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
    fig.savefig(f"{filename}.pdf", dpi=250, bbox_inches="tight")

In [None]:
# Sample silhouettes: one panel per distinct number of clusters, skipping
# clusterings with fewer than 2 clusters (silhouettes are undefined there).
# Uses the results unpacked from `silhouette_data` in the "Silhouettes" section.


def _plot_silhouette_sample(sample_silhouettes, cluster_labels, average_silhouette, ax):
    """Draw per-time-point silhouettes grouped by cluster, sorted within each cluster.

    NOTE(review): no `plot_silhouette_sample` is defined or imported in this
    notebook; replace this helper with the project's version if one exists.
    """
    sample_silhouettes = np.asarray(sample_silhouettes)
    cluster_labels = np.asarray(cluster_labels)
    y_lower = 0
    for label in np.unique(cluster_labels):
        values = np.sort(sample_silhouettes[cluster_labels == label])
        y_upper = y_lower + len(values)
        ax.fill_betweenx(np.arange(y_lower, y_upper), 0, values)
        y_lower = y_upper
    ax.axvline(average_silhouette, color="k", linestyle="--")


# First index of each distinct number of clusters (np.unique sorts ascending).
n_clust_unique, idx_unique = np.unique(np.asarray(numbers_of_clusters), return_index=True)
panel_indices = [int(j) for n, j in zip(n_clust_unique, idx_unique) if n >= 2]

ncols = 4
nrows = max(1, -(-len(panel_indices) // ncols))  # ceiling division

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, sharey=True,
                        figsize=(10, 2 * nrows), squeeze=False)

for ax, j_uni in zip(axs.flat, panel_indices):
    n_clusters = len(set(range_of_clusters[j_uni]))
    ax.set_title(f"{n_clusters} clusters")
    _plot_silhouette_sample(silhouette_samples[j_uni], range_of_clusters[j_uni], average_silhouettes[j_uni], ax=ax)

# Hide the unused panels at the end of the last row.
for ax in axs.flat[len(panel_indices):]:
    ax.set_visible(False)

for ax in axs[:, 0]:
    ax.set_ylabel("Ordered time index")

for ax in axs[-1]:
    ax.set_xlabel("Silhouette score")

fig.suptitle(f"Sample silhouette, method: '{method}' ({temporal_network_name})")
fig.subplots_adjust(top=0.8)

# Save plots (the directory is created if missing; None disables saving)
if output_directory is not None:
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    filename = f"{output_directory}phase_clusters_silhouette_sample_method_{method}_{temporal_network_name}"
    fig.savefig(f"{filename}.png", dpi=250, bbox_inches="tight")
    fig.savefig(f"{filename}.pdf", dpi=250, bbox_inches="tight")

## K-means

In [None]:
# Build one flattened feature vector per time point for K-means.
# NOTE(review): `tnet` is not defined in this notebook; it looks like an older
# name for `temporal_network` — confirm which object provides `df_to_array()`.
snapshots = tnet.df_to_array()
snapshots = np.swapaxes(snapshots, 0, 2)  # put time as zeroth axis
# Number of time points, derived from the data; `T` is also used by the
# K-means cells below and is not defined anywhere before this point.
T = snapshots.shape[0]
snapshot_flat = snapshots.reshape(T, -1)  # each matrix is flattened, represented as a vector

print(snapshots.shape)

In [None]:
# Cluster the flattened snapshots with K-means for k = 1 .. max_clusters_limit - 1
# and record, for each k, the labels, the number of distinct clusters, and the
# average and per-sample silhouettes (left at 0 when there is only 1 cluster).
# NOTE: this overrides the global `method` ('ward'), which later cells use in
# titles and file names.
method = "kmeans"

maxclust_max = max_clusters_limit  # same sweep limit as the hierarchical silhouettes
maxclust_range = range(1, maxclust_max)
n_maxclust = len(maxclust_range)
clusters_arr = np.zeros((n_maxclust, T))
n_clust_arr = np.zeros(n_maxclust)

silhouette_avg_arr = np.zeros(n_maxclust)
silhouette_sample_arr = np.zeros((n_maxclust, T))

for i, nclust in enumerate(maxclust_range):
    # A dedicated name keeps the global `clusters` (the hierarchical linkage
    # matrix) intact. A fixed seed and an explicit n_init make results
    # reproducible and independent of scikit-learn's changing n_init default.
    kmeans_labels = KMeans(n_clusters=nclust, n_init=10, random_state=0).fit_predict(snapshot_flat)

    clusters_arr[i] = kmeans_labels

    n_clusters = len(set(kmeans_labels))
    n_clust_arr[i] = n_clusters

    # Silhouettes are only defined for at least 2 clusters.
    if n_clusters > 1:
        silhouette_avg = metrics.silhouette_score(snapshot_flat, kmeans_labels, metric="euclidean")
        silhouette_avg_arr[i] = silhouette_avg

        silhouette_sample = metrics.silhouette_samples(snapshot_flat, kmeans_labels, metric="euclidean")
        silhouette_sample_arr[i] = silhouette_sample


In [None]:
# Plot array of K-means clusters: cluster membership over time for each k
# (left) and the average silhouette of each clustering (right).

gridspec_kw = {"width_ratios": [9, 2]}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 3), gridspec_kw=gridspec_kw)

# Reuse the ordered `times` computed from the temporal network earlier.
# Rebuilding them from a `set` would not guarantee sorted order, and would
# overwrite the global used by other cells.

# -------------- main plot with time clusters
clustering.plot_time_clusters(times, clusters_arr, ax=ax1)

ax1.set_ylabel("Max # clusters")
ax1.set_xlabel("Times (min)")

ax1.set_xticks(range(0, 100 + 5, 10))
ax1.set_ylim([-1, ax1.get_ylim()[1]])
ax1.grid(axis="x")
ax1.set_axisbelow(True)
sb.despine(ax=ax1)

# ---------------- twin plot for labels on right
ax11 = ax1.twinx()
ax11.set_ylim(ax1.get_ylim())
ax11.set_yticks(maxclust_range)

# Show a label only where the number of clusters changes from the previous row
# (diff computed once instead of once per element).
kmeans_count_changes = np.diff(n_clust_arr)
labels_right = [
    int(n_clust) if (i == 0 or kmeans_count_changes[i - 1] != 0) else ''
    for i, n_clust in enumerate(n_clust_arr)]
ax11.set_yticklabels(labels_right)
sb.despine(ax=ax11, right=False)

# ----------- side plot: average silhouette per clustering
ax2.plot(silhouette_avg_arr, n_clust_arr, 'ko-')
ax2.set_xlim(right=1.1)
ax2.set_ylim(ax1.get_ylim())
ax2.set_yticks(maxclust_range)
ax2.set_yticklabels(labels_right)
ax2.set_ylabel("Actual # clusters")
ax2.set_xlabel("Average silhouette")

# K-means is not hierarchical, so the title does not say "Hier. clust.".
fig.suptitle(f"Clustering method: '{method}' ({temporal_network_name})")

fig.subplots_adjust(wspace=0.4, top=0.8)

# Save plot (the directory is created if missing; None disables saving)
if output_directory is not None:
    pathlib.Path(output_directory).mkdir(parents=True, exist_ok=True)
    fig.savefig(f"{output_directory}phase_clusters_kmeans_{temporal_network_name}.png", dpi=250, bbox_inches="tight")

In [None]:
# Display the current value of `times` as the notebook cell output.
times

# Scratch work — ignore everything below (not part of the main analysis)

## Test on toy snapshots

In [None]:
# Toy data: T random binary N x N snapshots. A seeded Generator makes the
# example reproducible between runs, unlike the legacy global RNG.
N = 3
T = 4
rng = np.random.default_rng(0)
snapshots = rng.integers(0, 2, size=(T, N, N))

In [None]:
# Display the toy snapshots as the notebook cell output.
snapshots

In [None]:
# Hierarchically cluster the toy snapshots and draw the dendrogram.
# Condensed pairwise distances between flattened snapshots (one row per time
# point), using the notebook's `distance_type`; 'ward' linkage assumes
# Euclidean distances.
# NOTE(review): replaces the undefined `cluster_snapshots.compute_snapshot_distances`;
# `clustering.compute_snapshot_distances` appears to take a TemporalNetwork,
# not an array, so it cannot be used directly on these toy snapshots.
distance_matrix_condensed = scipy.spatial.distance.pdist(
    snapshots.reshape(len(snapshots), -1), metric=distance_type)
linked = sch.linkage(distance_matrix_condensed, method="ward")

fig, ax = plt.subplots()
dend = sch.dendrogram(linked, ax=ax)
plt.show()