In [None]:
 # Switch the kernel's working directory to the project root; the relative
 # './10X_PBMC/...' paths used when loading data below depend on it.
 %cd /sci/labs/yotamd/lab_share/avishai.wizel/eRNA/

In [None]:
import anndata as ad
import pandas as pd
from scipy.sparse import csr_matrix
import scanpy as sc
import numpy as np
from sklearn.model_selection import train_test_split
from typing import Tuple, Union
from scipy.sparse import issparse
from sklearn.metrics import roc_auc_score

In [None]:
# Load the filtered RNA and ATAC AnnData objects. The paths are relative, so
# they resolve against the working directory set by the %cd cell above.
# NOTE(review): later cells require both objects to list the same cells in the
# same order (the split helper raises ValueError otherwise).
sc_rna = ad.read_h5ad('./10X_PBMC/03_filtered_data/filtered_rna_adata.h5ad')
sc_atac = ad.read_h5ad("./10X_PBMC/03_filtered_data/filtered_atac_adata.h5ad")

In [None]:
# RNA preprocessing. None of the return values are captured, so these calls
# rely on scanpy modifying `sc_rna` in place -- re-running this cell without
# reloading the data would therefore apply the transforms a second time.

# 1. Size normalization (per-cell totals rescaled to target_sum=1e4)
sc.pp.normalize_total(sc_rna, target_sum=1e4)

# 2. Log-transformation (log1p)
sc.pp.log1p(sc_rna)

# 3. Standardization (scale), with values capped via max_value=10
# NOTE(review): this runs before the train/test split below, so the scaling
# statistics presumably include the test cells -- confirm this is intended.
sc.pp.scale(sc_rna, max_value=10)

In [None]:
def split_anndata_xy_train_test(
    adata_x: ad.AnnData,
    adata_y: ad.AnnData,
    test_size: float = 0.25,
    random_state: Union[int, None] = None,
    stratify: Union[np.ndarray, pd.Series, None] = None
) -> Tuple[ad.AnnData, ad.AnnData, ad.AnnData, ad.AnnData]:
    """
    Splits two paired AnnData objects (X and Y, typically RNA and ATAC) into training and
    testing sets, applying the exact same cell partition to both objects.

    The split is drawn over integer row positions rather than over barcode labels, so it
    also works when obs_names contain duplicates (label-based subsetting requires unique
    names). For a given random_state the partition is the same one that splitting the
    barcodes themselves would produce.

    Args:
        adata_x (anndata.AnnData):
            The AnnData object representing the features (e.g., RNA). Cells are observations (rows).
        adata_y (anndata.AnnData):
            The AnnData object representing the targets/labels (e.g., ATAC). Cells are observations (rows).
            Must have the same cells in the same order as adata_x.
        test_size (float):
            The proportion of the dataset to include in the test split.
        random_state (int, optional):
            Controls the shuffling for reproducibility.
        stratify (array-like, optional):
            Class labels for stratified splitting, aligned row-by-row with adata_x.

    Returns:
        Tuple[ad.AnnData, ad.AnnData, ad.AnnData, ad.AnnData]:
            A tuple containing (adata_x_train, adata_x_test, adata_y_train, adata_y_test) AnnData
            objects. Each one is an independent copy, not a view of its input.

    Raises:
        ValueError: If adata_x and adata_y do not share identical obs names in identical order.
    """

    # Both objects must describe the same cells, row for row, before a shared
    # positional split is meaningful.
    if not np.array_equal(adata_x.obs_names, adata_y.obs_names):
        raise ValueError("adata_x and adata_y must have identical cell (obs) names and order.")

    # A pandas Series carries its own index; hand sklearn the bare values so
    # alignment is purely positional.
    if stratify is not None and isinstance(stratify, pd.Series):
        stratify = stratify.values

    # Split row positions instead of barcodes: immune to duplicate obs names
    # and avoids a label lookup when subsetting.
    row_positions = np.arange(adata_x.n_obs)
    train_rows, test_rows = train_test_split(
        row_positions,
        test_size=test_size,
        random_state=random_state,
        stratify=stratify
    )

    # Materialise independent copies so the results do not alias the inputs.
    adata_x_train = adata_x[train_rows, :].copy()
    adata_x_test = adata_x[test_rows, :].copy()

    adata_y_train = adata_y[train_rows, :].copy()
    adata_y_test = adata_y[test_rows, :].copy()

    return adata_x_train, adata_x_test, adata_y_train, adata_y_test


