In [None]:
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
import networkx as nx
from importlib import reload
import scipy
from sklearn import preprocessing
from sklearn import metrics

# local imports 
import graph as gr
import utils as ut

# Genome Dynamics

In [None]:
# Directory holding the per-condition single-cell label and expression parquet files.
inputDir =  "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/single_cell_expression/"

keys = ["ND", "HFD8", 'HFD14']
macTypes = ['Mac1','Mac2','Mac3','Mac4','Mac5',] # macrophage subtypes

# data[key] accumulates every per-condition artifact; later cells add more entries.
data = {}

for key in keys:
    labelFile  = f"{inputDir}{key}_labels.pq"
    rnaFile  = f"{inputDir}{key}_cpm.pq"

    # read the label table and the expression table for this condition
    lf = pd.read_parquet(labelFile)
    rf = pd.read_parquet(rnaFile)

    # every column of the expression table is treated as a gene feature
    genes = rf.columns

    # join labels onto expression rows: expression index <-> label column 'cellId'
    df = rf.merge(
        lf,
        how='left',
        left_index=True,
        right_on='cellId',
    )

    # collapse all macrophage subtypes into a single 'Macrophages' label
    is_macrophage = df['cellType'].isin(macTypes)
    df['metaType'] = np.where(is_macrophage, 'Macrophages', df['cellType'])

    # expression matrix with missing values replaced by zero
    X = df[genes].fillna(0)
    n_cells, n_genes = X.shape
    print(f"{key} n cells: {n_cells} n genes: {n_genes}")

    data[key] = dict(labels=lf, df=df, X=X)

print('done')

In [None]:
# break

# Reduce Dimension

In [None]:
n_components = 30 # 90% explained variance
# n_components = 0.90 # 90% explained variance
pca_args = {
    'svd_solver' : 'full',
}

for key in keys:
    X = data[key]['X']
    df = data[key]['df']
    
    # construct the cell type correlation matrices from low-dimensional embeddings
    embedding, reducer = gr.reduce_dim(X, n_components, method='pca', **pca_args)

    # scale embedding so that it's always positive
    embedding = preprocessing.minmax_scale(embedding, feature_range=(0, 1))

    embedding = pd.DataFrame(embedding)
    features = embedding.columns # get the embedding column names

    # Attach the labels by POSITION, not by index label: `embedding` has a fresh
    # 0..n-1 index while `df` carries whatever index the merge produced, so a plain
    # Series assignment would re-align on labels and could silently attach the wrong
    # cell type (or NaN) to a row. Assumes gr.reduce_dim returns one embedding row
    # per row of X, in the same order -- TODO confirm.
    embedding['cellType'] = df['cellType'].to_numpy()
    embedding['metaType'] = df['metaType'].to_numpy()

    data[key]['embedding'] = embedding
    data[key]['reducer'] = reducer
    data[key]['features'] = features

    # total variance captured by the retained components
    exVar = reducer.explained_variance_ratio_.cumsum()[-1]
    print(f"{key} explained variance: {exVar:.3f} with {n_components=}")

    # NOTE(review): only the first entry of `keys` is processed because of this break.
    break

print('done')

In [None]:
# figure defaults for the cumulative explained-variance curve
plt.rcParams.update({
    'figure.dpi' : 300,
    'figure.figsize' : (4, 3),
})

# marker face colour used for each condition
colors = dict(ND='darkgreen', HFD8='goldenrod', HFD14='red')

for key in keys:
    reducer = data[key]['reducer']
    cumulative_variance = reducer.explained_variance_ratio_.cumsum()

    sns.lineplot(
        data=cumulative_variance,
        marker=".",
        mec='k',
        lw=0.5,
        c='k',
        mfc=colors[key],
    )

    # only the first condition is drawn
    break

sns.despine()
plt.ylabel("Explained Variance")
plt.xlabel("Principal Component")

In [None]:
# break

# correlation matrices

