In [124]:
import math
import os
import warnings
from collections import Counter
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy import stats
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

In [15]:
## Optional step:
## drop cell types whose cell count is below a lower bound.
## If no lower bound is given (threshold=None), one is derived from the number of unique ages.
def filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    """Drop every cell type that has fewer cells than a minimum count.

    Parameters
    ----------
    celltype_df : pandas.DataFrame
        One row per cell; must contain the column ``celltype_col``.
    age_df : pandas.DataFrame
        Age annotation of the cells. Only consulted when ``threshold`` is
        ``None``.
    celltype_col : str
        Name of the cell-type column in ``celltype_df``.
    threshold : int or None
        Minimum number of cells a cell type needs in order to be kept. When
        ``None``, the minimum is ``(number of unique ages + 1) * 100``.

    Returns
    -------
    pandas.DataFrame
        The rows of ``celltype_df`` whose cell type reaches the minimum
        count. ``celltype_df`` itself is not modified; it is returned as-is
        when nothing has to be dropped.
    """
    print("Checking low count cell types...")

    # The minimum does not depend on the cell type, so work it out once
    # instead of once per cell type.
    if threshold is None:
        min_cells = (len(np.unique(age_df)) + 1) * 100
    else:
        min_cells = threshold

    celltype_count = Counter(celltype_df[celltype_col])
    low_count_types = [
        key for key, n_cells in celltype_count.items() if n_cells < min_cells
    ]
    for key in low_count_types:
        print(key, " has too low counts")

    if not low_count_types:
        return celltype_df

    # A single boolean mask replaces re-filtering the frame once per cell type.
    keep = ~celltype_df[celltype_col].isin(low_count_types)
    return celltype_df[keep]

In [16]:
def get_skewed_count_info(adata, class_col, age_col, age_threshold):
    """Find the classes in which a single age group dominates.

    For every (class, age) pair, the share of that class's cells falling in
    the age group is computed. A class is reported when at least one of its
    age groups has a share strictly greater than ``age_threshold``.

    Parameters
    ----------
    adata : object with an ``obs`` DataFrame
        Source of the per-cell annotations.
    class_col : str
        Column of ``adata.obs`` holding the class label.
    age_col : str
        Column of ``adata.obs`` holding the age label.
    age_threshold : float
        Share above which a class counts as skewed.

    Returns
    -------
    pandas.Index
        The unique class labels that are skewed.
    """
    print("Checking skewed count cell types...")

    obs = adata.obs

    # Number of cells per class, and per (class, age) combination
    cells_per_class = obs.groupby([class_col]).size()
    cells_per_class_and_age = obs.groupby([class_col, age_col]).size()

    # Share of each age group inside its class
    age_share = cells_per_class_and_age / cells_per_class

    # Level 0 of the index is the class label
    is_skewed = age_share > age_threshold
    skewed_pairs = age_share[is_skewed]
    return skewed_pairs.index.get_level_values(0).unique()

In [17]:
## Read two h5ad files, concatenate them,
## and filter cell types based on cell count and age distribution thresholds.
def read_and_filter_h5ad(filepath_1, filepath_2, class_col="celltype", age_col="age", age_threshold=0.8, count_threshold=None):
    """Load two h5ad files, concatenate them and drop unsuitable cell types.

    Filtering happens in two passes: first cell types with too few cells are
    removed (``filter_low_counts``), then cell types dominated by a single
    age group are removed (``get_skewed_count_info``).

    Parameters
    ----------
    filepath_1, filepath_2 : str
        Paths of the two h5ad files to read and concatenate.
    class_col : str, optional (default: 'celltype')
        Column of ``adata.obs`` holding the cell ontology class.
    age_col : str, optional (default: 'age')
        Column of ``adata.obs`` holding the age of the cells.
    age_threshold : float, optional (default: 0.8)
        If one age group makes up more than this fraction of a class, the
        class is removed.
    count_threshold : int or None, optional (default: None)
        Minimum number of cells a class needs in order to be kept. It is
        passed straight to ``filter_low_counts``, which derives a minimum
        from the number of unique ages when this is ``None``.

    Returns
    -------
    AnnData
        A copy of the concatenated data without the removed classes.

    Raises
    ------
    Exception
        Any error raised while reading or filtering propagates unchanged.
    """
    adata1 = sc.read_h5ad(filepath_1)
    adata2 = sc.read_h5ad(filepath_2)
    # NOTE(review): AnnData.concatenate is deprecated in recent anndata
    # releases in favour of anndata.concat. It is kept here because the two
    # differ in defaults (batch column, obs-name suffixes) -- confirm before
    # migrating.
    adata = adata1.concatenate(adata2)

    celltype_df = adata.obs[[class_col]].copy()
    age_df = adata.obs[[age_col]].copy()

    # Pass 1: remove classes with too few cells
    celltype_df = filter_low_counts(celltype_df, age_df, class_col, count_threshold)
    filtered_adata = adata[celltype_df.index].copy()

    # Pass 2: remove classes dominated by one age group
    classes_to_filter = get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)

    # Report every class that is about to be removed, not only the first one.
    for skewed_class in classes_to_filter:
        print(skewed_class, " has skewed cell counts")

    is_skewed = filtered_adata.obs[class_col].isin(classes_to_filter)
    final_filtered_adata = filtered_adata[~is_skewed].copy()

    return final_filtered_adata

