In [1]:
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from collections import Counter
from scipy import stats
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import math
from pathlib import Path
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
# import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils.class_weight import compute_class_weight
# from torch.utils.data import TensorDataset, DataLoader
# import torch.nn as nn
# import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import r2_score
import csv

In [2]:
import pandas as pd
import numpy as np
import scanpy as sc
from collections import Counter


def _filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    print("Checking low count cell types...")
    
    celltype_count = Counter(celltype_df[celltype_col])
    for key in celltype_count:
        if threshold == None:
            unique_ages = np.unique(age_df)
            num_groups = (len(unique_ages) + 1) * 100
            if celltype_count[key] < num_groups:
                print(key, " has too low counts")
                celltype_df = celltype_df[celltype_df[celltype_col] != key]
        else:
            if celltype_count[key] < threshold:
                print(key, " has too low counts")
                celltype_df = celltype_df[celltype_df[celltype_col] != key]
    return celltype_df

def _get_skewed_count_info(adata, class_col, age_col, age_threshold):
    print("Checking skewed count cell types...")
    
    # Compute the fraction of cells for each age group within each cell ontology class
    group_counts = adata.obs.groupby([class_col, age_col]).size()
    total_counts = adata.obs.groupby([class_col]).size()
    
    # Calculate the fraction of each age group within each class
    class_age_fraction = group_counts / total_counts
    
    # Find the cell classes to filter out based on age distribution
    classes_to_filter = class_age_fraction[class_age_fraction > age_threshold].index.get_level_values(0).unique()
    
    return classes_to_filter
    


def read_and_filter_h5ad(filepath, class_col="celltype", age_col="age", filter_gender=True, gender="male", age_threshold=0.8, count_threshold=None):
    """Parameters:
    filepath: path to AnnData object
        The Scanpy AnnData object containing single-cell data.
    class_col: str, optional (default: 'celltype')
        The column name in adata.obs representing the cell ontology class.
    age_col: str, optional (default: 'age')
        The column name in adata.obs representing the age of the cells.
    filter_gender: boolen, optional (default: True)
        Whether you would like to filter a gender out
    gender: str, optional only if the filter_gender is True(default: "male")
        Choose which gender to keep
    age_threshold: float, optional (default: 0.8)
        The threshold fraction for filtering based on age distribution. If one age group has more than this
        fraction of cells in a class, the class will be filtered out.
    count_threshold: int, optional (default: None(100))
        Lower threshold for filtering cell types based on count.
 
    Returns:
    filtered_adata: AnnData object
        The filtered AnnData object with specified cell ontology classes removed based on both criteria."""
    try:
        adata = sc.read_h5ad(filepath)
        
        if filter_gender:
             filtered_adata = adata[adata.obs["sex"] == gender, :].copy()
                
        celltype_df = filtered_adata.obs[[class_col]].copy()
        age_df = filtered_adata.obs[[age_col]].copy()
        
        # Apply the cell count threshold filtering
        celltype_df = _filter_low_counts(celltype_df, age_df, class_col, count_threshold)
    
        # Create a filtered AnnData object based on cell count filtering
        filtered_adata = filtered_adata[celltype_df.index].copy()
        
        # Identify the skewed classes to filter based on age distribution
        classes_to_filter = _get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)
        
        if len(classes_to_filter):
            print(classes_to_filter[0], " has skewed cell counts")
            
        # Further filter the AnnData object based on age distribution
        final_filtered_adata = filtered_adata[~filtered_adata.obs[class_col].isin(classes_to_filter)].copy()
        
        return final_filtered_adata
    except Exception as e:
        raise(e)

In [3]:
# current_dir = Path.cwd()
# print(f"Current working directory: {current_dir}")
# file1 = "/Users/chanyue/Desktop/Pellegrini_Lab/Aging/Mouse_Tabula_Muris/tabula-muris-senis-facs-processed-official-annotations-Brain_combined.h5ad"
# file1 = "/Users/chanyue/Downloads/AsianImmune_scAging.h5ad"
file1 = "/Users/chanyue/Desktop/Pellegrini_Lab/Aging/Mouse_Tabula_Muris/Limb_Muscle_droplet.h5ad"
file2 = None
print(f"File 1 path: {file1}")
# assert file1.is_file(), f"File not found: {file1}"

File 1 path: /Users/chanyue/Desktop/Pellegrini_Lab/Aging/Mouse_Tabula_Muris/Limb_Muscle_droplet.h5ad


In [6]:
# adata = read_and_filter_h5ad(file1, "cell_type", "development_stage")
adata = read_and_filter_h5ad(file1, "cell_ontology_class", "age")


This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


Checking low count cell types...
Schwann cell  has too low counts
skeletal muscle cell  has too low counts
Checking skewed count cell types...