In [None]:
# For each condition, build one feature-by-feature correlation matrix per metaType
# from the embedding and store them under data[key]['genome'].
for key in keys:
    features = data[key]['features']
    embedding = data[key]['embedding']

    genome = {}

    for celltype, group in embedding.groupby('metaType'):
        A = group[features].corr()
        genome[celltype] = A

        # A correlation matrix is symmetric, so use the symmetric solver: it returns
        # real eigenvalues by construction, whereas the general `eig` can return a
        # complex array when round-off breaks exact symmetry.
        eA = np.linalg.eigvalsh(A)
        print(f"{key} {celltype} max lambda: {eA.max():.3f}")

    data[key]['genome'] = genome
    print()

    # NOTE(review): only the first entry of `keys` is processed because of this break.
    break

print('done')

In [None]:
# show which entries have been stored for the 'ND' condition so far
data['ND'].keys()

In [None]:
"""SAVE THE DATA"""

savePath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/correlation_matrices/"

# Write each cell type's correlation matrix to its own parquet file.
# NOTE(review): the file name does not contain `key`, so if the `break` below is ever
# removed, each condition will overwrite the files written by the previous one.
for key in keys:
    genome = data[key]['genome']
    for celltype in genome:
        A = genome[celltype]
        print(key, celltype, A.shape)
        fpath = f"{savePath}{celltype.replace(' ','_')}_PCA_{n_components}.pq"
        A.to_parquet(fpath)
    break

print('done')

In [None]:
# break

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 17, 7.5

lut = 5
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported spelling and takes the same (name, lut) arguments.
cmap = plt.get_cmap('PuOr', lut)
cmap.set_bad(color='lightgrey')

# one row per condition, one column per cell type
fig, axs = plt.subplots(3, 7)

for i, key in enumerate(keys):
    for j, (celltype, A) in enumerate(data[key]['genome'].items()):
        axs[i, j].imshow(A, cmap=cmap, vmin=-1, vmax=1)
        axs[i, j].set_xticks([])
        axs[i, j].set_yticks([])
        
        # column titles only on the top row
        if i == 0:
            axs[i, j].set_title(celltype, fontsize=15)
        # if j == 0:
            # axs[i, j].set_ylabel(r"$t0$", fontsize=15, rotation=0, labelpad=10, ha='right')

    # NOTE(review): only the first condition (top row) is drawn because of this break.
    break

plt.tight_layout()
plt.show()

In [None]:
# draw a colorbar for `cmap` via the local helper; the meaning of the positional
# arguments is defined by ut.makeColorbar, which is not visible in this notebook
ut.makeColorbar(cmap, 2, 0.3, 'Correlation', 'horizontal', ['-1', '1'])

In [None]:
key, ct = 'ND', 'B cells'

entry = data[key]
A = entry['genome'][ct]
embedding = entry['embedding'] # required for the mean computations
features = entry['features'] # required for the mean computations

# mean embedding coordinates over the cells of the selected type
edf = embedding.loc[embedding['metaType'] == ct]
c = edf[features].mean(axis=0)

# non-negative least squares: x0 >= 0 minimising ||A x0 - c||
x0, rnorm = scipy.optimize.nnls(A, c, maxiter=100)
print(f"{rnorm=:.3f} {x0.shape=}")
print()

# per-feature residual of the fit, A x0 - c
dxdt = np.dot(A, x0) - c

plt.rcParams.update({
    'figure.dpi' : 300,
    'figure.figsize' : (4, 2),
})
plt.plot(dxdt, marker=".")
plt.axhline(y=0, c='r')
# plt.plot(-1*c, marker=".")
# print(dxdt)
# print()
# print(c)


# plt.rcParams['figure.dpi'] = 300
# plt.rcParams['figure.figsize'] = 2, 2

# plt.imshow(A)
# ax = plt.gca()
# ax.axis(False)

# divider = make_axes_locatable(ax)
# ax2 = divider.append_axes("right", size="10%", pad="2%")

# ax2.imshow(x0.reshape(-1, 1))
# ax2.axis(False)

# # divider = make_axes_locatable(ax)
# ax3 = divider.append_axes("right", size="10%", pad="20%")
# ax3.imshow(c.to_numpy().reshape(-1, 1))
# ax3.axis(False)

# eA, _ = np.linalg.eig(A)
# eA = np.flip(eA)
# eig = pd.DataFrame(eA, columns=['v'])
# eig = eig.reset_index(drop=False)

