In [18]:
import sys
import seaborn as sns
import pandas as pd 
import numpy as np
from itertools import combinations
from scipy.spatial.distance import squareform, pdist
from scipy.spatial.distance import cdist
import matplotlib.pyplot as plt
import torch
import anndata as an
import scanpy as sc
import re
import scipy
from collections import Counter
import os
import gc
from importlib import reload

from datasets import Dataset, load_from_disk
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from geneformer import EmbExtractor
import geneformer

# classifer tools
import xgboost
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, KFold, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import minmax_scale

# local imports
sys.path.insert(0, '../../scripts/')
import geneformer_utils as gtu

sns.set_style('white')
torch.cuda.empty_cache()

# Load the model

In [2]:
"""Load the model"""
model_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/"
model = gtu.load_model(model_path)
print('loaded!')

Some weights of BertForMaskedLM were not initialized from the model checkpoint at /scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/240715_geneformer_cellClassifier_no_induced/ksplit1/ and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


loaded!


In [3]:
token_data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/resources/token_mapping.csv"
token_df = pd.read_csv(token_data_path)
token_df.head()

Unnamed: 0,gene_id,token_id,gene_name,nonzero_median,gene_version,gene_biotype,Chromosome,Start,End,scenic_tf
0,<pad>,0,,,,,,,,False
1,<mask>,1,,,,,,,,False
2,ENSG00000000003,2,TSPAN6,2.001186,15.0,protein_coding,X,100627107.0,100639991.0,False
3,ENSG00000000005,3,TNMD,3.228213,6.0,protein_coding,X,100584935.0,100599885.0,False
4,ENSG00000000419,4,DPM1,2.218874,14.0,protein_coding,20,50934866.0,50959140.0,False


# Load data

In [4]:
data_path = "/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/fb_hsc_only.dataset"

# Load from pre-trained data
print(f"Loading dataset from '{data_path}'...")
data = load_from_disk(data_path)

# Convert to DataFrame for filtering
df = data.to_pandas()
df = df.drop(columns=['__index_level_0__'])
print(f"\nDataset loaded successfully!")
print("\nOriginal Dataset:")
print(f" - Number of samples: {df.shape[0]:,}")
print(f" - Number of columns: {df.shape[1]:,}")


def apply_pad(x):
    """A function to pad in the inputs"""
    return list(x) + [0] * (2048 - len(x)) if len(x) < 2048 else list(x)
    
# need the opriginal length to properly build the attention mask
df['raw_length'] = df['length'].copy()
# pre-pad the input ids to avoid doing this when extracting embeddings
df["input_ids"] = df["input_ids"].apply(lambda x: apply_pad(x))

# update the lengths
df['length'] = df['input_ids'].apply(lambda x: len(x))

# Filter datasets (assuming 'standardized_cell_type' column exists)
fb_df = df[df['standardized_cell_type'] == 'Fibroblast']
hsc_df = df[df['standardized_cell_type'] == 'HSC']

print(f"\nFiltered datasets:")
print(f" - Fibroblasts samples: {fb_df.shape[0]:,}")
print(f" - HSC samples: {hsc_df.shape[0]:,}")

# create separate data objects
fb_data = Dataset.from_pandas(fb_df)
hsc_data = Dataset.from_pandas(hsc_df)

df.head()

Loading dataset from '/scratch/indikar_root/indikar1/shared_data/geneformer/fine_tune/fb_hsc_only.dataset'...

Dataset loaded successfully!

Original Dataset:
 - Number of samples: 35,398
 - Number of columns: 8

Filtered datasets:
 - Fibroblasts samples: 15,308
 - HSC samples: 20,090


Unnamed: 0,input_ids,cell_type,dataset,length,ignore,standardized_cell_type,broad_type,cell_id,raw_length
0,"[14577, 17163, 10265, 7725, 18049, 6816, 806, ...",HSC,weng_old1_BMMC_HSPC,2048,HSC,HSC,stem/progenitor,cell_1,2048
1,"[14577, 3649, 17163, 9855, 5575, 7725, 8687, 1...",HSC,weng_old1_BMMC_HSPC,2048,HSC,HSC,stem/progenitor,cell_2,2048
2,"[10062, 3659, 17163, 7725, 9855, 9408, 2560, 5...",HSC,weng_old1_BMMC_HSPC,2048,HSC,HSC,stem/progenitor,cell_3,2048
3,"[17163, 10265, 7725, 9855, 6876, 1911, 9951, 1...",HSC,weng_old1_BMMC_HSPC,2048,HSC,HSC,stem/progenitor,cell_4,1743
4,"[14577, 10265, 1734, 3187, 7725, 1329, 9512, 9...",HSC,weng_old1_BMMC_HSPC,2048,HSC,HSC,stem/progenitor,cell_5,2048