In [10]:
adata.obs["sex"]

index
AAACCTGGTATTCTCT-1-34-0-0    male
AAACCTGTCTGATACG-1-34-0-0    male
AAACGGGCAAAGCAAT-1-34-0-0    male
AAACGGGCAGCTCGAC-1-34-0-0    male
AAACGGGCATATGGTC-1-34-0-0    male
                             ... 
TTTGGTTTCGTGGACC-1-49-1-0    male
TTTGTCAAGCTTCGCG-1-49-1-0    male
TTTGTCAGTACTCTCC-1-49-1-0    male
TTTGTCAGTATCTGCA-1-49-1-0    male
TTTGTCATCTGTTGAG-1-49-1-0    male
Name: sex, Length: 17396, dtype: category
Categories (1, object): ['male']

In [11]:
sc.pp.filter_genes(adata, min_cells=3)
sc.pp.filter_cells(adata, min_genes=200)

In [12]:
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)

In [13]:
sc.pp.log1p(adata)

In [14]:
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
adata = adata[:, adata.var.highly_variable]

In [15]:
adata

View of AnnData object with n_obs × n_vars = 17396 × 4568
    obs: 'age', 'batch', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'n_genes', 'sex', 'subtissue', 'tissue', 'tissue_free_annotation', 'n_counts', 'louvain', 'cluster_names'
    var: 'n_cells', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
    uns: 'cluster_names_colors', 'louvain', 'neighbors', 'pca', 'rank_genes_groups', 'log1p', 'hvg'
    obsm: 'X_pca', 'X_umap', 'X_tsne'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [16]:
#print out first 5 rows of the data
print(adata.obs["n_genes"].head())

index
AAACCTGGTATTCTCT-1-34-0-0    1891
AAACCTGTCTGATACG-1-34-0-0    1004
AAACGGGCAAAGCAAT-1-34-0-0    1323
AAACGGGCAGCTCGAC-1-34-0-0    1351
AAACGGGCATATGGTC-1-34-0-0    1451
Name: n_genes, dtype: int64


In [17]:
# X = adata.X.toarray() if not isinstance(adata.X, np.ndarray) else adata.X

# # Round values to nearest int and ensure non-negative
# X = np.round(X).astype(int)
# X[X < 0] = 0  # just in case

# # Get gene names
# genes = np.array(adata.var_names)

# # Prepare the list of gene strings
# gene_strings = []
# for row in X:
#     repeated_genes = np.repeat(genes, row)
#     gene_string = " ".join(repeated_genes)
#     gene_strings.append(gene_string)

# # Get relevant metadata
# df_meta = adata.obs[['age', 'sex', 'cell_ontology_class', 'tissue']].copy()
# df_meta.columns = ['age', 'gender', 'cell_ontology_class', 'tissue']

# # Create final DataFrame
# df_final = pd.DataFrame({
#     'genes': gene_strings,
#     'age': df_meta['age'].values,
#     'gender': df_meta['gender'].values,
#     'cell_ontology_class': df_meta['cell_ontology_class'].values,
#     'tissue': df_meta['tissue'].values
# })

# # Set index as 0 to n-1
# df_final.index = range(df_final.shape[0])

# # Optional: Save to CSV
# df_final.to_csv("processed_cells.csv", index_label="index")

# # Return final DataFrame
# df_final



## too memory intensive 
## write to file directly

In [18]:
# # Prepare output file
# output_file = "processed_cells_streamed.csv"

# # Get gene names
# genes = np.array(adata.var_names)

# # Get expression matrix (dense row-by-row)
# X = adata.X

# # Prepare metadata
# ages = adata.obs["age"].values
# genders = adata.obs["sex"].values
# classes = adata.obs["cell_ontology_class"].values
# tissues = adata.obs["tissue"].values

# # Open file and write line-by-line
# with open(output_file, mode='w', newline='', encoding='utf-8') as f:
#     writer = csv.writer(f)
#     # Write header
#     writer.writerow(["index", "genes", "age", "gender", "cell_ontology_class", "tissue"])

#     for i in range(adata.n_obs):
#         # Get row i as dense array
#         row = X[i].toarray().flatten() if not isinstance(X, np.ndarray) else X[i]
        
#         # Convert to int and clip negatives
#         row = np.maximum(np.round(row).astype(int), 0)

#         # Efficiently repeat gene names
#         repeated_genes = np.repeat(genes, row)
#         gene_string = " ".join(repeated_genes)

#         # Write row to file
#         writer.writerow([
#             i,
#             gene_string,
#             ages[i],
#             genders[i],
#             classes[i],
#             tissues[i]
#         ])
# f.close()

In [19]:
from collections import Counter
import numpy as np
import json, math, os, pathlib
from sklearn.model_selection import StratifiedKFold

