In [None]:
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns 
import scanpy as sc

from importlib import reload

import plotting as plt2

In [None]:
# Location of the clustered AnnData file read by this notebook.
fpath = "/scratch/indikar_root/indikar1/cstansbu/HSC/scanpy/clustered.anndata.h5ad"

# Read the .h5ad file into memory as `adata`.
adata = sc.read_h5ad(fpath)
# Store the same 'X_umap' entry from .obsm under .uns as well.
# NOTE(review): presumably a downstream helper reads the embedding from .uns — confirm.
adata.uns['X_umap'] = adata.obsm['X_umap']
# Bare expression so the notebook displays the AnnData summary.
adata

In [None]:
# Read the BJ gene table and rank its rows from highest to lowest 'tpm'.
bj = (
    pd.read_csv("/nfs/turbo/umms-indikar/shared/projects/R01/BJ_FIB_GENES.csv")
    .sort_values(by='tpm', ascending=False)
)

# Companion table restricted to rows whose 'log_tpm' is strictly positive.
bj_nz = bj.loc[bj['log_tpm'] > 0]

bj.head()



In [None]:
n_genes = 20

# Marker table stored on the AnnData object under 'panglaodb'.
pan = adata.uns['panglaodb']

# Attach the 'Fibroblasts' marker rows to the BJ table, keep only the genes
# that actually matched a marker row, and retain the `n_genes` rows with the
# highest 'log_tpm'.
fib = (
    bj.merge(
        pan.loc[pan['cell_type'] == 'Fibroblasts'],
        how='left',
        on='gene_name',
    )
    .dropna(subset=['cell_type'])
    .sort_values(by='log_tpm', ascending=False)
    .head(n_genes)
)

fib.head()

In [None]:
sample_size = 100

# Index the AnnData variables by gene name so columns can be selected with
# the names held in `fib`.
adata.var_names = adata.var['gene_name'].values

# Expression of the selected genes, one row per cell; reset_index() turns the
# cell index into a regular column.
df = adata[:, fib['gene_name'].values].to_df().reset_index()

# Draw a reproducible subset of cells: the fixed seed makes a fresh
# Restart & Run All pick the same cells, and the cap avoids the ValueError
# that DataFrame.sample raises when asked for more rows than exist.
df = df.sample(min(sample_size, len(df)), random_state=0)

# Long format (one row per cell/gene pair). Requires the column produced by
# reset_index() to be named 'cell_id'.
df = pd.melt(df, id_vars='cell_id')
df.head()

In [None]:
plt.rcParams['figure.dpi'] = 300
plt.rcParams['figure.figsize'] = 5, 2

# Pin one explicit category order for both layers. The strip layer only
# receives rows with value > 0, so a gene with no expressing sampled cell
# would otherwise vanish from its categories and shift every following gene
# under the wrong bar.
gene_order = fib['gene_name'].drop_duplicates().tolist()

# Bars: 'log_tpm' per gene from the `fib` table.
sns.barplot(data=fib,
            x='gene_name',
            y='log_tpm',
            order=gene_order,
            edgecolor='k',
            color='lightgrey',
           )

# Points: per-cell values from the sampled cells, zeros excluded.
# `color` is seaborn's parameter for the point colour; the matplotlib
# shorthand `c` is not part of the stripplot signature.
sns.stripplot(data=df[df['value'] > 0],
              x='variable',
              y='value',
              order=gene_order,
              color='r',
              size=2,
              alpha=0.6,
             )

plt.ylabel("TPM (log)")
plt.xlabel("")

plt.gca().tick_params(axis='x', rotation=90)
plt.margins(x=0.05, y=0.15)

