In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import mmcci
import pandas as pd

# Directory holding the per-sample stLearn CCI result pickles. Defaults to the
# original cluster location; set the SKIN_ATLAS_CCI_DIR environment variable to
# point somewhere else without editing the notebook.
# os.path.join(..., "") guarantees a trailing separator, because the loading
# cell builds paths by plain string concatenation (data_dir + file_name).
data_dir = os.path.join(
    os.environ.get(
        "SKIN_ATLAS_CCI_DIR",
        "/scratch/project/stseq/Levi/SkinCancerAtlas/cci_results/",
    ),
    "",
)

In [None]:
# Registry of stLearn result pickles:
#   {"<technology>_<cancer type>": {sample label: pickle file name}}
# File names are relative to data_dir.
# Visium / Xenium files are named "<label>_stlearn.pkl".
# CosMx files are named "<stem>_cutoff50_cosmx_stlearn.pkl"; for the SCC and BCC
# groups the stem is not always identical to the sample label, so both are listed.
samples = {
    'visium_scc': {
        label: label + '_stlearn.pkl'
        for label in ('B18_SCC', 'E15_SCC', 'F21_SCC', 'P13_SCC', 'P30_SCC')
    },
    'visium_bcc': {
        label: label + '_stlearn.pkl'
        for label in ('B18_BCC', 'E15_BCC', 'F21_BCC')
    },
    'visium_mela': {
        label: label + '_stlearn.pkl'
        for label in ('6767_Mel', '21031_Mel', '48974_Mel', '66487_Mel')
    },
    'xenium_mela': {
        label: label + '_stlearn.pkl'
        for label in (
            '23346-10SP', '30037-07BR', '53023-07BR',
            '6475-07FC', '9474-06BR', '98594-09PY',
        )
    },
    'cosmx_scc': {
        label: stem + '_cutoff50_cosmx_stlearn.pkl'
        for label, stem in (
            ('B18_SCC', 'B18_SCC'),
            ('P13_SCC', 'P13'),
            ('P30_SCC', 'P30'),
        )
    },
    'cosmx_bcc': {
        label: stem + '_cutoff50_cosmx_stlearn.pkl'
        for label, stem in (
            ('B18_BCC', 'B18_BCC'),
            ('D12_BCC', 'D12'),
        )
    },
    'cosmx_mela': {
        label: label + '_cutoff50_cosmx_stlearn.pkl'
        for label in ('6747-085P', '66487-1A', '48974-2B', '21031-08TB')
    },
}

In [None]:
# Load every pickled stLearn result and replace the file name stored in
# `samples` with the output of mmcci.io.convert_stLearn, so later cells can read
# the converted objects straight out of `samples`.
# NOTE(security): pickle.load can execute arbitrary code -- only point data_dir
# at files from a trusted source.
for tech, tech_samples in samples.items():
    for sample_name, file_name in tech_samples.items():
        # After a first run the entry is no longer a file name. Skipping it makes
        # the cell safe to re-run instead of failing on `data_dir + <object>`.
        if not isinstance(file_name, str):
            continue
        with open(data_dir + file_name, 'rb') as f:
            tech_samples[sample_name] = mmcci.io.convert_stLearn(pickle.load(f))

In [None]:
# Fixed colour per cell-type label; passed as `node_colors` to the network plots
# in this notebook. Insertion order is the order shown in the legend preview.
# ("B Cell" used to be listed twice with the same colour; the duplicate key is
# removed, which leaves the resulting dict unchanged.)
colors = {
    "DC": "#5f9d9e",
    "Endothelial Cell": "#f8a41e",
    "Fibroblast": "#458b41",
    "KC Basal": "#f16b6b",
    "KC Cornified": "#9a1f61",
    "KC Granular": "#c72685",
    "KC Differentiating": "#9583bd",
    "KC Hair": "#eb2627",
    "LC": "#37479b",
    "Macrophage": "#eae71d",
    "Melanocytes": "#8b471f",
    "Melanoma": "black",
    "NK": "#99ca3e",
    "T Cell": "#41baeb",
    "Treg": "#bbe5f3",
    "pDC": "#66bf76",
    "Mast Cell": "#7f8133",
    "mRegDC": "#105146",
    "CD8+ T Cell": "#406573",
    "B Cell": "#fed9b9",
    "Pericytes": "#dca566",
    "Sweat gland related": "#f2634b",
    "nan": "grey",
    "Schwann Cell": "#0b507b",
    "Plasma": "#f1ea9d",
    "KC IFN": "#f06ba8",
    "Monocytes": "#9cc7a1",
    "KC Dysplastic": "#d8c0dd",
    "Ambiguous": "grey",
}

