In [3]:
import pandas as pd
import numpy as np
import scanpy as sc
from sklearn.linear_model import LinearRegression
from collections import Counter
from scipy import stats
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import math
from pathlib import Path
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from sklearn.preprocessing import label_binarize
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import r2_score
import csv

In [4]:
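## Optional:
## delete cell types whose cell count is below a given threshold.
## Specifying a lower bound is recommended; if the threshold is None,
## a default of (number of unique ages + 1) * 100 cells is used.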
def filter_low_counts(celltype_df, age_df, celltype_col, threshold):
    """Drop the rows of every cell type whose cell count is below a minimum.

    Parameters:
    celltype_df: pandas.DataFrame
        Frame holding the cell-type label of each cell in `celltype_col`.
    age_df: pandas.DataFrame or array-like
        Ages of the cells. Only consulted when `threshold` is None.
    celltype_col: str
        Name of the label column in `celltype_df`.
    threshold: number or None
        Minimum number of cells a cell type needs in order to be kept.
        When None, the minimum defaults to (number of unique ages + 1) * 100.

    Returns:
    pandas.DataFrame
        `celltype_df` without the rows belonging to low-count cell types.
        The original index labels of the remaining rows are preserved.
    """
    print("Checking low count cell types...")

    celltype_count = Counter(celltype_df[celltype_col])

    # The minimum does not depend on the cell type, so resolve it once
    # instead of recomputing the unique ages for every key.
    if threshold is None:
        min_count = (len(np.unique(age_df)) + 1) * 100
    else:
        min_count = threshold

    low_count_types = [key for key, count in celltype_count.items() if count < min_count]
    for key in low_count_types:
        print(key, " has too low counts")

    if not low_count_types:
        return celltype_df

    # Remove all low-count types in a single pass over the frame.
    return celltype_df[~celltype_df[celltype_col].isin(low_count_types)]

def get_skewed_count_info(adata, class_col, age_col, age_threshold):
    """Find the cell classes in which one age group dominates.

    Parameters:
    adata: object with an `obs` DataFrame (e.g. an AnnData object)
        Per-cell annotations are read from `adata.obs`.
    class_col: str
        Column of `adata.obs` holding the cell class.
    age_col: str
        Column of `adata.obs` holding the age of each cell.
    age_threshold: float
        A class is reported when the fraction of its cells that belongs to
        a single age group is strictly greater than this value.

    Returns:
    pandas.Index
        The unique class labels that have at least one age group above
        `age_threshold`.
    """
    print("Checking skewed count cell types...")

    obs = adata.obs

    # observed=True is passed explicitly: relying on the implicit default
    # raises a pandas FutureWarning for categorical columns, and unused
    # categories would otherwise add zero-count rows (0 / 0 = NaN fractions)
    # that can never exceed the threshold anyway.
    group_counts = obs.groupby([class_col, age_col], observed=True).size()
    total_counts = obs.groupby([class_col], observed=True).size()

    # Fraction of each class's cells contributed by each age group; the
    # division is aligned on the shared class level of the two indexes.
    class_age_fraction = group_counts / total_counts

    skewed = class_age_fraction[class_age_fraction > age_threshold]
    classes_to_filter = skewed.index.get_level_values(0).unique()

    return classes_to_filter

## Read h5ad file 
## and do cell type filtering based on age distribution and cell count thresholds.
def read_and_filter_h5ad(filepath_1, filepath_2 = None, class_col="celltype", age_col="age", age_threshold=0.8, count_threshold=None):
    """Load one or two h5ad files and drop low-count and age-skewed cell classes.

    Parameters:
    filepath_1: str
        Path of the h5ad file to read.
    filepath_2: str, optional (default: None)
        Path of a second h5ad file. When given, it is read and concatenated
        onto the data from `filepath_1` before any filtering.
    class_col: str, optional (default: 'celltype')
        The column name in adata.obs representing the cell ontology class.
    age_col: str, optional (default: 'age')
        The column name in adata.obs representing the age of the cells.
    age_threshold: float, optional (default: 0.8)
        The threshold fraction for filtering based on age distribution. If one age group has more than this
        fraction of cells in a class, the class will be filtered out.
    count_threshold: number, optional (default: None)
        Minimum number of cells a class needs in order to be kept. It is passed
        unchanged to filter_low_counts, which falls back to
        (number of unique ages + 1) * 100 when the value is None.

    Returns:
    filtered_adata: AnnData object
        The filtered AnnData object with specified cell ontology classes removed based on both criteria.

    Errors raised while reading or filtering propagate to the caller unchanged."""
    adata = sc.read_h5ad(filepath_1)
    if filepath_2 is not None:
        # NOTE(review): AnnData.concatenate is deprecated in recent anndata
        # releases in favour of anndata.concat. It is kept here because the two
        # differ in how they label obs names / batches -- confirm before migrating.
        adata = adata.concatenate(sc.read_h5ad(filepath_2))

    celltype_df = adata.obs[[class_col]].copy()
    age_df = adata.obs[[age_col]].copy()

    # Apply the cell count threshold filtering
    celltype_df = filter_low_counts(celltype_df, age_df, class_col, count_threshold)

    # Keep only the cells (selected by obs label) that survived the count filter
    filtered_adata = adata[celltype_df.index].copy()

    # Identify the skewed classes to filter based on age distribution
    classes_to_filter = get_skewed_count_info(filtered_adata, class_col, age_col, age_threshold)

    # Report every class that is about to be removed, not just the first one
    for skewed_class in classes_to_filter:
        print(skewed_class, " has skewed cell counts")

    # Further filter the AnnData object based on age distribution
    final_filtered_adata = filtered_adata[~filtered_adata.obs[class_col].isin(classes_to_filter)].copy()

    return final_filtered_adata

In [5]:
# Resolve the input dataset relative to the directory the notebook runs in.
current_dir = Path.cwd()
print(f"Current working directory: {current_dir}")

file1 = current_dir.joinpath(
    "tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad"
)
file2 = None  # no second dataset to concatenate
print(f"File 1 path: {file1}")

# Fail early with a readable message if the dataset is missing.
assert file1.is_file(), f"File not found: {file1}"

# Load the data and drop low-count / age-skewed cell classes.
adata = read_and_filter_h5ad(str(file1), file2, "cell_ontology_class", "age")


Current working directory: /home/hang/SC_Ageing_Prediction
File 1 path: /home/hang/SC_Ageing_Prediction/tabula-muris-senis-facs-processed-official-annotations-Brain_Myeloid.h5ad



This is where adjacency matrices should go now.
  warn(

This is where adjacency matrices should go now.
  warn(


Checking low count cell types...
macrophage  has too low counts
Checking skewed count cell types...


  group_counts = adata.obs.groupby([class_col, age_col]).size()
  total_counts = adata.obs.groupby([class_col]).size()


In [6]:
# Display the summary (dimensions and annotation keys) of the filtered AnnData object
adata

AnnData object with n_obs × n_vars = 13130 × 22966
    obs: 'FACS.selection', 'age', 'cell', 'cell_ontology_class', 'cell_ontology_id', 'free_annotation', 'method', 'mouse.id', 'sex', 'subtissue', 'tissue', 'n_genes', 'n_counts', 'louvain', 'leiden'
    var: 'n_cells', 'means', 'dispersions', 'dispersions_norm', 'highly_variable'
    uns: 'age_colors', 'cell_ontology_class_colors', 'leiden', 'louvain', 'neighbors', 'pca'
    obsm: 'X_pca', 'X_tsne', 'X_umap'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

In [7]:
# Preview the per-cell `n_genes` annotation for the first 5 cells
print(adata.obs["n_genes"].head(5))

index
A10_B001060_B009250_S214.mm10-plus-1-0    1505
A10_B001061_B009251_S298.mm10-plus-1-0    2384
A10_B002503_B009456_S10.mm10-plus-1-0      599
A10_B002702_B009296_S154.mm10-plus-1-0    1547
A10_D045853_B009304_S106.mm10-plus-1-0     931
Name: n_genes, dtype: int64


In [None]:
# X = adata.X.toarray() if not isinstance(adata.X, np.ndarray) else adata.X

# # Round values to nearest int and ensure non-negative
# X = np.round(X).astype(int)
# X[X < 0] = 0  # just in case

# # Get gene names
# genes = np.array(adata.var_names)

# # Prepare the list of gene strings
# gene_strings = []
# for row in X:
#     repeated_genes = np.repeat(genes, row)
#     gene_string = " ".join(repeated_genes)
#     gene_strings.append(gene_string)

# # Get relevant metadata
# df_meta = adata.obs[['age', 'sex', 'cell_ontology_class', 'tissue']].copy()
# df_meta.columns = ['age', 'gender', 'cell_ontology_class', 'tissue']

# # Create final DataFrame
# df_final = pd.DataFrame({
#     'genes': gene_strings,
#     'age': df_meta['age'].values,
#     'gender': df_meta['gender'].values,
#     'cell_ontology_class': df_meta['cell_ontology_class'].values,
#     'tissue': df_meta['tissue'].values
# })

# # Set index as 0 to n-1
# df_final.index = range(df_final.shape[0])

# # Optional: Save to CSV
# df_final.to_csv("processed_cells.csv", index_label="index")

# # Return final DataFrame
# df_final



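## The approach above is too memory intensive: it densifies the whole matrix
## and keeps every gene string in memory before saving.
## Instead, write each row to the file directly (see the next cell).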

In [8]:
# Stream one CSV row per cell so the gene strings of all cells never have to
# be held in memory at the same time.
output_file = "processed_cells_streamed.csv"

# Gene names; their order must match the columns of the expression matrix
genes = np.array(adata.var_names)

# Expression matrix; rows are converted to dense one at a time inside the loop
X = adata.X
# Loop-invariant: decide once whether rows still need converting to dense
X_is_dense = isinstance(X, np.ndarray)

# Per-cell metadata, indexed in the same order as the rows of X
ages = adata.obs["age"].values
genders = adata.obs["sex"].values
classes = adata.obs["cell_ontology_class"].values
tissues = adata.obs["tissue"].values

# Open file and write line-by-line
with open(output_file, mode='w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    # Write header
    writer.writerow(["index", "genes", "age", "gender", "cell_ontology_class", "tissue"])

    for i in range(adata.n_obs):
        # Get row i as a flat dense array
        row = X[i] if X_is_dense else X[i].toarray().flatten()

        # Round to whole counts and clip negatives to zero, because the values
        # are used as repeat counts below
        row = np.maximum(np.round(row).astype(int), 0)

        # Repeat each gene name as many times as its rounded expression value
        repeated_genes = np.repeat(genes, row)
        gene_string = " ".join(repeated_genes)

        # Write row to file
        writer.writerow([
            i,
            gene_string,
            ages[i],
            genders[i],
            classes[i],
            tissues[i]
        ])
# The `with` block closes the file on exit, so no explicit close() is needed.

In [None]:
# numpy (np), math and os are already imported in the notebook's import cell;
# json is the only module this cell needs in addition.
import json

# Gene names and data
genes = np.array(adata.var_names)
X = adata.X
# Loop-invariant: decide once whether rows still need converting to dense
X_is_dense = isinstance(X, np.ndarray)

# Per-cell metadata, indexed in the same order as the rows of X
ages = adata.obs["age"].values
genders = adata.obs["sex"].values
classes = adata.obs["cell_ontology_class"].values
tissues = adata.obs["tissue"].values

# Split config: cells are written in order into at most n_splits files
n_cells = adata.n_obs
n_splits = 10
split_size = math.ceil(n_cells / n_splits)

# Create output directory
output_dir = "fine_tune_chunks"
os.makedirs(output_dir, exist_ok=True)

# --------------------------------------------
# Constants for Gemma-3 chat template
# NOTE(review): the text built below puts a newline between <bos> and the first
# <start_of_turn>, and many tokenizers add <bos> themselves -- confirm this
# layout against the chat template used by the fine-tuning pipeline.
BOS = "<bos>"
START = "<start_of_turn>"
END = "<end_of_turn>"

# Writing split files – Gemma-3 format
# NOTE: despite the .json extension, each file holds JSON Lines
# (one JSON object per line), not a single JSON document.
for split_index in range(n_splits):
    start = split_index * split_size
    end = min((split_index + 1) * split_size, n_cells)
    if start >= end:
        # Fewer cells than splits: this chunk and all later ones would be
        # empty, so stop instead of creating empty files.
        break

    filename = os.path.join(output_dir, f"cell_data_part_{split_index+1}.json")
    with open(filename, "w", encoding="utf-8") as f:
        for i in range(start, end):
            # Gene bag-of-words: each gene name repeated by its rounded,
            # non-negative expression value
            row = X[i] if X_is_dense else X[i].toarray().flatten()
            row = np.maximum(np.round(row).astype(int), 0)
            gene_string = " ".join(np.repeat(genes, row))

            # user / model messages
            user_msg = (
                f"I have a cell, its genes are {gene_string}. "
                f"Its gender is {genders[i]}. "
                f"Its class is {classes[i]}. "
                f"Its tissue is {tissues[i]}."
            )
            model_msg = f"The age of the cell should be {ages[i]}"

            # Gemma-3 conversation string
            text = (
                f"{BOS}\n"
                f"{START}user\n{user_msg}{END}\n"
                f"{START}model\n{model_msg}{END}"
            )

            # one-line JSONL record
            json.dump({"text": text}, f, ensure_ascii=False)
            f.write("\n")