In [None]:
# Resolve the two input files relative to the directory the notebook runs in
current_dir = Path.cwd()
print(f"Current working directory: {current_dir}")

# Both Tabula Muris Senis brain files are expected to sit in that directory
file1 = current_dir.joinpath("tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad")
file2 = current_dir.joinpath("tabula-muris-senis-facs-processed-official-annotations-Brain_Non-Myeloid.h5ad")

# Show the resolved paths for verification
print(f"File 1 path: {file1}")
print(f"File 2 path: {file2}")

# Stop early with a clear message if either file is missing
assert file1.is_file(), f"File not found: {file1}"
assert file2.is_file(), f"File not found: {file2}"

# Load both files and filter the cell types
adata = read_and_filter_h5ad(
    str(file1),
    str(file2),
    class_col="cell_ontology_class",
    age_col="age",
)

# Alternative, if the files live in a sibling folder instead of the working directory:
# adata = read_and_filter_h5ad("../Mouse Tabula Muris/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad",
#                              "../Mouse Tabula Muris/tabula-muris-senis-facs-processed-official-annotations-Brain_Non-Myeloid.h5ad",
#                              "cell_ontology_class", "age")

In [None]:
# Quick look at the loaded object

# Container type, its summary, the type of the obs table and the obs column names
print(type(adata), adata, type(adata.obs), adata.obs.columns, sep="\n")
# Shape of the array stored under obsm['X_tsne']
print(adata.obsm['X_tsne'].shape)
# Expression matrix, transposed
print(adata.X.T)

In [None]:
# Value counts of every obs column, each followed by a separator line and blank lines
for col in adata.obs.columns:
    print(adata.obs[col].value_counts())
    print("***************************************************", "\n", sep="\n")


# To look at a single column only: adata.obs['age'].value_counts()

# Open question: why is there no 18m age group here?

In [None]:
# Data preprocessing: quality filtering, normalisation, gene selection,
# dimensionality reduction, clustering and embedding.

# Keep genes seen in at least 5 cells and cells with at least 500 genes
sc.pp.filter_genes(adata, min_cells=5)
sc.pp.filter_cells(adata, min_genes=500)

# Total counts per cell.
# NOTE(review): `.A1` exists on numpy matrices, so this assumes adata.X is a
# scipy sparse matrix -- confirm.
adata.obs['n_counts'] = np.sum(adata.X, axis=1).A1

# Boolean indexing returns a *view* of the AnnData object. The calls below
# modify the object in place, so take an explicit copy rather than relying
# on the view being converted implicitly.
adata = adata[adata.obs['n_counts']>=3000].copy()

# Library-size normalisation to 10,000 counts per cell.
# NOTE(review): normalize_per_cell and filter_genes_dispersion are marked as
# deprecated in recent scanpy releases -- confirm before upgrading scanpy.
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)

# Annotate genes by dispersion; subset=False keeps all genes, copy=True
# returns a new object, which is why the result is reassigned.
adata = sc.pp.filter_genes_dispersion(adata, subset = False, min_disp=.5, max_disp=None, 
                              min_mean=.0125, max_mean=10, n_bins=20, n_top_genes=None, 
                              log=True, copy=True)
sc.pp.log1p(adata)
# Scale without centring and clip values at 10
sc.pp.scale(adata, max_value=10, zero_center=False)

# PCA on the highly variable genes, then neighbour graph, clustering, UMAP
sc.tl.pca(adata,use_highly_variable=True)
sc.pp.neighbors(adata, n_neighbors=18)
sc.tl.louvain(adata, resolution = 1)
sc.tl.umap(adata)

# NOTE(review): sc.tl.umap above is expected to have populated
# adata.obsm['X_umap'] already, so this fallback looks unreachable --
# confirm and remove.
if 'X_umap' not in adata.obsm.keys():
    sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_pca')
    sc.tl.umap(adata)