# sns.scatterplot(data=eig, 
#                 x='v',
#                 y='index',
#                 ax=ax2)

In [None]:
# break

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 2, 2

# Bar plot of the largest eigenvalue of each cell type's correlation matrix.
for key in keys:
    res = []
    for celltype, A in data[key]['genome'].items():
        # symmetric solver: real eigenvalues by construction (the general `eig`
        # can return a complex array when round-off breaks exact symmetry)
        eA = np.linalg.eigvalsh(A)
        res.append({
            'cellType' : celltype,
            'maxE' : eA.max(),
        })

    res = pd.DataFrame(res)
    sns.barplot(data=res, 
                x='cellType',
                y='maxE',
                palette='viridis',
                edgecolor='k')
    
    plt.gca().tick_params(axis='x', rotation=90)
    plt.xlabel("")
    plt.ylabel(r"$\lambda_{max}$")
    # plt.title('Healthy')
    sns.despine()
    plt.show()
    # NOTE(review): only the first condition is plotted because of this break.
    break

In [None]:
# celltype = "Macrophages"
# A = corrs[celltype]
# # A = A * -1/
# eV, eW = np.linalg.eig(A)
# print(f"for A: {eV.max()}")

# n_cells = 5
# # print(f"{A.shape=}")

# B = np.kron(np.eye(n_cells), A)

# eV, eW = np.linalg.eig(B)
# print(f"for B: {eV.max()}")

# B = np.where(B == 0, np.nan, B)


# lut = 5
# cmap = plt.cm.get_cmap('magma', lut=lut)
# cmap.set_bad(color='whitesmoke')

# plt.rcParams['figure.dpi'] = 300
# plt.rcParams['figure.figsize'] = 5, 5

# plt.imshow(B, cmap=cmap, vmin=-1, vmax=1)

# plt.title(f"n={n_cells}")
# _ = plt.yticks([], [])
# _ = plt.xticks([], [])
# plt.gca().set_aspect('equal')

# Interaction Dynamics

In [None]:
dirpath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/"

# The cell type assignments and coordinates for every condition live in one file,
# so read it once here instead of once per loop iteration.
labelpath = f"{dirpath}/global_card_outputs.pq"
cdf_all = pd.read_parquet(labelpath)

for key in keys:
    distpath = f"{dirpath}/distances/{key}_euclidean_distances.pq"
    expressonPath =  f"{dirpath}/spatial_expression/{key}_spatial_cpm.pq"

    # physical distances
    D = pd.read_parquet(distpath)
    D = D.drop(columns='key')

    # spatial data
    sdf = pd.read_parquet(expressonPath)

    # rows for this condition only, plus a combined score over the macrophage subtypes
    cdf = cdf_all[cdf_all['key'] == key].reset_index(drop=True)
    cdf['Macrophages'] = cdf[macTypes].sum(axis=1)

    data[key]['D'] = D
    data[key]['sdf'] = sdf
    data[key]['cdf'] = cdf

    print(f"{key} {D.shape=} {sdf.shape=} {cdf.shape=}")


print('done')

# visualization example

In [None]:
q = 0.75 # quantile threshold for card output, above this value is a positive hit
n_nodes = 100 # number of nodes per sample after thresholding
celltype = 'Macrophages'
key = 'ND'

cdf = data[key]['cdf']
sdf = data[key]['sdf']
D = data[key]['D']

threshold = np.quantile(cdf[celltype], q)
print(f"{key} {celltype} threshold is: {threshold:.4f}")

nbrhd = gr.get_neighborhood(cdf, 
                            center=True, 
                            n=n_nodes, 
                            metric='minkowski')

# subset the edges and coords of the spatial data
coords = cdf[cdf['nodeId'].isin(nbrhd)].reset_index()

# flag the high-scoring locations for a specific cell type
coords['flag'] = np.where(coords[celltype] > threshold, 1, 0)
nodeSet = coords['nodeId'].to_list()    

# get the edges between the nodes selected above 
edges = D[(D['node1'].isin(nodeSet)) & (D['node2'].isin(nodeSet))].reset_index()

# get gene expression for these nodes
stx = sdf[sdf.index.isin(nodeSet)]
print(f"{coords.shape=} {edges.shape=} {stx.shape=}")

