# Cluster analysis

## Imports, preparations, global variables and functions

In [1]:
from glob import glob
from collections import namedtuple
import pandas as pd
import numpy as np
from scipy.stats import kstest
import napari
from imodmodel import ImodModel
import matplotlib

def sample_to_cell_line(s) -> str:
    """Translate a sample name into the name of its cell line.

    The second underscore-separated field of ``s`` is used as the lookup key.

    Args:
        s: Sample name, e.g. ``"PR8_Rab11wt_210503_TS_01"``.

    Returns:
        The cell line name associated with the sample.

    Raises:
        IndexError: If ``s`` contains no underscore.
        KeyError: If the second field is not a known cell line tag.
    """
    cell_line_by_tag = {
        "Rab11wt": "A549-Rab11wt",
        "A549wt": "A549wt",
        "Rab11dn": "A549-Rab11dn",
    }
    tag = s.split("_")[1]
    return cell_line_by_tag[tag]

def load_sample(path) -> dict:
    """Collect the metadata and model file paths belonging to one sample.

    Model files are looked up by glob pattern inside ``{path}/mtk/``; the
    first match of each pattern is used.

    Args:
        path: Sample directory; also used as the sample name and parsed by
            ``sample_to_cell_line``.

    Returns:
        A dict with the keys ``name``, ``cell_line``, ``vRNP_model_path``,
        ``HA_model_path``, ``vRNP_model_path_shifted`` and
        ``HA_model_path_shifted``.

    Raises:
        IndexError: If one of the glob patterns matches no file.
    """
    def first_match(pattern):
        # glob returns a list; an empty list makes the [0] raise IndexError.
        return glob(f"{path}/mtk/{pattern}")[0]

    sample = {"name": path, "cell_line": sample_to_cell_line(path)}
    sample["vRNP_model_path"] = first_match("*_mtk_vRNPs.mod")
    sample["HA_model_path"] = first_match("*_mtk_HA.mod")
    sample["vRNP_model_path_shifted"] = first_match("*_mtk_vRNPs_shifted.mod")
    sample["HA_model_path_shifted"] = first_match("*_mtk_HA_shifted.mod")
    return sample
                    
def contour_lengths(obj):
    """Return the distance between the first two points of every contour.

    Only ``points[0]`` and ``points[1]`` of each contour are used; any
    further points are ignored.

    Args:
        obj: Object exposing ``contours``, each contour exposing ``points``.

    Returns:
        1-D float array with one Euclidean distance per contour, in contour
        order.
    """
    first_segment_norms = (
        np.linalg.norm(c.points[1] - c.points[0]) for c in obj.contours
    )
    return np.fromiter(first_segment_norms, dtype=float, count=len(obj.contours))

def find_clusters(thresh, distance_matrix):
    """Group items into clusters linked by distances below ``thresh``.

    Starting from the first unlabelled item, the cluster is grown by
    repeatedly adding every item that is closer than ``thresh`` to any current
    member, until no new item is added. Clusters are numbered 0, 1, 2, ... in
    the order of their first member.

    Args:
        thresh: Two items are linked when their distance is strictly smaller
            than this value.
        distance_matrix: Square 2-D array of pairwise distances.

    Returns:
        1-D integer array holding the cluster label of every item.
    """
    labels = np.full(len(distance_matrix), -1, dtype=int)
    next_label = 0
    while (labels == -1).any():
        # While unlabelled items remain, the minimum is -1 and argmin picks
        # the first of them as the seed of the next cluster.
        members = np.array([np.argmin(labels)])
        while True:
            linked = np.nonzero(
                np.any(distance_matrix[members] < thresh, axis=0)
            )[0]
            grown = np.unique(np.concatenate([members, linked]))
            # grown is a sorted superset of members, so an unchanged size
            # means the cluster has stopped growing.
            if grown.size == members.size:
                break
            members = grown
        labels[members] = next_label
        next_label += 1
    return labels

