# Reprogramming Recipes

Authors: Nat Oliven, Joshua Pickard

Date: August 26, 2024

In [4]:
import numpy as np
import pandas as pd
import scanpy as sp
import os

# Day 5

In [5]:
# Copied from a prev day

# Josh, please read: perturb_counts sets adata.obs['scalar'] = scalar, which copies the
# scalar down to every cell in X for that call. Likewise, var['scaled'] and var['scaled_by']
# record which genes were perturbed. This is useful in case the data is later appended into
# one AnnData object. Still, the perturb_counts loop (cell below this) returns a dictionary of
# all of the perturbed objects, since appending along either axis would probably overwrite
# obs or var.


def perturb_counts(tf_list, scalar, adata):
    """
    Multiply the expression of the genes in `tf_list` by each cell's maximum expression times `scalar`.

    Steps:
    1. Check that every gene in `tf_list` appears in ``adata.var['gene_symbol']``; if any
       are missing, raise before anything in `adata` is modified.
    2. Compute the maximum expression level of each cell (row of ``adata.X``), before
       any changes.
    3. Multiply every entry of the selected gene columns in cell i by
       ``max_exp[i] * scalar``.
    4. Record the perturbation:
       - ``var['scaled']``: True for genes in `tf_list`, False otherwise.
       - ``var['scaled_by']``: `scalar` for genes in `tf_list`, 1.0 otherwise. The full
         multiplier is ``max_exp[i] * scalar`` and varies per cell, so it cannot be
         stored per gene.
       - ``obs['scalar']``: `scalar`, repeated for every cell.

    Parameters:
    tf_list (list): Gene symbols to perturb.
    scalar (float): Multiplier applied on top of each cell's maximum expression.
    adata (AnnData): Object modified in place; ``adata.X`` may be a dense ndarray or a
        scipy sparse matrix. A non-floating X is converted to float64 first.

    Returns:
    AnnData: The same `adata` object, modified in place.

    Raises:
    ValueError: If any gene in `tf_list` is not in ``adata.var['gene_symbol']``
        (such names must be updated in adata manually).
    """
    gene_symbols = adata.var['gene_symbol']

    # Validate before touching adata, so a failed call leaves it unchanged.
    known_symbols = set(gene_symbols.values)
    missing_genes = [gene for gene in tf_list if gene not in known_symbols]
    if missing_genes:
        raise ValueError(f"Genes {missing_genes} not found in anndata object")

    # Boolean mask over genes (columns of X) and the matching column positions.
    gene_mask = gene_symbols.isin(tf_list).to_numpy()
    gene_cols = np.flatnonzero(gene_mask)

    # Scaling by a float factor needs a floating-point matrix; integer counts would
    # otherwise be truncated or rejected by numpy's in-place casting rules.
    if not np.issubdtype(adata.X.dtype, np.floating):
        adata.X = adata.X.astype(np.float64)

    # Maximum expression of each cell as a dense (n_cells, 1) column. For a scipy sparse X,
    # np.max returns an (n_cells, 1) sparse matrix that cannot be indexed with
    # [:, np.newaxis], so densify it first.
    max_exp = np.max(adata.X, axis=1)
    if hasattr(max_exp, 'toarray'):
        max_exp = max_exp.toarray()
    factor = np.asarray(max_exp, dtype=np.float64).reshape(-1, 1) * scalar

    # Multiply the selected columns cell by cell. For scipy sparse matrices `*` is matrix
    # multiplication, so use the element-wise .multiply() there.
    selected = adata.X[:, gene_cols]
    if hasattr(selected, 'multiply'):
        scaled = selected.multiply(factor)
    else:
        scaled = selected * factor
    adata.X[:, gene_cols] = scaled

    # Record which genes were perturbed and with which scalar.
    adata.var['scaled'] = gene_mask
    adata.var['scaled_by'] = np.where(gene_mask, float(scalar), 1.0)
    adata.obs['scalar'] = scalar

    return adata



In [6]:
# Copied from a prev day

import anndata

def iterate_perturb_counts(tf_list, scalar_list, adata):
    """
    Run `perturb_counts` once per scalar, each time on a fresh copy of `adata`.

    Every scalar gets its own ``adata.copy()``; only that copy is handed to
    `perturb_counts`, so the `adata` passed in here is not modified by this function.

    Parameters:
    tf_list (list): Gene symbols (transcription factors) to perturb.
    scalar_list (list): Scalars to apply; one perturbed copy is produced per scalar.
    adata (AnnData): Source expression data (cells x genes).

    Returns:
    dict: Maps each scalar to the AnnData returned by `perturb_counts` for it. If a
    scalar occurs more than once, the later result replaces the earlier one.
    """
    return {
        scalar: perturb_counts(tf_list, scalar, adata.copy())
        for scalar in scalar_list
    }



In [None]:

# Copied from a prev day

# Load the Tabula Sapiens epithelial extract and show both ends of its gene-symbol list.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
epithelial_path = os.path.join(DATAPATH, FILE)
adata = sp.read_h5ad(epithelial_path)
adata_gene_list = adata.var['gene_symbol'].tolist()

for preview_header, preview_entries in (
    ("First 5 entries:", adata_gene_list[:5]),
    ("Last 5 entries:", adata_gene_list[-5:]),
):
    print(preview_header)
    print(preview_entries)

In [None]:
# Kept outside the function so these can be manipulated directly.
# I also checked whether it matters if the comparison is case-insensitive (upper-case everything, then compare)
# or case-sensitive. Unsurprisingly, case-sensitive has more discrepancies (45 vs. 44), the one extra being "Ptf1a".
# I left the case-insensitive version.

# Words from the TFs column of the review paper's table. These are candidate gene names,
# but the list also contains filler such as "and" and "+".
TABLE_1_PATH = "/home/oliven/scFoundationModels/notebooks/reprogramming/data/table_1_data_from_paper_9_1.csv"
table_1_df = pd.read_csv(TABLE_1_PATH)
tf_cells = table_1_df['TFs'].astype(str)
word_list = ' '.join(tf_cells).replace(',', '').split()

# Gene symbols present in the counts matrix.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
adata = sp.read_h5ad(os.path.join(DATAPATH, FILE))
adata_gene_list = adata.var['gene_symbol'].tolist()

# New today
def check_valid_tfs(word_list, adata_gene_list):
    """
    Return the table entries that do not match any gene symbol in the counts matrix.

    The comparison is case-insensitive: both inputs are upper-cased before comparing.

    Parameters:
    word_list (iterable of str): Candidate gene names (words parsed from the paper's table).
    adata_gene_list (iterable of str): Gene symbols present in the counts matrix.

    Returns:
    list of str: Upper-cased, de-duplicated entries of `word_list` with no match,
    sorted alphabetically so the output is reproducible between runs.
    """
    known_genes_upper = {gene.upper() for gene in adata_gene_list}
    not_valid_gene = {word.upper() for word in word_list} - known_genes_upper

    print("Entries in the table that are not genes in the counts matrix: ")

    # A bare set has no stable order; sort so repeated runs print the same list.
    return sorted(not_valid_gene)
    
check_valid_tfs(word_list=word_list, adata_gene_list=adata_gene_list)