corr = []
pvals = []

# Pearson correlation between the expression profiles of each edge's endpoints
for n1, n2 in edges[['node1', 'node2']].values:
    g1 = stx.loc[n1, :].values
    g2 = stx.loc[n2, :].values

    score, pval = scipy.stats.pearsonr(g1, g2)
    corr.append(score)
    pvals.append(pval)

# Look up each endpoint's flag through a single nodeId -> flag table instead of
# scanning `coords` once per edge. drop_duplicates keeps the first row per nodeId,
# which is the row the per-edge scan (`.values[0]`) used to pick. Every edge endpoint
# is in nodeSet by construction, so the lookup has no misses.
flag_by_node = coords.drop_duplicates('nodeId').set_index('nodeId')['flag']
edges['flag1'] = edges['node1'].map(flag_by_node).to_numpy()
edges['flag2'] = edges['node2'].map(flag_by_node).to_numpy()
# 0, 1 or 2 flagged endpoints
edges['flag'] = edges['flag1'] + edges['flag2']
print(edges['flag'].value_counts())

edges['correlation'] = corr
edges['pvals'] = pvals
# edge weight: correlation damped by the fourth root of the distance column 'd'
edges['w'] = edges['correlation'] / (edges['d'] ** 0.25)
# edges['wt'] = np.where(edges['w'] > 0.15, 1, 0)


edges.head()





In [None]:
G = gr.build_graph(edges, coords)

# copy every coords column onto the matching graph node as an attribute
node_attr = coords.set_index('nodeId').to_dict('index')
nx.set_node_attributes(G, node_attr)    

plt.rcParams['figure.dpi'] = 300
# plt.rcParams['figure.facecolor'] = "none"
plt.rcParams['figure.figsize'] = 8, 4

# plotting params
node_size = 60
lut = 5
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported spelling and takes the same (name, lut) arguments.
cmap = plt.get_cmap('RdYlGn', lut)
fig, ax = plt.subplots(1, 3)

# --- panel 0: node selection (flagged nodes in red, the rest in grey) ---
colorMap = {
    1 : 'r',
    0 : 'lightgrey',
}
node_colors = [colorMap[attrs['flag']] for _, attrs in G.nodes(data=True)]

nx.draw_networkx_nodes(G,
                       pos=G.pos,
                       node_size=node_size,
                       node_color=node_colors,
                       edgecolors='k',
                       linewidths=1.5,
                       ax=ax[0])

# Flagged nodes, and edges whose two endpoints are both flagged (edge flag == 2).
# Computed once and shared by panels 1 and 2.
node_list = [n for n, attrs in G.nodes(data=True) if attrs['flag'] == 1]
flagged_edges = [(u, v, attrs) for u, v, attrs in G.edges(data=True) if attrs['flag'] == 2]
edge_list = [(u, v) for u, v, _attrs in flagged_edges]
scaled_weights = preprocessing.minmax_scale(
    [attrs['w'] for _u, _v, attrs in flagged_edges],
    feature_range=(0, 1),
)

for panel in (ax[1], ax[2]):
    nx.draw_networkx_nodes(G,
                           pos=G.pos,
                           node_size=node_size,
                           node_color=colorMap[1],
                           nodelist=node_list,
                           edgecolors='k',
                           linewidths=1.5,
                           ax=panel)

# --- panel 1: distance weighted graph (width, colour and alpha follow the weight) ---
nx.draw_networkx_edges(G,
                       pos=G.pos,
                       nodelist=node_list,
                       edgelist=edge_list,
                       width=scaled_weights*2,
                       # edge_color='k',
                       edge_color=scaled_weights*5,
                       edge_cmap=cmap, 
                       # alpha=0.75,
                       alpha=scaled_weights,
                       ax=ax[1])   

# --- panel 2: spatial adjacency (only edges with scaled weight above 0.56 are drawn) ---
eweights = np.where(scaled_weights > 0.56, 'k', 'none')

nx.draw_networkx_edges(G,
                       pos=G.pos,
                       nodelist=node_list,
                       edgelist=edge_list,
                       width=1,
                       edge_color=eweights,
                       # alpha=0.5,
                       ax=ax[2])   