In [None]:
# Split the paired RNA (features) and ATAC (targets) objects with one shared
# cell partition. random_state is fixed so the split is reproducible; no
# stratification labels are passed.
rna_train, rna_test, atac_train, atac_test = split_anndata_xy_train_test(
    sc_rna,
    sc_atac,
    test_size=0.2, # 20% for testing
    random_state=42
)

In [None]:
import os
import sys

import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsRegressor
from sklearn import metrics
from scipy import sparse

import anndata as ad



class KNNRegressor(object):
    """
    Brute-force k-nearest-neighbour regressor for paired AnnData objects.

    The neighbour search is done directly with sklearn.metrics.pairwise_distances
    (default metric); a prediction is the uniform, unweighted mean of the
    targets of the k closest training cells.
    Key behaviours:
    - If k = 0, every query cell is predicted as the mean of all training targets
    - Takes AnnData objects and checks var name compatibility before predicting
    - Features and targets may each be dense or scipy-sparse, independently
    """

    def __init__(self, k: int = 0):
        """Initialize the model.

        Args:
            k: Number of neighbours to average; 0 means "use all training data".
        """
        assert k >= 0
        self.k = k
        self.model = None  # never assigned elsewhere in this class; kept for compatibility
        self.x_var_names = []
        self.y_var_names = []

        # Training matrices and the per-variable target mean; populated by fit()
        self.x = None
        self.y = None
        self.y_mean = None

    def fit(self, x: ad.AnnData, y: ad.AnnData) -> None:
        """Store the training features/targets and the per-variable target mean.

        Args:
            x: Training features (cells x input variables).
            y: Training targets (cells x output variables); same cells, same order as x.
        """
        # array_equal also copes with differing lengths, where an element-wise
        # `==` on the two indexes would raise an unrelated pandas error.
        assert np.array_equal(x.obs_names, y.obs_names), "Mismatched obs names"
        assert len(x.shape) == len(y.shape) == 2
        self.x_var_names = x.var_names
        self.y_var_names = y.var_names

        self.x = x.X
        self.y = y.X
        self.y_mean = y.X.mean(axis=0)

    def predict(self, x: ad.AnnData) -> ad.AnnData:
        """Predict targets for every cell in x.

        Args:
            x: Query features; var names must match those seen in fit().

        Returns:
            AnnData of shape (x.n_obs, number of target variables) holding a dense
            2-D prediction array, with obs names taken from x and var names taken
            from the training targets.
        """
        assert self.y is not None, "Model must be fit before calling predict"
        assert np.array_equal(x.var_names, self.x_var_names)
        if self.k > 0:
            # Brute force pairwise distances in parallel
            pairwise_distances = metrics.pairwise_distances(x.X, self.x, n_jobs=-1)
            # argsort across each row and keep the k closest training cells
            # https://numpy.org/doc/stable/reference/generated/numpy.argsort.html
            pairwise_dist_idx = np.argsort(pairwise_distances, axis=1)[:, : self.k]

            # Row selection is performed on the *targets*, so choose the code
            # path from self.y (not self.x). CSR gives efficient row indexing.
            targets = self.y.tocsr() if sparse.issparse(self.y) else self.y

            # Uniform weighting of all points in each neighbourhood. A sparse
            # (or np.matrix) mean comes back 2-D, so flatten every row to 1-D
            # to guarantee a plain (n_obs, n_vars) result after stacking.
            pred_values = np.stack(
                [
                    np.asarray(targets[neighbor_rows].mean(axis=0)).ravel()
                    for neighbor_rows in pairwise_dist_idx
                ]
            )

        else:
            # k == 0: broadcast the global training mean to every query cell
            mean_profile = np.asarray(self.y_mean).ravel()
            pred_values = np.tile(mean_profile, (x.n_obs, 1))

        retval = ad.AnnData(
            pred_values,
            obs=pd.DataFrame(index=x.obs_names),
            var=pd.DataFrame(index=self.y_var_names),
        )
        return retval