In [128]:
# Columns of adata.obs at this point, for reference:
# Index(['FACS.selection', 'age', 'cell', 'cell_ontology_class',
#        'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex',
#        'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden',
#        'batch']

# --- Cell type -------------------------------------------------------------
# Single-column frame, renamed from "cell_ontology_class" to "celltype"
celltype_df = adata.obs[["cell_ontology_class"]].rename(
    columns={"cell_ontology_class": "celltype"}
)

# All distinct cell types present in the data
unique_cell_types = celltype_df['celltype'].unique()

# Encode each cell type as an integer label
le = LabelEncoder()
celltype_df['celltype_encoded'] = le.fit_transform(celltype_df['celltype'])

# Lookup table: cell type name -> integer code
celltype_mapping = {
    name: code for name, code in zip(le.classes_, le.transform(le.classes_))
}

# --- Age -------------------------------------------------------------------
age_df = adata.obs[["age"]].copy()

# --- Gender ----------------------------------------------------------------
# male -> 0, female -> 1
gender_mapping = {'male': 0, 'female': 1}
gender_df = adata.obs[["sex"]].rename(columns={"sex": "gender"})
gender_df['gender'] = gender_df['gender'].map(gender_mapping)

# --- Number of genes per cell ------------------------------------------------
ngenes_df = adata.obs[["n_genes"]].copy()






In [23]:
def clean_age(age_df, substring):
    """Convert the ``"age"`` column of ``age_df`` to integers, in place.

    Each string value has the characters of ``substring`` stripped from both
    ends (``str.strip`` semantics, e.g. ``"18m"`` with ``"m"`` gives ``18``)
    and is then converted with ``int``.

    Values that are not strings are kept unchanged, so running the function
    a second time on an already cleaned frame is harmless. A string that
    cannot be converted triggers a warning and becomes NaN; the remaining
    rows are still processed.

    Parameters
    ----------
    age_df : pandas.DataFrame
        Frame with an ``"age"`` column. It is modified in place.
    substring : str
        Characters to strip from both ends of each age string.

    Returns
    -------
    pandas.DataFrame
        The same ``age_df`` object, with the converted ``"age"`` column.
    """
    values = []
    for x in age_df["age"]:
        if not isinstance(x, str):
            # Already converted (or missing): leave untouched.
            values.append(x)
            continue
        try:
            values.append(int(x.strip(substring)))
        except ValueError:
            warnings.warn(
                f"Warning: '{x}' could not be converted to an integer.",
                stacklevel=2,
            )
            # Keep one entry per row so the assignment below stays aligned.
            values.append(float("nan"))
    age_df["age"] = values
    return age_df

def get_raw_counts(adata, celltype_df):
    """Build a genes x cells integer table from ``adata.X``.

    The transposed matrix is wrapped in a sparse-backed DataFrame indexed by
    ``adata.var_names`` with ``adata.obs_names`` as columns, cast to ``int``,
    and restricted to the cells listed in ``celltype_df.index``.

    NOTE(review): ``adata.X`` is used as-is; if it was normalised or scaled
    upstream, the ``int`` cast discards the fractional part -- confirm that
    ``X`` holds raw counts at the point of the call.
    """
    genes_by_cells = adata.X.T
    table = pd.DataFrame.sparse.from_spmatrix(
        genes_by_cells,
        index=adata.var_names,
        columns=adata.obs_names,
    )
    table = table.astype(int)

    # Keep only the requested cells, in the order given by celltype_df
    wanted_cells = list(celltype_df.index)
    return table[wanted_cells]

In [None]:
# clean_age converts the "age" column of age_df in place and returns the same frame
cleaned_age_df = clean_age(age_df, substring="m")


In [None]:
# Genes x cells table restricted to the cells listed in celltype_df
raw_count = get_raw_counts(adata, celltype_df)
raw_count.to_numpy()

In [None]:
# First ten rows of each frame
print(raw_count.head(10))
print(age_df.head(10))
print(gender_df.head(10))
print(celltype_df.head(10))


In [None]:
# Preparation for concatenating all frames:
# transpose raw_count so that cells become the row index, like the other frames
raw_count_T = raw_count.transpose()

# Sanity check: the frames should agree on the number of rows
print("Shape of raw_count_T:", raw_count_T.shape)
print("Shape of age_df:", age_df.shape)
print("Shape of gender_df:", gender_df.shape)
print("Shape of celltype_df:", celltype_df.shape)

# Cell identifiers carried by each frame (used to find the cells common to all)
cells_in_raw = set(raw_count_T.index)
cells_in_age = set(age_df.index)
cells_in_gender = set(gender_df.index)
cells_in_celltype = set(celltype_df.index)