# ─── Pull data from AnnData ─────────────────────────────────────────────────────
genes   = np.asarray(adata.var_names)
X       = adata.X
ages    = adata.obs["age"].values
genders = adata.obs["sex"].values
classes = adata.obs["cell_ontology_class"].values
tissues = adata.obs["tissue"].values

print(Counter(ages))

Counter({'24m': 6167, '1m': 5020, '18m': 3740, '30m': 2469})


In [20]:
# # ─── Split config ───────────────────────────────────────────────────────────────
# N_SPLITS   = 11
# SPLIT_SIZE = math.ceil(adata.n_obs / N_SPLITS)
# OUT_DIR    = "fine_tune_chunks"
# os.makedirs(OUT_DIR, exist_ok=True)

# INSTRUCTION = "Predict the age of a single cell from gene expression and metadata."

# def bag_of_words(counts_row: np.ndarray, gene_names: np.ndarray) -> str:
#     """Convert a vector of counts to space-separated tokens."""
#     counts_row = np.maximum(np.round(counts_row).astype(int), 0)
#     return " ".join(np.repeat(gene_names, counts_row))

# # ─── Build chunk files in **Alpaca** format ─────────────────────────────────────

# skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
# for split_idx, (_, split_indices) in enumerate(skf.split(np.zeros(len(ages)), ages)):
#     records = []
#     for i in split_indices:
#         row = X[i].toarray().ravel() if not isinstance(X, np.ndarray) else X[i]
#         cell_input = (
#             f"Genes: {bag_of_words(row, genes)}\n"
#             f"Gender: {genders[i]}\n"
#             f"Class: {classes[i]}\n"
#             f"Tissue: {tissues[i]}"
#         )

#         records.append({
#             "instruction": INSTRUCTION,
#             "input": cell_input,
#             "output": str(ages[i])
#         })
#         break
# #     fname = os.path.join(OUT_DIR, f"cell_data_part_{split_idx+1}.json")
# #     with open(fname, "w", encoding="utf-8") as f:
# #         json.dump(records, f, indent=2, ensure_ascii=False)

# #     print(f"Wrote {len(records):>5} samples → {fname}")


In [21]:
import inflect

p = inflect.engine()

In [23]:
# ─── Split config ───────────────────────────────────────────────────────────────
N_SPLITS   = 11
SPLIT_SIZE = math.ceil(adata.n_obs / N_SPLITS)
OUT_DIR    = "fine_tune_chunks"
os.makedirs(OUT_DIR, exist_ok=True)

INSTRUCTION = "Predict the age of a single cell from gene expression and metadata."

def bag_of_words(counts_row: np.ndarray, gene_names: np.ndarray) -> str:
    """Convert a vector of counts to 'gene english(count)' tokens separated by space."""
    counts_row = np.maximum(np.round(counts_row).astype(int), 0)
    tokens = [f"{gene} {p.number_to_words(count)}"
        for gene, count in zip(gene_names, counts_row)
        if count > 0]
    return " ".join(tokens)

# ─── Build chunk files in **Alpaca** format ─────────────────────────────────────
# check length 
lengths = []
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
for split_idx, (_, split_indices) in enumerate(skf.split(np.zeros(len(ages)), ages)):
    records = []
    for i in split_indices:
        row = X[i].toarray().ravel() if not isinstance(X, np.ndarray) else X[i]
        lengths.append(len(bag_of_words(row, genes)))
        cell_input = (
            f"Genes: {bag_of_words(row, genes)}\n"
            f"Gender: {genders[i]}\n"
            f"Class: {classes[i]}\n"
            f"Tissue: {tissues[i]}"
        )
    
        records.append({
            "instruction": INSTRUCTION,
            "input": cell_input,
            "output": str(ages[i])
        })
    fname = os.path.join(OUT_DIR, f"limb-muscle_cell_data_part_{split_idx+1}.json")
    with open(fname, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=2, ensure_ascii=False)

    print(f"Wrote {len(records):>5} samples → {fname}")

Wrote  1582 samples → fine_tune_chunks/limb-muscle_cell_data_part_1.json
Wrote  1582 samples → fine_tune_chunks/limb-muscle_cell_data_part_2.json
Wrote  1582 samples → fine_tune_chunks/limb-muscle_cell_data_part_3.json
Wrote  1582 samples → fine_tune_chunks/limb-muscle_cell_data_part_4.json
Wrote  1582 samples → fine_tune_chunks/limb-muscle_cell_data_part_5.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_6.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_7.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_8.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_9.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_10.json
Wrote  1581 samples → fine_tune_chunks/limb-muscle_cell_data_part_11.json


In [24]:
max(lengths), min(lengths)

(10872, 819)