In [None]:
# For genes with multiple aliases: list the candidate names to look up in the counts matrix.
# Matching is meant to be case-insensitive, so alias casing here does not matter.
# Each list starts with the official HGNC symbol where one is known, since that is
# presumably what adata.var['gene_symbol'] uses (TODO confirm for the remaining keys).
multiple_alias_dict = {
    'P53': ['TP53', 'BCC7', 'BMFS5', 'LFS1', 'TRP53'],
    'OCT3/4': ['POU5F1', 'OCT3', 'OCT4', 'OTF4', 'MGC22487'],
    'MASH1': ['HASH1', 'BHLHa46', 'ASH1', 'ASH-1', 'ASCL1'],
    'HNF6': ['HNF6', 'HNF6A', 'ONECUT1'],
    'HB9': ['MNX1', 'HOXHB9', 'SCRA1', 'HLXB9', 'GC07M156491', 'GC07M156786', 'GC07M150530'],
    'PPARG2': ['PPARG', 'NR1C3', 'PPARG1', 'PPARgamma', 'PPAR-Gamma', 'PPARG5', 'CIMT1', 'GLM1'],
    'PU.1': ['SPI1', 'SPI-A', 'SFPI1', 'SPI-1', 'OF', 'AGM10'],
    'N-MYC': ['MYCN', 'BHLHe37', 'N-Myc', 'NMYC', 'MYCNOT', 'MYCNsORF', 'MYCNsPEP', 'BHLHE37', 'FGLDS1', 'MODED', 'MPAPA', 'ODED'],
    # oct9 had a strange genecards lookup
    'OCT9': ['POU3F4', 'SLC22A16'],
    'LEF-1': ['TCF1ALPHA', 'TCF7L3', 'TCF10', 'LEF1'],
    # these next two were listed as ER71/ETV2
    'ER71/ETV2': ['ER71', 'ETV2', 'ETSRP71'],
    # sv40 had a strange genecards lookup (SV40 is a viral gene, so no human symbol is expected)
    'SV40': [''],
    # this one didn't show up; the key looks like a typo for LHX3, the official symbol
    'LXH3': ['LHX3', 'M2-LHX3', 'M2LHX3', 'CPHD3', 'LIM3'],
    'NGN2': ['NEUROG2', 'BHLHA8', 'MATH4A', 'Math4a', 'ATOH4', 'Ngn-2', 'BHLHA8', 'Atoh4', 'NGN-2'],
    'LMX1A': ['LMX1.1', 'LMX1', 'LMX-1.', 'DFNA7'],
    # NF-Kb had a strange genecards lookup. The key appears to use a Greek capital kappa
    # (as produced by upper-casing 'κ'), not a Latin K.
    'NF-ΚB': ['NFkb1', 'NFKB1'],
    'L-MYC': ['MYCL', 'LMYC', 'BHLHe38', 'MYCL1', 'BHLHE38'],
    'BRN2': ['POU3F2', 'BRN2', 'OCT7', 'POUF3', 'OTF7', 'Brain-2', 'OTF-7', 'Brn-2', 'Oct-7', 'N-Oct3'],
    'NURR1': ['NR4A2', 'TINUR', 'NOT', 'HZF3', 'NURR1', 'RNR1', 'IDLDP'],
    'SOX2': ['SOX2', 'SRY-Box 2', 'MCOPS3', 'ANOP3'],
    'NEUROD': ['NEUROD1', 'BHLHa3', 'BETA2', 'BHF-1', 'MODY6', 'NeuroD1', 'BHLHA3', 'T2D'],
    'C-MYC': ['MYC', 'C-MYC', 'MYCC', 'MRTL', 'BHLHE39'],
    # AP-2A had a strange genecards lookup; TFAP2A is the official symbol for AP-2 alpha
    'AP-2A': ['TFAP2A'],
    'PAX6': ['PAX6', 'D11S812E', 'WAGR', 'AN2', 'AN', 'AN1', 'ASGD5', 'FVH1', 'MGDA'],
    'OSTERIX': ['SP7', 'OSX', 'OI11', 'OI12'],
    # OCT6 had a strange genecards lookup
    'OCT6': ['POU3F1', 'SCIP', 'OTF6', 'OTF-6'],
}

In [None]:
# day 4, technically (migrate this)
# Refresh the gene-symbol list from whatever `adata` is currently loaded.
adata_gene_list = adata.var['gene_symbol'].tolist()
def identify_gene_name_translations(multiple_alias_dict, adata_gene_list):
    """
    Report which aliases of each gene key appear in adata_gene_list (case-insensitive).

    Parameters:
    multiple_alias_dict (dict): Maps a gene name (key) to a list of candidate aliases.
    adata_gene_list (list): Gene symbols to check against.

    Prints:
    One line per key: either the matching aliases (upper-cased, in alias-list order,
    duplicates kept) or that none were found.

    Returns:
    None
    """
    # Upper-case the gene symbols once into a set: each alias lookup is then O(1)
    # instead of a scan over every gene symbol in the matrix.
    adata_gene_set_upper = {gene.upper() for gene in adata_gene_list}

    for key, values in multiple_alias_dict.items():
        values_upper = [value.upper() for value in values]
        found_values = [value for value in values_upper if value in adata_gene_set_upper]

        if found_values:
            print(f"{key} was found in gene list as {found_values}")
        else:
            print(f"{key} was not found in gene list.")

identify_gene_name_translations(multiple_alias_dict=multiple_alias_dict, adata_gene_list=adata_gene_list)

In [None]:
# Halt a "Run All" here so the earlier days' cells below are not re-executed.
# A bare `break` outside a loop only stopped execution because it is a SyntaxError;
# an explicit exception states the intent and keeps this cell valid Python.
raise RuntimeError("Stopping here: the cells below are earlier days' work.")

# Day 4: Mine

In [None]:

# Copied from a prev day

# Load the Tabula Sapiens epithelial extract and show both ends of its gene-symbol list.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
epithelial_path = os.path.join(DATAPATH, FILE)
adata = sp.read_h5ad(epithelial_path)
adata_gene_list = adata.var['gene_symbol'].tolist()

for preview_header, preview_entries in (
    ("First 5 entries:", adata_gene_list[:5]),
    ("Last 5 entries:", adata_gene_list[-5:]),
):
    print(preview_header)
    print(preview_entries)

In [None]:
# Kept outside the function so these can be manipulated directly.
# I also checked whether it matters if the comparison is case-insensitive (upper-case everything, then compare)
# or case-sensitive. Unsurprisingly, case-sensitive has more discrepancies (45 vs. 44), the one extra being "Ptf1a".
# I left the case-insensitive version.

# Words from the TFs column of the review paper's table. These are candidate gene names,
# but the list also contains filler such as "and" and "+".
TABLE_1_PATH = "/home/oliven/scFoundationModels/notebooks/reprogramming/data/table_1_data_from_paper_9_1.csv"
table_1_df = pd.read_csv(TABLE_1_PATH)
tf_cells = table_1_df['TFs'].astype(str)
word_list = ' '.join(tf_cells).replace(',', '').split()

# Gene symbols present in the counts matrix.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
adata = sp.read_h5ad(os.path.join(DATAPATH, FILE))
adata_gene_list = adata.var['gene_symbol'].tolist()

# New today
def check_valid_tfs(word_list, adata_gene_list):
    """
    Return the table entries that do not match any gene symbol in the counts matrix.

    The comparison is case-insensitive: both inputs are upper-cased before comparing.

    Parameters:
    word_list (iterable of str): Candidate gene names (words parsed from the paper's table).
    adata_gene_list (iterable of str): Gene symbols present in the counts matrix.

    Returns:
    list of str: Upper-cased, de-duplicated entries of `word_list` with no match,
    sorted alphabetically so the output is reproducible between runs.
    """
    known_genes_upper = {gene.upper() for gene in adata_gene_list}
    not_valid_gene = {word.upper() for word in word_list} - known_genes_upper

    print("Entries in the table that are not genes in the counts matrix: ")

    # A bare set has no stable order; sort so repeated runs print the same list.
    return sorted(not_valid_gene)
    
check_valid_tfs(word_list=word_list, adata_gene_list=adata_gene_list)