In [54]:
# Cells present in every frame
common_cells = cells_in_raw & cells_in_age & cells_in_gender & cells_in_celltype

# Sort the identifiers: iterating a set of strings yields an order that can
# differ between interpreter runs (hash randomisation), which would make the
# row order of every frame below change from one kernel restart to the next.
common_cells_list = sorted(common_cells)

# Restrict every frame to the common cells, all in the same row order
raw_count_T = raw_count_T.loc[common_cells_list]
age_df = age_df.loc[common_cells_list]
gender_df = gender_df.loc[common_cells_list]
celltype_df = celltype_df.loc[common_cells_list]

In [122]:
# Side-by-side concatenation (one row per cell) for deep learning:
# gene counts first, then age, gender and cell type columns
combined_df = pd.concat(
    [raw_count_T, age_df, gender_df, celltype_df],
    axis="columns",
)



In [None]:
# Step 1: Ensure DataFrames are properly aligned
def prepare_dataframe(df):
    """Normalise the row labels of ``df`` in place and sort the rows by them.

    Labels are converted to strings, stripped of surrounding whitespace and
    lower-cased. The frame passed in is modified and also returned.
    """
    cleaned_labels = df.index.astype(str).str.strip()
    df.index = cleaned_labels.str.lower()
    df.sort_index(inplace=True)
    return df

# Normalise and sort the index of every frame so rows line up one-to-one
raw_count_T, age_df, gender_df, celltype_df, combined_df = (
    prepare_dataframe(raw_count_T),
    prepare_dataframe(age_df),
    prepare_dataframe(gender_df),
    prepare_dataframe(celltype_df),
    prepare_dataframe(combined_df),
)

# Restrict every frame to the labels shared by the four individual frames
common_indices = (
    raw_count_T.index
    .intersection(age_df.index)
    .intersection(gender_df.index)
    .intersection(celltype_df.index)
)
raw_count_T, age_df, gender_df, celltype_df, combined_df = (
    raw_count_T.loc[common_indices],
    age_df.loc[common_indices],
    gender_df.loc[common_indices],
    celltype_df.loc[common_indices],
    combined_df.loc[common_indices],
)

# Step 2: the gene columns of combined_df must equal raw_count_T
gene_columns = raw_count_T.columns
combined_gene_data = combined_df[gene_columns]
gene_data_matches = combined_gene_data.equals(raw_count_T)
print("Gene expression data matches:", gene_data_matches)

# Step 3: each metadata column of combined_df must equal its source frame
age_matches = combined_df['age'].equals(age_df['age'])
print("Age data matches:", age_matches)

gender_matches = combined_df['gender'].equals(gender_df['gender'])
print("Gender data matches:", gender_matches)

celltype_matches = combined_df['celltype'].equals(celltype_df['celltype'])
print("Cell type data matches:", celltype_matches)

# Step 4: overall verdict
all_data_matches = all(
    [gene_data_matches, age_matches, gender_matches, celltype_matches]
)
print("\nOverall data matches:", all_data_matches)

# Step 5: on a mismatch, report where the frames disagree
if all_data_matches:
    print("\nAll data in combined_df matches the original DataFrames.")
else:
    if not gene_data_matches:
        # Element-wise comparison, then collapse per cell (row) and per gene (column)
        gene_diff = (combined_gene_data != raw_count_T)
        cells_with_diff = gene_diff.any(axis=1)
        genes_with_diff = gene_diff.any(axis=0)
        print("\nDiscrepancies found in gene expression data.")
        print(f"Number of cells with discrepancies: {cells_with_diff.sum()}")
        print(f"Number of genes with discrepancies: {genes_with_diff.sum()}")
        # Show at most five cells, with at most five differing genes each
        discrepant_cells = cells_with_diff[cells_with_diff].index[:5]
        for cell in discrepant_cells:
            diff_genes = gene_diff.columns[gene_diff.loc[cell].to_numpy()].tolist()
            print(f"Cell '{cell}' has discrepancies in genes: {diff_genes[:5]}")

    if not age_matches:
        age_diff = combined_df['age'] != age_df['age']
        discrepant_cells = age_diff[age_diff].index.tolist()
        print("\nDiscrepancies found in age data for cells:", discrepant_cells)

    if not gender_matches:
        gender_diff = combined_df['gender'] != gender_df['gender']
        discrepant_cells = gender_diff[gender_diff].index.tolist()
        print("\nDiscrepancies found in gender data for cells:", discrepant_cells)

    if not celltype_matches:
        celltype_diff = combined_df['celltype'] != celltype_df['celltype']
        discrepant_cells = celltype_diff[celltype_diff].index.tolist()
        print("\nDiscrepancies found in cell type data for cells:", discrepant_cells)