# Preview the palette. The zero-height bars exist only to create one legend
# handle per cell type, so the figure shows nothing but the colour legend.
# (matplotlib.pyplot is imported as plt in the notebook's import cell.)
fig, ax = plt.subplots()
for position, (cell_type, color) in enumerate(colors.items()):
    ax.bar(position, 0, color=color, label=cell_type)
# Make the legend frame fully opaque (the original comment promised this but
# the legend was created with default transparency).
ax.legend(framealpha=1.0)
# Remove axis ticks and lines.
ax.axis('off')
plt.show()

In [None]:
# For every loaded sample in every group of `samples`: call scale_by_nspots(),
# then filter_by_p_vals(assay='scaled') on the result, and store the outcome
# back under the same sample label.
for cohort in samples.values():
    for label, cci in cohort.items():
        cohort[label] = cci.scale_by_nspots().filter_by_p_vals(assay='scaled')

In [None]:
def _integrate_samples(cohort, tech):
    """Integrate all samples stored under ``samples[cohort]``.

    Calls ``mmcci.it.lr_integration`` on the cohort's sample objects with
    method ">=50%" on the "scaled" assay and tags the result with
    ``metadata={'tech': tech}``.

    Args:
        cohort (str): Key into the notebook-level ``samples`` dict.
        tech (str): Technology label stored in the result's metadata.

    Returns:
        The object returned by ``mmcci.it.lr_integration``.
    """
    return mmcci.it.lr_integration(
        list(samples[cohort].values()),
        method=">=50%",
        assay="scaled",
        metadata={'tech': tech},
    )


# One integrated result per (technology, cancer type).
# Note: samples['xenium_mela'] is loaded earlier but not integrated here.
visium_scc = _integrate_samples('visium_scc', 'visium')
cosmx_scc = _integrate_samples('cosmx_scc', 'cosmx')

visium_bcc = _integrate_samples('visium_bcc', 'visium')
cosmx_bcc = _integrate_samples('cosmx_bcc', 'cosmx')

visium_mela = _integrate_samples('visium_mela', 'visium')
cosmx_mela = _integrate_samples('cosmx_mela', 'cosmx')

In [None]:
def _scale_and_merge_technologies(visium, cosmx, group):
    """Scale a Visium and a CosMx result against each other, then integrate them.

    Steps, in order:
      1. ``mmcci.it.calc_scale_factors`` on both results (assay 'raw', grouped by
         the 'tech' metadata key) yields one factor per technology.
      2. Each result is passed through ``.scale(<its factor>, assay='raw')``.
      3. The two scaled results are integrated with method ">=50%" on the
         "scaled" assay and tagged with ``metadata={'group': group}``.

    Args:
        visium: Integrated Visium result whose metadata 'tech' is 'visium'.
        cosmx: Integrated CosMx result whose metadata 'tech' is 'cosmx'.
        group (str): Group label stored in the merged result's metadata.

    Returns:
        tuple: ``(scaled_visium, scaled_cosmx, merged)``.
    """
    factors = mmcci.it.calc_scale_factors([visium, cosmx], assay='raw', group_key='tech')
    scaled_visium = visium.scale(factors['visium'], assay='raw')
    scaled_cosmx = cosmx.scale(factors['cosmx'], assay='raw')
    merged = mmcci.it.lr_integration(
        [scaled_visium, scaled_cosmx],
        method=">=50%",
        assay="scaled",
        metadata={'group': group},
    )
    return scaled_visium, scaled_cosmx, merged


# Same recipe for each cancer type; the per-technology variables are rebound to
# their scaled versions exactly as before.
visium_scc, cosmx_scc, scc = _scale_and_merge_technologies(visium_scc, cosmx_scc, 'scc')
visium_bcc, cosmx_bcc, bcc = _scale_and_merge_technologies(visium_bcc, cosmx_bcc, 'bcc')
visium_mela, cosmx_mela, mela = _scale_and_merge_technologies(visium_mela, cosmx_mela, 'mela')

In [None]:
# Compute per-group scale factors for SCC vs BCC (grouped by the 'group'
# metadata key), apply them, then integrate both into one result tagged
# group='scc_bcc'.
scale_factors = mmcci.it.calc_scale_factors([scc, bcc], assay='raw', group_key='group')
scc = scc.scale(scale_factors['scc'], assay='raw')
bcc = bcc.scale(scale_factors['bcc'], assay='raw')
# NOTE(review): this integration reads assay="raw", whereas every other
# lr_integration call in this notebook reads assay="scaled" right after a
# .scale(...) step. If .scale() writes its output to a separate "scaled" assay,
# the two scaling lines above have no effect on scc_bcc -- TODO confirm which
# assay is intended.
scc_bcc = mmcci.it.lr_integration([scc, bcc], method=">=50%", assay="raw", metadata={'group': 'scc_bcc'})