In [None]:
# For genes with multiple aliases: list the candidate names to look up in the counts matrix.
# Matching is meant to be case-insensitive, so alias casing here does not matter.
multiple_alias_dict = {
    # TP53 is the official (HGNC) symbol for p53.
    'p53': ['TP53', 'BCC7', 'BMFS5', 'LFS1', 'TRP53'],
    'OCT3/4': ['POU5F1', 'OCT3', 'OCT4', 'OTF4', 'MGC22487'],
    'MASH1': ['HASH1', 'BHLHa46', 'ASH1', 'ASH-1', 'ASCL1'],
    'HNF6': ['HNF6', 'HNF6A', 'ONECUT1'],
    'HB9': ['MNX1', 'HOXHB9', 'SCRA1', 'HLXB9', 'GC07M156491', 'GC07M156786', 'GC07M150530'],
    'PPARG2': ['PPARG', 'NR1C3', 'PPARG1', 'PPARgamma', 'PPAR-Gamma', 'PPARG5', 'CIMT1', 'GLM1'],
    'PU.1': ['SPI1', 'SPI-A', 'SFPI1', 'SPI-1', 'OF', 'AGM10'],
    'N-MYC': ['MYCN', 'BHLHe37', 'N-Myc', 'NMYC', 'MYCNOT', 'MYCNsORF', 'MYCNsPEP', 'BHLHE37', 'FGLDS1', 'MODED', 'MPAPA', 'ODED'],
    # oct9 had a strange genecards lookup
    'OCT9': ['POU3F4', 'SLC22A16'],
    'LEF-1': ['TCF1ALPHA', 'TCF7L3', 'TCF10', 'LEF1'],
    # these next two were listed as ER71/ETV2
    'ER71/ETV2': ['ER71', 'ETV2', 'ETSRP71'],
    # sv40 had a strange genecards lookup
    'SV40': [''],
}

In [None]:
# day 4, technically (migrate this)
# Refresh the gene-symbol list from whatever `adata` is currently loaded.
adata_gene_list = adata.var['gene_symbol'].tolist()
def identify_gene_name_translations(multiple_alias_dict, adata_gene_list):
    """
    Report which aliases of each gene key appear in adata_gene_list (case-insensitive).

    Parameters:
    multiple_alias_dict (dict): Maps a gene name (key) to a list of candidate aliases.
    adata_gene_list (list): Gene symbols to check against.

    Prints:
    One line per key: either the matching aliases (upper-cased, in alias-list order,
    duplicates kept) or that none were found.

    Returns:
    None
    """
    # Upper-case the gene symbols once into a set: each alias lookup is then O(1)
    # instead of a scan over every gene symbol in the matrix.
    adata_gene_set_upper = {gene.upper() for gene in adata_gene_list}

    for key, values in multiple_alias_dict.items():
        values_upper = [value.upper() for value in values]
        found_values = [value for value in values_upper if value in adata_gene_set_upper]

        if found_values:
            print(f"{key} was found in gene list as {found_values}")
        else:
            print(f"{key} was not found in gene list.")

identify_gene_name_translations(multiple_alias_dict=multiple_alias_dict, adata_gene_list=adata_gene_list)

In [None]:
# The above list is small enough that I can manually check it.

# Things to remove from word_list:
not_genes = ['Variant', 'Large', '(ETS', '2)', 'Knockdown']

# Valid genes to replace/rename in word_list (index-aligned: genes_to_translate[i] -> translated_names[i]).
# 'P53' is commented out for now.
genes_to_translate = ['LMX1A;']  # , 'P53'
translated_names = ['LMX1A']

# These ones might have appeared as multiple entries, etc. b/c of spacing. easiest way was to delete and add back
genes_to_add = ['ETS2']

# Replace, then subtract and add.
translation = dict(zip(genes_to_translate, translated_names))
translated_words = [translation.get(word, word) for word in word_list]
new_word_list = (set(translated_words) - set(not_genes)) | set(genes_to_add)

# Running the function one more time to check:
check_valid_tfs(new_word_list, adata_gene_list)





In [None]:
""" Renamed med_nonz to max_exp to be more accurate. """

In [None]:
# Copied from a prev day

# Josh, please read: perturb_counts sets adata.obs['scalar'] = scalar, which copies the
# scalar down to every cell in X for that call. Likewise, var['scaled'] and var['scaled_by']
# record which genes were perturbed. This is useful in case the data is later appended into
# one AnnData object. Still, the perturb_counts loop (cell below this) returns a dictionary of
# all of the perturbed objects, since appending along either axis would probably overwrite
# obs or var.


def perturb_counts(tf_list, scalar, adata):
    """
    Multiply the expression of the genes in `tf_list` by each cell's maximum expression times `scalar`.

    Steps:
    1. Check that every gene in `tf_list` appears in ``adata.var['gene_symbol']``; if any
       are missing, raise before anything in `adata` is modified.
    2. Compute the maximum expression level of each cell (row of ``adata.X``), before
       any changes.
    3. Multiply every entry of the selected gene columns in cell i by
       ``max_exp[i] * scalar``.
    4. Record the perturbation:
       - ``var['scaled']``: True for genes in `tf_list`, False otherwise.
       - ``var['scaled_by']``: `scalar` for genes in `tf_list`, 1.0 otherwise. The full
         multiplier is ``max_exp[i] * scalar`` and varies per cell, so it cannot be
         stored per gene.
       - ``obs['scalar']``: `scalar`, repeated for every cell.

    Parameters:
    tf_list (list): Gene symbols to perturb.
    scalar (float): Multiplier applied on top of each cell's maximum expression.
    adata (AnnData): Object modified in place; ``adata.X`` may be a dense ndarray or a
        scipy sparse matrix. A non-floating X is converted to float64 first.

    Returns:
    AnnData: The same `adata` object, modified in place.

    Raises:
    ValueError: If any gene in `tf_list` is not in ``adata.var['gene_symbol']``
        (such names must be updated in adata manually).
    """
    gene_symbols = adata.var['gene_symbol']

    # Validate before touching adata, so a failed call leaves it unchanged.
    known_symbols = set(gene_symbols.values)
    missing_genes = [gene for gene in tf_list if gene not in known_symbols]
    if missing_genes:
        raise ValueError(f"Genes {missing_genes} not found in anndata object")

    # Boolean mask over genes (columns of X) and the matching column positions.
    gene_mask = gene_symbols.isin(tf_list).to_numpy()
    gene_cols = np.flatnonzero(gene_mask)

    # Scaling by a float factor needs a floating-point matrix; integer counts would
    # otherwise be truncated or rejected by numpy's in-place casting rules.
    if not np.issubdtype(adata.X.dtype, np.floating):
        adata.X = adata.X.astype(np.float64)

    # Maximum expression of each cell as a dense (n_cells, 1) column. For a scipy sparse X,
    # np.max returns an (n_cells, 1) sparse matrix that cannot be indexed with
    # [:, np.newaxis], so densify it first.
    max_exp = np.max(adata.X, axis=1)
    if hasattr(max_exp, 'toarray'):
        max_exp = max_exp.toarray()
    factor = np.asarray(max_exp, dtype=np.float64).reshape(-1, 1) * scalar

    # Multiply the selected columns cell by cell. For scipy sparse matrices `*` is matrix
    # multiplication, so use the element-wise .multiply() there.
    selected = adata.X[:, gene_cols]
    if hasattr(selected, 'multiply'):
        scaled = selected.multiply(factor)
    else:
        scaled = selected * factor
    adata.X[:, gene_cols] = scaled

    # Record which genes were perturbed and with which scalar.
    adata.var['scaled'] = gene_mask
    adata.var['scaled_by'] = np.where(gene_mask, float(scalar), 1.0)
    adata.obs['scalar'] = scalar

    return adata



In [None]:
# Copied from a prev day

import anndata