for panel in ax:
    panel.set_aspect('equal')
    panel.axis(False)



In [None]:
# Save the spatial data: node table, binarised adjacency matrix and edge table.

outdir = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/for_amit/"

# Build the flagged-edge subgraph here so this cell runs on a fresh kernel
# (Restart & Run All) instead of relying on `H` left over from another cell.
H = G.edge_subgraph(edge_list)

# get the spatial adjacency matrix: scale weights by their maximum, then threshold
W = nx.adjacency_matrix(H, weight='w')
W = W.todense()
W *= (1.0/W.max()) #
W = np.where(W > 0.4, 1, 0)

# node attributes as a table with an explicit 'nodeId' column
nodes = pd.DataFrame.from_dict(dict(H.nodes(data=True)), orient='index')
nodes = nodes.drop(columns=['index'])
nodes = nodes.reset_index(drop=False)
nodes = nodes.rename(columns={'index' : 'nodeId'})

# --- node data ---
cols = [
    'nodeId', 'x', 'y', 'Macrophages',
]

fname = f"{outdir}node_data.csv"
nodes[cols].to_csv(fname, index=False)
node_list = nodes['nodeId'].to_list()
print(f"Saved node data: {fname}")

# --- spatial adjacency, labelled by nodeId on both axes ---
W = pd.DataFrame(W, index=node_list, columns=node_list)
W = W.reset_index(drop=False)
W = W.rename(columns={'index' : 'nodeId'})

fname = f"{outdir}A.csv"
W.to_csv(fname, index=False)
print(f"Saved adjacency matrix: {fname}")

# --- edge data restricted to the saved nodes ---
fedges = edges[(edges['node1'].isin(node_list)) & (edges['node2'].isin(node_list)) ]
cols = [
    'node1', 
    'node2', 
    'd', 
    'correlation',
]

fname = f"{outdir}edges.csv"
fedges[cols].to_csv(fname, index=False)
print(f"Saved edge data: {fname}")

# fedges[cols].head()

# W = pd.DataFrame()





In [None]:
# list the columns available on the filtered edge table
fedges.columns

In [None]:


# binarise the cell type correlation matrix at |value| > 0.5
A = data[key]['genome'][celltype]
A = np.where(A.abs() > 0.5, 1, 0)

# binarised adjacency of the flagged-edge subgraph: scale by the max weight, cut at 0.4
H = G.edge_subgraph(edge_list)
W = nx.adjacency_matrix(H, weight='w').todense()
W *= (1.0/W.max()) #
W = np.where(W > 0.4, 1, 0)

# Kronecker product of the two binary matrices and its eigendecomposition
K = np.kron(W, A)
eK, evK = np.linalg.eigh(K)

# show A, W and K side by side without axes
fig, ax = plt.subplots(1, 3)
panels = (
    (A, {'cmap' : 'plasma'}),
    (W, {}),
    (K, {}),
)
for axis, (matrix, imshow_kwargs) in zip(ax, panels):
    axis.imshow(matrix, **imshow_kwargs)
    axis.axis(False)

plt.tight_layout()

In [None]:
# plot the eigenvalues of K in reversed order (eigh returns them ascending)
plt.plot(np.flip(eK))
# plt.xscale('log')
# plt.yscale('log')

In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError; presumably a deliberate
# stopper so that "Run All" halts here -- confirm before removing.
break

In [None]:
# show which entries have been stored for the currently selected condition
data[key].keys()
    

In [None]:
sample_size = 1 # number of samples
q = 0.9 # quantile threshold for card output, above this value is a positive hit
n_nodes = 30 # number of nodes per sample after thresholding

for key in keys:
    cdf = data[key]['cdf']
    df = data[key]['df']

    # graphs for each cell type
    for celltype in sorted(df['metaType'].unique()):
        threshold = np.quantile(cdf[celltype], q)
        print(f"{key} {celltype} threshold is: {threshold:.4f}")

        # copy of the coordinate table with a 0/1 flag for above-threshold spots
        above_threshold = cdf[celltype] > threshold
        coords = cdf.assign(flag=np.where(above_threshold, 1, 0))

        # only the first cell type is processed
        break

    print()
    # only the first condition is processed
    break

    # cdf = cdf[cdf[celltype] > threshold]
    
    # print(cdf.shape)