In [None]:
def plot_gene_expression_by_cell_type(
    adata,
    bj,
    n_genes: int = 20,
    sample_size: int = 100,
    database_key: str = 'panglaodb',
    filter_column: str = 'cell_type',
    filter_value: str = 'Fibroblasts',
    sort_column: str = 'log_tpm',
    ascending: bool = False,
    random_state=None,
) -> None:
    """
    Plots top-ranked marker genes for one cell type with per-cell expression overlaid.

    Bars show `sort_column` from the merged gene table; red points show the
    non-zero expression values of the same genes in a random sample of cells
    from `adata`. Both layers share one explicit gene order so each strip sits
    over its own bar. `adata` is not modified.

    Args:
        adata: An AnnData object. `adata.var` must have a 'gene_name' column and
            `adata.uns[database_key]` must hold the marker DataFrame.
        bj: A DataFrame containing gene information (must have 'gene_name' column).
        n_genes: Number of top rows to keep after sorting. Defaults to 20.
        sample_size: Number of cells to sample for the stripplot. Capped at the
            number of cells available. Defaults to 100.
        database_key: Key in adata.uns for the DataFrame containing cell type information.
        filter_column: Column in the database to filter by.
        filter_value: Value in the filter_column to select marker rows.
        sort_column: Column to rank genes by and to draw as bar height.
        ascending: Whether to sort in ascending order. Defaults to False (descending).
        random_state: Seed forwarded to DataFrame.sample for a reproducible
            cell sample. Defaults to None (a different sample on every call).

    Raises:
        KeyError: If the database or the 'gene_name' column is missing, or if a
            selected gene is absent from `adata`.
        ValueError: If no gene in `bj` matches the requested cell type.
    """
    database = adata.uns.get(database_key)
    if database is None:
        raise KeyError(f"adata.uns has no {database_key!r} entry to take marker genes from")
    if 'gene_name' not in adata.var:
        raise KeyError("adata.var has no 'gene_name' column")

    markers = database[database[filter_column] == filter_value]

    # An inner merge keeps exactly the genes that have a marker row, which is
    # what a left merge followed by dropping unmatched rows produces.
    top_genes = (
        bj.merge(markers, how='inner', on='gene_name')
        .sort_values(by=sort_column, ascending=ascending)
        .head(n_genes)
    )
    if top_genes.empty:
        raise ValueError(
            f"no genes in `bj` match {filter_column} == {filter_value!r} in {database_key!r}"
        )
    gene_order = top_genes['gene_name'].drop_duplicates().tolist()

    # Resolve gene names to column positions instead of overwriting
    # adata.var_names, so the caller's object is left untouched. When a gene
    # name occurs more than once in adata.var, its first column is used.
    var_genes = pd.Index(adata.var['gene_name'])
    first_occurrence = ~var_genes.duplicated(keep='first')
    column_of = pd.Series(np.flatnonzero(first_occurrence), index=var_genes[first_occurrence])
    missing = [gene for gene in gene_order if gene not in column_of.index]
    if missing:
        raise KeyError(f"genes not found in adata.var['gene_name']: {missing}")

    expression = adata[:, column_of.loc[gene_order].to_numpy()].to_df()
    expression.columns = gene_order

    # DataFrame.sample raises if asked for more rows than exist, so cap it.
    n_cells = min(sample_size, len(expression))
    expression = expression.sample(n_cells, random_state=random_state)

    points = expression.melt(var_name='gene_name', value_name='expression')
    points = points[points['expression'] > 0]

    ax = plt.gca()
    sns.barplot(
        data=top_genes,
        x='gene_name',
        y=sort_column,  # Use the specified sort column for the y-axis
        order=gene_order,
        color='lightgrey',
        edgecolor='black',
        linewidth=0.75,
        ax=ax,
    )

    # `order` keeps each gene at its bar's position even when the zero filter
    # removed every sampled value for some gene.
    if not points.empty:
        sns.stripplot(
            data=points,
            x='gene_name',
            y='expression',
            order=gene_order,
            color='red',
            size=2,
            alpha=0.6,
            jitter=0.2,
            ax=ax,
        )

    # Keep the established label for the default column; otherwise name the
    # column actually drawn rather than mislabelling it as TPM.
    ax.set_ylabel("TPM (log)" if sort_column == 'log_tpm' else sort_column)
    ax.set_xlabel("")
    ax.tick_params(axis='x', rotation=90)
    ax.margins(x=0.05, y=0.15)
    plt.show()
    
    
# Figure settings for this panel, then draw it with the function's default
# cell type ('Fibroblasts') using the full BJ table.
plt.rcParams.update({'figure.dpi': 300, 'figure.figsize': (4, 1.25)})
plot_gene_expression_by_cell_type(
    adata,
    bj,
    n_genes=15,
    sample_size=300,
)

In [None]:
# Same figure settings; this panel uses the non-zero BJ table, selects the
# 'Hematopoietic stem cells' marker rows, and ranks genes in ascending order.
plt.rcParams.update({'figure.dpi': 300, 'figure.figsize': (4, 1.25)})
plot_gene_expression_by_cell_type(
    adata,
    bj_nz,
    n_genes=15,
    sample_size=300,
    filter_value='Hematopoietic stem cells',
    ascending=True,
)