def iterate_perturb_counts(tf_list, scalar_list, adata):
    """
    Run `perturb_counts` once per scalar, each time on a fresh copy of `adata`.

    Every scalar gets its own ``adata.copy()``; only that copy is handed to
    `perturb_counts`, so the `adata` passed in here is not modified by this function.

    Parameters:
    tf_list (list): Gene symbols (transcription factors) to perturb.
    scalar_list (list): Scalars to apply; one perturbed copy is produced per scalar.
    adata (AnnData): Source expression data (cells x genes).

    Returns:
    dict: Maps each scalar to the AnnData returned by `perturb_counts` for it. If a
    scalar occurs more than once, the later result replaces the earlier one.
    """
    return {
        scalar: perturb_counts(tf_list, scalar, adata.copy())
        for scalar in scalar_list
    }



# Day 4: Copied From Josh's Notebook

**Focus:** check out Nat's code (debug a bit) and create a few perturbations
- changes made to NO's code:
    1. iterate_perturb_counts: changed the order of the arguments to `adata, tf_list, scalar_list`
    2. perturb_counts: changed the order of the arguments to `adata, tf_list, scalar`
    3. perturb_counts: there was an issue with the use of `[:, np.newaxis]` on `max_exp`, which is a `coo_matrix` (a special type of sparse matrix). The code was modified to address the error being thrown here.
- new function:
    1. validateTFs(TFs, adata): this checks if all the transcription factors are present in the adata
- perturbation driver (`Perform Perturbations and Create new files`):
    1. loads Fibroblast data from Tabula Sapiens
    2. loads `.csv` file of known reprogramming protocols
    3. for each set of TFs that are validated by `validateTFs`:
        1. use `iterate_perturb_counts` to generate perturbations with scalars `[0.5, 0.75, 1.001]`
        2. concatenate the perturbed AnnData objects into a single AnnData object
        3. save metadata from reprogramming protocol (i.e. PMID, source/targets, etc.)
        4. save the new anndata as a `.h5ad` file

## Nat's Code with some modifications

In [None]:
import numpy as np
import anndata as ad
import pandas as pd
import scanpy as sp
import os

In [None]:
def iterate_perturb_counts(adata, tf_list, scalar_list):
    """
    Run `perturb_counts` once per scalar, each time on a fresh copy of `adata`.

    Every scalar gets its own ``adata.copy()``; only that copy is handed to
    `perturb_counts`, so the `adata` passed in here is not modified by this function.

    Parameters:
    adata (AnnData): Source expression data (cells x genes).
    tf_list (list): Gene symbols (transcription factors) to perturb.
    scalar_list (list): Scalars to apply; one perturbed copy is produced per scalar.

    Returns:
    dict: Maps each scalar to the AnnData returned by `perturb_counts` for it. If a
    scalar occurs more than once, the later result replaces the earlier one.
    """
    return {
        scalar: perturb_counts(adata.copy(), tf_list, scalar)
        for scalar in scalar_list
    }

def perturb_counts(adata, tf_list, scalar):
    """
    Set the expression of the genes in `tf_list` to each cell's maximum expression times `scalar`.

    This overwrites, not multiplies: after the call, every selected gene in cell i holds
    ``scalar * max(X[i, :])``, where the maximum is taken before any change. With a scalar
    slightly above 1 (e.g. 1.001), the selected genes therefore become the most highly
    expressed genes in every cell whose maximum is positive.

    Parameters:
    adata (AnnData): Object modified in place; ``adata.X`` may be a dense ndarray or a
        scipy sparse matrix. A non-floating X is converted to float64 first.
    tf_list (list): Gene symbols to perturb.
    scalar (float): Multiple of each cell's maximum expression to assign.

    Returns:
    AnnData: The same `adata` object, modified in place, with
    ``var['scaled']`` (True for genes in `tf_list`) and ``var['scaled_by']``
    (`scalar` for genes in `tf_list`, 1.0 otherwise).

    Raises:
    ValueError: If any gene in `tf_list` is not in ``adata.var['gene_symbol']``
        (such names must be updated in adata manually).
    """
    gene_symbols = adata.var['gene_symbol']

    # Validate before modifying adata, so a failed call leaves it unchanged.
    known_symbols = set(gene_symbols.values)
    missing_genes = [gene for gene in tf_list if gene not in known_symbols]
    if missing_genes:
        raise ValueError(f"Genes {missing_genes} not found in anndata object")

    gene_mask = gene_symbols.isin(tf_list).to_numpy()

    # Fractional targets (scalar * max) would be truncated in an integer matrix.
    if not np.issubdtype(adata.X.dtype, np.floating):
        adata.X = adata.X.astype(np.float64)

    # Per-cell maximum as a dense (n_cells, 1) column: np.max returns a 1-D array for a
    # dense X but an (n_cells, 1) sparse matrix for a scipy sparse X.
    max_exp = np.max(adata.X, axis=1)
    if hasattr(max_exp, 'toarray'):
        max_exp = max_exp.toarray()
    max_exp = np.asarray(max_exp).reshape(-1, 1)

    # Overwrite the selected columns; the (n_cells, 1) column broadcasts across them.
    adata.X[:, gene_mask] = max_exp * scalar

    # Add/Update the per-gene annotations.
    adata.var['scaled'] = gene_mask
    adata.var['scaled_by'] = np.where(gene_mask, float(scalar), 1.0)

    return adata


## New code to validate lists of TFs

In [None]:
def validateTFs(TFs, adata):
    """
    Return True if every transcription factor in `TFs` is a gene symbol in `adata`.

    Membership is exact (case-sensitive) against ``adata.var['gene_symbol']``.
    An empty `TFs` returns True.
    """
    gene_symbols = adata.var['gene_symbol'].values.tolist()
    return all(tf in gene_symbols for tf in TFs)

## Load data and perturbations

In [None]:
# Reprogramming recipes; each row lists a TF combination in its 'TFs' column.
RECIPES_CSV = 'data/first_5_recepies_8_29_2024.csv'
df = pd.read_csv(RECIPES_CSV)

# Tabula Sapiens fibroblast data to perturb, plus its gene symbols.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/jpic/"
FILE = "fibroblast.h5ad"
fibroblast_path = os.path.join(DATAPATH, FILE)
adata = sp.read_h5ad(fibroblast_path)
adata_gene_list = adata.var['gene_symbol'].tolist()


## Perform Perturbations and Create new files

In [None]:
df['TFs'].shape[0]

In [None]:
output_directory = "/nfs/turbo/umms-indikar/shared/projects/DARPA_AI/in-silico-reprogramming/one-shot/perturbed"
scalars = [0.5, 0.75, 1.001]

# Create the output directory if needed so write_h5ad does not fail on a fresh path.
os.makedirs(output_directory, exist_ok=True)

# Recipe metadata copied onto every cell of the output: obs column -> df column.
metadata_columns = {
    'Source_cells': 'Source cells',
    'Target_cells': 'Target cells',
    'Treatment': 'Treatment',
    'Species': 'Species',
    'Cell_Transplantation': 'Cell Transplantation',
    'Published_Year': 'Published Year',
    'PMID': 'PMID',
}

# ',', ';' and ':' all separate TF names in the TFs column.
separator_table = str.maketrans({',': ' ', ';': ' ', ':': ' '})