In [None]:
reload(gr)

celltype = "Macrophages"
sample_size = 1 # number of samples
nodes = 100 # number of nodes per sample
q = 0.7 # quantile threshold for card output

threshold = np.quantile(cdf[celltype], q)
print(f"{celltype} threshold={threshold:.4f} in {key}")

# one graph per sampled neighborhood
graphs = []

for i in range(sample_size):
    nbrhd = gr.get_neighborhood(
        cdf,
        center=True,
        n=nodes,
        metric='minkowski',
    )

    # coordinates of the spots that fall in the neighborhood
    coords = cdf[cdf['nodeId'].isin(nbrhd)].reset_index()
    nodeSet = coords['nodeId'].to_list()

    # keep the distance rows whose two endpoints are both in the neighborhood
    both_in_nbrhd = D['node1'].isin(nodeSet) & D['node2'].isin(nodeSet)
    edges = D[both_in_nbrhd].reset_index()

    # gene expression for these nodes
    stx = sdf[sdf.index.isin(nodeSet)]
    print(f"{coords.shape=} {edges.shape=} {stx.shape=}")

    # Pearson correlation between the expression profiles of each edge's endpoints
    pearson = [
        scipy.stats.pearsonr(stx.loc[n1, :].values, stx.loc[n2, :].values)
        for n1, n2 in edges[['node1', 'node2']].values
    ]
    corr = [score for score, _pval in pearson]
    pvals = [pval for _score, pval in pearson]

    edges['correlation'] = corr
    edges['pvals'] = pvals
    # edge weight: correlation damped by the fourth root of the distance column 'd'
    edges['w'] = edges['correlation'] / (edges['d'] ** 0.25)
    edges['wt'] = np.where(edges['w'] > 0.15, 1, 0)

    G = gr.build_graph(edges, coords)
    graphs.append(G)


print('done')

In [None]:
# scratch inspection: despite the name, this holds the graph's node view, not colours
node_colors = G.nodes()
node_colors

In [None]:

G = graphs[0]
relationship = 'w'

plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.facecolor'] = "none"
plt.rcParams['figure.figsize'] = 9, 9

# plotting params
lut = 5
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported spelling and takes the same (name, lut) arguments.
cmap = plt.get_cmap('RdYlGn', lut)
fig, ax = plt.subplots()

# Colour nodes red when their score for `celltype` exceeds the threshold, grey
# otherwise. A single nodeId -> score table replaces one DataFrame scan per node;
# drop_duplicates keeps the first row per nodeId, which is the row the per-node scan
# (`.values[0]`) used to pick.
score_by_node = coords.drop_duplicates('nodeId').set_index('nodeId')[celltype]
node_colors = [
    "r" if score_by_node.loc[n] > threshold else "lightgrey"
    for n in G.nodes()
]

# network plots
nx.draw_networkx_nodes(G,
                       pos=G.pos,
                       node_size=150,
                       # node_color='lightgrey',
                       node_color=node_colors,
                       edgecolors='k',
                       linewidths=1.5,
                       ax=ax)

# edge weights rescaled to [0, 1]; only used by the disabled edge drawing below
eweights = np.array([e[relationship] for node1, node2, e in G.edges(data=True)])
eweights = preprocessing.minmax_scale(eweights, feature_range=(0, 1))


# print(eweights)

# nx.draw_networkx_edges(G,
#                        pos=G.pos,
#                        # width=3,
#                        width=eweights*6,
#                        # edge_color='k',
#                        edge_color=eweights,
#                        edge_cmap=cmap, 
#                        # alpha=0.75,
#                        alpha=eweights,
#                        ax=ax)   

# sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin = -1, vmax=1))
# sm._A = []
# plt.colorbar(sm)

ax.set_aspect('equal')
ax.axis(False)

In [None]:
# np.isnan(edgecolors)

In [None]:
# plot the 'w'-weighted adjacency matrix next to a binarised version

plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.facecolor'] = "none"
plt.rcParams['figure.figsize'] = 9, 9

