In [1]:
from SpaGE.main import SpaGE
import scanpy as sc
import numpy as np
import pandas as pd

# Load data
data_path = '../dataset/Chromium_FFPE_Human_Breast_Cancer_Chromium_FFPE_Human_Breast_Cancer_count_sample_filtered_feature_bc_matrix.h5'
sc_data = sc.read_10x_h5(data_path)
sc_data.X = sc_data.X.toarray()
sc.pp.normalize_total(sc_data)
sc.pp.log1p(sc_data)

st_data_path = '../dataset/Xenium_breast_cancer_sample1_replicate1.h5ad'
st_data = sc.read_h5ad(st_data_path)
sc.pp.normalize_total(st_data)
sc.pp.log1p(st_data)
st_data

# 检查 sc_data.var_names 是否有重复项
duplicates = sc_data.var_names[sc_data.var_names.duplicated()]
print(f"重复的基因: {duplicates}")
# 删除 sc_data.var_names 中的重复基因
sc_data = sc_data[:, ~sc_data.var_names.duplicated()]
st_data.X = st_data.X.toarray()

# Define hold-out genes for validation
hold_out_genes = ['POSTN', 'IL7R', 'ITGAX', 'ACTA2', 'KRT15', 'VWF', 'FASN', 'CEACAM6']

# Convert to dataframes
sc_df = pd.DataFrame(sc_data.X, index=sc_data.obs.index, columns=sc_data.var.index)
st_df = pd.DataFrame(st_data.X, index=st_data.obs.index, columns=st_data.var.index)

# Check for NaN values in both dataframes
print("NaN values in sc_df:", sc_df.isna().sum().sum())
print("NaN values in st_df:", st_df.isna().sum().sum())

# Fill NaN values with zeros if they exist
if sc_df.isna().sum().sum() > 0:
    print("Filling NaN values in sc_df with zeros")
    sc_df = sc_df.fillna(0)

if st_df.isna().sum().sum() > 0:
    print("Filling NaN values in st_df with zeros")
    st_df = st_df.fillna(0)

# Verify no more NaNs exist
print("After filling, NaN values in sc_df:", sc_df.isna().sum().sum())
print("After filling, NaN values in st_df:", st_df.isna().sum().sum())

# Store the ground truth values for hold-out genes before removing them
ground_truth = {}
for gene in hold_out_genes:
    if gene in st_df.columns:
        ground_truth[gene] = st_df[gene].copy()
        print(f"Stored ground truth for {gene}")
    else:
        print(f"Warning: {gene} not found in spatial data")

# Remove hold-out genes from spatial data
st_df_holdout = st_df.copy()
for gene in hold_out_genes:
    if gene in st_df_holdout.columns:
        st_df_holdout = st_df_holdout.drop(columns=[gene])
        print(f"Removed {gene} from spatial data for hold-out validation")

# Run SpaGE imputation with hold-out genes removed
imputed_genes = SpaGE(Spatial_data=st_df_holdout, RNA_data=sc_df, n_pv=5)

# Convert the index to string type and add 1 to each value
imputed_genes.index = imputed_genes.index + 1
imputed_genes.index = imputed_genes.index.astype(str)

# Evaluate imputation quality for hold-out genes
from scipy.stats import pearsonr
print("\nEvaluating imputation quality for hold-out genes:")
correlations = {}
for gene in hold_out_genes:
    if gene in imputed_genes.columns and gene in ground_truth:
        corr, p_value = pearsonr(ground_truth[gene], imputed_genes[gene])
        correlations[gene] = corr
        print(f"{gene}: Pearson correlation = {corr:.4f} (p-value: {p_value:.4e})")
    else:
        print(f"{gene}: Not available for evaluation")

# Calculate average correlation across all hold-out genes
if correlations:
    avg_corr = sum(correlations.values()) / len(correlations)
    print(f"\nAverage correlation across all hold-out genes: {avg_corr:.4f}")

# Create a combined dataframe with both original and imputed genes
# First, check if there are any overlapping genes between the original data and imputed genes
overlapping_genes = np.intersect1d(st_df.columns, imputed_genes.columns)
if len(overlapping_genes) > 0:
    print(f"Warning: Found {len(overlapping_genes)} overlapping genes between original data and imputed genes.")
    print("These genes will be taken from the original data.")

# Create a combined dataframe with both original and imputed genes
# Use axis=1 to concatenate columns (genes) not rows
combined_df = pd.concat([st_df, imputed_genes.loc[st_df.index]], axis=1)