for i, raw_tfs in enumerate(df['TFs']):
    # split() with no argument collapses runs of whitespace, so "A, B" yields ['A', 'B']
    # instead of ['A', '', 'B'] (the empty string would fail validateTFs and the recipe
    # would be skipped silently). str() guards against missing (NaN) cells.
    TFs = str(raw_tfs).translate(separator_table).split()
    if not TFs or not validateTFs(TFs, adata):
        continue

    print(TFs)
    adataDict = iterate_perturb_counts(adata, TFs, scalars)

    # Concatenate along observations. Passing the dict (not just its values) records the
    # scalar that produced each cell in obs['scalar'] and suffixes obs names with it so
    # they stay unique; otherwise every copy shares the same obs names and the rows
    # cannot be told apart.
    concatenated_adata = ad.concat(adataDict, axis=0, label='scalar', index_unique='_')

    # Save reprogramming metadata into the concatenated_adata.obs table
    for obs_column, df_column in metadata_columns.items():
        concatenated_adata.obs[obs_column] = df[df_column].iloc[i]

    # NOTE: the file name depends only on the TF list, so two recipes with identical TFs
    # write to the same path and the later one overwrites the earlier.
    TFs_str = "_".join(TFs)
    output_path = os.path.join(output_directory, f"{TFs_str}.h5ad")

    # Save the concatenated AnnData object to the file
    concatenated_adata.write_h5ad(output_path)


# Day 3

# Day 2

In [None]:

# Copied from a prev day

# Load the Tabula Sapiens epithelial extract and show both ends of its gene-symbol list.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
epithelial_path = os.path.join(DATAPATH, FILE)
adata = sp.read_h5ad(epithelial_path)
adata_gene_list = adata.var['gene_symbol'].tolist()

for preview_header, preview_entries in (
    ("First 5 entries:", adata_gene_list[:5]),
    ("Last 5 entries:", adata_gene_list[-5:]),
):
    print(preview_header)
    print(preview_entries)

In [None]:
# Kept outside the function so these can be manipulated directly.
# I also checked whether it matters if the comparison is case-insensitive (upper-case everything, then compare)
# or case-sensitive. Unsurprisingly, case-sensitive has more discrepancies (45 vs. 44), the one extra being "Ptf1a".
# I left the case-insensitive version.

# Words from the TFs column of the review paper's table. These are candidate gene names,
# but the list also contains filler such as "and" and "+".
TABLE_1_PATH = "/home/oliven/scFoundationModels/notebooks/reprogramming/data/table_1_data_from_paper_9_1.csv"
table_1_df = pd.read_csv(TABLE_1_PATH)
tf_cells = table_1_df['TFs'].astype(str)
word_list = ' '.join(tf_cells).replace(',', '').split()

# Gene symbols present in the counts matrix.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
adata = sp.read_h5ad(os.path.join(DATAPATH, FILE))
adata_gene_list = adata.var['gene_symbol'].tolist()

# New today
def check_valid_tfs(word_list, adata_gene_list):
    """
    Return the table entries that do not match any gene symbol in the counts matrix.

    The comparison is case-insensitive: both inputs are upper-cased before comparing.

    Parameters:
    word_list (iterable of str): Candidate gene names (words parsed from the paper's table).
    adata_gene_list (iterable of str): Gene symbols present in the counts matrix.

    Returns:
    list of str: Upper-cased, de-duplicated entries of `word_list` with no match,
    sorted alphabetically so the output is reproducible between runs.
    """
    known_genes_upper = {gene.upper() for gene in adata_gene_list}
    not_valid_gene = {word.upper() for word in word_list} - known_genes_upper

    print("Entries in the table that are not genes in the counts matrix: ")

    # A bare set has no stable order; sort so repeated runs print the same list.
    return sorted(not_valid_gene)
    
check_valid_tfs(word_list=word_list, adata_gene_list=adata_gene_list)

In [None]:
# For genes with multiple aliases: list the candidate names to look up in the counts matrix.
# Matching is meant to be case-insensitive, so alias casing here does not matter.
multiple_alias_dict = {
    # TP53 is the official (HGNC) symbol for p53.
    'p53': ['TP53', 'BCC7', 'BMFS5', 'LFS1', 'TRP53'],
    'OCT3/4': ['POU5F1', 'OCT3', 'OCT4', 'OTF4', 'MGC22487'],
    'MASH1': ['HASH1', 'BHLHa46', 'ASH1', 'ASH-1', 'ASCL1'],
    'HNF6': ['HNF6', 'HNF6A', 'ONECUT1'],
    'HB9': ['MNX1', 'HOXHB9', 'SCRA1', 'HLXB9', 'GC07M156491', 'GC07M156786', 'GC07M150530'],
    'PPARG2': ['PPARG', 'NR1C3', 'PPARG1', 'PPARgamma', 'PPAR-Gamma', 'PPARG5', 'CIMT1', 'GLM1'],
    'PU.1': ['SPI1', 'SPI-A', 'SFPI1', 'SPI-1', 'OF', 'AGM10'],
    'N-MYC': ['MYCN', 'BHLHe37', 'N-Myc', 'NMYC', 'MYCNOT', 'MYCNsORF', 'MYCNsPEP', 'BHLHE37', 'FGLDS1', 'MODED', 'MPAPA', 'ODED'],
    # oct9 had a strange genecards lookup
    'OCT9': ['POU3F4', 'SLC22A16'],
    'LEF-1': ['TCF1ALPHA', 'TCF7L3', 'TCF10', 'LEF1'],
    # these next two were listed as ER71/ETV2
    'ER71/ETV2': ['ER71', 'ETV2', 'ETSRP71'],
    # sv40 had a strange genecards lookup
    'SV40': [''],
}

In [None]:
# day 4, technically (migrate this)
# Refresh the gene-symbol list from whatever `adata` is currently loaded.
adata_gene_list = adata.var['gene_symbol'].tolist()
def identify_gene_name_translations(multiple_alias_dict, adata_gene_list):
    """
    Report which aliases of each gene key appear in adata_gene_list (case-insensitive).

    Parameters:
    multiple_alias_dict (dict): Maps a gene name (key) to a list of candidate aliases.
    adata_gene_list (list): Gene symbols to check against.

    Prints:
    One line per key: either the matching aliases (upper-cased, in alias-list order,
    duplicates kept) or that none were found.

    Returns:
    None
    """
    # Upper-case the gene symbols once into a set: each alias lookup is then O(1)
    # instead of a scan over every gene symbol in the matrix.
    adata_gene_set_upper = {gene.upper() for gene in adata_gene_list}

    for key, values in multiple_alias_dict.items():
        values_upper = [value.upper() for value in values]
        found_values = [value for value in values_upper if value in adata_gene_set_upper]

        if found_values:
            print(f"{key} was found in gene list as {found_values}")
        else:
            print(f"{key} was not found in gene list.")

identify_gene_name_translations(multiple_alias_dict=multiple_alias_dict, adata_gene_list=adata_gene_list)

In [None]:
# The above list is small enough that I can manually check it.

# Things to remove from word_list:
not_genes = ['Variant', 'Large', '(ETS', '2)', 'Knockdown']

# Valid genes to replace/rename in word_list (index-aligned: genes_to_translate[i] -> translated_names[i]).
# 'P53' is commented out for now.
genes_to_translate = ['LMX1A;']  # , 'P53'
translated_names = ['LMX1A']

# These ones might have appeared as multiple entries, etc. b/c of spacing. easiest way was to delete and add back
genes_to_add = ['ETS2']

# Replace, then subtract and add.
translation = dict(zip(genes_to_translate, translated_names))
translated_words = [translation.get(word, word) for word in word_list]
new_word_list = (set(translated_words) - set(not_genes)) | set(genes_to_add)

# Running the function one more time to check:
check_valid_tfs(new_word_list, adata_gene_list)





In [None]:
""" Renamed med_nonz to max_exp to be more accurate. """

In [None]:
# Copied from a prev day

"""
Josh, please read: the adata.obs['scalar'] = scalar copies the scalar down for that call, associated with every cell in X. Same with ['scaled'] and ['scaled_by']
in var. This is good in case the data is later appended into one anndata object.
But my return from the perturb_counts loop (cell below this) is a dictionary of all of the perturb_counts, since appending along any axis will probably either overwrite
obs or var.
"""