def reciprocal_distance_matrix(model, obj_number, subtract_constant) -> np.ndarray:
    """Build a symmetric pairwise distance matrix from one model object.

    The contour lengths of ``model.objects[obj_number]`` are scaled by
    ``model.header.pixelsize`` and interpreted as the row-major upper triangle
    (without the diagonal) of an ``n x n`` matrix, i.e. the pairs
    (0, 1), (0, 2), ..., (0, n-1), (1, 2), ... in that order.

    Args:
        model: Model exposing ``objects`` and ``header.pixelsize``.
        obj_number: Index of the object holding one contour per pair.
        subtract_constant: Value subtracted from every distance; results are
            clipped at 0.

    Returns:
        ``(n, n)`` float array. The diagonal is ``inf``; every other entry is
        ``max(0, distance - subtract_constant)``.

    Raises:
        ValueError: If the number of contours is not of the form
            ``n * (n - 1) / 2``.
    """
    distances = contour_lengths(model.objects[obj_number]) * model.header.pixelsize
    n_pairs = len(distances)
    # Invert n_pairs = n * (n - 1) / 2 and verify the result exactly.
    n_contours = int(round((1 + np.sqrt(1 + 8 * n_pairs)) / 2))
    if n_contours * (n_contours - 1) // 2 != n_pairs:
        raise ValueError(
            f"{n_pairs} pairwise distances do not form a complete "
            "upper-triangular distance matrix"
        )
    mat = np.full((n_contours, n_contours), np.inf, dtype=float)
    # triu_indices enumerates the strict upper triangle row by row, matching
    # the expected order of the distances.
    rows, cols = np.triu_indices(n_contours, k=1)
    mat[rows, cols] = distances
    mat[cols, rows] = distances
    return np.maximum(0, mat - subtract_constant)

def distance_matrix(model, obj_number, n_vRNPs, subtract_constant) -> np.ndarray:
    """Build a rectangular distance matrix with one row per vRNP.

    The contour lengths of ``model.objects[obj_number]`` are scaled by
    ``model.header.pixelsize`` and reshaped to ``n_vRNPs`` rows.

    Args:
        model: Model exposing ``objects`` and ``header.pixelsize``.
        obj_number: Index of the object whose contours encode the distances.
        n_vRNPs: Number of rows of the result; must divide the contour count.
        subtract_constant: Value subtracted from every distance; results are
            clipped at 0.

    Returns:
        ``(n_vRNPs, -1)`` float array of ``max(0, distance - subtract_constant)``.
    """
    lengths = contour_lengths(model.objects[obj_number])
    per_vRNP = (lengths * model.header.pixelsize).reshape((n_vRNPs, -1))
    return np.maximum(0, per_vRNP - subtract_constant)

def row_by_contour(d, sample: str, contour: int):
    """Select the rows of ``d`` for one contour of one sample.

    Args:
        d: DataFrame with a ``"sample"`` and a ``"vRNP contour"`` column.
        sample: Sample name. It is passed to ``str.fullmatch`` and is
            therefore interpreted as a regular expression that must match the
            entire value.
        contour: Value to look up in the ``"vRNP contour"`` column.

    Returns:
        The matching rows of ``d`` (possibly empty).
    """
    # Filter the frame that was passed in rather than the notebook-level
    # ``data`` variable, which the argument used to be silently ignored for.
    return d[d["sample"].str.fullmatch(sample) & (d["vRNP contour"] == contour)]

# Sample records (name, cell line and model file paths) for every tomogram
# included in the analysis.
samples = list(map(load_sample, (
    "PR8_Rab11wt_210503_TS_01",
    "PR8_Rab11wt_210120_TS_02",
    "PR8_A549wt_211220_TS_16",
    "PR8_A549wt_220408_TS_07",
    "PR8_Rab11dn_211201_TS_06",
    "PR8_Rab11dn_220815_TS_10",
    "PR8_Rab11dn_220815_TS_07",
)))

## Read data into one big dataframe