# If there are duplicated columns, keep the first occurrence (from original data)
if combined_df.columns.duplicated().any():
    combined_df = combined_df.loc[:, ~combined_df.columns.duplicated()]

print(f"Original data shape: {st_df.shape}")
print(f"Imputed genes shape: {imputed_genes.shape}")
print(f"Combined data shape: {combined_df.shape}")

# Verify that the row count matches the original data
assert combined_df.shape[0] == st_df.shape[0], "Row count mismatch after concatenation"
# Verify that the column count is the sum of unique genes
assert combined_df.shape[1] == len(set(st_df.columns) | set(imputed_genes.columns)), "Column count mismatch"
combined_df = combined_df.astype(np.float16)
# Check the data type of combined_df
print("Data type of combined_df:")
print(f"DataFrame dtype: {combined_df.dtypes.value_counts()}")
print(f"Memory usage: {combined_df.memory_usage().sum() / 1e6:.2f} MB")

# Create and save the imputed AnnData object
original_st_data = sc.read_h5ad(st_data_path)
imputed_st_data = sc.AnnData(X=combined_df, 
                             obs=original_st_data.obs,
                             uns=original_st_data.uns
                             )
imputed_st_data.uns['COVET_genes'] = imputed_genes.columns.tolist()
imputed_st_data.uns['hold_out_genes'] = hold_out_genes
imputed_st_data.uns['hold_out_correlations'] = correlations
imputed_st_data.write_h5ad('../dataset/SpaGE_imputed_Xenium_breast_cancer_sample1_replicate1_hold_out.h5ad')
imputed_st_data.X.sum(axis=1)

  utils.warn_names_duplicates("var")


  utils.warn_names_duplicates("var")




重复的基因: Index(['TBCE', 'HSPA14', 'TMSB15B'], dtype='object')


NaN values in sc_df: 0
NaN values in st_df: 0


After filling, NaN values in sc_df: 0
After filling, NaN values in st_df: 0
Stored ground truth for POSTN
Stored ground truth for IL7R
Stored ground truth for ITGAX
Stored ground truth for ACTA2
Stored ground truth for KRT15
Stored ground truth for VWF
Stored ground truth for FASN
Stored ground truth for CEACAM6


Removed POSTN from spatial data for hold-out validation
Removed IL7R from spatial data for hold-out validation
Removed ITGAX from spatial data for hold-out validation
Removed ACTA2 from spatial data for hold-out validation
Removed KRT15 from spatial data for hold-out validation


Removed VWF from spatial data for hold-out validation
Removed FASN from spatial data for hold-out validation
Removed CEACAM6 from spatial data for hold-out validation
Genes to predict:  ['A1CF' 'A2M' 'A2ML1' ... 'ZYG11B' 'ZYX' 'ZZEF1']


Spatial_data_scaled NaN values:  0


RNA_data_scaled NaN values:  11720890


Spatial_data_scaled NaN values after filling:  0


RNA_data_scaled NaN values after filling:  0
Common_data NaN values:  0


  return var(axis=axis, dtype=dtype, out=out, ddof=ddof, **kwargs)
  return var(axis=axis, dtype=dtype, out=out, ddof=ddof, **kwargs)



Evaluating imputation quality for hold-out genes:
POSTN: Pearson correlation = 0.6982 (p-value: 0.0000e+00)


IL7R: Pearson correlation = 0.7130 (p-value: 0.0000e+00)
ITGAX: Pearson correlation = 0.5150 (p-value: 0.0000e+00)


ACTA2: Pearson correlation = 0.5002 (p-value: 0.0000e+00)
KRT15: Pearson correlation = 0.6288 (p-value: 0.0000e+00)


VWF: Pearson correlation = 0.7218 (p-value: 0.0000e+00)
FASN: Pearson correlation = 0.7754 (p-value: 0.0000e+00)


CEACAM6: Pearson correlation = 0.5760 (p-value: 0.0000e+00)

Average correlation across all hold-out genes: 0.6411
These genes will be taken from the original data.


Original data shape: (167780, 313)
Imputed genes shape: (167780, 17780)
Combined data shape: (167780, 18085)


Data type of combined_df:
DataFrame dtype: float16    18085
Name: count, dtype: int64


Memory usage: 6074.17 MB


array([ 955.5, 1002.5,  733. , ..., 1033. ,  999.5, 1024. ], dtype=float16)