def _row_max(X):
    """Return the maximum of each row of X as a flat 1-D numpy array (dense or scipy-sparse X)."""
    row_max = X.max(axis=1)
    # Sparse .max(axis=1) returns a sparse result and np.matrix returns an (n, 1) matrix; flatten both.
    if hasattr(row_max, "toarray"):
        row_max = row_max.toarray()
    return np.asarray(row_max).ravel()


def _scale_columns_by_row(adata, col_mask, row_factor):
    """Multiply adata.X[i, j] by row_factor[i] for every column j where col_mask is True, in place."""
    X = adata.X
    if hasattr(X, "tocsr"):
        # Sparse X: scaling is multiplicative, so zeros stay zero; only stored values need scaling.
        csr = X.tocsr()
        entry_rows = np.repeat(np.arange(csr.shape[0]), np.diff(csr.indptr))
        selected = col_mask[csr.indices]
        csr.data[selected] = csr.data[selected] * row_factor[entry_rows[selected]]
        adata.X = csr.asformat(X.format)
    else:
        # Dense X: use np.multiply (not `*`, which is matrix multiplication for np.matrix).
        block = X[:, col_mask]
        np.multiply(block, row_factor[:, np.newaxis], out=block)
        X[:, col_mask] = block


def perturb_counts(tf_list, scalar, adata):
    """
    Applies a perturbation to the expression data of specific genes in an AnnData object.

    This function performs the following steps:
    1. Checks that every gene in `tf_list` is present in `adata.var['gene_symbol']`. If any is
       missing, a ValueError is raised before `adata` is modified.
    2. Computes the maximum gene expression level for each cell (row of `adata.X`).
    3. Multiplies every entry of the genes listed in `tf_list` by the maximum expression level
       of its cell times `scalar`. Dense and scipy-sparse `adata.X` are both supported.
    4. Records the perturbation:
       - obs['max_exp']: the per-cell maximum expression used for scaling.
       - obs['scalar']: `scalar`, repeated for every cell.
       - var['scaled']: True for genes in `tf_list`, False otherwise.
       - var['scaled_by']: `scalar` for genes in `tf_list`, `1` otherwise. The factor actually
         applied to cell i of a scaled gene j is obs['max_exp'][i] * var['scaled_by'][j]; the
         per-cell part cannot be stored in a per-gene (var) column.

    Parameters:
    tf_list (list): A list of gene symbols to be perturbed.
    scalar (float): The scalar value used to scale the expression levels.
    adata (AnnData): The AnnData object containing gene expression data (cells x genes).
        It is modified in place.

    Returns:
    AnnData: The same `adata` object, with applied perturbations and new columns.

    Raises:
    ValueError: If any gene in `tf_list` is not found in `adata.var['gene_symbol']`.
    """
    gene_symbols = adata.var['gene_symbol']

    # Validate before touching adata, so a failed call leaves it unchanged (nothing to restore).
    known_genes = set(gene_symbols)
    missing_genes = [gene for gene in tf_list if gene not in known_genes]
    if missing_genes:
        raise ValueError(f"Genes {missing_genes} not found in anndata object")

    gene_mask = gene_symbols.isin(tf_list).to_numpy()

    # Maximum expression level of each cell
    max_exp = _row_max(adata.X)

    # Apply the scaling operation to the specified genes
    _scale_columns_by_row(adata, gene_mask, max_exp * scalar)

    # Per-cell records
    adata.obs['max_exp'] = max_exp
    adata.obs['scalar'] = scalar

    # Per-gene records
    adata.var['scaled'] = gene_mask
    adata.var['scaled_by'] = np.where(gene_mask, scalar, 1)

    return adata



In [None]:
# Copied from a prev day

import anndata

def iterate_perturb_counts(tf_list, scalar_list, adata):
    """
    Builds one perturbed copy of `adata` for every value in `scalar_list`.

    For each scalar, a fresh `adata.copy()` is handed to `perturb_counts` together with
    `tf_list`, so the input object itself is never modified. `perturb_counts` scales the
    genes in `tf_list` by each cell's maximum expression times that scalar.

    Parameters:
    tf_list (list): Gene symbols (transcription factors) to perturb.
    scalar_list (list): Scalar values to apply, one perturbed copy per value.
    adata (AnnData): The AnnData object containing gene expression data (cells x genes).

    Returns:
    dict: Maps each scalar to its perturbed AnnData copy. If a scalar occurs more than once,
    the copy made for its last occurrence is kept.
    """
    return {
        scalar: perturb_counts(tf_list, scalar, adata.copy())
        for scalar in scalar_list
    }



In [None]:
# We want each recipe (returned as a dictionary of anndata objects, one adata object for each scalar)
# to be saved to turbo.
def save_perturb_to_turbo():
    """
    Placeholder: save each recipe's perturbation results to turbo storage.

    Each recipe is produced as a dictionary of AnnData objects, one per scalar
    (the output of `iterate_perturb_counts`). Not implemented yet.

    Raises:
    NotImplementedError: always, until the saving logic is written.
    """
    raise NotImplementedError("save_perturb_to_turbo is not implemented yet")

In [None]:
### Testing on one of the tf lists from the file.


# Day 1

## Perturbation Model Discussion

E.V. = expression values

Possible algorithm:
```
1. find highest E.V.  for a single cell
2. find expression value of TFs being modified
3. have a value k for the number of different concentrations we want to test
4. choose k different amounts by which to increase the TFs, ranging from their measured E.V. up to 150% of the maximum E.V.
   - make an arbitrary choice and code it up
```

**A reasonable person could implement this in dozens of different ways.**

In [None]:
import numpy as np
# Load the epithelial Tabula Sapiens data as a working copy: the "w" in adata_w stands for
# "working", kept separate so the original data is not altered by the experiments below.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
adata_w = sp.read_h5ad(os.path.join(DATAPATH, FILE))
adata_w
# Earlier median-nonzero version (returned NaN for all-zero columns); superseded by the next cell.
# def median_nonzero(col):
#     nonzero_vals = col[col != 0]  # Extract nonzero values
#     if len(nonzero_vals) == 0:    # If no nonzero values, return NaN
#         return np.nan
#     return np.median(nonzero_vals)

# # Apply the function to each column and store the results
# med_nonz = np.apply_along_axis(median_nonzero, axis=0, arr=adata_w.X)
# adata_w.var['med_nonz'] = med_nonz
# adata_w.var

In [None]:
import numpy as np


# Dense version of the expression matrix: anything that is not already an ndarray (presumably a
# scipy sparse matrix) is converted with .toarray(), which materializes the full matrix in memory.
X = adata_w.X.toarray() if not isinstance(adata_w.X, np.ndarray) else adata_w.X

def median_nonzero(col):
    """Return the median of the nonzero entries of `col`, or 0 when it has no nonzero entry."""
    nonzero_entries = col[col != 0]
    if len(nonzero_entries) == 0:
        return 0
    return np.median(nonzero_entries)

# Median nonzero expression of each column of X (axis=0, one value per gene), stored in var['med_nonz'].
med_nonz = np.apply_along_axis(median_nonzero, axis=0, arr=X)
adata_w.var['med_nonz'] = med_nonz


In [None]:
adata_w.var.head()  # preview var, including the 'med_nonz' column added above

In [None]:
# seeing what it looks like before tf_list changes the first 3 rows
# Convert to dense if it's sparse and display the first five rows
import numpy as np

# Convert to a dense array if necessary
dense_X = adata_w.X.toarray() if not isinstance(adata_w.X, np.ndarray) else adata_w.X

# Display the first five rows (cells)
dense_X[:5]



