# GPU-accelerated single-cell integration

This notebook performs QC, normalization, donor-only batch correction, and clustering for the fasting cohort using `rapids_singlecell` on GPU. Cells with >400 detected genes and genes expressed in >50 cells are retained, and every generated PNG is saved at 300 dpi while simultaneously displayed inline.

## Workflow overview
- Load the dataset meta-sheet (`MolsPerCell_MEX_metadata.csv`) that maps donors, sample types, and timepoints to MEX archives.
- Stage and read each sample as 10x-format matrices, compute QC metrics, and keep only cells with more than 400 genes.
- Filter genes expressed in more than 50 cells, normalize, and identify HVGs on the GPU with `rapids_singlecell`.
- Run donor-level batch correction (timepoints remain biological) and compute embeddings, clusters, and marker genes.
- Export annotated results, 300 dpi figures, and an execution log so the entire run is reproducible.

In [None]:
import os
import warnings
import logging
from pathlib import Path

# Thread-count environment variables are generally read only when the
# BLAS / OpenMP / numexpr / numba runtimes initialise, i.e. when numpy, scanpy
# and friends are first imported. They are therefore set BEFORE those imports;
# assigning them afterwards (or re-running this cell in a live kernel) does not
# resize thread pools that are already initialised.
MAX_THREADS = int(os.environ.get('SC_MAX_THREADS', '48'))
thread_env_vars = [
    'OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS',
    'BLIS_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 'NUMBA_NUM_THREADS'
]
for var in thread_env_vars:
    os.environ[var] = str(MAX_THREADS)

import numpy as np  # noqa: E402  (must come after the thread settings above)
import pandas as pd  # noqa: E402
import scanpy as sc  # noqa: E402
import rapids_singlecell as rsc  # noqa: E402
import cupy as cp  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import seaborn as sns  # noqa: E402
from IPython.display import Image, display  # noqa: E402

warnings.filterwarnings('ignore')

# scanpy's own CPU parallelism and figure defaults (300 dpi both inline and when saving).
sc.settings.n_jobs = MAX_THREADS
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=300, dpi_save=300, fontsize=11)
sns.set_theme(style='whitegrid')
print(f"Using up to {MAX_THREADS} CPU threads and GPU device {cp.cuda.runtime.getDevice()} for Rapids workflows.")

In [None]:
# Project layout: every location is anchored one directory above the notebook.
PROJECT_ROOT = Path('..').resolve()
DATA_ROOT = PROJECT_ROOT  # relative 'Path' entries of the meta-sheet are joined onto this
METADATA_PATH = PROJECT_ROOT / 'MolsPerCell_MEX_metadata.csv'
STAGING_DIR = PROJECT_ROOT / 'data' / 'staged_mex'
FIG_DIR = PROJECT_ROOT / 'figures'
RESULTS_DIR = PROJECT_ROOT / 'results'
LOG_PATH = PROJECT_ROOT / 'logs' / 'singlecell_analysis.log'

# Create every output location up front; existing directories are left alone.
for output_dir in (STAGING_DIR, FIG_DIR, RESULTS_DIR, LOG_PATH.parent):
    output_dir.mkdir(parents=True, exist_ok=True)

# Send scanpy's own `save=` output to the same figures directory.
sc.settings.figdir = str(FIG_DIR)

# A named logger that writes to both the log file and the notebook output.
# Handlers are attached only when none exist, so re-running this cell does not
# duplicate every log line.
logger = logging.getLogger('rapids_singlecell_notebook')
logger.setLevel(logging.INFO)
if not logger.handlers:
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    for handler in (logging.FileHandler(LOG_PATH), logging.StreamHandler()):
        handler.setFormatter(formatter)
        logger.addHandler(handler)

logger.info('Notebook session initialised with donor-only batch correction policy.')
print(f'Project root: {PROJECT_ROOT}')
print(f'Log file: {LOG_PATH}')

## Load and curate metadata
You can optionally restrict the analysis to specific donors, timepoints, or sample types by editing the filter lists below. Rows that share the same donor, sample type, and timepoint (e.g., a repeated timepoint) receive a numeric suffix (`_2`, `_3`, …) so that every `sample_id` is unique.

In [None]:
# Optional subsets; leave as None (or an empty list) to keep every meta-sheet row.
subjects_of_interest = None  # e.g. ['FSB', 'WHX']
timepoints_of_interest = None  # e.g. ['T1', 'T2']
sampletypes_of_interest = None  # e.g. ['WB', 'PBMC']

metadata = pd.read_csv(METADATA_PATH)

# Readable sample ID: <Subject>_<SampleType>_<TimePoint>. IDs are assigned before
# subsetting, so an archive keeps the same ID whichever filters are active.
metadata['sample_base'] = (
    metadata['Subject'].astype(str) + '_' +
    metadata['SampleType'].astype(str) + '_' +
    metadata['TimePoint'].astype(str)
)
# Repeated combinations: the first row keeps the bare base, later ones get _2, _3, ...
metadata['duplicate_idx'] = metadata.groupby('sample_base').cumcount()
metadata['sample_id'] = metadata['sample_base'].where(
    metadata['duplicate_idx'] == 0,
    metadata['sample_base'] + '_' + (metadata['duplicate_idx'] + 1).astype(str),
)
# 'Path' entries are interpreted relative to DATA_ROOT.
metadata['zip_path'] = metadata['Path'].apply(lambda rel: (DATA_ROOT / rel).resolve())

if subjects_of_interest:
    metadata = metadata[metadata['Subject'].isin(subjects_of_interest)]
if timepoints_of_interest:
    metadata = metadata[metadata['TimePoint'].isin(timepoints_of_interest)]
if sampletypes_of_interest:
    metadata = metadata[metadata['SampleType'].isin(sampletypes_of_interest)]

metadata = metadata.reset_index(drop=True)
if metadata.empty:
    # Fail here with a clear message instead of an obscure error at concatenation.
    raise ValueError('No samples remain after applying the subject/timepoint/sample-type filters.')

# Only the selected archives need to be present on disk.
metadata['exists'] = metadata['zip_path'].apply(lambda p: p.exists())
missing = metadata[~metadata['exists']]
if not missing.empty:
    # Real newline (the previous '\\n' printed a literal backslash-n).
    raise FileNotFoundError('Missing MEX archives:\n' + missing[['sample_id', 'zip_path']].to_string(index=False))

logger.info('Loaded metadata for %d samples (%d unique donors).', len(metadata), metadata['Subject'].nunique())
metadata[['sample_id', 'Subject', 'SampleType', 'TimePoint', 'zip_path']].head()

## Stage MEX archives and build per-sample AnnData objects
Each archive is extracted once under `data/staged_mex/<sample_id>` and re-used on future runs. QC metrics per cell are computed immediately so filtering thresholds can be applied uniformly downstream.

In [None]:
import zipfile

import shutil
import tempfile

_MATRIX_FILENAMES = ('matrix.mtx', 'matrix.mtx.gz')


def _find_matrix_dir(root):
    """Return the shallowest directory under ``root`` holding a 10x matrix file, or None."""
    if not root.is_dir():
        return None
    hits = [path.parent for name in _MATRIX_FILENAMES for path in root.rglob(name)]
    if not hits:
        return None
    # Prefer the top-most match; ties are broken by path for determinism.
    return min(hits, key=lambda d: (len(d.parts), str(d)))


def stage_mex_archive(row):
    """Extract a sample's MEX zip once and return the directory holding its matrix.

    The archive ``row['zip_path']`` is unpacked under ``STAGING_DIR/<sample_id>``.
    Extraction goes to a temporary directory inside STAGING_DIR that is renamed
    into place only after ``extractall`` completes. An interrupted run therefore
    cannot leave a half-populated staging directory that later runs would
    mistake for a finished one. If the archive wraps its files in a sub-folder,
    that sub-folder is returned.

    Args:
        row: Metadata row providing 'sample_id' and 'zip_path'.

    Returns:
        Path of the directory containing matrix.mtx or matrix.mtx.gz.

    Raises:
        FileNotFoundError: if the extracted archive contains no matrix file.
    """
    sample_dir = STAGING_DIR / row['sample_id']
    matrix_dir = _find_matrix_dir(sample_dir)
    if matrix_dir is not None:
        return matrix_dir  # already staged by an earlier run

    logger.info('Extracting %s to %s', row['zip_path'], sample_dir)
    tmp_dir = Path(tempfile.mkdtemp(prefix='.staging_', dir=STAGING_DIR))
    try:
        with zipfile.ZipFile(row['zip_path'], 'r') as zf:
            zf.extractall(tmp_dir)
        if sample_dir.exists():
            # It holds no matrix file, so it is debris from an incomplete run.
            shutil.rmtree(sample_dir)
        tmp_dir.rename(sample_dir)
    except BaseException:
        shutil.rmtree(tmp_dir, ignore_errors=True)
        raise

    matrix_dir = _find_matrix_dir(sample_dir)
    if matrix_dir is None:
        raise FileNotFoundError(f"No matrix.mtx or matrix.mtx.gz found in {row['zip_path']}")
    return matrix_dir


def load_sample(row):
    """Stage one sample's MEX archive and return its AnnData with per-cell QC metrics.

    Args:
        row: Metadata row; reads 'Subject', 'TimePoint', 'SampleType' and
            'sample_id' (plus whatever ``stage_mex_archive`` reads).

    Returns:
        AnnData restricted to 'Gene Expression' features when the archive labels
        feature types. It carries the sample annotations in ``obs``, a boolean
        ``mt`` flag in ``var``, and the scanpy QC metrics from
        ``calculate_qc_metrics`` (e.g. n_genes_by_counts, total_counts,
        pct_counts_mt).
    """
    staged_dir = stage_mex_archive(row)
    adata = sc.read_10x_mtx(staged_dir, var_names='gene_symbols', make_unique=True, gex_only=False)
    adata.var_names = adata.var_names.astype(str)
    adata.var_names_make_unique()
    if 'feature_types' in adata.var.columns:
        gex_rows = (adata.var['feature_types'].astype(str) == 'Gene Expression').to_numpy()
        # If nothing is labelled 'Gene Expression', keep every feature rather than
        # returning an empty object.
        if gex_rows.any():
            adata = adata[:, gex_rows].copy()

    # Coerce labels to str, as the metadata cell does when building sample_id.
    # Numeric subjects/timepoints would otherwise raise TypeError when joined
    # with '_' and would be treated as numeric rather than categorical downstream.
    subject = str(row['Subject'])
    timepoint = str(row['TimePoint'])
    adata.obs['subject'] = subject
    adata.obs['timepoint'] = timepoint
    adata.obs['sample_type'] = str(row['SampleType'])
    adata.obs['sample_id'] = row['sample_id']
    adata.obs['donor_timepoint'] = f'{subject}_{timepoint}'
    adata.var['mt'] = adata.var_names.str.upper().str.startswith('MT-')
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, inplace=True)
    return adata

# Read every selected sample into memory, keyed by its unique sample_id;
# insertion order follows the metadata rows.
adata_by_sample = {}
for _, sample_row in metadata.iterrows():
    sample_key = sample_row['sample_id']
    sample_adata = load_sample(sample_row)
    adata_by_sample[sample_key] = sample_adata
    logger.info(
        '%s: %d cells, %d genes before QC',
        sample_key, sample_adata.n_obs, sample_adata.n_vars,
    )

len(adata_by_sample)

## Concatenate cohorts and filter cells/genes
Only cells expressing more than 400 genes are kept. Genes must be observed in more than 50 cells across the cohort. The QC summary logs retention statistics for reproducibility.

In [None]:
# Sample IDs come from the mapping's keys. anndata's concat raises a TypeError
# when a Mapping is combined with an explicit `keys=` argument, so it is not
# passed. Barcodes are only unique within one sample, so `index_unique` appends
# '-<sample_id>' to every cell name to keep obs_names unique.
adata = sc.concat(
    adata_by_sample, label='sample_label', join='outer', fill_value=0, index_unique='-'
)
initial_cells = adata.n_obs
initial_genes = adata.n_vars

MIN_GENES_PER_CELL = 400  # keep cells with MORE than this many detected genes
MIN_CELLS_PER_GENE = 50  # keep genes detected in MORE than this many cells

cell_filter = (adata.obs['n_genes_by_counts'] > MIN_GENES_PER_CELL).to_numpy()
adata = adata[cell_filter].copy()
filtered_cells = adata.n_obs
if filtered_cells == 0:
    raise ValueError(
        f'No cells have more than {MIN_GENES_PER_CELL} detected genes; check the inputs or the threshold.'
    )

# filter_genes keeps genes seen in >= min_cells cells, so +1 gives the strict ">50 cells" rule.
sc.pp.filter_genes(adata, min_cells=MIN_CELLS_PER_GENE + 1)
filtered_genes = adata.n_vars

logger.info('Cell filter (>400 genes) retained %.2f%% (%d/%d).', 100*filtered_cells/initial_cells, filtered_cells, initial_cells)
logger.info('Gene filter (>50 cells) retained %d/%d genes.', filtered_genes, initial_genes)

qc_summary = {
    'initial_cells': initial_cells,
    'filtered_cells': filtered_cells,
    'initial_genes': initial_genes,
    'filtered_genes': filtered_genes,
}
qc_summary

## Per-sample QC barplots split by donor
Barplots summarise retained cells, mean detected genes, mean UMI counts, and mean mitochondrial percentage per sample. Bars are colored by donor to highlight donor-specific shifts.

In [None]:
# One row per sample: cells retained plus mean QC metrics and donor/timepoint labels,
# ordered donor -> timepoint -> sample so bars of the same donor sit together.
sample_stats = (
    adata.obs
    .groupby('sample_id')
    .agg(
        cells=('sample_id', 'size'),
        mean_genes=('n_genes_by_counts', 'mean'),
        mean_counts=('total_counts', 'mean'),
        mean_pct_mt=('pct_counts_mt', 'mean'),
        subject=('subject', 'first'),
        timepoint=('timepoint', 'first'),
    )
    .reset_index()
    .sort_values(['subject', 'timepoint', 'sample_id'])
)

# (column in sample_stats, panel title / y-axis label)
panel_specs = (
    ('cells', 'Cells retained'),
    ('mean_genes', 'Mean genes per cell'),
    ('mean_counts', 'Mean counts per cell'),
    ('mean_pct_mt', 'Mean % mitochondrial RNA'),
)

fig, axes = plt.subplots(2, 2, figsize=(18, 10))
for panel_ax, (column, label) in zip(axes.ravel(), panel_specs):
    sns.barplot(data=sample_stats, x='sample_id', y=column, hue='subject', ax=panel_ax, dodge=False)
    panel_ax.set(title=label, xlabel='Sample ID', ylabel=label)
    panel_ax.tick_params(axis='x', rotation=80)
    panel_ax.legend(loc='upper right')

fig.tight_layout()
barplot_path = FIG_DIR / 'qc_barplots_by_donor.png'
fig.savefig(barplot_path, dpi=300, bbox_inches='tight')
plt.close(fig)
display(Image(filename=barplot_path))
logger.info('Saved per-sample QC barplots to %s', barplot_path)
sample_stats

## Remove non-coding, mitochondrial, and ribosomal genes before clustering
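Unsupervised steps use an approximately protein-coding gene set: mitochondrial, ribosomal, and common non-coding gene families (e.g., lncRNAs, snoRNAs, miRNAs) are removed by matching gene-symbol prefixes and suffixes. This is therefore a naming heuristic, not a biotype annotation.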

In [None]:
# Approximate a protein-coding gene set from gene symbols alone: drop
# mitochondrial, ribosomal-protein and common non-coding families before HVG
# selection and clustering.
gene_names = adata.var_names.str.upper()
mt_mask = gene_names.str.startswith('MT-')
ribo_mask = gene_names.str.startswith(('RPS', 'RPL', 'MRPS', 'MRPL'))
# Non-coding symbol conventions: miRNAs (MIR*, which also covers MIRLET*),
# sno/scaRNAs and their host genes, snRNAs, lncRNAs (LINC*, antisense -AS*,
# intronic -IT*), and clone-based names (AC*/AL* accessions, RP*-, CTD-, LOC*).
# FAM* ("family with sequence similarity") is intentionally NOT listed: those
# symbols are predominantly protein-coding, and removing them contradicted the
# coding-only intent of this step.
noncoding_patterns = [
    r'^MIR', r'^SNORD', r'^SNORA', r'^SCARNA', r'^SNHG', r'^RNU', r'^RNV',
    r'^LINC', r'^LOC', r'^AC[0-9]', r'^AL[0-9]', r'^RP[0-9]+-', r'^CTD-', r'-AS[0-9]*$', r'-IT[0-9]*$'
]
noncoding_mask = gene_names.str.contains('|'.join(noncoding_patterns), regex=True, na=False)

coding_mask = ~(mt_mask | ribo_mask | noncoding_mask)
removed_genes = int((~coding_mask).sum())
adata = adata[:, coding_mask].copy()

logger.info('Removed %d genes (non-coding/mt/ribosomal). %d genes remain for clustering.', removed_genes, adata.n_vars)
adata

## Normalisation, HVGs, and scaling (GPU)
Normalization, log1p, and HVG selection run through `rapids_singlecell` to keep the workflow on GPU memory where possible.

In [None]:
# rapids_singlecell preprocessing functions operate on GPU (CuPy) matrices, so
# move X to the device first.
rsc.get.anndata_to_GPU(adata)
# flavor='seurat_v3' models raw counts, so keep them in a layer before normalising
# and point HVG selection at that layer instead of the log-normalised X.
adata.layers['counts'] = adata.X.copy()
rsc.pp.normalize_total(adata, target_sum=1e4)
rsc.pp.log1p(adata)
rsc.pp.highly_variable_genes(adata, flavor='seurat_v3', n_top_genes=3000, layer='counts')
adata = adata[:, adata.var['highly_variable'].to_numpy()].copy()
rsc.pp.scale(adata, max_value=10)
# Bring X and every layer back to host memory so later scanpy plotting and the
# .h5ad export receive NumPy/SciPy matrices.
# NOTE(review): if your rapids_singlecell version requires device input for
# pp.pca, move this call to the end of the embedding cell instead.
rsc.get.anndata_to_CPU(adata, convert_all=True)
logger.info('Retained %d highly variable genes after coding-only filter.', adata.n_vars)
adata

## Donor-only batch correction and embeddings
Harmony integration uses `subject` as its only batch key, so only donor-level effects are corrected. Timepoint is not passed as a covariate, so differences between timepoints are preserved as biological signal.

In [None]:
# 50-component PCA of the scaled HVG matrix.
rsc.pp.pca(adata, n_comps=50)
# Harmony with the donor ('subject') as the single batch covariate; timepoint is
# deliberately not passed, so timepoint differences stay in the embedding.
rsc.pp.harmony_integrate(adata, key='subject', basis='X_pca')
# kNN graph from the first 40 Harmony-corrected components; UMAP and Leiden follow.
rsc.pp.neighbors(adata, use_rep='X_pca_harmony', n_pcs=40, n_neighbors=20)
rsc.tl.umap(adata)
rsc.tl.leiden(adata, key_added='leiden_0_8', resolution=0.8)
logger.info('Computed embeddings and Leiden clusters (resolution 0.8) using coding genes only.')

## QC visualisations
All figures are saved as 300 dpi PNGs under `figures/` and displayed inline for quick review.

In [None]:
def save_and_display_current(fig, filename):
    """Save ``fig`` as FIG_DIR/filename at 300 dpi, close it, and show the file inline.

    Args:
        fig: Matplotlib figure to persist.
        filename: File name (with extension) relative to FIG_DIR.
    """
    out_path = FIG_DIR / filename
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)  # release the figure; the saved file is what gets displayed
    display(Image(filename=out_path))
    logger.info('Saved figure %s', out_path)


# Per-donor distributions of the per-cell QC metrics stored in adata.obs.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
for qc_ax, qc_key in zip(axes, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']):
    sc.pl.violin(adata, [qc_key], groupby='subject', rotation=90, ax=qc_ax, show=False)
save_and_display_current(fig, 'qc_violin_by_subject.png')

# X now holds scaled values (rsc.pp.scale, clipped at 10), which cannot be read
# as expression fractions. Use the raw-count layer when it exists. It only covers
# the genes still in adata (the HVGs), so fractions are relative to those genes.
expr_layer = 'counts' if 'counts' in adata.layers else None
ax = sc.pl.highest_expr_genes(adata, n_top=20, layer=expr_layer, show=False)
fig = ax.figure if hasattr(ax, 'figure') else plt.gcf()
save_and_display_current(fig, 'qc_highest_expr_genes.png')

## UMAP views and donor/timepoint inspection

In [None]:
# scanpy appends `save` to the plot name and takes the file format from its
# extension. Without '.png' it falls back to settings.file_format_figs (PDF by
# default), so the PNG displayed below would not exist.
for color in ['subject', 'timepoint', 'sample_id', 'leiden_0_8']:
    sc.pl.umap(adata, color=color, frameon=False, show=False, save=f'_{color}.png')
    out_path = Path(sc.settings.figdir) / f'umap_{color}.png'
    display(Image(filename=out_path))
    logger.info('Saved figure %s', out_path)

# Same donor view with labels drawn on the embedding instead of a side legend.
sc.pl.umap(adata, color='subject', legend_loc='on data', frameon=False, show=False)
fig = plt.gcf()
annotated_path = FIG_DIR / 'umap_subject_annotated.png'
fig.savefig(annotated_path, dpi=300, bbox_inches='tight')
plt.close(fig)
display(Image(filename=annotated_path))
logger.info('Saved figure %s', annotated_path)

## Marker discovery
Logistic regression markers are computed per Leiden cluster. Heatmaps and dot plots summarise the leading markers.

In [None]:
# Logistic-regression marker ranking per Leiden cluster.
rsc.tl.rank_genes_groups_logreg(adata, groupby='leiden_0_8')

# Save the marker plots explicitly as 300 dpi PNGs. Passing `save='_leiden_...'`
# to scanpy writes files named with scanpy's own plot prefix in the default figure
# format (PDF unless configured), so the PNG paths previously probed here never
# existed and nothing was displayed.
sc.pl.rank_genes_groups_heatmap(adata, n_genes=8, show=False)
save_and_display_current(plt.gcf(), 'rank_genes_groups_heatmap_leiden_heatmap.png')

sc.pl.rank_genes_groups_dotplot(adata, n_genes=8, show=False)
save_and_display_current(plt.gcf(), 'rank_genes_groups_dotplot_leiden_dotplot.png')

## Save annotated object and tabular outputs

In [None]:
# Output locations for the integrated object and the tabular summaries.
final_h5ad = RESULTS_DIR / 'rapids_singlecell_integrated.h5ad'
cluster_csv = RESULTS_DIR / 'cluster_sizes.csv'
marker_csv = RESULTS_DIR / 'cluster_markers.csv'

adata.write(final_h5ad)

# Number of cells per Leiden cluster, sorted by cluster label.
cluster_counts = adata.obs['leiden_0_8'].value_counts().sort_index()
cluster_counts.to_csv(cluster_csv, header=['n_cells'])

# Marker table covering every cluster (group=None selects all groups).
markers_df = sc.get.rank_genes_groups_df(adata, group=None)
markers_df.to_csv(marker_csv, index=False)

logger.info('Saved %s, %s, and %s', final_h5ad, cluster_csv, marker_csv)
print('Results saved to:', final_h5ad)

## Next steps
- Re-run clustering at alternative resolutions (e.g., 0.6 or 1.0) to test stability.
- Subset specific immune lineages and repeat HVG/UMAP to magnify subtle states.
- Feed markers into pathway enrichment or cell-type annotation tools (Azimuth, CellTypist, etc.).