In [20]:
def masked_mean(tensor, lengths):
    A, B, C = tensor.shape
    mask = torch.arange(B, device=tensor.device).expand(A, B) >= lengths.unsqueeze(1)  # Inverted mask
    masked_tensor = tensor.masked_fill(mask.unsqueeze(-1), 0)  # Use masked_fill
    summed_tensor = masked_tensor.sum(dim=1)    
    mean_tensor = summed_tensor / lengths.unsqueeze(-1)
    return mean_tensor

def extract_cell_embs(model, data, layer=-1, batch_size=32):
    """a function to get cell embeddings """
    torch.cuda.empty_cache()
    total_length = len(data)
    
    # for the attention mask
    max_range_tensor = torch.arange(2048, device="cuda") 
    
    def create_attention_mask(lengths):
        return max_range_tensor.unsqueeze(0) < lengths.unsqueeze(-1)
    
    embs_list = []
    
    for i in range(0, total_length, batch_size):
        max_range = min(i + batch_size, total_length)
        minibatch = data.select([i for i in range(i, max_range)])
        minibatch.set_format(type="torch")
        
        input_ids = minibatch["input_ids"]

        attention_mask = create_attention_mask(
            minibatch['raw_length'].to("cuda")
        )
                
        with torch.no_grad():
            outputs = model(
                input_ids = input_ids.to("cuda"),
                attention_mask = attention_mask.to("cuda"),
            )

        embs_i = outputs.hidden_states[layer]
        embs_i = masked_mean(
            embs_i, 
            minibatch['raw_length'].to("cuda"),
        )
        embs_list.append(embs_i)
        
    return torch.cat(embs_list, dim=0).to("cpu").numpy()


embs = extract_cell_embs(
    model, 
    fb_data,
    layer=-1,
)

print(f"{embs.shape=}")

embs.shape=(15308, 512)


In [34]:
# sns.scatterplot(
#     x=embs[:, 0],
#     y=embs[:, 300],
#     hue=fb_data['cell_type']
# )

In [6]:
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

In [None]:
def fast_extract(model, data, layer=-1, batch_size=500):
    """Extracts embeddings from a model and returns them as a DataFrame.
    """

    embs = geneformer.emb_extractor.get_embs(
        model,
        data,
        'cell', # embedding mode
        layer,
        0,  # Assuming this is a constant parameter for the function
        batch_size,
        summary_stat=None,  
        silent=False, 
    )
    data = embs.cpu().numpy()
    return pd.DataFrame(data)


embs = fast_extract(
    model, 
    Dataset.from_pandas(fb_df),
    layer=-1,
    
)

embs.shape


In [None]:
break

# Compute HSC mean and plot

In [None]:
hsc_df = df[df['standardized_cell_type'] == 'HSC'].reset_index(drop=True)
hsc_data = Dataset.from_pandas(hsc_df)