vRNP_diameter = 14
min_cluster_size = 3
cluster_distance_threshold = 4


def _add_cluster_columns(frame, distances, suffix, threshold, min_size):
    """Add cluster ID, size, size-filtered ID and color columns to ``frame``.

    Args:
        frame: Per-sample DataFrame with one row per vRNP; modified in place.
        distances: Square vRNP-vRNP distance matrix in row order of ``frame``.
        suffix: Text inserted into the column names (``""`` or ``" shifted"``).
        threshold: Linking distance passed to ``find_clusters``.
        min_size: Clusters are kept only if they have strictly more members
            than this (matching the ``>`` in the column name).
    """
    kept_col = f"Cluster ID{suffix} > {min_size}"

    cluster_ids = find_clusters(threshold, distances)
    # find_clusters numbers clusters 0..k-1, so the counts returned by
    # np.unique can be indexed directly by cluster ID.
    cluster_sizes = np.unique(cluster_ids, return_counts=True)[1]
    frame[f"Cluster ID{suffix}"] = cluster_ids
    frame[f"Cluster size{suffix}"] = cluster_sizes[cluster_ids]

    # Remove small clusters (ID -1) and renumber the remaining ones
    # consecutively from 0. Only non-negative IDs take part in the
    # renumbering, so the first cluster keeps a valid ID even when no cluster
    # was removed.
    kept_ids = np.where(cluster_sizes[cluster_ids] > min_size, cluster_ids, -1)
    retained = np.unique(kept_ids[kept_ids >= 0])
    frame[kept_col] = np.where(
        kept_ids >= 0, np.searchsorted(retained, kept_ids), -1)

    # Assign colors; vRNPs outside any retained cluster are drawn in grey.
    # NOTE(review): tab20 has 20 entries, so cluster IDs >= 20 presumably all
    # receive the same color -- confirm whether cycling is wanted.
    tab20 = matplotlib.colormaps['tab20']
    hsv = matplotlib.colormaps['hsv']
    hsv_norm = max(1, frame[kept_col].max())  # computed once, not per row
    frame[f"Color_tab20{suffix}"] = frame[kept_col].transform(
        lambda i: tab20(i)
        if i >= 0 else np.array([0.5, 0.5, 0.5, 1]))
    frame[f"Color_hsv{suffix}"] = frame[kept_col].transform(
        lambda i: hsv(i / hsv_norm * 0.9)
        if i >= 0 else np.array([0.5, 0.5, 0.5, 1]))


dfs = list()
for sample in samples:
    vRNP_model = ImodModel.from_file(sample["vRNP_model_path"])
    HA_model = ImodModel.from_file(sample["HA_model_path"])
    vRNP_model_shifted = ImodModel.from_file(sample["vRNP_model_path_shifted"])
    HA_model_shifted = ImodModel.from_file(sample["HA_model_path_shifted"])

    data = dict(sample=sample["name"], cell_line=sample["cell_line"])
    sample["pixel_size"] = vRNP_model.header.pixelsize

    # vRNP coordinates: one point array per contour of the first object
    data["vRNP coords"] = [
        contour.points for contour in vRNP_model.objects[0].contours]
    data = pd.DataFrame(data)

    # vRNP vRNP distances (second-to-last object of the model)
    sample["vRNP_distance_matrix"] = reciprocal_distance_matrix(vRNP_model, -2, vRNP_diameter)
    n_vRNPs = len(sample["vRNP_distance_matrix"])
    # Contour numbers are 1-based, whereas "Closest vRNP index" below is a
    # 0-based position in the distance matrix.
    data["vRNP contour"] = np.arange(1, n_vRNPs + 1)
    data["vRNP vRNP distance [nm]"] = np.amin(sample["vRNP_distance_matrix"], axis=0)
    data["Closest vRNP index"] = np.argmin(sample["vRNP_distance_matrix"], axis=0)

    # vRNP membrane distances
    sample["HA_distance_matrix"] = distance_matrix(HA_model, -2, n_vRNPs, vRNP_diameter/2)
    data["vRNP membrane distance [nm]"] = np.amin(sample["HA_distance_matrix"], axis=1)

    # Same for shifted objects
    # vRNP coordinates
    data["vRNP coords shifted"] = [
        contour.points for contour in vRNP_model_shifted.objects[0].contours]
    # vRNP vRNP distances
    sample["vRNP_distance_matrix_shifted"] = reciprocal_distance_matrix(vRNP_model_shifted, -2, vRNP_diameter)
    data["vRNP vRNP distance shifted [nm]"] = np.amin(sample["vRNP_distance_matrix_shifted"], axis=0)
    data["Closest vRNP index shifted"] = np.argmin(sample["vRNP_distance_matrix_shifted"], axis=0)
    # vRNP membrane distances
    sample["HA_distance_matrix_shifted"] = distance_matrix(HA_model_shifted, -2, n_vRNPs, vRNP_diameter/2)
    data["vRNP membrane distance shifted [nm]"] = np.amin(sample["HA_distance_matrix_shifted"], axis=1)

    # Clusters, first for the segmented and then for the shifted vRNPs
    _add_cluster_columns(
        data, sample["vRNP_distance_matrix"], "",
        cluster_distance_threshold, min_cluster_size)
    _add_cluster_columns(
        data, sample["vRNP_distance_matrix_shifted"], " shifted",
        cluster_distance_threshold, min_cluster_size)

    # Append to final list
    dfs.append(data)

