In [1]:
import sys, importlib
import scanpy as sc
import os
import anndata as ad
from sklearn.model_selection import train_test_split
import numpy as np

sys.path.append("..")
import scripts.data_prep as dp
importlib.reload(dp)

IN_PATH_neu = "../data/processed/GSE169569_GSE169570_neutrophils_bowel_disease-25.07.0.h5ad"
IN_PATH_cov = "../data/processed/GSE228841_PRJNA951718_COVID-19_vaccine-25.07.0.h5ad"

adata_neu = sc.read_h5ad(IN_PATH_neu)
adata_cov = sc.read_h5ad(IN_PATH_cov)



### Check information of the datasets

In [2]:
print(adata_neu.obs.columns)
print("------------------    ------------------  ------------------    ------------------")
print(adata_cov.obs.columns)

Index(['total_counts_before_preprocessing', 'total_counts_after_trimming',
       'total_counts_after_preprocessing',
       'total_count_ratio__after_to_before', 'QC_mapping_ratio_bacterial',
       'QC_mapping_ratio_viral', 'QC_mapping_ratio_miRNA',
       'QC_mapping_ratio_sncRNA', 'Bases', 'Bytes', 'Avg_spot_length',
       'BioProject', 'BioSample', 'Experiment', 'GEO_Accession_exp',
       'SRA_study', 'Sample_name', 'Consent', 'Center_name', 'Organism',
       'source_name', 'cohort', 'sex', 'age', 'treatment', 'diagnosis',
       'Assay_type', 'Sequencer', 'Sample_type', 'Lab_library_layout',
       'Lab_library_selection', 'Lab_library_source',
       'Lab_RNA_extraction_protocol', 'Lab_Blocking_State',
       'Lab_library_preparation_kit', 'Release_date', 'create_date', 'version',
       'total_count'],
      dtype='object')