reload(gtu)
torch.cuda.empty_cache()
hsc_embs = gtu.extract_embedding_in_mem(
    model, 
    hsc_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{hsc_embs.shape=}")

# translate into an anndata object and plot
hsc_adata = gtu.embedding_to_adata(hsc_embs)
hsc_adata.obs = hsc_df.copy()

sc.tl.pca(hsc_adata, n_comps=25)
sc.pp.neighbors(hsc_adata, n_neighbors=200)
sc.tl.umap(hsc_adata, 
           min_dist=0.75,
          )

hsc_adata

In [None]:
# create the true mean vector based on the embedding space
hsc_mean = hsc_adata.X.mean(axis=0)
hsc_mean = hsc_mean.reshape(-1, 1)
print(f"{hsc_mean.shape=}")


# plot the cells around the mean
plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 4, 4

# Add Contour Plot
x = hsc_adata.obsm['X_umap'][:, 0]
y = hsc_adata.obsm['X_umap'][:, 1]

# Create a grid for density estimation
x_grid, y_grid = np.meshgrid(np.linspace(x.min(), x.max(), 100),
                             np.linspace(y.min(), y.max(), 100))

# Calculate kernel density estimation and plot
z = sns.kdeplot(x=x, y=y, levels=5, thresh=0.1, cmap="Blues", fill=True)

# Set zorder to place the contour behind the scatter points
z.collections[0].set_zorder(0) 

# Plot scatterplot on top of the contour
sns.scatterplot(
    data=hsc_adata.obs,
    x=hsc_adata.obsm['X_umap'][:, 0],
    y=hsc_adata.obsm['X_umap'][:, 1],
    color='k',
    ec='k',
    s=2,
    zorder=1
)

# Plot the mean point
sns.scatterplot(
    x=[hsc_adata.obsm['X_umap'][:, 0].mean()],
    y=[hsc_adata.obsm['X_umap'][:, 1].mean()],
    color='red',
    ec='k',
    s=150,
    zorder=2  # Set zorder to place the mean point on top
)

# Add plot details
plt.yticks([])
plt.xticks([])
plt.ylabel("UMAP 2")
plt.xlabel("UMAP 1")
plt.title('HSC Cells')

# Display the plot
plt.show()

# Plot the Fibroblast data

In [None]:
fb_df = df[df['standardized_cell_type'] == 'Fibroblast'].reset_index(drop=True)
fb_data = Dataset.from_pandas(fb_df)

reload(gtu)
torch.cuda.empty_cache()
fb_embs = gtu.extract_embedding_in_mem(
    model, 
    fb_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{fb_embs.shape=}")

# translate into an anndata object and plot
fb_adata = gtu.embedding_to_adata(fb_embs)
fb_adata.obs = fb_df.copy()

sc.tl.pca(fb_adata, n_comps=25)
sc.pp.neighbors(fb_adata, n_neighbors=200)
sc.tl.umap(fb_adata, 
           min_dist=0.75,
          )

fb_adata

In [None]:
# plot the cells around the mean
plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 4, 4

# Add Contour Plot
x = fb_adata.obsm['X_umap'][:, 0]
y = fb_adata.obsm['X_umap'][:, 1]

# Create a grid for density estimation
x_grid, y_grid = np.meshgrid(np.linspace(x.min(), x.max(), 100),
                             np.linspace(y.min(), y.max(), 100))

# Calculate kernel density estimation and plot
z = sns.kdeplot(x=x, y=y, levels=5, thresh=0.1, cmap="Greens", fill=True)

# Set zorder to place the contour behind the scatter points
z.collections[0].set_zorder(0) 

# Plot scatterplot on top of the contour
sns.scatterplot(
    data=fb_adata.obs,
    x=fb_adata.obsm['X_umap'][:, 0],
    y=fb_adata.obsm['X_umap'][:, 1],
    color='k',
    ec='k',
    s=2,
    zorder=1
)


# Add plot details
plt.yticks([])
plt.xticks([])
plt.ylabel("UMAP 2")
plt.xlabel("UMAP 1")
plt.title('Fibroblast Cells')

# Display the plot
plt.show()

# Plot them together

In [None]:
reload(gtu)
torch.cuda.empty_cache()
embs = gtu.extract_embedding_in_mem(model, data, layer_to_quant=-1)
print(f"{embs.shape=}")

# translate into an anndata object and plot
adata = gtu.embedding_to_adata(embs)
adata.obs = df.copy()

sc.tl.pca(adata, n_comps=25)
sc.pp.neighbors(adata, n_neighbors=200)
sc.tl.umap(adata, 
           min_dist=0.75,
          )
adata

In [None]:
sns.scatterplot(
    data=adata.obs,
    x=adata.obsm['X_umap'][:, 0],
    y=adata.obsm['X_umap'][:, 1],
    hue='standardized_cell_type',
    ec='none',
    palette=['C0', 'Green'],
    s=10,
    alpha=0.7,
    zorder=1
)

# Add plot details
plt.yticks([])
plt.xticks([])
plt.ylabel("UMAP 2")
plt.xlabel("UMAP 1")
plt.title('All Cells')

sns.move_legend(
    plt.gca(),
    title="",
    loc='upper right',
    bbox_to_anchor=(1.5, 1.02),
)

# pre-compute cell-cell distances

In [None]:
def make_colorbar(cmap='viridis', 
                  width=0.2,
                  height=2.5, 
                  title='', 
                  orientation='vertical', 
                  tick_labels=[0, 1]):
    """
    Creates and displays a standalone colorbar using Matplotlib.

    Args:
        cmap (str or matplotlib.colors.Colormap): The colormap to use for the colorbar.
        width (float): The width of the colorbar figure in inches.
        height (float): The height of the colorbar figure in inches.
        title (str): The title to display above or next to the colorbar.
        orientation (str): The orientation of the colorbar ('vertical' or 'horizontal').
        tick_labels (list of str): The labels to display at each tick on the colorbar.

    Returns:
        None: This function displays the colorbar directly using Matplotlib.

    Raises:
        ValueError: If the `orientation` is not 'vertical' or 'horizontal'.
    """
    
    a = np.array([[0, 1]])  # Dummy data for the image
    plt.figure(figsize=(width, height))
    img = plt.imshow(a, cmap=cmap)
    plt.gca().set_visible(False)  # Hide the axes of the image
    cax = plt.axes([0.1, 0.2, 0.8, 0.6])  # Define the colorbar position

    ticks = np.linspace(0, 1, len(tick_labels)) 
    cbar = plt.colorbar(
        orientation=orientation,
        cax=cax,
        label=title,
        ticks=ticks
    )

    if orientation == 'vertical':
        cbar.ax.set_yticklabels(tick_labels)
    elif orientation == 'horizontal':
        cbar.ax.set_xticklabels(tick_labels)

In [None]:
# sort by cell-type

"""
In this block I compute the consine distance in embedding space between all fibroblasts and all hsc cells
"""

sorted_cells = adata.obs.sort_values(by='standardized_cell_type')
change_indices = np.argwhere(sorted_cells['standardized_cell_type'] != sorted_cells['standardized_cell_type'].shift(1))
change_indices = np.ravel(change_indices)
X = adata[sorted_cells.index, :].X
print(f"{X.shape=}")

metric = 'cosine'
cmap = 'hot'
D = squareform(pdist(X, metric=metric))
print(f"{D.shape=}")

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 6, 6
plt.imshow(D, cmap=cmap)

# add dividing lines
for pos in change_indices[1:]:
    plt.axvline(x=pos, c='k', lw=1)
    plt.axhline(y=pos, c='k', lw=1)
    
# label the blocks
all_indices = list(change_indices) + [len(X)]
midpoints = [(all_indices[i] + all_indices[i + 1]) / 2 for i in range(len(all_indices) - 1)]
plt.yticks(midpoints, sorted_cells['standardized_cell_type'].unique())

plt.title('Cell-cell Distances')
plt.xticks([])

make_colorbar(cmap=cmap, tick_labels=['Low', 'High'])

# Visualize the variability

In [None]:
""" Compute the hsc-hsc distances """
hsc_dists = D[853:2000, 853:2000]
print(f"{hsc_dists.shape=}")

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 4, 3

sns.histplot(
    x=hsc_dists.mean(axis=1),
    ec='k',
    kde=True,
    color='C0',
)

# add mean distance
x = hsc_dists.mean(axis=1).mean()
plt.axvline(x=x, c='r', label=f"Mean: {x:.3f}")
plt.legend()

plt.xlabel('Cosine Distance')
plt.title('HSC-HSC Distances')

In [None]:
""" Compute the fibroblast-fibroblast distances """
fb_dists = D[0:853, 0:853]
print(f"{fb_dists.shape=}")

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 4, 3

sns.histplot(
    x=fb_dists.mean(axis=1),
    ec='k',
    kde=True,
    color='C2',
)

# add mean distance
x = fb_dists.mean(axis=1).mean()
plt.axvline(x=x, c='r', label=f"Mean: {x:.3f}")
plt.legend()

plt.xlabel('Cosine Distance')
plt.title('FB-FB Distances')

In [None]:
""" Compute the fibroblast-hsc distances """
fb_hsc_dists = D[0:853, 853:2000]
print(f"{fb_hsc_dists.shape=}")

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 4, 3

sns.histplot(
    x=fb_hsc_dists.mean(axis=1),
    ec='k',
    kde=True,
    color='C4',
)

# add mean distance
x = fb_hsc_dists.mean(axis=1).mean()
plt.axvline(x=x, c='r', label=f"Mean: {x:.3f}")
plt.legend()

plt.xlabel('Cosine Distance')
plt.title('FB-HSC Distances')

# Pertrubations

In [None]:
gene_list = [
    'GATA2', 
    'GFI1B', 
    'FOS', 
    'STAT5A',
    'REL',
    'FOSB',
    'IKZF1',
    'RUNX3',
    'MEF2C',
    'ETV6',
]

genes = token_df[token_df['gene_name'].isin(gene_list)]
tf_map = dict(zip(genes['token_id'].values, genes['gene_name'].values))

genes

# Set up the inputs

In [None]:
# compute all possible combinations of 5 TFs,
n_tf = 5
inputs = list(combinations(genes['token_id'], n_tf))
print(f'Number of recipes: {len(inputs)}')

def map_tfs(tokens):
    return list(map(tf_map.get, tokens))

print(inputs[0])
print(map_tfs(inputs[0]))

In [None]:
def add_perturbations_to_cell(cell_tokens, perturbation_tokens):
    """
    Modifies a list of cell tokens by adding perturbation tokens and padding.

    Args:
        cell_tokens (list): A list of integers representing gene tokens.
        perturbation_tokens (list): A list of integers representing perturbation tokens.

    Returns:
        list: A new list of tokens with perturbations added, existing perturbations removed,
             and truncated/padded to the original length.
    """

    original_length = len(cell_tokens)

    # Remove existing perturbation tokens from the cell
    cell_tokens = [token for token in cell_tokens if token not in perturbation_tokens]

    # Add perturbations, then slice or pad to match original length
    final_tokens = (perturbation_tokens + cell_tokens)[:original_length]  # Slice if too long
    final_tokens += [0] * (original_length - len(final_tokens))            # Pad if too short

    return final_tokens


# test
test_cell = df.head(1)['input_ids'].values[0]
perturbed = add_perturbations_to_cell(test_cell, list(inputs[0]))

print(test_cell[:10])
print(perturbed[:10])

In [None]:
# set up some perturbations
sample_size = 100

raw_cells = fb_df.sample(sample_size).reset_index(drop=True)
print(f"{raw_cells.shape=}")
raw_cells['recipe'] = 'raw'
raw_cells['type'] = 'initial'

hsc_sample = hsc_df.sample(sample_size).reset_index(drop=True)
hsc_sample['recipe'] = 'hsc'
hsc_sample['type'] = 'target'

reprogramming_df = [
    raw_cells,
    hsc_sample,
]

for i, tfs in enumerate(inputs):
    
    if i % 25 == 0:
        print(f"Pertubation {i}/{len(inputs)}...")
    
    # make the dataframe easily useable
    perturb = raw_cells.copy()
    recipe = ";".join(map_tfs(tfs))
    perturb['recipe'] = recipe
    perturb['type'] = 'reprogrammed'
    
    # do the actual perturbation
    perturb['input_ids'] = perturb['input_ids'].apply(lambda x: add_perturbations_to_cell(x, list(tfs)))
    
    # store the updated data
    reprogramming_df.append(perturb)
    
reprogramming_df = pd.concat(reprogramming_df)
reprogramming_df = reprogramming_df.reset_index(drop=True)
print(f"{reprogramming_df.shape=}")
reprogramming_df.head()

# Perturbation Embedding

In [None]:
reload(gtu)
torch.cuda.empty_cache()

reprogramming_data = Dataset.from_pandas(reprogramming_df)

reprogramming_embs = gtu.extract_embedding_in_mem(
    model, 
    reprogramming_data, 
    layer_to_quant=-1,
    forward_batch_size=100,
)
print(f"{reprogramming_embs.shape=}")

# translate into an anndata object and plot
reprogramming_adata = gtu.embedding_to_adata(reprogramming_embs)
reprogramming_adata.obs = reprogramming_df.copy()

sc.tl.pca(reprogramming_adata, n_comps=25)
sc.pp.neighbors(reprogramming_adata, n_neighbors=200)
sc.tl.umap(reprogramming_adata, 
           min_dist=0.75,
          )

reprogramming_adata

# Plot

In [None]:
pdf = reprogramming_adata.obs.copy()
pdf['UMAP 1'] = reprogramming_adata.obsm['X_umap'][:, 0]
pdf['UMAP 2'] = reprogramming_adata.obsm['X_umap'][:, 1]

pdf = pdf.sort_values(by='type', ascending=False)

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 7, 7

sns.scatterplot(
    data=pdf,
    x='UMAP 1',
    y='UMAP 2',
    hue='type',
    ec='none',
    palette=['C0', 'lightgrey', 'green'],
    s=20,
    alpha=0.9,
    zorder=1
)

# # Add plot details
plt.yticks([])
plt.xticks([])
# # plt.title('All Cells')

sns.move_legend(
    plt.gca(),
    title="",
    loc='upper right',
    bbox_to_anchor=(1.4, 1.02),
)

# minimal distance to target

In [None]:
initial = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'initial']
target = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'target']
repro = reprogramming_adata.obs[reprogramming_adata.obs['type'] == 'reprogrammed']