data = pd.concat(dfs, ignore_index=True)
del dfs
# Number the tomograms for each sample: within every cell line the samples
# are numbered 1, 2, ... in groupby (sorted) order. The comprehension keeps
# its loop variables local, so the notebook-level ``sample`` is not rebound
# to a string here.
sample_nums = pd.concat([
    pd.DataFrame(index=sample_rows.index, data={"Tomogram #": number})
    for _, cell_line_rows in data.groupby("cell_line")
    for number, (_, sample_rows) in enumerate(
        cell_line_rows.groupby("sample"), start=1)
])
data = data.join(sample_nums)
data.to_csv("nearest_neighbor_analysis.csv", index=False)
data.head()

Unnamed: 0,sample,cell_line,vRNP coords,vRNP contour,vRNP vRNP distance [nm],Closest vRNP index,vRNP membrane distance [nm],vRNP coords shifted,vRNP vRNP distance shifted [nm],Closest vRNP index shifted,...,Cluster size,Cluster ID > 3,Color_tab20,Color_hsv,Cluster ID shifted,Cluster size shifted,Cluster ID shifted > 3,Color_tab20 shifted,Color_hsv shifted,Tomogram #
0,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,"[[965.0, 3642.0, 233.50003051757812], [1064.66...",1,0.909301,424,15.347978,"[[1958.518798828125, 2746.15625, 299.0], [2058...",30.1144,183,...,13,0,"(0.12156862745098039, 0.4666666666666667, 0.70...","(1.0, 0.0, 0.0, 1.0)",0,1,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2
1,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,"[[920.2327880859375, 3714.22607421875, 281.858...",2,1.622792,5,31.295976,"[[1456.239013671875, 325.1280212402344, 251.99...",2.468414,77,...,13,0,"(0.12156862745098039, 0.4666666666666667, 0.70...","(1.0, 0.0, 0.0, 1.0)",1,3,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2
2,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,"[[3122.000244140625, 2304.0, 225.5000305175781...",3,3.85139,3,15.994725,"[[3070.345947265625, 641.43212890625, 315.0], ...",7.903561,267,...,2,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2,1,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2
3,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,"[[3108.0, 2442.0, 225.50003051757812], [3006.0...",4,3.85139,2,14.6529,"[[3311.0029296875, 4944.3291015625, 292.0], [3...",34.918543,120,...,2,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",3,1,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2
4,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,"[[2848.0, 2106.0, 225.50003051757812], [2828.0...",5,4.752192,493,11.538369,"[[1534.480224609375, 4986.94189453125, 36.0], ...",0.0,398,...,1,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",4,2,-1,"[0.5, 0.5, 0.5, 1.0]","[0.5, 0.5, 0.5, 1.0]",2