# Repeat the scale-factor step for the merged SCC/BCC result vs melanoma.
scale_factors = mmcci.it.calc_scale_factors([scc_bcc, mela], assay='raw', group_key='group')
scc_bcc = scc_bcc.scale(scale_factors['scc_bcc'], assay='raw')
mela = mela.scale(scale_factors['mela'], assay='raw')

# Filter by p-value, then compute the 'overall' entry on the 'filtered' assay.
# NOTE(review): filter_by_p_vals() is called with its default assay here --
# confirm it picks up the values scaled just above rather than the unscaled ones.
scc_bcc = scc_bcc.filter_by_p_vals().calc_overall(assay='filtered')
mela = mela.filter_by_p_vals().calc_overall(assay='filtered')

In [None]:
# SCC: filter by p-value (default arguments), compute the 'overall' entry on
# the 'filtered' assay, then draw it with the shared cell-type palette.
scc = scc.filter_by_p_vals()
scc = scc.calc_overall(assay='filtered')
mmcci.pl.network_plot(
    scc.assays['filtered']['overall'],
    node_colors=colors,
    figsize=(12, 10),
)

In [None]:
# BCC: filter by p-value (default arguments), compute the 'overall' entry on
# the 'filtered' assay, then draw it with the shared cell-type palette.
bcc = bcc.filter_by_p_vals()
bcc = bcc.calc_overall(assay='filtered')
mmcci.pl.network_plot(
    bcc.assays['filtered']['overall'],
    node_colors=colors,
    figsize=(12, 10),
)

In [None]:
# Melanoma: filter by p-value (default arguments), compute the 'overall' entry
# on the 'filtered' assay, then draw it with the shared cell-type palette.
# NOTE(review): `mela` already went through the same filter/calc_overall chain
# at the end of the cross-cancer scaling cell; this repeats it.
mela = mela.filter_by_p_vals()
mela = mela.calc_overall(assay='filtered')
mmcci.pl.network_plot(
    mela.assays['filtered']['overall'],
    node_colors=colors,
    figsize=(12, 10),
)

In [None]:
# Directed (sender, receiver) cell-type pairs to inspect in the melanoma result,
# in the order they are plotted.
celltype_pairs = [
    ("Fibroblast", "T Cell"),
    ("T Cell", "Fibroblast"),
    ("Melanocytes", "T Cell"),
    ("T Cell", "Melanocytes"),
    ("Melanocytes", "Fibroblast"),
    ("Fibroblast", "Melanocytes"),
]

# For each pair: bar plot of its LR pairs, then GSEA on the pair's LR proportions.
for sender, receiver in celltype_pairs:
    mmcci.pl.lrs_per_celltype(mela, sender, receiver, title=f"{sender} -> {receiver}")
    mmcci.an.run_gsea(
        mela,
        lrs=mela.get_lr_proportions(sender, receiver),
        top_term=10,
        return_results=False,
    )