# precompute all distances
metric = 'cosine'
D = squareform(pdist(reprogramming_adata.X, metric=metric))
print(f"{D.shape=}")

inital_to_target = D[initial.index, target.index].mean()
print(f"{inital_to_target=:.4f}")

result = []

for i, (recipe, group) in enumerate(repro.groupby('recipe')):
    
    if i % 25 == 0:
        print(f"Recipe {i}/{len(inputs)}...")
    
    # compute group to intial
    recipe_to_initial = D[group.index, initial.index].mean() # average over all cells
    
    # compute group to target
    recipe_to_target = D[group.index, target.index].mean() # average over all cells
    
    row = {
        'recipe' : recipe,
        'recipe_to_initial' : recipe_to_initial,
        'recipe_to_target' : recipe_to_target,
        'recipe_diff' : inital_to_target - recipe_to_target,
    }
    result.append(row)
    
result = pd.DataFrame(result)
result.head()

In [None]:
result = result.sort_values(by='recipe_diff', ascending=False)
result.head(15)

In [None]:
result['recipe_list'] = result['recipe'].str.split(";")
result['recipe_list'] = result['recipe_list'].apply(lambda x: ";".join(sorted(x)))
result['rank'] = result['recipe_diff'].rank(ascending=False)
result.head()

In [None]:
iHSC_tf = [
    'GATA2', 
    'GFI1B', 
    'FOS', 
    'STAT5A',
    'REL',
]