In [None]:
# Problem: scaling by a factor of the max-expressed gene in each cell means each cell could be scaled
# relative to a different gene, even though the cells are all of the same type. Instead, scale each gene
# by its own median nonzero expression across cells (adata_w.var['med_nonz']).

# for testing purposes: v
tf_list = ['DDX11L1', 'WASH7P', 'MIR6859-1']
tf = 'DDX11L1'
# for testing purposes: ^

# Positions of `tf` in var, i.e. COLUMN indices into X (cells x genes).
mask = np.where(adata_w.var['gene_symbol'] == tf)[0]

# Scale those gene columns (not rows) by each gene's own median nonzero expression.
tf_columns = adata_w.X[:, mask]
med_factors = adata_w.var['med_nonz'].values[mask]
# Sparse matrices treat `*` as matrix multiplication, so use the elementwise .multiply() when available.
if hasattr(tf_columns, 'multiply'):
    adata_w.X[:, mask] = tf_columns.multiply(med_factors)
else:
    adata_w.X[:, mask] = np.multiply(tf_columns, med_factors)



In [None]:
# Reload an unmodified copy of the data to measure how much the scaling cell changed adata_w.
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"
adata = sp.read_h5ad(os.path.join(DATAPATH, FILE))
# NOTE(review): `mask` holds positions from adata_w.var (gene/column indices), so `[mask, :]` selects
# the cells at those positions, not the gene itself. This matches a scaling step that indexes rows;
# if the scaling step indexes columns, compare `[:, mask]` instead.
(adata_w.X[mask, :] - adata.X[mask, :]).sum()

In [None]:
# NOTE(review): `mask` holds gene (column) positions from var; `[mask, :]` shows the cells at those
# positions. Use `adata_w.X[:, mask]` to view the gene's own column.
adata_w.X[mask, :]

### Scaling by median nonzero entry of each gene across all cells (old)

In [None]:

# import numpy as np
# DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
# FILE = "TS_epithelial.h5ad"
# adata_w = sp.read_h5ad(os.path.join(DATAPATH, FILE))
# #tf = 'DDX11L1'
# tf_list =['DDX11L1', 'WASH7P', 'MIR6859-1']

# def median_nonzero(col):
#     nonzero_vals = col[col != 0] 
#     return np.median(nonzero_vals) if len(nonzero_vals) > 0 else 0

# # requires scalar is a scalar
# def perturb_counts(tf_list, scalar, adata): 
#     # compute nonzero median expression of each gene across cells, save to var
#     med_nonz = np.apply_along_axis(median_nonzero, axis=0, arr=adata.X)
#     adata.var['med_nonz'] = med_nonz

#     # filter by desired tf(s), and apply the nonzero_median scaling operation to only these  
#     mask = np.where(adata.var['gene_symbol'].isin(tf_list))[0]
#     adata.X[mask, :] = adata.X[mask, :] * adata.var['med_nonz'].values * scalar
#     return adata




### Scaling by max gene expression within each cell

In [None]:
# old version with (I believe) improper mask that was being applied to rows and not columns

# import numpy as np
# DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
# FILE = "TS_epithelial.h5ad"
# adata_w = sp.read_h5ad(os.path.join(DATAPATH, FILE))
# #tf = 'DDX11L1'
# tf_list =['DDX11L1', 'WASH7P', 'MIR6859-1']


# # requires scalar is a scalar
# def perturb_counts(tf_list, scalar, adata): 
#     # compute nonzero median expression of each gene across cells, save to var
#     med_nonz = med_nonz = np.max(adata.X, axis=1)
#     adata.obs['med_nonz'] = med_nonz

#     # filter by desired tf(s), and apply the nonzero_median scaling operation to only these  
#     mask = np.where(adata.var['gene_symbol'].isin(tf_list))[0]
#     adata.X[mask, :] = adata.X[mask, :] * adata.obs['med_nonz'].values * scalar
#     return adata

# old version without extra obs and var rows telling what was scaled and by how much


# def perturb_counts(tf_list, scalar, adata): 
#     # Compute maximum expression level of each cell and save it to obs
#     med_nonz = np.max(adata.X, axis=1)
#     adata.obs['med_nonz'] = med_nonz
    
#     # apply operation only to genes in tf_list
#     mask = np.where(adata.var['gene_symbol'].isin(tf_list))[0]
#     adata.X[:, mask] = adata.X[:, mask] * adata.obs['med_nonz'].values[:, np.newaxis] * scalar
    
#     return adata

# version before gpt optimized

# def perturb_counts(tf_list, scalar, adata): 
#     # Compute maximum expression level of each cell and save it to obs
#     med_nonz = np.max(adata.X, axis=1)
#     adata.obs['med_nonz'] = med_nonz
    
#     # Add a new obs column called 'scalar' containing the scalar value for each row
#     adata.obs['scalar'] = scalar
    
#     # Create a mask for genes in tf_list
#     mask = np.where(adata.var['gene_symbol'].isin(tf_list))[0]
#     # Apply the scaling operation to the specified genes
#     adata.X[:, mask] = adata.X[:, mask] * adata.obs['med_nonz'].values[:, np.newaxis] * scalar
    
#     # Add a new var column called 'scaled' with True for genes in tf_list and False otherwise
#     adata.var['scaled'] = adata.var['gene_symbol'].isin(tf_list)
    
#     # Add a new var column 'scaled_by'
#     adata.var['scaled_by'] = 1
#     # Set scaling factor for genes in tf_list

#     ############I asked gpt to do this line and am unsure if it is correct. checking now.
#     adata.var.loc[adata.var['scaled'], 'scaled_by'] = adata.obs['med_nonz'].values[:, np.newaxis] * scalar
#     ###############
    
#     return adata

# loop version before gpt optimized

# # requires that within each perturbation, all of the transcription factors in tf_list are scaled by the same amount, that is, (scalar * "max gene expression in that cell")
# # requires adata is cells x genes

# import anndata

# def iterate_perturb_counts(tf_list, scalar_list, adata):
#     adata_dict = {}
    
#     for scalar in scalar_list:
#         adata_temp = adata.copy()
#         perturbed_adata = perturb_counts(tf_list, scalar, adata_temp)
#         adata_dict[scalar] = perturbed_adata
    
#     return adata_dict



In [None]:

"""
Josh, please read: the adata.obs['scalar'] = scalar copies the scalar down for that call, associated with every cell in X. Same with ['scaled'] and ['scaled_by']
in var. This is good in case the data is later appended into one anndata object.
But my return from the perturb_counts loop (cell below this) is a dictionary of all of the perturb_counts, since appending along any axis will probably either overwrite
obs or var.
"""
import numpy as np

def _row_max(X):
    """Return the maximum of each row of X as a flat 1-D numpy array (dense or scipy-sparse X)."""
    row_max = X.max(axis=1)
    # Sparse .max(axis=1) returns a sparse result and np.matrix returns an (n, 1) matrix; flatten both.
    if hasattr(row_max, "toarray"):
        row_max = row_max.toarray()
    return np.asarray(row_max).ravel()


def _scale_columns_by_row(adata, col_mask, row_factor):
    """Multiply adata.X[i, j] by row_factor[i] for every column j where col_mask is True, in place."""
    X = adata.X
    if hasattr(X, "tocsr"):
        # Sparse X: scaling is multiplicative, so zeros stay zero; only stored values need scaling.
        csr = X.tocsr()
        entry_rows = np.repeat(np.arange(csr.shape[0]), np.diff(csr.indptr))
        selected = col_mask[csr.indices]
        csr.data[selected] = csr.data[selected] * row_factor[entry_rows[selected]]
        adata.X = csr.asformat(X.format)
    else:
        # Dense X: use np.multiply (not `*`, which is matrix multiplication for np.matrix).
        block = X[:, col_mask]
        np.multiply(block, row_factor[:, np.newaxis], out=block)
        X[:, col_mask] = block