# scale the weights by their maximum
Adj = nx.adjacency_matrix(G, weight='w')
Adj = Adj.todense()
Adj *= (1.0/Adj.max()) # need something different for arrays

lut = 5
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported spelling and takes the same (name, lut) arguments.
cmap = plt.get_cmap('viridis', lut)
cmap.set_bad(color='whitesmoke')

fig, axs = plt.subplots(1, 2)

axs[0].imshow(Adj, cmap, vmin=0, vmax=1)

axs[0].set_yticks([], [])
axs[0].set_xticks([], [])
axs[0].set_aspect('equal')


# entries above 0.6 of the maximum weight become 1, everything else 0
Adj_binary = np.where(Adj > 0.6, 1, 0)
axs[1].imshow(Adj_binary, cmap='binary')
axs[1].set_yticks([], [])
axs[1].set_xticks([], [])
axs[1].set_aspect('equal')

plt.tight_layout()

In [None]:
# plot the 'd'-weighted adjacency matrix

plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.facecolor'] = "none"
plt.rcParams['figure.figsize'] = 9, 9

# scale the distances by their maximum
Adj = nx.adjacency_matrix(G, weight='d')
Adj = Adj.todense()
Adj *= (1.0/Adj.max()) # need something different for arrays

lut = 5
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported spelling and takes the same (name, lut) arguments.
cmap = plt.get_cmap('coolwarm', lut)
cmap.set_bad(color='whitesmoke')


plt.imshow(Adj, cmap, vmin=0, vmax=1)

axs = plt.gca()

axs.set_yticks([], [])
axs.set_xticks([], [])
axs.set_aspect('equal')


# Adj_binary = np.where(Adj > 0.6, 1, 0)
# axs[1].imshow(Adj_binary, cmap='binary')
# axs[1].set_yticks([], [])
# axs[1].set_xticks([], [])
# axs[1].set_aspect('equal')

# plt.tight_layout()

In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError; presumably a deliberate
# stopper so that "Run All" halts here -- confirm before removing.
break

In [None]:
# NOTE(review): `ldf` is not assigned in any visible cell of this notebook, so this
# raises NameError on a fresh kernel -- confirm where it was meant to come from.
ldf.head()

In [None]:

# NOTE(review): a bare `break` outside a loop is a SyntaxError; presumably a deliberate
# stopper so that "Run All" halts here -- confirm before removing.
break

In [None]:
# n = len(corrs) * n_components
# B = np.zeros((n, n))

# # plotting stuff
# yticks = []
# ylabels = []
# lines = []

# for i, (celltype, A) in enumerate(corrs.items()):
#     start = i * n_components
#     end = ((i + 1) * n_components)
#     midpoint = (end + start) / 2
#     yticks.append(midpoint)
#     ylabels.append(celltype)
#     lines.append(end)

#     # build the block matrix
#     B[start:end, start:end] = A
    

# plt.rcParams['figure.dpi'] = 300

# plt.imshow(B, 
#            cmap='magma',
#            vmin=0)    
# _ = plt.yticks(yticks, ylabels)
# _ = plt.xticks([], [])

# for l in lines:
#     plt.axvline(x=l-0.5, c='w', lw=1)
#     plt.axhline(y=l-0.5, c='w', lw=1)


In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 15, 3

# NOTE(review): `corrs` is only assigned in commented-out code elsewhere in this
# notebook; presumably a dict of cell type -> correlation matrix -- confirm.
fig, axs = plt.subplots(1, len(corrs))
# plt.subplots returns a bare Axes (not an array) for a single panel; wrap it so the
# indexing below also works when `corrs` holds exactly one matrix.
axs = np.atleast_1d(axs)

# one heatmap per cell type on a shared [-1, 1] colour scale
for i, (celltype, A) in enumerate(corrs.items()):
    axs[i].imshow(A, cmap='magma', vmin=-1, vmax=1)
    axs[i].set_title(celltype)
    axs[i].axis(False)

In [None]:
# Block-diagonal matrix made of n_cells copies of one cell type's matrix.
n_cells = 5
celltype = "Macrophages"
A = corrs[celltype]
print(f"{A.shape=}")