## Cluster analysis

In [2]:
# Repeat the clustering for every integer threshold from 0 to 12, on both the
# segmented and the shifted distance matrix of each sample, and record the
# cluster ID and cluster size of every vRNP.
matrix_labels = {
    "vRNP_distance_matrix": "segmentation",
    "vRNP_distance_matrix_shifted": "random",
}
dfs = list()
for sample in samples:
    for thresh in range(0, 13):
        for matrix, label in matrix_labels.items():
            cluster_id = find_clusters(thresh, sample[matrix])
            # Cluster IDs are 0..k-1, so the counts can be indexed by ID.
            sizes_by_id = np.unique(cluster_id, return_counts=True)[1]
            dfs.append(pd.DataFrame({
                "sample": sample["name"],
                "cell_line": sample["cell_line"],
                "matrix": label,
                "threshold": thresh,
                "cluster_id": cluster_id,
                "cluster_size": sizes_by_id[cluster_id],
            }))
cluster_size_data = pd.concat(dfs, ignore_index=True)
del dfs
cluster_size_data.to_csv("cluster_size_analysis.csv", index=False)
cluster_size_data.head()

Unnamed: 0,sample,cell_line,matrix,threshold,cluster_id,cluster_size
0,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,segmentation,0,0,1
1,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,segmentation,0,1,1
2,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,segmentation,0,2,1
3,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,segmentation,0,3,1
4,PR8_Rab11wt_210503_TS_01,A549-Rab11wt,segmentation,0,4,1


## Napari

In [3]:
# Switch for the interactive 3D visualisation: every statement in this section
# is guarded by this flag, so nothing napari-related runs while it is False.
napari_window = False
if napari_window:
    import napari
    # Open a viewer in 3D display mode.
    viewer = napari.Viewer(ndisplay=3)

In [4]:
if napari_window:
    # Show the vRNP contours of one sample, colored by cluster.
    sample = samples[4]
    # Compare names exactly: str.match treats the name as a regular
    # expression anchored only at the start, so a sample whose name is a
    # prefix of another one would also select the other sample's rows.
    sample_data = data[data["sample"] == sample["name"]]
    # Reverse the column order of every point array before handing it to
    # napari (presumably x, y, z -> z, y, x; confirm against the model).
    vRNP_coords = sample_data["vRNP coords"].transform(lambda contour: contour[:, (2, 1, 0)])
    layer_contours = viewer.add_shapes(vRNP_coords.to_list(), shape_type="path")
    vRNP_coords_shifted = sample_data["vRNP coords shifted"].transform(lambda contour: contour[:, (2, 1, 0)])
    layer_contours_shifted = viewer.add_shapes(vRNP_coords_shifted.to_list(), shape_type="path")
    model_layers = viewer.open(sample["vRNP_model_path"], plugin="napari-imodmodel")

    # Color
    layer_contours.edge_color = list(sample_data["Color_tab20"])
    layer_contours.edge_width = vRNP_diameter / 2 / sample["pixel_size"]
    # Same for shifted
    layer_contours_shifted.edge_color = list(sample_data["Color_tab20 shifted"])
    # NOTE(review): fixed width, unlike the pixel-size-derived width used for
    # the unshifted layer above -- confirm this difference is intended.
    layer_contours_shifted.edge_width = 8
    viewer.camera.angles = (90, 0, -90)

In [5]:
if napari_window:
    # Point the camera at the angles (90, 0, -90); the same angles are also
    # set at the end of the layer setup cell.
    viewer.camera.angles = (90, 0, -90)

In [6]:
if napari_window:
    # Same as the previous view but with the sign of the third angle flipped
    # (presumably the view from the opposite side -- confirm in the viewer).
    viewer.camera.angles = (90, 0, 90)