In [None]:
def lrs_per_celltype(
    sample,
    sender = None,
    receiver = None,
    assay="raw",
    key="cci_scores",
    p_vals=None,
    n=15,
    x_label_size=18,
    y_label_size=24,
    x_tick_size=14,
    y_tick_size=12,
    figsize=(6, 5),
    show=True,
    title=None,
    title_size=14
):
    """Plot a horizontal bar chart of LR pairs and their proportions for one sender/receiver pair.

    When ``p_vals`` is given, every bar is annotated with its p-value
    (formatted to three decimals, or "<0.001") and bars with p < 0.05 are drawn
    in blue while the rest are grey.

    Args:
        sample (CCIData): The CCIData object.
        sender (str): The sender cell type. Defaults to None.
        receiver (str): The receiver cell type. Defaults to None.
        assay (str): The assay to use. Defaults to "raw".
        key (str): The key to use. Defaults to "cci_scores".
        p_vals (dict): A dictionary of p-values. Defaults to None.
        n (int): Number of LR pairs to plot. Defaults to 15.
        x_label_size (int): Font size for x-axis label. Defaults to 18.
        y_label_size (int): Font size for y-axis label. Defaults to 24.
        x_tick_size (int): Font size for x tick labels. Defaults to 14.
        y_tick_size (int): Font size for y tick labels. Defaults to 12.
        figsize (tuple): Size of the figure. Defaults to (6, 5).
        show (bool): If True, display the figure via ``plt.show()`` and return
            None. If False, return the figure without showing it. Defaults to True.
        title (str) (optional): Title for the plot. Defaults to None.
        title_size (int): Font size of the title. Defaults to 14.

    Returns:
        matplotlib.figure.Figure or None: The figure when ``show`` is False,
        otherwise None.

    Raises:
        KeyError: If ``p_vals`` is given but has no entry for one of the
            plotted LR pairs.
    """
    pairs = sample.get_lr_proportions(sender, receiver, assay, key)

    # Take the first n entries in the order returned, then reverse them: barh
    # draws its first item at the bottom, so this puts the first pair on top.
    lr_names = list(pairs.keys())[:n][::-1]
    proportions = list(pairs.values())[:n][::-1]

    bar_labels = None
    bar_colors = None
    if p_vals is not None:
        # The helper lives in mmcci's analysis module. The bare name `an` used
        # previously is not defined in this notebook and raised a NameError.
        p_val_pairs = mmcci.an.get_p_vals_per_celltype(p_vals, sender, receiver)
        pair_p_vals = [p_val_pairs[lr_name] for lr_name in lr_names]

        # Readable labels: anything below 0.001 is shown as "<0.001".
        bar_labels = [
            "<0.001" if p_val < 0.001 else f"{p_val:.3f}" for p_val in pair_p_vals
        ]

        # Highlight significant pairs (p < 0.05). Named bar_colors so it does
        # not shadow the notebook-level `colors` palette.
        bar_colors = ['#1f77b4' if p_val < 0.05 else 'grey' for p_val in pair_p_vals]

    # Create the figure and axis.
    fig, ax = plt.subplots(figsize=figsize)
    # Resets matplotlib's rcParams to the default style (a global side effect).
    # Kept in its original position, after the figure has been created.
    plt.style.use('default')

    if p_vals is None:
        ax.barh(lr_names, proportions)
    else:
        bars = ax.barh(lr_names, proportions, color=bar_colors)
        ax.bar_label(bars, bar_labels)

    ax.set_xlabel("Proportion", fontsize=x_label_size)
    ax.set_ylabel("LR Pair", fontsize=y_label_size)
    ax.tick_params(axis='x', which='major', labelsize=x_tick_size)
    ax.tick_params(axis='y', which='major', labelsize=y_tick_size)

    if title:
        ax.set_title(title, pad=20, fontsize=title_size)

    fig.tight_layout()

    if show:
        plt.show()
        return None
    return fig

# Draw the "T Cell -> Melanocytes" bar plot without showing it, inside an rc
# context that sets figure size (6, 5) and figure.dpi 300, and save the current
# figure as barplot.pdf.
barplot_rc = {"figure.figsize": (6, 5), "figure.dpi": 300}
with plt.rc_context(barplot_rc):
    lrs_per_celltype(
        mela,
        "T Cell",
        "Melanocytes",
        title="T Cell -> Melanocytes",
        show=False,
    )
    plt.savefig("barplot.pdf", bbox_inches="tight")

In [None]:
# Network plot of the CD80_CTLA4 pair from the 'filtered' assay for SCC, BCC
# and melanoma (in that order); unconnected nodes are kept.
for cohort_result in (scc, bcc, mela):
    mmcci.pl.network_plot(
        cohort_result.assays['filtered']['cci_scores']['CD80_CTLA4'],
        node_colors=colors,
        figsize=(12, 10),
        remove_unconnected=False,
    )

In [None]:
# Compare the CD80_CTLA4 networks of the merged SCC/BCC result and melanoma
# with mmcci.an.get_network_diff, then plot the returned 'diff' and 'p_vals'.
scc_bcc_network = scc_bcc.assays['filtered']['cci_scores']['CD80_CTLA4']
mela_network = mela.assays['filtered']['cci_scores']['CD80_CTLA4']
diff = mmcci.an.get_network_diff(scc_bcc_network, mela_network)
mmcci.pl.network_plot(
    diff['diff'],
    diff['p_vals'],
    diff_plot=True,
    node_colors=colors,
    figsize=(13, 10),
)

In [None]:
# Re-draw the difference network without showing it, inside an rc context that
# sets figure size (13, 10) and figure.dpi 300, and save the current figure as
# network_plot.pdf.
network_rc = {"figure.figsize": (13, 10), "figure.dpi": 300}
with plt.rc_context(network_rc):
    mmcci.pl.network_plot(
        diff['diff'],
        diff['p_vals'],
        diff_plot=True,
        node_colors=colors,
        figsize=(13, 10),
        show=False,
    )
    plt.savefig("network_plot.pdf", bbox_inches="tight")