In [1]:
import polars as pl
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import pdist
import plotly.graph_objects as go

In [2]:
ccm = pl.read_csv("../data/output_CellCycle2X_CCM.csv", null_values=["NA"])

In [3]:
ccm.head()

Unnamed: 0_level_0,E,tau,tp,nn,lib_column,target_column,lib_size,num_pred,rho,mae,rmse
i64,i64,i64,i64,i64,i64,i64,i64,i64,f64,f64,f64
1,4,4,0,5,1,1,44,44,0.726545,0.416029,0.642524
2,4,4,0,5,1,2,44,44,0.535088,0.583298,0.714398
3,4,4,0,5,1,3,44,44,-0.283443,0.5247,0.651745
4,4,4,0,5,1,4,44,44,-0.225491,0.772743,0.976238
5,4,4,0,5,1,5,44,44,0.222429,0.646938,0.848623


In [4]:
uniq = {
    col: ccm.unique(subset=col)[col]
    for col in [
        "E",
        "tau",
        "tp",
        "nn",
        "lib_column",
        "target_column",
        "lib_size",
        "num_pred",
    ]
}

In [5]:
uniq["lib_column"]

lib_column
i64
4085
4368
4487
1209
4109
…
235
2575
1295
375


In [6]:
best_rho = (
    ccm.group_by(["lib_column", "target_column"])
    .agg(pl.max("rho"))
    .sort("lib_column", "target_column")
)

In [7]:
best_rho

lib_column,target_column,rho
i64,i64,f64
1,1,0.726545
1,2,0.535088
1,3,-0.283443
1,4,-0.225491
1,5,0.222429
…,…,…
4824,4820,-0.564208
4824,4821,0.065406
4824,4822,0.310207
4824,4823,0.196878


In [8]:
N = len(uniq["lib_column"])

Matrix = best_rho["rho"].to_numpy().reshape((N, N))
Matrix = np.nan_to_num(Matrix)

In [9]:
N_nan = np.isnan(Matrix).sum().sum()
print(f"Number of NaNs: {N_nan}", f"Percentage of NaNs: {N_nan / (N * N) * 100}%")

Number of NaNs: 0 Percentage of NaNs: 0.0%


In [10]:
def cluster(M: np.ndarray, method="ward"):
    Normailzed = (M - M.mean(axis=0)) / M.std(axis=0)
    D_row, D_col = (
        pdist(Normailzed, metric="correlation"),
        pdist(Normailzed.T, metric="correlation"),
    )
    Z_row, Z_col = linkage(D_row, method=method), linkage(D_col, method=method)
    col_indecies, row_indecies = leaves_list(Z_row), leaves_list(Z_col)
    return M[row_indecies, :][:, col_indecies]

In [19]:
Clustered = cluster(Matrix, method="average")

In [25]:
from PIL import Image
from matplotlib import colormaps

CELL_SIZE = 2


def plot_matrix(Matrix: np.ndarray, filename: str):
    Matrix = (Matrix + 1) / 2

    colormap = colormaps.get_cmap("jet")
    colors = (colormap(Matrix) * 255).astype(np.uint8)
    img = Image.fromarray(colors)

    width, height = img.size
    img = img.resize(
        (width * CELL_SIZE, height * CELL_SIZE), resample=Image.Resampling.NEAREST
    )

    img.save(filename)


plot_matrix(Clustered, "clustered_complete.png")