def perturb_counts(tf_list, scalar, adata):
    """
    Applies a perturbation to the expression data of specific genes in an AnnData object.

    This function performs the following steps:
    1. Computes the maximum gene expression level for each cell (row of `adata.X`).
    2. Multiplies every entry of the genes listed in `tf_list` by the maximum expression level
       of its cell times `scalar`. Gene symbols in `tf_list` that are not present in
       `adata.var['gene_symbol']` are silently ignored. Dense and scipy-sparse X are supported.
    3. Records the perturbation:
       - obs['max_exp']: the per-cell maximum expression used for scaling.
       - obs['scalar']: `scalar`, repeated for every cell.
       - var['scaled']: True for genes in `tf_list`, False otherwise.
       - var['scaled_by']: `scalar` for genes in `tf_list`, `1` otherwise. The factor actually
         applied to cell i of a scaled gene j is obs['max_exp'][i] * var['scaled_by'][j]; the
         per-cell part cannot be stored in a per-gene (var) column.

    Parameters:
    tf_list (list): A list of gene symbols to be perturbed.
    scalar (float): The scalar value used to scale the expression levels.
    adata (AnnData): The AnnData object containing gene expression data (cells x genes).
        It is modified in place.

    Returns:
    AnnData: The same `adata` object, with applied perturbations and new columns.
    """
    # Boolean mask over genes (columns of X) that are in tf_list
    gene_mask = adata.var['gene_symbol'].isin(tf_list).to_numpy()

    # Maximum expression level of each cell
    max_exp = _row_max(adata.X)

    # Apply the scaling operation to the specified genes
    _scale_columns_by_row(adata, gene_mask, max_exp * scalar)

    # Per-cell records
    adata.obs['max_exp'] = max_exp
    adata.obs['scalar'] = scalar

    # Per-gene records
    adata.var['scaled'] = gene_mask
    adata.var['scaled_by'] = np.where(gene_mask, scalar, 1)

    return adata



In [None]:
import anndata

def iterate_perturb_counts(tf_list, scalar_list, adata):
    """
    Builds one perturbed copy of `adata` for every value in `scalar_list`.

    For each scalar, a fresh `adata.copy()` is handed to `perturb_counts` together with
    `tf_list`, so the input object itself is never modified. `perturb_counts` scales the
    genes in `tf_list` by each cell's maximum expression times that scalar.

    Parameters:
    tf_list (list): Gene symbols (transcription factors) to perturb.
    scalar_list (list): Scalar values to apply, one perturbed copy per value.
    adata (AnnData): The AnnData object containing gene expression data (cells x genes).

    Returns:
    dict: Maps each scalar to its perturbed AnnData copy. If a scalar occurs more than once,
    the copy made for its last occurrence is kept.
    """
    return {
        scalar: perturb_counts(tf_list, scalar, adata.copy())
        for scalar in scalar_list
    }



## Visualize Input Data

In [None]:
DATAPATH = "/nfs/turbo/umms-indikar/shared/projects/DGC/data/tabula_sapiens/extract/"
FILE = "TS_epithelial.h5ad"

# Load the (unmodified) epithelial Tabula Sapiens data for inspection.
adata = sp.read_h5ad(os.path.join(DATAPATH, FILE))

In [None]:
adata

In [None]:
adata.var  # per-gene metadata (includes the 'gene_symbol' column used throughout)

In [None]:
adata.X.max(axis=1) # what is the value of the highest expressed gene for each cell?

In [None]:
# Positional index in var (i.e. the column of X) of the gene whose symbol is TF.
TF = 'DDX11L1'
index = np.where(adata.var['gene_symbol'] == TF)[0]
index

In [None]:
adata.var['gene_symbol']

In [None]:
adata.obs  # per-cell metadata

## Build driver

In [None]:
import pandas as pd

def main(job_number, parameter_file, load_adata=None):
    """
    This is the main function for the array job to perform the reprogramming experiment.

    `job_number` selects one row of the parameter table in `parameter_file`. That row gives the
    transcription factors (TFs), the embedding model, and the source and target cell types. The
    source data is perturbed with the TFs, and the source, perturbed, and target data are then
    embedded together with the chosen model.

    Parameters:
    job_number (int or str): Positional row index into the parameter table. Strings, such as an
        array-job index read from the command line, are converted with int().
    parameter_file (str or file-like): CSV with columns 'TFs', 'model', 'source', 'target'.
    load_adata (callable, optional): Maps a cell-type name (the 'source'/'target' value) to an
        AnnData object. Data loading is not implemented here yet, so this must be supplied.

    Returns:
    int: 0 on success.

    Raises:
    ValueError: If the row's model is not one of 'geneformer', 'tGPT', 'scGTP'.
    NotImplementedError: If `load_adata` is not supplied.
    """
    # Determine embedding parameters and recipe
    row = pd.read_csv(parameter_file).iloc[int(job_number)]
    # NOTE(review): TFs is read back from CSV as a string (e.g. "['A', 'B']" if the table was written
    # from Python lists); confirm the format perturbation_model expects.
    TFs = row['TFs']
    model = row['model']
    source = row['source']
    target = row['target']

    # Fail fast on an unknown model, before any expensive loading or perturbation.
    known_models = ('geneformer', 'tGPT', 'scGTP')
    if model not in known_models:
        raise ValueError(f"Unknown model {model!r}; expected one of {known_models}")

    # Load the source and target data
    if load_adata is None:
        raise NotImplementedError("Loading source/target data is not implemented yet; pass load_adata")
    source_adata = load_adata(source)
    target_adata = load_adata(target)

    # Perturb a copy, so source_adata stays unperturbed for the comparison embedding
    perturbed_adata = perturbation_model(source_adata.copy(), TFs)

    # Generate embeddings
    inputs = [source_adata, perturbed_adata, target_adata]
    if model == 'geneformer':
        adata_embedded = embed_geneformer(inputs)
    elif model == 'tGPT':
        adata_embedded = embed_tGPT(inputs)
    else:  # 'scGTP' (validated above)
        adata_embedded = embed_scGTP(inputs)

    # TODO: save adata_embedded to a file.

    return 0


## Build parameter dataframe

In [None]:
embedding_parameters = {
    'source': [],
    'target': [],
    'TFs'   : [],
    'model' : []
}
# NOTE(review): 'scGTP' may be a typo for scGPT; kept as-is because main() dispatches on this exact string.
models = ['geneformer', 'tGPT', 'scGTP']

df = pd.read_csv('data/first_5_recepies_8_29_2024.csv')

# One parameter row per (recipe, model) pair, for every recipe in the file.
for tf_string, source, target in zip(df['TFs'], df['Source'], df['Target']):
    # Accept TF lists separated by spaces and/or commas.
    TFs = tf_string.replace(',', ' ').split()
    for model in models:
        embedding_parameters['TFs'].append(TFs)
        embedding_parameters['source'].append(source)
        embedding_parameters['target'].append(target)
        embedding_parameters['model'].append(model)

df_embedding_parameters = pd.DataFrame(embedding_parameters)

df_embedding_parameters

In [None]:
df  # display the recipe table read from data/first_5_recepies_8_29_2024.csv

# Day 0

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# load known reprogramming regimes
df = pd.read_csv('data/known-regiems-T1.csv')

In [None]:
# Sorted list of unique transcription factors across all regimes.
# TF lists may be separated by commas and/or spaces ("A,B" must not fuse into "AB");
# missing (NaN) entries are skipped.
TFs = sorted({
    tf
    for regime in df['TFs'].dropna().unique()
    for tf in regime.replace(',', ' ').split()
})
print(f"{len(TFs)=}")
TFs