query = ";".join(sorted(iHSC_tf))

result[result['recipe_list'] == query]

In [None]:
plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 5, 5

pdf = result.copy()
pdf['init'] = minmax_scale(pdf['recipe_to_initial'])
pdf['target'] = minmax_scale(pdf['recipe_to_target'])
pdf['Score'] = minmax_scale(pdf['recipe_diff'])

sns.scatterplot(
    data=pdf,
    x='target',
    y='init',
    ec='k',
    hue='Score',
    palette='RdYlGn',
)

plt.ylabel('Initial to Reprogramed')
plt.xlabel('Reprogramed to Target')

sns.move_legend(
    plt.gca(),
    loc='best',            
)

In [None]:
top = 25
pdf = result.copy()
pdf = pdf.sort_values(by='recipe_diff', ascending=False)
pdf = pdf.head(top)

def flatten_list(nested_list):
    flat_list = []
    for sublist in nested_list:
        for item in sublist:
            flat_list.append(item)
    return flat_list

genes = pdf['recipe'].to_list()
genes = [x.split(";") for x in genes]
genes = flatten_list(genes)

counts = pd.DataFrame.from_dict(
    Counter(genes), 
    orient='index', 
    columns=['Count'],
)

counts = counts.reset_index(names='TF')
counts = counts.sort_values(by='Count', ascending=False)

