In [None]:
 # Switch the working directory to the eRNA project root; the data-loading cell below uses paths relative to it (./10X_PBMC/...).
 %cd /sci/labs/yotamd/lab_share/avishai.wizel/eRNA/

In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import anndata as ad
import scanpy as sc
from scipy.sparse import issparse
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras import regularizers
from tensorflow import keras
import tensorflow_addons as tfa



# Load data

In [None]:
# Read the pre-filtered RNA and ATAC AnnData objects (paths are relative to the working directory).
# A later cell indexes the ATAC object by the RNA cell barcodes, so both objects are expected to share barcodes.
sc_rna = ad.read_h5ad('./10X_PBMC/03_filtered_data/filtered_rna_adata.h5ad')
sc_atac = ad.read_h5ad("./10X_PBMC/03_filtered_data/filtered_atac_adata.h5ad")

# Main parameters

In [None]:
# Number of highly variable genes to keep (passed as n_top_genes to sc.pp.highly_variable_genes).
top_rna_var_genes = 15000
# Number of most-dispersed ATAC peaks to keep (passed to get_top_n_highly_variable_peaks).
top_var_peaks = 500

# Filter cells

In [None]:
def filter_cells_by_qc_metrics(
    adata: ad.AnnData,
    min_genes: int = 200,
    max_genes: int = 2500,
    min_counts: int = 1000,
    max_mt_pct: float = 40.0
) -> ad.AnnData:
    """
    Filters cells based on standard quality control metrics:
    number of genes detected, total counts, and mitochondrial gene percentage.

    Note that the QC annotations are written into the AnnData object that is
    passed in (``adata.var['mt']`` when missing, plus the columns produced by
    ``sc.pp.calculate_qc_metrics``). Pass a copy if the caller's object must
    stay untouched.

    Args:
        adata (anndata.AnnData):
            The AnnData object containing raw gene expression counts (cells x genes).
            If ``adata.var['mt']`` is absent, mitochondrial genes are detected by
            the 'MT-' (human) or 'mt-' (mouse) gene-name prefix.
        min_genes (int): Minimum number of genes expressed per cell.
        max_genes (int): Maximum number of genes expressed per cell (to remove doublets).
        min_counts (int): Minimum total counts per cell.
        max_mt_pct (float): Maximum allowed percentage of mitochondrial counts per cell.

    Returns:
        anndata.AnnData: A new AnnData object (a copy) containing only the cells
                         that pass all four thresholds. All bounds are inclusive.
    """

    # Flag mitochondrial genes only when the caller has not already provided the
    # annotation. Both prefixes named in the docstring are accepted; matching
    # only 'MT-' would silently report 0% mitochondrial counts for mouse data.
    if 'mt' not in adata.var:
        gene_names = adata.var_names.str
        adata.var['mt'] = gene_names.startswith('MT-') | gene_names.startswith('mt-')

    print("Calculating QC metrics...")
    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=['mt'],
        percent_top=None,
        log1p=False,
        inplace=True
    )

    n_cells_before = adata.n_obs
    print(f"Original number of cells: {n_cells_before}")

    # A cell is kept only if it satisfies every criterion.
    qc = adata.obs
    cells_to_keep = (
        (qc['n_genes_by_counts'] >= min_genes) &
        (qc['n_genes_by_counts'] <= max_genes) &
        (qc['total_counts'] >= min_counts) &
        (qc['pct_counts_mt'] <= max_mt_pct)
    )

    # .copy() turns the AnnData view into an independent object.
    adata_filtered = adata[cells_to_keep, :].copy()

    n_cells_after = adata_filtered.n_obs
    print(f"Cells removed: {n_cells_before - n_cells_after}")
    print(f"Number of cells after filtering: {n_cells_after}")

    return adata_filtered

In [None]:
# Run the QC filter on a copy of the loaded RNA object, because the function
# writes its QC annotations into the AnnData it receives.
# min_counts=500 and max_mt_pct=25 override the function defaults (1000 and 40.0).
adata_filtered_qc = filter_cells_by_qc_metrics(
    sc_rna.copy(),
    min_genes=200,
    max_genes=2500,
    min_counts=500,
    max_mt_pct=25
)