------------------    ------------------  ------------------    ------------------
Index(['total_counts_before_preprocessing', 'total_counts_after_prepro

In [3]:
print(adata_neu.var.columns)
print("------------------    ------------------  ------------------    ------------------")
print(adata_cov.var.columns)

Index(['length', 'g_fraction', 'a_fraction', 't_fraction', 'c_fraction',
       'gc_fraction', 'MFE_37', 'spikein', 'qiaseq_spikein', 'hbdx_spikein',
       ...
       'MS2_Cancer__Thrombocytes__RPMscaled_prop',
       'MS2_Cancer__Monocytes__RPMscaled_prop',
       'MS2_Cancer__B_cells__RPMscaled_prop',
       'MS2_Cancer__Basophils__RPMscaled_prop',
       'MS2_Cancer__Eosinophils__RPMscaled_prop',
       'MS2_Cancer__Erythrocytes__RPMscaled_prop',
       'MS2_Cancer__main_cell_type__100', 'MS2_Cancer__main_cell_type__75',
       'MS2_Cancer__main_cell_type__50', 'MS2_Cancer__plasma_cell_RPM_ratio'],
      dtype='object', length=186)
------------------    ------------------  ------------------    ------------------
Index(['length', 'g_fraction', 'a_fraction', 't_fraction', 'c_fraction',
       'gc_fraction', 'MFE_37', 'spikein', 'qiaseq_spikein', 'hbdx_spikein',
       ...
       'MS2_Cancer__Thrombocytes__RPMscaled_prop',
       'MS2_Cancer__Monocytes__RPMscaled_prop',
       'MS2_C

In [5]:
adata_neu.var_names               # Index([...])

Index(['GTGCATGATCTCAAGTTTTCAATCTGAGACCT', 'ATCACAGGGTAGAACCACGGAC',
       'TGGAGAGAAAGGCAGTTCCTGT', 'TCCTGACTCCAGGTCCTGTGT',
       'CGTGGTCTCCCAACCCTTGTACCAGT', 'TCGAGGACCCCCCCTGCCTGG',
       'TCGGGCCGATCGCACGCC', 'CTGGGAATACCGGGTGCTGTAGGCTA',
       'TGACCTATGAATTGGCAGCC', 'AGGTTCCGGATAAGTAAGAGCT',
       ...
       'TCTCGTCTGATCTCGGAAGCTAAGCAGGGTCGGG', 'AATCCCGGACGAGCCCTG',
       'ATTCGTAGACGACCTGCTTCTGGGTCGGGGTT', 'GAAAGATGGTGAACTATGCCTGGGCAGGG',
       'TTCAAGTAATCCAGGATAGGCTT', 'TTCCGTACTGAGCTGCCCCGA',
       'CTGGCCCTCTCTGCCCTTAAGA', 'AACCGAGCGTCCAAGCTCTTT',
       'TTTTCATTAATCAAGAACGAAAGTCGGAGG',
       'GAGACCCGTCGCCGCGCTCTCCCCCCTCCCGGCGCC'],
      dtype='object', name='Sequence', length=51777)

In [None]:
# batch for neu: bioproject, center_name, sequencer, lab_library_preparation_kit/lab_rna_extraction_protocol
# batch for cov: bioproject, SRA_study, center_name, sequencer, lab_* fields
adata_neu.obs[['BioProject','Center_name', 'Organism', 'source_name', 'cohort', 'sex', 'age', 'treatment', 'diagnosis', 'Assay_type', 'Sequencer','Sample_type']]

Unnamed: 0_level_0,BioProject,Center_name,Organism,source_name,cohort,sex,age,treatment,diagnosis,Assay_type,Sequencer,Sample_type
smpID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
SRR14062693,PRJNA717018,GEO,Homo sapiens,peripheral blood,Swedish,male,27,treatment-naive,Symptomatic control,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14062694,PRJNA717018,GEO,Homo sapiens,peripheral blood,Swedish,female,20,treatment-naive,Ulcerative colitis,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14062695,PRJNA717018,GEO,Homo sapiens,peripheral blood,Swedish,female,30,treatment-naive,Symptomatic control,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14062696,PRJNA717018,GEO,Homo sapiens,peripheral blood,Swedish,male,27,treatment-naive,Crohn's disease,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14062697,PRJNA717018,GEO,Homo sapiens,peripheral blood,Swedish,male,32,treatment-naive,Ulcerative colitis,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
...,...,...,...,...,...,...,...,...,...,...,...,...
SRR14063135,PRJNA717025,GEO,Homo sapiens,peripheral blood,German,male,28,treatment-exposed,Crohn's disease,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14063136,PRJNA717025,GEO,Homo sapiens,peripheral blood,German,female,66,treatment-exposed,Ulcerative colitis,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14063137,PRJNA717025,GEO,Homo sapiens,peripheral blood,German,male,23,treatment-exposed,Crohn's disease,ncRNA-Seq,Illumina HiSeq 2500,PAXgene
SRR14063138,PRJNA717025,GEO,Homo sapiens,peripheral blood,German,female,43,treatment-exposed,Crohn's disease,ncRNA-Seq,Illumina HiSeq 2500,PAXgene


In [None]:
print(dp.print_adata_summary(adata_neu))
print("------------------    ------------------  ------------------    -------------------  -----------------  ------------------    ------------------")
print(dp.print_adata_summary(adata_cov))

Shape: (447, 51777)
X dtype: float64
Example values from X: [0.         0.         0.65439486 0.         0.         0.
 0.65439486 0.81989226 0.         6.50348223]
Layers: []

.obs columns: ['total_counts_before_preprocessing', 'total_counts_after_trimming', 'total_counts_after_preprocessing', 'total_count_ratio__after_to_before', 'QC_mapping_ratio_bacterial', 'QC_mapping_ratio_viral', 'QC_mapping_ratio_miRNA', 'QC_mapping_ratio_sncRNA', 'Bases', 'Bytes', 'Avg_spot_length', 'BioProject', 'BioSample', 'Experiment', 'GEO_Accession_exp', 'SRA_study', 'Sample_name', 'Consent', 'Center_name', 'Organism', 'source_name', 'cohort', 'sex', 'age', 'treatment', 'diagnosis', 'Assay_type', 'Sequencer', 'Sample_type', 'Lab_library_layout', 'Lab_library_selection', 'Lab_library_source', 'Lab_RNA_extraction_protocol', 'Lab_Blocking_State', 'Lab_library_preparation_kit', 'Release_date', 'create_date', 'version', 'total_count']

.obs sample (first 5 rows):
             total_counts_before_preprocessing

In [None]:
# Zero fraction
print(dp._zero_fraction(adata_neu.X))
print(dp._zero_fraction(adata_cov.X))

0.6622658890935611
0.654555137757781


#### batch candidates

In [8]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LinearRegression
import scanpy as sc

def guess_batch_candidates(adata, exclude=()):
    candidates = []
    for col in adata.obs.columns:
        if col in exclude:
            continue
        if adata.obs[col].dtype.name == "category" or adata.obs[col].dtype == object:
            if adata.obs[col].nunique() > 1:
                candidates.append(col)
    return candidates

def _batch_r2_for_pcs(X, expl_var, obs, batch_col):
    """
    PCA for each PC, R² of batch_col
    weighted_R2 = variance ratio weighted average
    """
    enc = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
    B = enc.fit_transform(obs[[batch_col]])
    r2_vec = []
    for i in range(X.shape[1]):
        y = X[:, i]
        mdl = LinearRegression().fit(B, y)
        r2_vec.append(mdl.score(B, y))
    r2_vec = np.array(r2_vec)
    weighted = float((r2_vec * expl_var).sum() / expl_var.sum())
    return weighted, r2_vec

def rank_batches_within(adata, candidates, base_covars=()):
    """
    Rank batch candidates by weighted_R²
    """

    adata_pca = adata.copy()
    sc.pp.highly_variable_genes(adata_pca, flavor="seurat_v3", n_top_genes=2000)
    adata_pca = adata_pca[:, adata_pca.var['highly_variable']]
    sc.pp.scale(adata_pca, max_value=10)
    sc.tl.pca(adata_pca, n_comps=30)

    X = adata_pca.obsm['X_pca']
    expl_var = adata_pca.uns['pca']['variance_ratio']

    rows = []
    for col in candidates:
        wr2, r2_vec = _batch_r2_for_pcs(X, expl_var, adata.obs, col)
        rows.append({'factor': col, 'weighted_partial_R2': wr2, 'n_levels': adata.obs[col].nunique()})
    rank_df = pd.DataFrame(rows).sort_values("weighted_partial_R2", ascending=False).reset_index(drop=True)
    return rank_df, adata_pca

def add_confound_flags(rank_df, adata, bio_label='cell_type', high=0.7):
    """
    add flag column to rank_df based on confounding with bio_label
    """
    df = rank_df.copy()
    if bio_label not in adata.obs:
        df['confound_flag'] = 'NA'
        return df
    flags = []
    for f in df['factor']:
        if f not in adata.obs: 
            flags.append('NA'); continue
        cross = pd.crosstab(adata.obs[f], adata.obs[bio_label])
        
        chi2 = ((cross - cross.mean(axis=0))**2 / (cross.mean(axis=0)+1e-6)).sum().sum()
        norm = chi2 / (len(adata)*len(adata.obs[bio_label].unique()))
        if norm > high:
            flags.append('high')
        else:
            flags.append('low')
    df['confound_flag'] = flags
    return df

def deduplicate_batches(adata, ordered, thresh=0.95):
    """
    deduplicate batches if nearly completely confounded
    """
    seen_patterns = []
    pruned = []
    for f in ordered:
        if f not in adata.obs: continue
        pattern = tuple(adata.obs[f].astype(str))
        dup = False
        for p in seen_patterns:
            overlap = np.mean([a==b for a,b in zip(pattern, p)])
            if overlap >= thresh:
                dup = True; break
        if not dup:
            seen_patterns.append(pattern)
            pruned.append(f)
    return pruned

def select_minimal_batch_set(adata, ordered, base_covars=(), delta=0.01):
    """
    add batch factors from ordered list if partial R² gain >= delta
    """
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=2000)
    adata_pca = adata[:, adata.var['highly_variable']]
    sc.pp.scale(adata_pca, max_value=10)
    sc.tl.pca(adata_pca, n_comps=30)
    X = adata_pca.obsm['X_pca']
    expl_var = adata_pca.uns['pca']['variance_ratio']

    selected, gains = [], []
    tot = 0.0
    for f in ordered:
        wr2, _ = _batch_r2_for_pcs(X, expl_var, adata.obs, f)
        incr = wr2 - tot
        if incr >= delta:
            selected.append(f)
            gains.append((f, incr))
            tot = wr2
    return selected, tot, gains


In [None]:
neu = adata_neu
cov = adata_cov

for name, ad in [('neu', neu), ('cov', cov)]:
    print(f"\n### [{name}] within-dataset batch selection ###")
    # 1) Explore candidate batches
    candidates = guess_batch_candidates(ad, exclude=('sex', 'age', 'study','cell_type','condition','label','donor'))
    print("candidates:", candidates)

    # 2) Rank candidates by weighted partial R²
    rank_df, ad_pca = rank_batches_within(ad, candidates, base_covars=())
    rank_df = add_confound_flags(rank_df, ad, bio_label='cell_type', high=0.7)
    print(rank_df.head(10))

    # 3) Delete duplicates (if any)
    ordered = rank_df['factor'].tolist()
    pruned = deduplicate_batches(ad, ordered, thresh=0.95)
    print("after de-dup:", pruned[:10])

    # 4) Select minimal set
    selected, tot_w, gains = select_minimal_batch_set(ad, pruned, base_covars=(), delta=0.01)
    print("selected batch set:", selected, " (total weighted partial R² =", round(tot_w, 3), ")")
    print("incremental gains:", [(f, round(g,3)) for f,g in gains])
    
    # Visualize top factors and selected
    



### [neu] within-dataset batch selection ###
candidates: ['BioProject', 'BioSample', 'Experiment', 'GEO_Accession_exp', 'SRA_study', 'Sample_name', 'cohort', 'treatment', 'diagnosis', 'create_date']


  view_to_actual(adata)


              factor  weighted_partial_R2  n_levels confound_flag
0          BioSample             1.000000       447            NA
1         Experiment             1.000000       447            NA
2  GEO_Accession_exp             1.000000       447            NA
3        Sample_name             1.000000       447            NA
4        create_date             0.389127        14            NA
5         BioProject             0.289879         2            NA
6          SRA_study             0.289879         2            NA
7             cohort             0.289879         2            NA
8          treatment             0.289879         2            NA
9          diagnosis             0.119703         4            NA
after de-dup: ['BioSample', 'Experiment', 'GEO_Accession_exp', 'create_date', 'BioProject', 'SRA_study', 'cohort', 'treatment', 'diagnosis']


  view_to_actual(adata)


selected batch set: ['BioSample']  (total weighted partial R² = 1.0 )
incremental gains: [('BioSample', 1.0)]

### [cov] within-dataset batch selection ###
candidates: ['BioSample', 'Experiment', 'Sample_name', 'vaccine', 'naat_result', 'timepoint', 'create_date']


  view_to_actual(adata)


        factor  weighted_partial_R2  n_levels confound_flag
0    BioSample             1.000000       180            NA
1   Experiment             1.000000       180            NA
2  Sample_name             1.000000       180            NA
3  naat_result             0.115688         3            NA
4  create_date             0.035053         3            NA
5    timepoint             0.017321         3            NA
6      vaccine             0.005348         2            NA
after de-dup: ['BioSample', 'Experiment', 'Sample_name', 'naat_result', 'create_date', 'timepoint', 'vaccine']


  view_to_actual(adata)


selected batch set: ['BioSample']  (total weighted partial R² = 1.0 )
incremental gains: [('BioSample', 1.0)]


### Denormalize back to counts

In [16]:
for adata in [adata_neu, adata_cov]:
    dp.invert_log2cpm_to_counts(
        adata,
        totals_col="total_counts_after_preprocessing",
        out_layer="counts",
        clip_negative=True
    )

adata_neu.write("/mnt/data/thesis_project/data/processed/raw_count/GSE169569_raw_counts.h5ad")
adata_cov.write("/mnt/data/thesis_project/data/processed/raw_count/GSE228841_raw_counts.h5ad")

Zero fraction (counts): 0.6622658890935611
Row-sum corr vs totals: 0.9999999980500865
Median abs diff: 314.0
Min/Max counts: 0 3857841
Zero fraction (counts): 0.654555137757781
Row-sum corr vs totals: 0.9999997938787484
Median abs diff: 6923.5
Min/Max counts: 0 1599031


In [17]:
IN_PATH_neu = "../data/processed/raw_count/GSE169569_raw_counts.h5ad"
IN_PATH_cov = "../data/processed/raw_count/GSE228841_raw_counts.h5ad"

adata_neu = sc.read_h5ad(IN_PATH_neu)
adata_cov = sc.read_h5ad(IN_PATH_cov)

print(adata_neu.layers.keys())
print(adata_cov.layers.keys())

KeysView(Layers with keys: counts, log2_1p_CPM_original)
KeysView(Layers with keys: counts, log2_1p_CPM_original)