In [None]:
# Fit a KNN regressor that averages the ATAC profiles of the 10 nearest
# training cells in RNA space. fit() only stores the training matrices and the
# target mean; the neighbour search itself happens in predict().
model  = KNNRegressor(k=10)
model.fit(rna_train, atac_train)

In [None]:
# Predict ATAC profiles for every cell in the held-out RNA set.
# NOTE: despite its name, `first_10_predict` covers all test cells; the "10"
# matches the k=10 neighbours used by the model, not the first ten cells.
first_10_predict = model.predict(rna_test)
print(first_10_predict)

In [None]:
def calculate_atac_auroc(
    adata_atac_true: ad.AnnData,
    adata_atac_predicted: ad.AnnData
) -> float:
    """
    Calculates the global Area Under the Receiver Operating Characteristic (AUROC) curve
    between true ATAC-seq profiles and predicted ATAC-seq profiles.

    Every (cell, peak) entry is treated as one independent binary prediction: both
    matrices are flattened to 1-D and scored with a single roc_auc_score call.

    Assumes:
    - Both AnnData objects have the same cells and peaks in the same order (checked).
    - adata_atac_true contains binary (0 or 1) true ATAC values; any other values are
      binarized as `value > 0` after printing a warning.
    - adata_atac_predicted contains continuous prediction scores (e.g., from KNN averaging).

    Args:
        adata_atac_true (anndata.AnnData):
            AnnData object containing the true ATAC-seq profiles (cells x peaks).
            Expected to be binary (0 or 1).
        adata_atac_predicted (anndata.AnnData):
            AnnData object containing the predicted ATAC-seq profiles (cells x peaks).
            Expected to contain continuous scores.

    Returns:
        float: The global AUROC score. Returns NaN when only one class is present in the
               true values. A value close to 0.5 indicates random prediction, 1.0 is
               perfect prediction.

    Raises:
        ValueError: If the shapes, cell names or peak names of the two objects differ.
    """

    # 1. Ensure dimensions and labels match
    if adata_atac_true.shape != adata_atac_predicted.shape:
        raise ValueError("Shapes of true and predicted ATAC AnnData objects do not match.")
    if not np.array_equal(adata_atac_true.obs_names, adata_atac_predicted.obs_names) or \
       not np.array_equal(adata_atac_true.var_names, adata_atac_predicted.var_names):
        raise ValueError("Cell and/or peak names/order do not match between true and predicted ATAC objects.")

    # 2. Densify sparse inputs, then flatten to genuine 1-D vectors.
    # np.asarray(...).ravel() is used instead of .flatten() because
    # np.matrix.flatten() returns a 2-D (1, N) matrix, and ravel() also avoids
    # copying an array that is already contiguous.
    true_matrix = adata_atac_true.X
    score_matrix = adata_atac_predicted.X
    y_true_flat = np.asarray(
        true_matrix.toarray() if issparse(true_matrix) else true_matrix
    ).ravel()
    y_score_flat = np.asarray(
        score_matrix.toarray() if issparse(score_matrix) else score_matrix
    ).ravel()

    # 3. Binarize only when labels other than 0/1 are actually present, so the
    # warning is not printed for already-binary data that has a single class.
    observed_labels = np.unique(y_true_flat)
    if not np.isin(observed_labels, (0, 1)).all():
        print("Warning: True ATAC values are not strictly binary (0 or 1). Attempting to binarize > 0 to 1.")
        y_true_flat = (y_true_flat > 0).astype(int)
        observed_labels = np.unique(y_true_flat)

    # 4. AUROC is undefined unless both positives and negatives exist
    if observed_labels.size < 2:
        print("Warning: Only one class present in true ATAC values. AUROC cannot be calculated.")
        return np.nan # Not a Number

    print(f"Calculating AUROC for {len(y_true_flat)} individual (cell, peak) predictions...")
    auroc_score = roc_auc_score(y_true_flat, y_score_flat)

    return auroc_score

In [None]:
# Global AUROC of the KNN predictions against the held-out ATAC profiles.
# As the last expression in the cell, the score is shown as the cell output.
calculate_atac_auroc(adata_atac_true=atac_test,adata_atac_predicted=first_10_predict)