# Normalize scRNA-seq

In [None]:
# 1. Size normalization: rescale each cell to a total of 10,000 counts
sc.pp.normalize_total(adata_filtered_qc, target_sum=1e4)

# 2. Log-transformation (log1p)
sc.pp.log1p(adata_filtered_qc)
# 3. Flag the top `top_rna_var_genes` highly variable genes; the next cell subsets
#    the genes using the resulting .var.highly_variable column.
sc.pp.highly_variable_genes(adata_filtered_qc, n_top_genes=top_rna_var_genes, flavor='seurat')


Keep only the highly variable genes:

In [None]:
# Keep only the genes flagged as highly variable.
sc_rna_filtered = adata_filtered_qc[:, adata_filtered_qc.var.highly_variable].copy()

# Restrict the ATAC cells to the RNA cells that survived QC.
# Indexing by barcode label (instead of a boolean `isin` mask) also puts the
# ATAC rows in the same order as the RNA rows, which the row-paired X / y
# matrices built later depend on.
filtered_cell_barcodes = sc_rna_filtered.obs_names

# Fail early with a readable message if some RNA cells have no ATAC profile.
missing_barcodes = filtered_cell_barcodes[~filtered_cell_barcodes.isin(sc_atac.obs_names)]
if len(missing_barcodes) > 0:
    raise KeyError(
        f"{len(missing_barcodes)} RNA cell barcodes are absent from the ATAC object "
        f"(e.g. {list(missing_barcodes[:5])}); the two modalities cannot be paired."
    )

adata_atac_filtered = sc_atac[filtered_cell_barcodes, :].copy()



# Filter for highly variable peaks

In [None]:
def get_top_n_highly_variable_peaks(adata_atac: ad.AnnData, n_top_peaks: int = 10000) -> ad.AnnData:
    """
    Identifies and keeps the top N highly variable peaks in an AnnData object
    based on their variance-to-mean ratio (a proxy for dispersion).

    The per-peak statistics are written into the ``.var`` table of the object
    that is passed in; pass a copy if the caller's object must stay untouched.

    Args:
        adata_atac (anndata.AnnData):
            The AnnData object containing ATAC-seq data (cells x peaks).
            Assumes adata_atac.X contains counts or binarized values.
            Both scipy sparse and dense matrices are accepted.
        n_top_peaks (int):
            The number of top highly variable peaks to select. If it exceeds the
            number of peaks, all peaks are kept.

    Returns:
        anndata.AnnData: A new AnnData object (a copy) subsetted to the selected top N peaks.
                         The columns 'peak_means', 'peak_variances', 'peak_dispersion'
                         and 'ranked_dispersion' are added to adata_atac.var.
    """

    print(f"Original AnnData shape (cells x peaks): {adata_atac.shape}")

    matrix = adata_atac.X
    if issparse(matrix):
        # Stay sparse: densifying a full cells x peaks ATAC matrix only to get
        # column moments can require many gigabytes of memory.
        # Population variance (ddof=0, as np.var) via Var[x] = E[x^2] - E[x]^2.
        matrix = matrix.astype(np.float64)
        peak_means = np.asarray(matrix.mean(axis=0)).ravel()
        mean_of_squares = np.asarray(matrix.multiply(matrix).mean(axis=0)).ravel()
        # Floating-point cancellation can yield tiny negative values; clip them.
        peak_variances = np.maximum(mean_of_squares - peak_means ** 2, 0.0)
    else:
        # Dense input: float32 keeps the temporary copy small.
        dense = matrix.astype(np.float32)
        peak_means = np.mean(dense, axis=0)
        peak_variances = np.var(dense, axis=0)

    # Variance-to-mean ratio as the dispersion measure. The small epsilon avoids
    # division by zero for peaks that are closed in every cell (mean 0, variance 0),
    # which therefore get a dispersion of 0.
    epsilon = 1e-6
    dispersion = peak_variances / (peak_means + epsilon)

    # Store the metrics in .var so they can be inspected afterwards.
    adata_atac.var['peak_means'] = peak_means
    adata_atac.var['peak_variances'] = peak_variances
    adata_atac.var['peak_dispersion'] = dispersion

    # Rank 1 = most dispersed peak; method='first' breaks ties by column order,
    # so exactly n_top_peaks peaks are selected.
    adata_atac.var['ranked_dispersion'] = adata_atac.var['peak_dispersion'].rank(ascending=False, method='first')

    selected_peaks = adata_atac.var[adata_atac.var['ranked_dispersion'] <= n_top_peaks].index

    adata_filtered_peaks = adata_atac[:, selected_peaks].copy()

    print(f"AnnData shape after filtering to top {n_top_peaks} highly variable peaks: {adata_filtered_peaks.shape}")

    return adata_filtered_peaks