plt.rcParams['figure.figsize'] = 2, 3.5

sns.barplot(data=counts,
            x='Count',
            y='TF',
            ec='k',
            width=0.5,
           )

plt.ylabel("")

In [None]:
break

In [None]:
pdf = reprogramming_adata.obs.copy()
pdf['UMAP 1'] = reprogramming_adata.obsm['X_umap'][:, 0]
pdf['UMAP 2'] = reprogramming_adata.obsm['X_umap'][:, 1]


# # extract the cells by UMAP coords
pdf = pdf[(pdf['UMAP 1'] < 10) & (pdf['UMAP 1'] > -5)]
pdf = pdf[(pdf['UMAP 2'] < 25) & (pdf['UMAP 2'] > 15)]

pdf = pdf.sort_values(by='type', ascending=False)

plt.rcParams['figure.dpi'] = 200
plt.rcParams['figure.figsize'] = 6, 6

sns.scatterplot(
    data=pdf,
    x='UMAP 1',
    y='UMAP 2',
    hue='cell_id',
    ec='none',
    s=20,
    alpha=0.9,
    zorder=1
)

# # Add plot details
# plt.yticks([])
# plt.xticks([])
# # plt.title('All Cells')

sns.move_legend(
    plt.gca(),
    title="",
    loc='upper right',
    bbox_to_anchor=(1.4, 1.02),
)

In [None]:
break