In [1]:
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from collections import Counter
from scipy import stats
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import math
from pathlib import Path
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import r2_score



In [2]:
## Optional preprocessing step:
## drop cell types whose cell count is below a minimum.
## If no threshold is given, a default minimum is derived from the number of age groups.
def filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    """Drop the rows of every cell type that has too few cells.

    Parameters
    ----------
    celltype_df : pandas.DataFrame
        One row per cell; column ``celltype_col`` holds the cell-type label.
    age_df : array-like
        Age labels of the cells. Only read when ``threshold`` is None, to
        derive the default minimum count.
    celltype_col : str
        Name of the cell-type column in ``celltype_df``.
    threshold : int or None
        Minimum number of cells a cell type needs in order to be kept.
        When None, the minimum is ``(number of unique ages + 1) * 100``.

    Returns
    -------
    pandas.DataFrame
        ``celltype_df`` without the rows of the low-count cell types. The
        input object itself is returned when no cell type is removed.
    """
    print("Checking low count cell types...")

    # The minimum is the same for every cell type, so compute it once
    # instead of re-running np.unique inside the loop.
    if threshold is None:
        min_count = (len(np.unique(age_df)) + 1) * 100
    else:
        min_count = threshold

    # Counter keeps first-appearance order, so messages are printed in the
    # same order as the labels first occur in the frame.
    celltype_count = Counter(celltype_df[celltype_col])
    low_count_types = [
        celltype for celltype, count in celltype_count.items() if count < min_count
    ]
    for celltype in low_count_types:
        print(celltype, " has too low counts")

    if not low_count_types:
        return celltype_df

    # Remove all low-count types with a single mask rather than one
    # full-frame filter per type.
    return celltype_df[~celltype_df[celltype_col].isin(low_count_types)]

def get_skewed_count_info(adata, class_col, age_col, age_threshold):
    """Find the cell classes dominated by a single age group.

    Parameters
    ----------
    adata : object with an ``obs`` DataFrame
        Only ``adata.obs`` is read; it must contain ``class_col`` and ``age_col``.
    class_col : str
        Column of ``adata.obs`` holding the cell class.
    age_col : str
        Column of ``adata.obs`` holding the age label.
    age_threshold : float
        A class is reported when any one age group makes up more than this
        fraction of the class's cells.

    Returns
    -------
    pandas.Index
        Unique class labels for which at least one age fraction exceeds
        ``age_threshold``.
    """
    print("Checking skewed count cell types...")

    # observed=True restricts the groups to class/age combinations that
    # actually occur. For categorical columns this avoids the deprecated
    # implicit observed=False default and the empty groups it creates;
    # empty groups can never exceed a positive threshold anyway.
    group_counts = adata.obs.groupby([class_col, age_col], observed=True).size()
    total_counts = adata.obs.groupby([class_col], observed=True).size()

    # Fraction of each age group within its class; the division aligns the
    # per-class totals with the (class, age) counts on the class level.
    class_age_fraction = group_counts / total_counts

    # Keep the classes where some age group exceeds the threshold.
    skewed = class_age_fraction[class_age_fraction > age_threshold]
    classes_to_filter = skewed.index.get_level_values(0).unique()

    return classes_to_filter

## Read one or two h5ad files
## and filter cell types by cell-count threshold and by age distribution.
def read_and_filter_h5ad(filepath_1, filepath_2=None, class_col="celltype", age_col="age", age_threshold=0.8, count_threshold=None):
    """Load single-cell data from h5ad file(s) and drop unsuitable cell types.

    Parameters
    ----------
    filepath_1 : str
        Path of the h5ad file to read.
    filepath_2 : str, optional (default: None)
        Path of a second h5ad file. When given, it is read and concatenated
        onto the first dataset before filtering.
    class_col : str, optional (default: 'celltype')
        Column of ``adata.obs`` holding the cell ontology class.
    age_col : str, optional (default: 'age')
        Column of ``adata.obs`` holding the age of the cells.
    age_threshold : float, optional (default: 0.8)
        If one age group makes up more than this fraction of a class's cells,
        the class is removed.
    count_threshold : int, optional (default: None)
        Cell types with fewer cells than this are removed. When None, the
        minimum is chosen by ``filter_low_counts`` from the number of age groups.

    Returns
    -------
    AnnData
        A copy of the data with the low-count and age-skewed cell classes removed.

    Raises
    ------
    Any exception raised while reading or filtering propagates unchanged.
    """
    adata = sc.read_h5ad(filepath_1)
    if filepath_2 is not None:
        # NOTE(review): AnnData.concatenate is deprecated upstream in favour of
        # anndata.concat; it is kept here because switching may change the
        # resulting obs names / batch annotation -- confirm before migrating.
        adata = adata.concatenate(sc.read_h5ad(filepath_2))

    celltype_df = adata.obs[[class_col]].copy()
    age_df = adata.obs[[age_col]].copy()

    # Step 1: drop cell types with too few cells.
    celltype_df = filter_low_counts(celltype_df, age_df, class_col, count_threshold)
    filtered_adata = adata[celltype_df.index].copy()

    # Step 2: drop cell types dominated by a single age group.
    classes_to_filter = get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)

    # Report every removed class, not only the first one.
    for skewed_class in classes_to_filter:
        print(skewed_class, " has skewed cell counts")

    keep_mask = ~filtered_adata.obs[class_col].isin(classes_to_filter)
    final_filtered_adata = filtered_adata[keep_mask].copy()

    return final_filtered_adata

In [8]:
# Locate the input file relative to the directory the notebook is run from.
current_dir = Path.cwd()
print(f"Current working directory: {current_dir}")
file1 = current_dir / "tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad"
# No second dataset is used in this run (the call below passes a literal None).
file2 = None
print(f"File 1 path: {file1}")
# Fail early with a readable message if the data file is missing.
assert file1.is_file(), f"File not found: {file1}"
# Load the data and filter cell types, using the function's default
# age_threshold and count_threshold.
adata = read_and_filter_h5ad(str(file1), None, "cell_ontology_class", "age")


Current working directory: /home/hang/SC_Ageing_Prediction
File 1 path: /home/hang/SC_Ageing_Prediction/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad



This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


Checking low count cell types...
macrophage  has too low counts
Checking skewed count cell types...


  group_counts = adata.obs.groupby([class_col, age_col]).size()
  total_counts = adata.obs.groupby([class_col]).size()


In [4]:
# Display the summary of the filtered AnnData object (dimensions and annotation keys).
adata

AnnData object with n_obs × n_vars = 13130 × 22966
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden'
    var: 'n_cells', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
    uns: 'age_colors', 'cell_ontology_class_colors', 'leiden', 'louvain', 'neighbors', 'pca'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [12]:
# Print the `cell` identifier of the cell at position 1 (the second row) of adata.obs.
# .iloc is used because obs is indexed by string labels; plain [1] would rely on
# the deprecated positional fallback of Series.__getitem__.
print(adata.obs["cell"].iloc[1])

  adata.obs["cell"][1]


'A10_B001061'