In [None]:
# Keep the `top_var_peaks` most dispersed peaks. A copy is passed because the
# function writes its per-peak statistics into the .var table of its argument.
adata_hvps = get_top_n_highly_variable_peaks(adata_atac_filtered.copy(), top_var_peaks)


In [None]:
# Feature matrix (cells x genes), kept in whatever format AnnData stores it.
scRNA = sc_rna_filtered.X

# Target matrix (cells x peaks) as a dense array. .X may already be dense, so
# .toarray() is only called for sparse input (the same guard that
# get_top_n_highly_variable_peaks applies).
# NOTE(review): nothing in this notebook binarizes the ATAC matrix; the variable
# name assumes the stored values are already 0/1 - confirm against the input file.
scATAC_binary = adata_hvps.X.toarray() if issparse(adata_hvps.X) else np.asarray(adata_hvps.X)

# Drop the references to the large AnnData objects that are no longer needed.
del sc_rna
del sc_atac
del adata_atac_filtered
del sc_rna_filtered

In [None]:
# Report the dimensions of the feature and target matrices.
for shape_label, shaped_matrix in (
    ("scRNA dim (cells X genes):", scRNA),
    ("scATAC dim (cells X peaks):", scATAC_binary),
):
    print(shape_label, shaped_matrix.shape)

In [None]:
# Standardize the RNA-seq data (per-gene zero mean and unit variance).
# The matrix is densified only when it is actually sparse: .X can also be a
# plain ndarray, which has no .toarray() method.
# NOTE(review): the scaler is fitted on all cells before the train/test split,
# so test-set statistics leak into the scaling - consider fitting on the
# training rows only.
scaler_X = StandardScaler()
X_scaled = scaler_X.fit_transform(scRNA.toarray() if issparse(scRNA) else np.asarray(scRNA))

In [None]:
# Rescale every peak column of the target matrix to the [0, 1] range.
# The fitted scaler is kept in `scaler_y` so the transform can be inverted later.
scaler_y = MinMaxScaler()
scaler_y.fit(scATAC_binary)
y_scaled = scaler_y.transform(scATAC_binary)

In [None]:
# Hold out 20% of the cells for testing; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled,
    y_scaled,
    test_size=0.2,
    random_state=42,
)

In [None]:
# # Calculate the total number of zeros and ones across all peaks combined
# # This flattens the array and then counts occurrences of 0s and 1s.
# num_zeros = np.sum(y_train == 0)
# num_ones = np.sum(y_train == 1)
# total_elements = num_zeros + num_ones # Total number of 0s and 1s combined in y_train

# # Calculate class weights
# # These weights are inversely proportional to the class frequencies.
# # This assigns a higher weight to the less frequent class (typically '1's in imbalanced data)
# # to make the model pay more attention to correctly classifying them.
# # The 'total_elements / 2.0' part ensures the sum of weights for a balanced dataset
# # would ideally be around 1, helping to stabilize the loss scale.

# # Ensure no division by zero if a class is completely absent (though rare for 0/1)
# weight_for_0 = (total_elements / (10.0 * num_zeros)) if num_zeros > 0 else 1.0
# weight_for_1 = (total_elements / (2.0 * num_ones)) if num_ones > 0 else 1.0