# kron(I_n, A) places A on each of the n diagonal blocks
block_identity = np.identity(n_cells)
B = np.kron(block_identity, A)

plt.imshow(B, cmap='coolwarm')

# hide both tick sets
_ = plt.yticks([], [])
_ = plt.xticks([], [])

In [None]:
# Input paths for one condition. Only the path strings are built here; nothing is read.
key = "ND"

# `dirpath` already ends with '/', so the distance and edge-list paths below contain '//'
dirpath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/"
coordpath = f"{dirpath}coordinates.pq"
distpath = f"{dirpath}/distances/{key}_euclidean_distances.pq"
edgepath = f"{dirpath}/edge_lists/{key}_harmonic_highlevel_edgelist.pq"

In [None]:
# NOTE(review): a bare `break` outside a loop is a SyntaxError; presumably a deliberate
# stopper so that "Run All" halts here -- confirm before removing.
break

In [None]:
# # build an arbitrary sizes A matrix for a sinfle cell type

# celltype = "Macrophages"
# A = corrs[celltype]
# print(f"{A.shape=}")
# n_cells = 5
# n = n_cells * n_components
# B = np.zeros((n, n))
# print(f"{B.shape=}")

# # plotting stuff
# yticks = []
# ylabels = []
# lines = []

# for i in range(n_cells):
#     start = i * n_components
#     end = ((i + 1) * n_components)
#     midpoint = (end + start) / 2

#     label = f"{celltype}_{i+1}"
    
#     yticks.append(midpoint)
#     ylabels.append(label)
#     lines.append(end)
    
#     B[start:end, start:end] = A
    
# plt.rcParams['figure.dpi'] = 300

# plt.imshow(B, 
#            cmap='magma',)    
# _ = plt.yticks(yticks, ylabels)
# _ = plt.xticks([], [])

# for l in lines:
#     plt.axvline(x=l-0.5, c='w', lw=1)
#     plt.axhline(y=l-0.5, c='w', lw=1)


In [None]:
# Block-diagonal matrix made of n_cells copies of one cell type's matrix.
n_cells = 5
celltype = "Macrophages"
A = corrs[celltype]
print(f"{A.shape=}")

# kron(I_n, A) places A on each of the n diagonal blocks
block_identity = np.identity(n_cells)
B = np.kron(block_identity, A)

plt.imshow(B, cmap='coolwarm')

# hide both tick sets
_ = plt.yticks([], [])
_ = plt.xticks([], [])

In [None]:
# Input paths for one condition. Only the path strings are built here; nothing is read.
key = "ND"

# `dirpath` already ends with '/', so the distance and edge-list paths below contain '//'
dirpath = "/nfs/turbo/umms-indikar/shared/projects/spatial_transcriptomics/graph_data/"
coordpath = f"{dirpath}coordinates.pq"
distpath = f"{dirpath}/distances/{key}_euclidean_distances.pq"
edgepath = f"{dirpath}/edge_lists/{key}_harmonic_highlevel_edgelist.pq"

In [None]:
# Eigenvalues of one cell type's matrix and of its n_cells-fold block-diagonal copy.
n_cells = 5
celltype = "Macrophages"
A = corrs[celltype]
B = np.kron(np.identity(n_cells), A)

for shape_report in (f"{A.shape=}", f"{B.shape=}"):
    print(shape_report)

# eigh returns (eigenvalues, eigenvectors); only the eigenvalues are kept
eA = np.linalg.eigh(A)[0]
eB = np.linalg.eigh(B)[0]

print(eA.round(2))

# plt.rcParams['figure.dpi'] = 300
# plt.rcParams['figure.figsize'] = 3, 3
# sns.lineplot(data=np.flip(eA),
#              marker=".", 
#              markeredgecolor='k',
#              lw=1)

# sns.lineplot(data=np.flip(eB),
#              marker=".", 
#              markeredgecolor='k',
#              lw=1,
#              color='C1')

# sns.despine()
# # plt.xscale('log')
# plt.yscale('log')

In [None]:
# largest eigenvalue of the negated matrix -A
eA, _ = np.linalg.eig(-1*A)
eA.max()

In [None]:
# largest eigenvalue of A
eA, _ = np.linalg.eig(A)
eA.max()