In [None]:
import os
from pathlib import Path

# Limit OpenMP and MKL (Intel Math Kernel Library) to a single thread each.
# Native BLAS/OpenMP runtimes typically read these variables once, when they are
# first loaded (numpy is pulled in by the scanpy import below), so they must be
# set BEFORE those imports; assigning them afterwards does not resize thread
# pools that are already initialised.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import scanpy as sc  # noqa: E402  (must follow the thread env setup above)
import scrublet as scr  # noqa: E402
from batchvae import BatchCorrector, SampleData, SampleDataHandler  # noqa: E402
from batchvae.utils.logger import Logger  # noqa: E402

In [None]:
# Project root: per-sample inputs are read from resources/, results are written
# under analysis/.
BASE_DIR_PATH = "/home/yuyasato/work2/Projects/vasc_aging"
# Derived from BASE_DIR_PATH so the project root is defined in exactly one place.
OUTPUT_DIR_PATH = f"{BASE_DIR_PATH}/analysis/6.batch_removal/mouse/out/projects"
logger = Logger()

# Restrict OpenMP and MKL (Intel Math Kernel Library) to one thread each.
# NOTE: native thread pools generally read these at library load time, so this
# only influences libraries initialised after this point. To also cover
# numpy/scanpy they need to be set before the imports in the first cell.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
# Sample accession -> series directory under f"{BASE_DIR_PATH}/resources".
# NOTE(review): the second entry maps "GSM00002" to "GSM00002" rather than a GSE
# id like the first entry; this looks like a typo for "GSE00002" — confirm the
# directory name before running.
SAMPLES = {"GSM00001": "GSE00001", "GSM00002": "GSM00002"}

# Per-cell QC thresholds (inclusive bounds).
MIN_GENES_PER_CELL = 200  # minimum n_genes_by_counts to keep a cell
MAX_GENES_PER_CELL = 5000  # maximum n_genes_by_counts to keep a cell
MAX_PCT_COUNTS_MT = 5  # maximum percentage of counts from mitochondrial genes
MT_GENE_PREFIX = "mt-"  # lower-case prefix, as used for mouse gene symbols

samples: list[SampleData] = []
for sample_id, series_id in SAMPLES.items():
    adata = sc.read_10x_h5(
        f"{BASE_DIR_PATH}/resources/{series_id}/{sample_id}/counted/outs/filtered_feature_bc_matrix.h5"
    )
    adata.var_names_make_unique()

    # Predict doublets with Scrublet on the loaded count matrix.
    logger.log("Simulating doublets with Scrublet...", "info")
    scrub = scr.Scrublet(adata.X)
    doublet_scores, predicted_doublets = scrub.scrub_doublets()
    if predicted_doublets is None:
        # Scrublet returns None when it cannot choose a score threshold
        # automatically; filtering on that would silently drop every cell.
        raise RuntimeError(
            f"Scrublet could not determine a doublet threshold for {sample_id}; "
            "call scrub.call_doublets(threshold=...) with an explicit threshold."
        )
    adata.obs["doublet_score"] = doublet_scores
    adata.obs["predicted_doublets"] = predicted_doublets

    # QC metrics, including pct_counts_mt for the mitochondrial gene set.
    adata.var["mt"] = adata.var_names.str.startswith(MT_GENE_PREFIX)
    logger.log("Calculating mito...", "info")
    sc.pp.calculate_qc_metrics(
        adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
    )

    # Build a single mask and materialise the subset with .copy(): this gives a
    # standalone AnnData (not a view-of-a-view chain) and lets the unfiltered
    # matrix be freed by the `del` below.
    keep = (
        (adata.obs["n_genes_by_counts"] >= MIN_GENES_PER_CELL)
        & (adata.obs["n_genes_by_counts"] <= MAX_GENES_PER_CELL)
        & (adata.obs["pct_counts_mt"] <= MAX_PCT_COUNTS_MT)
        & ~adata.obs["predicted_doublets"]
    )
    samples.append(SampleData(adata=adata[keep.to_numpy()].copy(), logger=logger))
    del adata

combined = SampleDataHandler(samples, logger=logger).combine()
# Snapshot the combined matrix into .raw so it is saved alongside X.
combined.adata.raw = combined.adata.copy()
Path(OUTPUT_DIR_PATH).mkdir(parents=True, exist_ok=True)
combined.write(f"{OUTPUT_DIR_PATH}/combined")

# Reload the combined dataset from disk, so this step does not depend on the
# in-memory object, then run batch correction with batch key "project".
output_dir = Path(OUTPUT_DIR_PATH)
adata = sc.read_h5ad(f"{OUTPUT_DIR_PATH}/combined.h5ad")
# Re-populate .raw with the matrix exactly as loaded.
adata.raw = adata.copy()
combined = SampleData(adata=adata, logger=logger)
corrector = BatchCorrector(combined=combined, logger=logger, batch_key="project")
corrected = corrector.correct(dir_path=output_dir)