# # Store the calculated weights in a dictionary format required by Keras
# class_weights = {0: weight_for_0, 1: weight_for_1}
# print(f"Computed class weights: {class_weights}")


In [None]:

# --- 2. Building the Keras Neural Network Model ---

# Layer widths come from the data: one input per gene, one output per ATAC peak.
input_dim = X_train.shape[1]
output_dim = y_train.shape[1]

# (units, dropout rate) for each hidden stage: Dense(relu) -> BatchNormalization
# -> optional Dropout. A rate of None means the stage has no dropout layer.
hidden_stages = [(512, 0.2), (256, 0.2), (128, None)]

model_layers = [layers.Input(shape=(input_dim,))]
for n_units, drop_rate in hidden_stages:
    model_layers.append(layers.Dense(units=n_units, activation='relu'))
    model_layers.append(layers.BatchNormalization())
    if drop_rate is not None:
        model_layers.append(layers.Dropout(drop_rate))

# Output layer: one sigmoid unit per peak, i.e. an independent value in [0, 1]
# for every ATAC location, matching targets that were scaled to [0, 1].
model_layers.append(layers.Dense(units=output_dim, activation='sigmoid'))

model = keras.Sequential(model_layers)

# Earlier compile configuration, kept for reference.
# NOTE: this SGD optimizer is only referenced by the commented-out compile call
# right below; the active model.compile(...) further down in this cell creates
# its own Adam optimizer and uses the combined focal + Dice loss instead.
# Loss options considered at the time:
# - 'mse' (Mean Squared Error) is standard for regression problems.
# - 'mae' (Mean Absolute Error) is another common choice for regression.
# - If the output is binary (0/1) with 'sigmoid' activation, use 'binary_crossentropy'.
optimizer=keras.optimizers.SGD(learning_rate=0.01, momentum=0.9) # Start with 0.01 for LR

# model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['mae', 'accuracy']) # Using MAE as an additional metric

# Duplicate of the import in the notebook's import cell.
import tensorflow as tf

def dice_loss(y_true, y_pred, smooth=1e-6):
    """Soft Dice loss, averaged over the batch.

    For every sample the Dice coefficient is computed along axis 1 as
    (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth);
    the loss is the batch mean of (1 - coefficient).

    Args:
        y_true: Ground-truth targets; cast to float32.
        y_pred: Predicted values; cast to float32.
        smooth: Small constant added to numerator and denominator so the ratio
            stays defined when both sums are zero.

    Returns:
        A scalar tensor holding the mean Dice loss.
    """
    targets = tf.cast(y_true, tf.float32)
    predictions = tf.cast(y_pred, tf.float32)

    overlap = tf.reduce_sum(targets * predictions, axis=1)
    total_mass = tf.reduce_sum(targets, axis=1) + tf.reduce_sum(predictions, axis=1)

    per_sample_dice = (2. * overlap + smooth) / (total_mass + smooth)
    return tf.reduce_mean(1 - per_sample_dice)

    
def combined_loss(y_true, y_pred):
    """Sum of the sigmoid focal cross-entropy (gamma=2.0, alpha=0.75) and the Dice loss.

    NOTE(review): the focal term is returned with the tensorflow_addons default
    reduction, so it may be per-sample rather than a scalar like the Dice term;
    confirm that the relative scale of the two terms is the intended one.
    """
    focal_fn = tfa.losses.SigmoidFocalCrossEntropy(gamma=2.0, alpha=0.75)
    return focal_fn(y_true, y_pred) + dice_loss(y_true, y_pred)
    
# Compile the model
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)  # Can lower to 1e-4 if data is noisy

multilabel_metrics = [
    # Precision-Recall AUC, computed per output label (multi-label mode needs
    # the number of labels, i.e. the size of the output vector).
    tf.keras.metrics.AUC(curve='PR', multi_label=True, num_labels=output_dim, name='pr_auc'),
    # Overall binary accuracy.
    tf.keras.metrics.BinaryAccuracy(name='bin_acc'),
    # How many predicted positives were correct.
    tf.keras.metrics.Precision(name='precision'),
    # How many actual positives were found.
    tf.keras.metrics.Recall(name='recall'),
]

model.compile(optimizer=adam, loss=combined_loss, metrics=multilabel_metrics)

# Print a summary of the model architecture
model.summary()


In [None]:

# --- 3. Training the Model ---

print("\nStarting model training...")

# epochs: number of passes over the entire training set
# batch_size: number of samples per gradient update
# validation_split: fraction of the training data held out to monitor
#                   performance on unseen data during training
# verbose=1: display the progress bar
# class_weight is not passed; the class-weight computation cell is commented out.
fit_settings = dict(epochs=50, batch_size=32, validation_split=0.1, verbose=1)
history = model.fit(X_train, y_train, **fit_settings)

print("Model training complete.")




In [None]:
def plot_metrics_from_history(history, metrics=('precision', 'recall', 'pr_auc')):
    """Plot one training-curve figure per metric from a Keras History object.

    Args:
        history: Object returned by ``model.fit``; its ``history`` attribute maps
            metric names to per-epoch value lists.
        metrics: Names of the metrics to plot. Defaults to the three metrics the
            model in this notebook is compiled with.

    Raises:
        KeyError: If a requested metric was not recorded during training.

    The validation curve (``val_<metric>``) is drawn only when it was recorded,
    so the function also works for runs trained without validation data.
    """
    for metric in metrics:
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.plot(history.history[metric], label=f"Train {metric}")

        val_key = f"val_{metric}"
        if val_key in history.history:
            ax.plot(history.history[val_key], label=f"Val {metric}")

        ax.set_title(f"{metric.upper()} Over Epochs")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric.upper())
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        plt.show()

# Plot the training and validation curves recorded by model.fit.
plot_metrics_from_history(history)

In [None]:
# Model outputs (sigmoid activations, one value per peak) for the held-out test cells.
y_pred_probs = model.predict(X_test)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, average_precision_score

# Per-peak precision-recall curves and average precision (AP).
# The three lists stay index-aligned with the peak columns: a peak that cannot
# be scored holds None / NaN placeholders instead of being dropped.
precisions = []
recalls = []
average_precisions = []
skipped_peaks = []

for i in range(y_test.shape[1]):
    peak_truth = y_test[:, i]
    peak_scores = y_pred_probs[:, i]

    if not np.any(peak_truth > 0):
        # AP is undefined when the test split contains no open cell for this
        # peak; scoring it anyway would distort the mean, so skip it.
        skipped_peaks.append(i)
        precisions.append(None)
        recalls.append(None)
        average_precisions.append(np.nan)
        continue

    precision, recall, _ = precision_recall_curve(peak_truth, peak_scores)
    ap = average_precision_score(peak_truth, peak_scores)
    precisions.append(precision)
    recalls.append(recall)
    average_precisions.append(ap)

# Mean of the AP scores over all peaks that could be scored
n_scored = len(average_precisions) - len(skipped_peaks)
mean_ap = float(np.nanmean(average_precisions)) if n_scored > 0 else float('nan')
print(f"Mean Average Precision (mAP): {mean_ap:.4f}")
if skipped_peaks:
    print(f"Peaks excluded from mAP (no positive test cells): {len(skipped_peaks)}")

# For an averaged plot, a simple option is the micro-average PR curve:
# flatten the arrays so every (cell, peak) pair counts as one sample.
y_true_flat = y_test.ravel()
y_scores_flat = y_pred_probs.ravel()

precision_micro, recall_micro, _ = precision_recall_curve(y_true_flat, y_scores_flat)
average_precision_micro = average_precision_score(y_true_flat, y_scores_flat)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(recall_micro, precision_micro, label=f'micro-average PR curve (AP = {average_precision_micro:.4f})')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall curve (Micro-average)')
ax.legend()
ax.grid(True)
plt.show()