In [2]:
from __future__ import annotations

import copy
import json
import os
import pickle
import re
from re import Pattern
from typing import Any, ClassVar, Literal

import dask
import dask.bag as db
import geopandas as gpd
import numpy as np
import pandas as pd
import polars as pl
import pyopenms as oms
import rtree
from MetaMSTools.ms_tools import (
    AdductDetector,
    AdductDetectorConfig,
    FeatureFinder,
    FeatureFinderConfig,
    FeatureLinker,
    OpenMSDataWrapper,
    RTAligner,
    TICSmoother,
)
from pydantic import BaseModel, ConfigDict, Field
from shapely.geometry import box
from sqlalchemy import create_engine, text


def get_data_wrapper():
    qc_datas = OpenMSDataWrapper(
        file_paths=[
            "../data/raw_files/QC1.mzML",
            "../data/raw_files/QC2.mzML"
        ]
    )
    qc_datas.init_exps()
    qc_datas = TICSmoother()(qc_datas)
    feature_config = FeatureFinderConfig()
    feature_config.feature_finding_metabo.charge_upper_bound = 1
    feature_config.feature_finding_metabo.charge_lower_bound = 1
    qc_datas = FeatureFinder(config=feature_config)(qc_datas)
    qc_datas.infer_ref_feature_for_align()
    qc_datas = RTAligner()(qc_datas)
    qc_datas = FeatureLinker()(qc_datas)
    adduct_config = AdductDetectorConfig(
        charge_min=1,
        charge_max=1,
    )
    qc_datas = AdductDetector(
        config=adduct_config
    )(qc_datas)
    return qc_datas

In [3]:
datas = get_data_wrapper()

Progress of 'mass trace detection':
-- done [took 4.54 s (CPU), 0.27 s (Wall)] -- 
Progress of 'mass trace detection':
-- done [took 0.21 s (CPU), 0.20 s (Wall)] -- 
Progress of 'elution peak detection':
-- done [took 1.23 s (CPU), 0.04 s (Wall)] -- 
Progress of 'elution peak detection':
-- done [took 2.20 s (CPU), 0.07 s (Wall)] -- 
Progress of 'assembling mass traces to features':
Loading metabolite isotope model with 5% RMS error
-- done [took 1.63 s (CPU), 0.05 s (Wall)] -- 
Progress of 'assembling mass traces to features':
-- done [took 1.57 s (CPU), 0.05 s (Wall)] -- 
Progress of 'Linking features':
-- done [took 0.71 s (CPU), 0.02 s (Wall)] -- 
Adding neutral: ---------- Adduct -----------------
Charge: 0
Amount: 1
MassSingle: 13.9793
Formula: H-2O1
log P: -2.30259

MassExplainer table size: 92
Generating Masses with threshold: -6.90776 ...
<Loading metabolite isotope model with 5% RMS error> occurred 2 times
done
0 of 13 valid net charge compomer results did not pass the featur



In [42]:
class BaseMap(BaseModel):

    model_config = ConfigDict({"arbitrary_types_allowed": True})

    table_schema: ClassVar[dict[str, dict]] = {}

    exp_name: str = Field(
        ...,
        data_type="metadata",
        save_mode="json",
        description="实验名称"
    )
    metadata: dict = Field(
        default={},
        data_type="metadata",
        save_mode="json",
        description="数据的metadata信息。"
    )

    def __getstate__(self):
        state = copy.deepcopy(super().__getstate__())
        for k,v in state['__dict__'].items():
            if isinstance(v, pl.DataFrame):
                if any(tp == pl.Object for tp in v.dtypes):
                    state['__dict__'][k] = v.to_pandas()
        return state

    def __setstate__(self, state):
        init_func = []
        for k,f in self.model_fields.items():
            if f.annotation == pl.DataFrame or f.annotation == pl.DataFrame | None:
                if isinstance(state['__dict__'][k], pd.DataFrame):
                    state['__dict__'][k] = pl.from_pandas(state['__dict__'][k])
            elif f.annotation == rtree.index.Index | None:
                if isinstance(state['__dict__'][k], rtree.index.Index):
                    if "init_func" in f.json_schema_extra:
                        init_func.append(f.json_schema_extra["init_func"])
            elif f.annotation == rtree.index.Index:
                if "init_func" in f.json_schema_extra:
                    init_func.append(f.json_schema_extra["init_func"])
        super().__setstate__(state)
        for func_name in init_func:
            getattr(self, func_name)()

    def save(self, save_dir_path: str):

        if not os.path.exists(save_dir_path):
            os.makedirs(save_dir_path)

        metadata_path = os.path.join(save_dir_path, "metadata.json")
        index_dir_path = os.path.join(save_dir_path, "index")
        if not os.path.exists(index_dir_path):
            os.makedirs(index_dir_path)
        data_dir_path = os.path.join(save_dir_path, "data")
        if not os.path.exists(data_dir_path):
            os.makedirs(data_dir_path)
        sqlite_db_path = os.path.join(data_dir_path, "data.sqlite")
        engine = create_engine(f"sqlite:///{sqlite_db_path}")

        metadata_to_save = {"module_type": self.__class__.__name__}
        for k,f in self.model_fields.items():
            if f.json_schema_extra['data_type'] == 'metadata':
                metadata_to_save[k] = getattr(self, k)
            elif f.json_schema_extra['data_type'] == 'index':
                if isinstance(getattr(self, k), rtree.index.Index):
                    rtree_save_path = os.path.join(index_dir_path, k)
                    if os.path.exists(rtree_save_path + ".dat"):
                        os.remove(rtree_save_path + ".dat")
                    if os.path.exists(rtree_save_path + ".idx"):
                        os.remove(rtree_save_path + ".idx")
                    tree:rtree.index.Index = getattr(self, f.json_schema_extra['build_func'])(rtree_save_path)
                    tree.close()
                elif isinstance(getattr(self, k), pd.Index):
                    index_save_path = os.path.join(index_dir_path, k+".csv")
                    pd.Series(getattr(self, k)).to_csv(index_save_path, header=False)
                elif isinstance(getattr(self, k), gpd.GeoDataFrame):
                    getattr(self, k).to_parquet(os.path.join(index_dir_path, k+".parquet"))
                else:
                    other_index_save_path = os.path.join(index_dir_path, k+".pkl")
                    with open(other_index_save_path, 'wb') as f:
                        pickle.dump(getattr(self, k), f)
            elif f.json_schema_extra['data_type'] == 'data':
                if f.json_schema_extra['save_mode'] == 'sqlite':
                    data = getattr(self, k)
                    if isinstance(data, pl.DataFrame):
                        with engine.connect() as conn:
                            data.write_database(table_name=k, connection=conn, if_table_exists="replace")
                    elif isinstance(data, pd.DataFrame):
                        with engine.connect() as conn:
                            data.to_sql(k, conn, if_exists="replace")
                    else:
                        raise ValueError(f"Unsupported data type to save as sqlite: {type(data)}")
                else:
                    other_data_save_path = os.path.join(data_dir_path, k+".pkl")
                    with open(other_data_save_path, 'wb') as f:
                        pickle.dump(getattr(self, k), f)

        with open(metadata_path, 'w') as f:
            json.dump(metadata_to_save, f)

        engine.dispose()

    @classmethod
    def _base_load(cls, save_dir_path: str) -> dict[str, Any]:

        data_dict = {}

        metadata_path = os.path.join(save_dir_path, "metadata.json")
        index_dir_path = os.path.join(save_dir_path, "index")
        data_dir_path = os.path.join(save_dir_path, "data")

        if not os.path.exists(metadata_path):
            raise ValueError(f"Metadata file not found in {save_dir_path}")
        with open(metadata_path) as f:
            metadata:dict = json.load(f)
            exp_name = metadata.pop('exp_name')
            metadata.pop('module_type')
        data_dict['exp_name'] = exp_name
        data_dict['metadata'] = metadata

        if os.path.exists(index_dir_path):
            for k,f in cls.model_fields.items():
                if f.json_schema_extra['data_type'] == 'index':
                    if f.json_schema_extra['save_mode'] == 'rtree':
                        rtree_save_path = os.path.join(index_dir_path, k)
                        if os.path.exists(rtree_save_path + ".dat") and os.path.exists(rtree_save_path + ".idx"):
                            data_dict[k] = rtree.index.Index(rtree_save_path)
                    elif f.annotation == pd.Index or f.annotation == pd.Index | None:
                        index_save_path = os.path.join(index_dir_path, k+".csv")
                        if os.path.exists(index_save_path):
                            data_dict[k] = pd.Index(pd.read_csv(index_save_path, header=None, index_col=0).iloc[:,0])
                    elif f.annotation == gpd.GeoDataFrame or f.annotation == gpd.GeoDataFrame | None:
                        index_save_path = os.path.join(index_dir_path, k+".parquet")
                        if os.path.exists(index_save_path):
                            data_dict[k] = gpd.read_parquet(index_save_path)
                    else:
                        other_index_save_path = os.path.join(index_dir_path, k+".pkl")
                        if os.path.exists(other_index_save_path):
                            with open(other_index_save_path, 'rb') as f:
                                data_dict[k] = pickle.load(f)

        if os.path.exists(data_dir_path):
            sqlite_db_path = os.path.join(data_dir_path, "data.sqlite")
            if os.path.exists(sqlite_db_path):
                engine = create_engine(f"sqlite:///{sqlite_db_path}")
                for k,f in cls.model_fields.items():
                    if f.json_schema_extra['data_type'] == 'data':
                        if f.json_schema_extra['save_mode'] == 'sqlite':
                            with engine.connect() as conn:
                                if conn.execute(text(
                                    f"SELECT name \
                                        FROM sqlite_master \
                                        WHERE type='table' \
                                        AND name='{k}'"
                                )).fetchone() is not None:
                                    if f.annotation == pl.DataFrame or f.annotation == pl.DataFrame | None:
                                        if k in cls.table_schema:
                                            table_schema = cls.table_schema[k]
                                        else:
                                            table_schema = None
                                        data_dict[k] = pl.read_database(
                                            query=f"SELECT * FROM {k}", connection=conn,
                                            schema_overrides=table_schema
                                        )
                                    elif f.annotation == pd.DataFrame or f.annotation == pd.DataFrame | None:
                                        data_dict[k] = pd.read_sql_query(f"SELECT * FROM {k}", conn)
                        else:
                            other_data_save_path = os.path.join(data_dir_path, k+".pkl")
                            if os.path.exists(other_data_save_path):
                                with open(other_data_save_path, 'rb') as f:
                                    data_dict[k] = pickle.load(f)
                engine.dispose()

        return data_dict

    @classmethod
    def load(cls, save_dir_path: str):

        return cls(**cls._base_load(save_dir_path))


In [47]:
class SpectrumMap(BaseMap):

    scan_id_matcher: ClassVar[Pattern] = re.compile(r'scan=(\d+)')
    table_schema: ClassVar[dict[str, dict]] = {
        "ms1_df": {
            "spec_id":pl.String,
            "rt":pl.Float32,
        },
        "ms2_df": {
            "spec_id":pl.String,
            "rt":pl.Float32,
            "precursor_mz":pl.Float32,
            "base_peak_mz":pl.Float32,
            "base_peak_intensity":pl.Float32,
        }
    }

    ms1_index: gpd.GeoDataFrame | None = Field(
        default=None,
        data_type="index",
        save_mode="parquet",
        description="MS1谱图的空间索引表，基于geopandas"
    )
    ms1_df: pl.DataFrame | None = Field(
        default=None,
        data_type="data",
        save_mode="sqlite",
        description="MS1谱图的DataFrame，基于polars"
    )
    ms2_index: gpd.GeoDataFrame | None = Field(
        default=None,
        data_type="index",
        save_mode="parquet",
        description="MS2谱图的空间索引表，基于geopandas"
    )
    ms2_df: pl.DataFrame | None = Field(
        default=None,
        data_type="data",
        save_mode="sqlite",
        description="MS2谱图的DataFrame，基于polars"
    )

    @staticmethod
    def get_exp_meta(exp: oms.MSExperiment) -> dict[str, str]:
        spec: oms.MSSpectrum = exp[0]
        meta_info_string = spec.getMetaValue("filter string")
        meta_info_list = meta_info_string.split(" ")
        ms_type = meta_info_list[0]
        ion_mode = meta_info_list[1]
        ion_source = meta_info_list[3]
        return {
            "ms_type": ms_type,
            "ion_mode": ion_mode,
            "ion_source": ion_source,
        }

    @staticmethod
    def get_scan_index(spec: oms.MSSpectrum) -> int:
        scan_id_match = SpectrumMap.scan_id_matcher.search(spec.getNativeID())
        if scan_id_match:
            return int(scan_id_match.group(1))
        else:
            raise ValueError(
                f"Cannot extract scan index from \
                spectrum native ID: {spec.getNativeID()}"
            )

    @staticmethod
    def ms2spec2dfdict(spec: oms.MSSpectrum) -> dict[
        Literal[
            "spec_id",
            "rt",
            "precursor_mz",
            "base_peak_mz",
            "base_peak_intensity",
            "mz_array",
            "intensity_array",
        ],
        int | float | np.ndarray
    ]:
        spec_id = SpectrumMap.get_scan_index(spec)
        rt = spec.getRT()
        precursor_mz = spec.getPrecursors()[0].getMZ()
        base_peak_mz = spec.getMetaValue("base peak m/z")
        base_peak_intensity = spec.getMetaValue("base peak intensity")
        mz_array, intensity_array = spec.get_peaks()
        return {
            "spec_id": spec_id,
            "rt": rt,
            "precursor_mz": precursor_mz,
            "base_peak_mz": base_peak_mz,
            "base_peak_intensity": base_peak_intensity,
            "mz_array": mz_array.tolist(),
            "intensity_array": intensity_array.tolist(),
        }

    @staticmethod
    def ms1spec2dfdict(spec: oms.MSSpectrum) -> dict[
        Literal[
            "spec_id",
            "rt",
            "mz_array",
            "intensity_array",
        ],
        int | float | np.ndarray
    ]:
        spec_id = SpectrumMap.get_scan_index(spec)
        rt = spec.getRT()
        mz_array, intensity_array = spec.get_peaks()
        return {
            "spec_id": spec_id,
            "rt": rt,
            "mz_array": mz_array.tolist(),
            "intensity_array": intensity_array.tolist(),
        }

    def insert_ms1_id_to_ms2(self) -> None:
        '''
        如果MS2谱图没有对应的MS1谱图ID，则插入null
        '''
        if self.ms1_df is None or self.ms2_df is None:
            raise ValueError(
                "MS1 and MS2 dataframes must be loaded \
                    before inserting MS1 IDs to MS2 dataframe"
            )
        ms1_df_mapping = self.ms1_df.with_columns(
            pl.col('spec_id').alias('ms1_id')
        ).select(['spec_id','ms1_id'])
        self.ms2_df = self.ms2_df.join_asof(
            ms1_df_mapping,
            left_on='spec_id',
            right_on='spec_id',
            strategy='backward'
        )

    def convert_scan_to_spec_id(self) -> None:
        if self.ms1_df.schema['spec_id'] in (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64
        ):
            self.ms1_df = self.ms1_df.with_columns(
                (f"{self.exp_name}::ms1::" + self.ms1_df['spec_id'].cast(str)).alias('spec_id')
            )
            self.ms1_index.index = pd.Index(self.ms1_df['spec_id'].to_list())
        if self.ms2_df.schema['spec_id'] in (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64
        ):
            self.ms2_df = self.ms2_df.with_columns(
                (f"{self.exp_name}::ms2::" + self.ms2_df['spec_id'].cast(str)).alias('spec_id')
            )
            self.ms2_index.index = pd.Index(self.ms2_df['spec_id'].to_list())
        if self.ms2_df.schema['ms1_id'] in (
            pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64
        ):
            self.ms2_df = self.ms2_df.with_columns(
                (f"{self.exp_name}::ms1::" + self.ms2_df['ms1_id'].cast(str)).alias('ms1_id')
            )

    def modify_ms2_rt(self) -> None:
        ms1_rt_df = self.ms1_df.select(['spec_id', 'rt']).rename({'rt': 'ms1_rt','spec_id':'ms1_id'})
        joined_df = self.ms2_df.join(ms1_rt_df, on='ms1_id', how='left')
        self.ms2_df = joined_df.with_columns(
            pl.when(pl.col('ms1_id').is_not_null())
            .then(pl.col('ms1_rt'))
            .otherwise(pl.col('rt'))
            .alias('rt')
        ).drop('ms1_rt')

    def search_ms1_by_range(
        self,
        coordinates: tuple[
            float, # min_rt
            float, # max_rt
        ],
        return_type: Literal["id", "indices", "df"] = "id",
    ) -> list[int] | list[str] | pl.DataFrame:
        coordinates = (coordinates[0], -1, coordinates[1], 1)
        iloc = list(self.ms1_index.sindex.intersection(coordinates))
        if return_type == "id":
            return self.ms1_index.index[iloc].tolist()
        elif return_type == "df":
            return self.ms1_df[iloc]
        else:
            return iloc

    def search_ms2_by_range(
        self,
        coordinates: tuple[
            float, # min_rt
            float, # min_mz
            float, # max_rt
            float, # max_mz
        ],
        return_type: Literal["id", "indices", "df"] = "id",
    ) -> list[int] | list[str] | pl.DataFrame:
        iloc = list(self.ms2_index.sindex.intersection(coordinates))
        if return_type == "id":
            return self.ms2_index.index[iloc].tolist()
        elif return_type == "df":
            return self.ms2_df[iloc]
        else:
            return iloc

    @classmethod
    def from_oms(
        cls,
        exp: oms.MSExperiment,
        exp_name: str,
        worker_type: Literal["threads", "processes", "synchronous"] = "threads",
        num_workers: int | None = None,
    ) -> SpectrumMap:
        spec_bag = db.from_sequence(exp,npartitions=num_workers)
        ms1_bag = spec_bag.filter(lambda x: x.getMSLevel() == 1)
        ms2_bag = spec_bag.filter(lambda x: x.getMSLevel() == 2)
        ms1_bag = ms1_bag.map(cls.ms1spec2dfdict)
        ms2_bag = ms2_bag.map(cls.ms2spec2dfdict)
        ms1,ms2 = dask.compute(ms1_bag,ms2_bag,scheduler=worker_type,num_workers=num_workers)
        ms1_df = pl.DataFrame(ms1,schema={
            "spec_id":pl.Int32,
            "rt":pl.Float32,
            "mz_array":pl.List(pl.Float32),
            "intensity_array":pl.List(pl.Float32),
        })
        ms1_df = ms1_df.with_columns(
            (pl.col('rt') / 60.0),
        )
        ms1_index = gpd.GeoDataFrame(
            {"iloc":range(len(ms1_df))},
            index=ms1_df['spec_id'],
            geometry=gpd.points_from_xy(
                x=ms1_df['rt'],
                y=[0] * len(ms1_df),
            )
        )
        ms2_df = pl.DataFrame(ms2,schema={
            "spec_id":pl.Int32,
            "rt":pl.Float32,
            "precursor_mz":pl.Float32,
            "base_peak_mz":pl.Float32,
            "base_peak_intensity":pl.Float32,
            "mz_array":pl.List(pl.Float32),
            "intensity_array":pl.List(pl.Float32),
        })
        ms2_df = ms2_df.with_columns(
            (pl.col('rt') / 60.0),
        )
        ms2_index = gpd.GeoDataFrame(
            {"iloc":range(len(ms2_df))},
            index=ms2_df['spec_id'],
            geometry=gpd.points_from_xy(
                x=ms2_df['rt'],
                y=ms2_df['precursor_mz'],
            )
        )
        metadata = cls.get_exp_meta(exp)
        spectrum_map = cls(
            exp_name=exp_name,
            metadata=metadata,
            ms1_index=ms1_index,
            ms1_df=ms1_df,
            ms2_index=ms2_index,
            ms2_df=ms2_df,
        )
        spectrum_map.insert_ms1_id_to_ms2()
        spectrum_map.convert_scan_to_spec_id()
        spectrum_map.modify_ms2_rt()
        return spectrum_map

    def save(self, save_dir_path: str):

        self_to_save = copy.copy(self)
        self_to_save.ms1_df = self.ms1_df.with_columns(
            ("[" + pl.col("mz_array").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("mz_array"),
            ("[" + pl.col("intensity_array").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("intensity_array"),
        )
        self_to_save.ms2_df = self.ms2_df.with_columns(
            ("[" + pl.col("mz_array").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("mz_array"),
            ("[" + pl.col("intensity_array").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("intensity_array"),
        )

        super(SpectrumMap, self_to_save).save(save_dir_path)

    @classmethod
    def load(cls, save_dir_path: str):

        data_dict = cls._base_load(save_dir_path)

        if 'ms1_df' in data_dict:
            if isinstance(data_dict['ms1_df'], pl.DataFrame):
                data_dict['ms1_df'] = data_dict['ms1_df'].with_columns(
                    pl.col("mz_array")
                        .str.strip_chars_start("[")
                        .str.strip_chars_end("]")
                        .str.split(",")
                        .cast(pl.List(pl.Float32)),
                    pl.col("intensity_array")
                        .str.strip_chars_start("[")
                        .str.strip_chars_end("]")
                        .str.split(",")
                        .cast(pl.List(pl.Float32)),
                )
        if 'ms2_df' in data_dict:
            if isinstance(data_dict['ms2_df'], pl.DataFrame):
                data_dict['ms2_df'] = data_dict['ms2_df'].with_columns(
                    pl.col("mz_array")
                        .str.strip_chars_start("[")
                        .str.strip_chars_end("]")
                        .str.split(",")
                        .cast(pl.List(pl.Float32)),
                    pl.col("intensity_array")
                        .str.strip_chars_start("[")
                        .str.strip_chars_end("]")
                        .str.split(",")
                        .cast(pl.List(pl.Float32)),
                )

        return cls(**data_dict)

In [48]:
spectrum_map = SpectrumMap.from_oms(datas.exps[0], datas.exp_names[0])

In [49]:
spectrum_map.save("../cache/test_spectrum_map")

In [50]:
reload_spectrum_map = SpectrumMap.load("../cache/test_spectrum_map")

In [51]:
reload_spectrum_map: SpectrumMap = pickle.loads(pickle.dumps(reload_spectrum_map))

In [52]:
reload_spectrum_map.ms1_index

Unnamed: 0,iloc,geometry
QC1.mzML::ms1::797,0,POINT (3.34399 0)
QC1.mzML::ms1::802,1,POINT (3.36396 0)
QC1.mzML::ms1::807,2,POINT (3.384 0)
QC1.mzML::ms1::812,3,POINT (3.40474 0)
QC1.mzML::ms1::817,4,POINT (3.4255 0)
...,...,...
QC1.mzML::ms1::2423,329,POINT (9.90593 0)
QC1.mzML::ms1::2428,330,POINT (9.92663 0)
QC1.mzML::ms1::2433,331,POINT (9.94589 0)
QC1.mzML::ms1::2438,332,POINT (9.9659 0)


In [53]:
reload_spectrum_map.ms2_index

Unnamed: 0,iloc,geometry
QC1.mzML::ms2::795,0,POINT (3.335 173.128)
QC1.mzML::ms2::796,1,POINT (3.339 233.128)
QC1.mzML::ms2::801,2,POINT (3.359 207.985)
QC1.mzML::ms2::804,3,POINT (3.37 224.128)
QC1.mzML::ms2::808,4,POINT (3.386 131.118)
...,...,...
QC1.mzML::ms2::2434,713,POINT (9.947 279.093)
QC1.mzML::ms2::2436,714,POINT (9.957 371.316)
QC1.mzML::ms2::2440,715,POINT (9.972 224.128)
QC1.mzML::ms2::2441,716,POINT (9.977 233.128)


In [54]:
reload_spectrum_map.ms1_df

spec_id,rt,mz_array,intensity_array
str,f32,list[f32],list[f32]
"""QC1.mzML::ms1::797""",3.343986,"[200.09166, 200.128098, … 794.977478]","[4236.865234, 2695.840576, … 1839.659058]"
"""QC1.mzML::ms1::802""",3.363965,"[200.091904, 200.128159, … 782.615112]","[1270.613525, 7574.313477, … 2615.272705]"
"""QC1.mzML::ms1::807""",3.384004,"[200.005188, 200.070435, … 796.565796]","[2496.52124, 759.796326, … 1426.792603]"
"""QC1.mzML::ms1::812""",3.40474,"[200.091904, 200.128113, … 798.653381]","[5781.839844, 971.905701, … 1681.213257]"
"""QC1.mzML::ms1::817""",3.425502,"[200.005081, 200.061249, … 781.128113]","[0.0, 6932.186035, … 1800.135132]"
…,…,…,…
"""QC1.mzML::ms1::2423""",9.905934,"[200.12822, 200.18399, … 798.510986]","[0.0, 13750.636719, … 2882.806641]"
"""QC1.mzML::ms1::2428""",9.92663,"[200.091843, 200.113205, … 799.597229]","[3245.281738, 0.0, … 2757.355957]"
"""QC1.mzML::ms1::2433""",9.945886,"[200.127975, 200.200912, … 799.996826]","[556.130371, 23764.205078, … 2381.471436]"
"""QC1.mzML::ms1::2438""",9.965904,"[200.09169, 200.128021, … 799.945312]","[0.0, 10572.25293, … 1917.658203]"


In [55]:
reload_spectrum_map.ms2_df

spec_id,rt,precursor_mz,base_peak_mz,base_peak_intensity,mz_array,intensity_array,ms1_id
str,f32,f32,f32,f32,list[f32],list[f32],str
"""QC1.mzML::ms2::795""",3.334743,173.128479,173.128403,20729.001953,[200.683197],[2170.664795],
"""QC1.mzML::ms2::796""",3.339177,233.128433,174.091309,106759.648438,"[212.950302, 216.101776, 233.128708]","[3391.868408, 8426.791016, 4821.996582]",
"""QC1.mzML::ms2::801""",3.343986,207.985245,184.969223,129272.453125,"[202.979721, 207.985367, 208.133255]","[2566.662598, 7957.563965, 19053.580078]","""QC1.mzML::ms1::797"""
"""QC1.mzML::ms2::804""",3.363965,224.128128,165.054581,57772.613281,"[203.77977, 224.128128, 229.219147]","[2704.05127, 19318.791016, 2518.687744]","""QC1.mzML::ms1::802"""
"""QC1.mzML::ms2::808""",3.384004,131.11792,90.947571,41532.6875,[207.170685],[2081.906982],"""QC1.mzML::ms1::807"""
…,…,…,…,…,…,…,…
"""QC1.mzML::ms2::2434""",9.945886,279.093353,219.05687,33164.667969,"[201.0457, 219.05687, 252.445023]","[2686.182373, 33164.667969, 2422.131592]","""QC1.mzML::ms1::2433"""
"""QC1.mzML::ms2::2436""",9.945886,371.315826,147.065063,56023.519531,"[241.179855, 259.189697, … 355.069611]","[5598.549805, 2642.418945, … 2885.548584]","""QC1.mzML::ms1::2433"""
"""QC1.mzML::ms2::2440""",9.965904,224.128128,155.974701,34200.113281,[224.128311],[11006.429688],"""QC1.mzML::ms1::2438"""
"""QC1.mzML::ms2::2441""",9.965904,233.128494,174.091278,65971.070312,"[216.101868, 233.127686]","[5839.043945, 4289.237305]","""QC1.mzML::ms1::2438"""


In [56]:
reload_spectrum_map.search_ms2_by_range(
    (3,200,4,250)
)

['QC1.mzML::ms2::941',
 'QC1.mzML::ms2::851',
 'QC1.mzML::ms2::801',
 'QC1.mzML::ms2::915',
 'QC1.mzML::ms2::895',
 'QC1.mzML::ms2::864',
 'QC1.mzML::ms2::844',
 'QC1.mzML::ms2::925',
 'QC1.mzML::ms2::804',
 'QC1.mzML::ms2::884',
 'QC1.mzML::ms2::824',
 'QC1.mzML::ms2::904',
 'QC1.mzML::ms2::944',
 'QC1.mzML::ms2::796',
 'QC1.mzML::ms2::821',
 'QC1.mzML::ms2::910',
 'QC1.mzML::ms2::880',
 'QC1.mzML::ms2::841']

In [57]:
class FeatureMap(BaseMap):

    table_schema: ClassVar[dict[str, dict]] = {
        "feature_info": {
                "RT": pl.Float32,
                "mz": pl.Float32,
                "intensity": pl.Float32,
                "MZstart": pl.Float32,
                "RTstart": pl.Float32,
                "MZend": pl.Float32,
                "RTend": pl.Float32,
                "hull_num": pl.Int8,
        },
        "hull_info": {
            "RTstart": pl.Float32,
            "RTend": pl.Float32,
            "MZstart": pl.Float32,
            "MZend": pl.Float32,
        }
    }

    feature_index: gpd.GeoDataFrame | None = Field(
        default=None,
        data_type="index",
        save_mode="parquet",
        description="Feature的空间索引表，基于geopandas"
    )
    feature_info: pl.DataFrame | None = Field(
        default=None,
        data_type="data",
        save_mode="sqlite",
        description="Feature信息表，基于polars"
    )
    hull_index: gpd.GeoDataFrame | None = Field(
        default=None,
        data_type="index",
        save_mode="parquet",
        description="Hull的空间索引表，基于geopandas"
    )
    hull_info: pl.DataFrame | None = Field(
        default=None,
        data_type="data",
        save_mode="sqlite",
        description="Hull信息表，基于polars"
    )

    @staticmethod
    def get_feature_metadata(feature: oms.Feature) -> dict[
        Literal[
            'hull_num',"hull_mz","hull_rt","hull_intensity",
            "isotope_pattern",
            "adduct_type","adduct_mass",
        ],
        str | float | int | list[float]
    ]:
        all_keys = []
        feature.getKeys(all_keys)
        all_keys = set(all_keys)
        metadata = {
            "hull_num": feature.getMetaValue("num_of_masstraces"),
            "hull_mz": feature.getMetaValue("masstrace_centroid_mz"),
            "hull_rt": feature.getMetaValue("masstrace_centroid_rt"),
            "hull_intensity": feature.getMetaValue("masstrace_intensity"),
            "isotope_pattern": feature.getMetaValue("isotope_distances"),
        }
        if "dc_charge_adducts" in all_keys:
            metadata["adduct_type"] = feature.getMetaValue("dc_charge_adducts")
            metadata["adduct_mass"] = feature.getMetaValue("dc_charge_adduct_mass")
        return metadata

    @staticmethod
    def get_feature_info(
        feature_map: oms.FeatureMap,
        worker_type: Literal["threads", "processes", "synchronous"] = "threads",
        num_workers: int | None = None,
    ) -> pl.DataFrame:
        feature_info = feature_map.get_df()[
            ["RT","mz","intensity","MZstart","RTstart","MZend","RTend"]
        ]
        feature_info.index.name = "feature_id"
        feature_info = pl.from_pandas(
            feature_info,
            schema_overrides = {
                "RT": pl.Float32,
                "mz": pl.Float32,
                "intensity": pl.Float32,
                "MZstart": pl.Float32,
                "RTstart": pl.Float32,
                "MZend": pl.Float32,
                "RTend": pl.Float32,
            },
            include_index=True
        )
        feature_info = feature_info.with_columns(
            (pl.col("RT") / 60).alias("RT"),
            (pl.col("RTstart") / 60).alias("RTstart"),
            (pl.col("RTend") / 60).alias("RTend"),
        )
        feature_bag = db.from_sequence(feature_map, npartitions=num_workers)
        feature_metadata_bag = feature_bag.map(FeatureMap.get_feature_metadata)
        feature_metadata_list = dask.compute(
            feature_metadata_bag, scheduler=worker_type, num_workers=num_workers
        )[0]
        feature_metadata_df = pl.DataFrame(
            feature_metadata_list,
            schema_overrides={
                "hull_num": pl.Int8,
                "hull_mz": pl.List(pl.Float32),
                "hull_rt": pl.List(pl.Float32),
                "hull_intensity": pl.List(pl.Float32),
                "isotope_pattern": pl.List(pl.Float32),
            }
        )
        feature_metadata_df = feature_metadata_df.with_columns(
            pl.col("isotope_pattern").list.eval(pl.element().cum_sum()),
            pl.col("hull_rt").list.eval(pl.element() / 60),
        )
        if "adduct_mass" in feature_metadata_df.columns:
            feature_metadata_df = feature_metadata_df.with_columns(
                pl.col("adduct_mass").cast(pl.Float32),
            )
        feature_info = pl.concat([feature_info, feature_metadata_df], how="horizontal")
        return feature_info

    @staticmethod
    def get_hulls(
        feature_map: oms.FeatureMap,
        feature_xic: dict[dict[oms.MSChromatogram]],
    ) -> pl.DataFrame:
        rt_hulls = {}
        for feature_rt_hulls in feature_xic:
            for rt_hull in feature_rt_hulls:
                rt_hulls[rt_hull.getNativeID().replace("_","::")] = rt_hull
        mz_hulls = {}
        for feature in feature_map:
            for i,mz_hull in enumerate(feature.getConvexHulls()):
                mz_hulls[f"{feature.getUniqueId()}::{i}"] = mz_hull
        hulls = []
        hulls_id = list(mz_hulls.keys())
        for hull_id in hulls_id:
            hull = {}
            hull['hull_id'] = hull_id
            rt_points, intens_points = rt_hulls[hull_id].get_peaks()
            mz_points = mz_hulls[hull_id].getHullPoints()[:,1][:len(rt_points)]
            hull['rt_points'] = rt_points.tolist()
            hull['mz_points'] = mz_points.tolist()
            hull['intens_points'] = intens_points.tolist()
            hulls.append(hull)
        hull_info = pl.DataFrame(
            hulls,
            schema_overrides={
                'rt_points': pl.List(pl.Float32),
                'mz_points': pl.List(pl.Float32),
                'intens_points': pl.List(pl.Float32),
            }
        )
        hull_info = hull_info.with_columns(
            pl.col("rt_points").list.eval(pl.element() / 60),
        )
        hull_info = hull_info.with_columns(
            pl.col("rt_points").list.min().alias("RTstart"),
            pl.col("rt_points").list.max().alias("RTend"),
            pl.col("mz_points").list.min().alias("MZstart"),
            pl.col("mz_points").list.max().alias("MZend"),
        )
        hull_info = hull_info.select(
            "hull_id",
            "RTstart",
            "RTend",
            "MZstart",
            "MZend",
            "rt_points",
            "mz_points",
            "intens_points",
        )
        return hull_info

    @classmethod
    def from_oms(
        cls,
        feature_map: oms.FeatureMap,
        feature_xic: list[list[oms.MSChromatogram]],
        exp_name: str,
        worker_type: Literal["threads", "processes", "synchronous"] = "threads",
        num_workers: int | None = None,
    ) -> FeatureMap:
        feature_info = cls.get_feature_info(feature_map, worker_type, num_workers)
        hull_info = cls.get_hulls(feature_map, feature_xic)
        feature_info = feature_info.with_columns(
            (f"{exp_name}::" + pl.col("feature_id").cast(str)).alias("feature_id"),
        )
        feature_index = gpd.GeoDataFrame(
            {"iloc": range(len(feature_info))},
            index=feature_info["feature_id"].to_list(),
            geometry=[
                box(rt_start, mz_start, rt_end, mz_end) \
                    for rt_start, mz_start, rt_end, mz_end in zip(
                        feature_info["RTstart"],
                        feature_info["MZstart"],
                        feature_info["RTend"],
                        feature_info["MZend"],
                    )
            ],
        )
        hull_info = hull_info.with_columns(
            (f"{exp_name}::" + pl.col("hull_id").cast(str)).alias("hull_id"),
        )
        hull_index = gpd.GeoDataFrame(
            {"iloc": range(len(hull_info))},
            index=hull_info["hull_id"].to_list(),
            geometry=[
                box(rt_start, mz_start, rt_end, mz_end) \
                    for rt_start, mz_start, rt_end, mz_end in zip(
                        hull_info["RTstart"],
                        hull_info["MZstart"],
                        hull_info["RTend"],
                        hull_info["MZend"],
                    )
            ]
        )
        return cls(
            exp_name=exp_name,
            feature_info=feature_info,
            hull_info=hull_info,
            feature_index=feature_index,
            hull_index=hull_index,
        )

    def get_oms_feature_map(self) -> oms.FeatureMap:
        feature_map = oms.FeatureMap()
        feature_info = self.feature_info.select(
            "feature_id","mz", "RT", "intensity",
        ).with_columns(
            pl.col("feature_id").str.split("::").list.get(1).cast(pl.Int128).alias("feature_id"),
        )
        for i in range(len(feature_info)):
            feature = oms.Feature()
            feature.setUniqueId(feature_info[i, "feature_id"])
            feature.setMZ(feature_info[i, "mz"])
            feature.setRT(feature_info[i, "RT"])
            feature.setIntensity(feature_info[i, "intensity"])
            feature_map.push_back(feature)
        return feature_map

    def search_feature_by_range(
        self,
        coordinates: tuple[
            float, # min_rt
            float, # min_mz
            float, # max_rt
            float, # max_mz
        ],
        return_type: Literal["id", "indices", "df"] = "id",
    ) -> list[int] | list[str] | pl.DataFrame:
        iloc = list(self.feature_index.sindex.intersection(coordinates))
        if return_type == "id":
            return self.feature_index[iloc].index.tolist()
        elif return_type == "df":
            return self.feature_info[iloc]
        else:
            return iloc

    def search_hull_by_range(
        self,
        coordinates: tuple[
            float, # min_rt
            float, # min_mz
            float, # max_rt
            float, # max_mz
        ],
        return_type: Literal["id", "indices", "df"] = "id",
    ) -> list[int] | list[str] | pl.DataFrame:
        iloc = list(self.hull_index.sindex.intersection(coordinates))
        if return_type == "id":
            return self.hull_index[iloc].index.tolist()
        elif return_type == "df":
            return self.hull_info[iloc]
        else:
            return iloc

    def save(self, save_dir_path: str):

        self_to_save = copy.copy(self)

        self_to_save.feature_info = self.feature_info.with_columns(
            ("[" + pl.col("hull_mz").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("hull_mz"),
            ("[" + pl.col("hull_rt").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("hull_rt"),
            ("[" + pl.col("hull_intensity").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("hull_intensity"),
            ("[" + pl.col("isotope_pattern").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("isotope_pattern"),
        )

        self_to_save.hull_info = self.hull_info.with_columns(
            ("[" + pl.col("rt_points").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("rt_points"),
            ("[" + pl.col("mz_points").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("mz_points"),
            ("[" + pl.col("intens_points").cast(pl.List(pl.String)).list.join(",") + "]")
            .alias("intens_points"),
        )

        super(FeatureMap, self_to_save).save(save_dir_path)

    @classmethod
    def load(cls, save_dir_path: str):

        data_dict = cls._base_load(save_dir_path)

        feature_info: pl.DataFrame | None = data_dict.pop("feature_info")

        if feature_info is not None:
            data_dict['feature_info'] = feature_info.with_columns(
                pl.col("hull_mz")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
                pl.col("hull_rt")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
                pl.col("hull_intensity")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
                pl.col("isotope_pattern")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
            )

        hull_info: pl.DataFrame | None = data_dict.pop("hull_info")
        if hull_info is not None:
            data_dict['hull_info'] = hull_info.with_columns(
                pl.col("rt_points")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
                pl.col("mz_points")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
                pl.col("intens_points")
                    .str.strip_chars_start("[")
                    .str.strip_chars_end("]")
                    .str.split(",")
                    .cast(pl.List(pl.Float32)),
            )

        return cls(**data_dict)

In [58]:
feature_map = FeatureMap.from_oms(datas.features[0],datas.chromatogram_peaks[0],datas.exp_names[0])

In [59]:
feature_map.save("../cache/test_feature_map")

In [60]:
reload_feature_map = FeatureMap.load("../cache/test_feature_map")

In [61]:
reload_feature_map:FeatureMap = pickle.loads(pickle.dumps(reload_feature_map))

In [62]:
reload_feature_map.feature_index

Unnamed: 0,iloc,geometry
QC1.mzML::537262113478123113,0,"POLYGON ((3.65 445.902, 3.65 446.903, 3.344 44..."
QC1.mzML::2409176774723997744,1,"POLYGON ((4.077 252.182, 4.077 253.186, 3.813 ..."
QC1.mzML::5674215280029449342,2,"POLYGON ((4.117 344.227, 4.117 345.232, 3.69 3..."
QC1.mzML::15249434478318674630,3,"POLYGON ((4.282 300.203, 4.282 301.207, 4.017 ..."
QC1.mzML::900057232785487999,4,"POLYGON ((4.261 297.191, 4.261 298.194, 4.017 ..."
...,...,...
QC1.mzML::4271180125471554430,70,"POLYGON ((9.459 579.064, 9.459 580.068, 9.113 ..."
QC1.mzML::15269413475962407845,71,"POLYGON ((9.624 671.788, 9.624 672.793, 9.278 ..."
QC1.mzML::14031656565992301066,72,"POLYGON ((9.644 598.412, 9.644 599.416, 9.418 ..."
QC1.mzML::13765281633924752481,73,"POLYGON ((9.765 705.811, 9.765 706.818, 9.563 ..."


In [63]:
reload_feature_map.hull_index

Unnamed: 0,iloc,geometry
QC1.mzML::537262113478123113::0,0,"POLYGON ((3.65 445.902, 3.65 445.903, 3.344 44..."
QC1.mzML::537262113478123113::1,1,"POLYGON ((3.548 446.902, 3.548 446.903, 3.344 ..."
QC1.mzML::2409176774723997744::0,2,"POLYGON ((4.077 252.182, 4.077 252.182, 3.813 ..."
QC1.mzML::2409176774723997744::1,3,"POLYGON ((3.995 253.185, 3.995 253.186, 3.813 ..."
QC1.mzML::5674215280029449342::0,4,"POLYGON ((4.117 344.227, 4.117 344.229, 3.69 3..."
...,...,...
QC1.mzML::14031656565992301066::1,150,"POLYGON ((9.603 599.415, 9.603 599.416, 9.501 ..."
QC1.mzML::13765281633924752481::0,151,"POLYGON ((9.765 705.811, 9.765 705.812, 9.563 ..."
QC1.mzML::13765281633924752481::1,152,"POLYGON ((9.765 706.811, 9.765 706.818, 9.563 ..."
QC1.mzML::9380827912261487001::0,153,"POLYGON ((9.765 209.19, 9.765 209.19, 9.664 20..."


In [64]:
reload_feature_map.feature_info

feature_id,RT,mz,intensity,MZstart,RTstart,MZend,RTend,hull_num,hull_mz,hull_rt,hull_intensity,isotope_pattern
str,f32,f32,f32,f32,f32,f32,f32,i8,list[f32],list[f32],list[f32],list[f32]
"""QC1.mzML::537262113478123113""",3.588297,445.902374,104382.570312,445.90213,3.343986,446.902557,3.650401,2,"[445.902374, 446.9021]","[3.588297, 3.363965]","[104382.570312, 26270.710938]",[0.999718]
"""QC1.mzML::2409176774723997744""",3.852764,252.1819,1.5391e6,252.181808,3.812734,253.185608,4.076713,2,"[252.1819, 253.185181]","[3.852764, 3.852764]","[1.5391e6, 321779.25]",[1.003272]
"""QC1.mzML::5674215280029449342""",3.893466,344.228027,197418.84375,344.226501,3.69046,345.231903,4.116995,2,"[344.228027, 345.231506]","[3.893466, 3.953598]","[197418.84375, 47032.886719]",[1.003477]
"""QC1.mzML::15249434478318674630""",4.05662,300.203064,1.7284e6,300.202606,4.01652,301.206665,4.281793,2,"[300.203064, 301.205994]","[4.05662, 4.05662]","[1.7284e6, 125906.328125]",[1.002912]
"""QC1.mzML::900057232785487999""",4.076713,297.190918,225549.453125,297.190704,4.01652,298.194336,4.261061,2,"[297.190918, 298.194275]","[4.076713, 4.05662]","[225549.453125, 35371.0625]",[1.003371]
…,…,…,…,…,…,…,…,…,…,…,…,…
"""QC1.mzML::4271180125471554430""",9.195253,579.064331,73000.851562,579.063843,9.11296,580.06781,9.459101,2,"[579.064331, 580.067444]","[9.195253, 9.277564]","[73000.851562, 34937.773438]",[1.003106]
"""QC1.mzML::15269413475962407845""",9.418377,671.789124,112240.289062,671.78772,9.277564,672.792969,9.623575,2,"[671.789124, 672.79187]","[9.418377, 9.357579]","[112240.289062, 42289.277344]",[1.002782]
"""QC1.mzML::14031656565992301066""",9.542112,598.411926,55715.152344,598.41156,9.418377,599.415588,9.64362,2,"[598.411926, 599.415222]","[9.542112, 9.582781]","[55715.152344, 25916.025391]",[1.003269]
"""QC1.mzML::13765281633924752481""",9.602751,705.811768,54921.792969,705.810608,9.56284,706.818054,9.764957,2,"[705.811768, 706.814697]","[9.602751, 9.623575]","[54921.792969, 30861.908203]",[1.002939]


In [65]:
reload_feature_map.hull_info

hull_id,RTstart,RTend,MZstart,MZend,rt_points,mz_points,intens_points
str,f32,f32,f32,f32,list[f32],list[f32],list[f32]
"""QC1.mzML::537262113478123113::…",3.343986,3.650401,445.90213,445.902527,"[3.343986, 3.363965, … 3.650401]","[445.902344, 445.902405, … 445.902191]","[5323.933105, 5364.649902, … 2042.94458]"
"""QC1.mzML::537262113478123113::…",3.343986,3.547573,446.90155,446.902557,"[3.343986, 3.363965, … 3.547573]","[446.902557, 446.90213, … 446.90155]","[2132.606689, 2963.083252, … 1389.333374]"
"""QC1.mzML::2409176774723997744:…",3.812734,4.076713,252.181808,252.181946,"[3.812734, 3.832, … 4.076713]","[252.181808, 252.181915, … 252.18187]","[6225.435547, 203993.890625, … 3754.619873]"
"""QC1.mzML::2409176774723997744:…",3.812734,3.995065,253.185089,253.185608,"[3.812734, 3.832, … 3.995065]","[253.185226, 253.185211, … 253.185272]","[3768.085938, 27921.669922, … 2633.637695]"
"""QC1.mzML::5674215280029449342:…",3.69046,4.116995,344.226501,344.228577,"[3.69046, 3.711202, … 4.116995]","[344.226501, 344.226898, … 344.227692]","[2444.303955, 2652.002441, … 2965.388428]"
…,…,…,…,…,…,…,…
"""QC1.mzML::14031656565992301066…",9.500566,9.602751,599.414551,599.415588,"[9.500566, 9.582781, 9.602751]","[599.415588, 599.415283, 599.414551]","[3903.819336, 4733.27002, 2966.22876]"
"""QC1.mzML::13765281633924752481…",9.56284,9.764957,705.810608,705.812317,"[9.56284, 9.582781, … 9.764957]","[705.812012, 705.811584, … 705.812317]","[6845.885254, 6800.776367, … 2179.937744]"
"""QC1.mzML::13765281633924752481…",9.56284,9.764957,706.811401,706.818054,"[9.56284, 9.582781, … 9.764957]","[706.815308, 706.814514, … 706.818054]","[3681.091797, 1325.098511, … 1915.797119]"
"""QC1.mzML::9380827912261487001:…",9.664292,9.764957,209.189941,209.190094,"[9.664292, 9.684225, … 9.764957]","[209.190002, 209.189972, … 209.190094]","[33363.902344, 53031.917969, … 5013.371582]"


In [155]:
rebuild_datas = OpenMSDataWrapper(features=[
    FeatureMap.from_oms(f,x,en).get_oms_feature_map() \
    for f,x,en in zip(datas.features,datas.chromatogram_peaks,datas.exp_names)
])
rebuild_datas.exp_names = datas.exp_names
rebuild_datas = FeatureLinker()(rebuild_datas)

Progress of 'Linking features':
-- done [took 0.01 s (CPU), 0.01 s (Wall)] -- 


In [156]:
rebuild_datas.consensus_map.get_df()

Unnamed: 0_level_0,sequence,charge,RT,mz,quality,QC1.mzML,QC2.mzML
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,,0,517.366455,225.148521,0.998806,5.911864e+05,6.568321e+05
0,,0,360.167831,207.159065,0.935903,5.035236e+06,2.195314e+06
0,,0,359.011963,207.985222,0.934963,9.938316e+05,1.047103e+06
0,,0,507.315598,293.109055,0.990750,9.651988e+04,1.042219e+05
0,,0,404.847656,695.425873,0.557731,1.026271e+05,4.537626e+06
...,...,...,...,...,...,...,...
0,,0,293.642609,225.131454,0.000000,0.000000e+00,2.846717e+05
0,,0,401.348358,755.347107,0.000000,0.000000e+00,1.737777e+05
0,,0,446.029358,618.372375,0.000000,0.000000e+00,5.882982e+06
0,,0,496.929535,670.389038,0.000000,0.000000e+00,4.167085e+05


In [253]:
def infer_sub_hull_from_feature(
    df_feat: pl.DataFrame,
    df_hull: pl.DataFrame,
    *,
    feature_id_col: str = "feature_id",
    num_hull_col: str = "hull_num",
    hull_id_col: str = "hull_id",
) -> list[pl.DataFrame]:
    df_ids = (
        df_feat
        .with_columns(
            idx=pl.int_ranges(0, pl.col(num_hull_col))
        )
        .explode("idx")
        .with_columns(
            hull_id=pl.format("{}::{}", pl.col(feature_id_col), pl.col("idx"))
        )
    )
    merged = df_ids.join(
        df_hull,
        left_on="hull_id",
        right_on=hull_id_col,
        how="left",
    )
    sub_map = (
        merged.group_by(feature_id_col)
        .agg(pl.all())                # 每组的全部列 -> 列表
        .partition_by(feature_id_col, as_dict=False)
    )
    id_order = df_feat[feature_id_col].to_list()
    id_to_df = {df[feature_id_col][0]: df.drop(feature_id_col) for df in sub_map}
    return [id_to_df.get(fid, pl.DataFrame()).select(
                ["hull_id", "RTstart", "RTend", "MZstart", "MZend",
                "rt_points", "mz_points", "intens_points"]
            ).explode(
                ["hull_id", "RTstart", "RTend", "MZstart", "MZend",
                "rt_points", "mz_points", "intens_points"]) \
            for fid in id_order]

def link_ms2_to_feature(
    feature_hulls: pd.DataFrame | pl.DataFrame,
    spectrum_map: SpectrumMap
) -> list[str]:
    spectrum_id_list = []
    for mz_start,rt_start,mz_end,rt_end in zip(
        feature_hulls['MZstart'],
        feature_hulls['RTstart'],
        feature_hulls['MZend'],
        feature_hulls['RTend'],
    ):
        spectrum_id_list += spectrum_map.search_ms2_by_range(
            (rt_start,mz_start,rt_end,mz_end),"id"
        )
    return spectrum_id_list

def link_ms2_and_feature_map(
    feature_map: FeatureMap,
    spectrum_map: SpectrumMap,
    key_id: Literal["feature","spectrum"] = "feature",
    worker_type: Literal["threads","processes","synchronous"] = "threads",
    num_workers: int | None = None,
) -> pl.DataFrame:
    hull_info_queue = infer_sub_hull_from_feature(feature_map.feature_info, feature_map.hull_info)
    hull_info_bag = db.from_sequence(hull_info_queue, npartitions=num_workers)
    spectrum_id_bag = hull_info_bag.map(
        lambda x: link_ms2_to_feature(x,spectrum_map)
    )
    spectrum_id_list = dask.compute(
        spectrum_id_bag, scheduler=worker_type, num_workers=num_workers
    )[0]
    mapping_df = pl.DataFrame(
        data = {
            "feature_id": feature_map.feature_info['feature_id'],
            "spectrum_id": spectrum_id_list,
        },
        schema=pl.Schema({
            "feature_id": pl.String,
            "spectrum_id": pl.List(pl.String),
        })
    )
    if key_id == "spectrum":
        mapping_df = mapping_df.explode("spectrum_id").filter(
            pl.col("spectrum_id").is_not_null()
        ).select(["spectrum_id", "feature_id"])
    return mapping_df

In [250]:
infer_sub_hull_from_feature(feature_map.feature_info, feature_map.hull_info)[0]

hull_id,RTstart,RTend,MZstart,MZend,rt_points,mz_points,intens_points
str,f32,f32,f32,f32,list[f32],list[f32],list[f32]
"""QC1.mzML::13005910100464212996…",3.343986,3.650401,445.90213,446.902557,"[3.343986, 3.363965, … 3.650401]","[445.902344, 445.902405, … 445.902191]","[5323.933105, 5364.649902, … 2042.94458]"
"""QC1.mzML::13005910100464212996…",3.343986,3.650401,445.90213,446.902557,"[3.343986, 3.363965, … 3.547573]","[446.902557, 446.90213, … 446.90155]","[2132.606689, 2963.083252, … 1389.333374]"


In [251]:
link_ms2_and_feature_map(feature_map, spectrum_map, "feature")

feature_id,spectrum_id
str,list[str]
"""QC1.mzML::13005910100464212996""",[]
"""QC1.mzML::14115125470287761880""","[""QC1.mzML::ms2::923"", ""QC1.mzML::ms2::946"", … ""QC1.mzML::ms2::946""]"
"""QC1.mzML::10693648182463709255""","[""QC1.mzML::ms2::936"", ""QC1.mzML::ms2::936""]"
"""QC1.mzML::8958118472990820462""","[""QC1.mzML::ms2::995"", ""QC1.mzML::ms2::973"", … ""QC1.mzML::ms2::973""]"
"""QC1.mzML::15745383017486200183""","[""QC1.mzML::ms2::976"", ""QC1.mzML::ms2::976""]"
…,…
"""QC1.mzML::14833318846184236006""",[]
"""QC1.mzML::17235000915360007629""",[]
"""QC1.mzML::4346243706513282941""",[]
"""QC1.mzML::10497263946829528981""",[]


In [254]:
link_ms2_and_feature_map(feature_map, spectrum_map, "spectrum")

spectrum_id,feature_id
str,str
"""QC1.mzML::ms2::923""","""QC1.mzML::14115125470287761880"""
"""QC1.mzML::ms2::946""","""QC1.mzML::14115125470287761880"""
"""QC1.mzML::ms2::923""","""QC1.mzML::14115125470287761880"""
"""QC1.mzML::ms2::946""","""QC1.mzML::14115125470287761880"""
"""QC1.mzML::ms2::936""","""QC1.mzML::10693648182463709255"""
…,…
"""QC1.mzML::ms2::2094""","""QC1.mzML::492909883837939811"""
"""QC1.mzML::ms2::2109""","""QC1.mzML::492909883837939811"""
"""QC1.mzML::ms2::2094""","""QC1.mzML::492909883837939811"""
"""QC1.mzML::ms2::2370""","""QC1.mzML::14614210394166893117"""


In [66]:
class XICMap(BaseMap):

    table_schema: ClassVar[dict[str,dict]] = {
        "ion_df": {
            "mz": pl.Float32,
            "rt": pl.Float32,
            "i": pl.Float32,
        }
    }

    ion_index: gpd.GeoDataFrame | None = Field(
        default=None,
        data_type="index",
        save_mode="parquet",
        description="离子流的空间索引，基于geopandas",
    )
    ion_df: pl.DataFrame | None = Field(
        default=None,
        data_type="data",
        save_mode="sqlite",
        description="离子流的空间数据，基于polars",
    )

    @classmethod
    def from_oms(
        cls,
        exp: oms.MSExperiment,
        exp_name: str,
    ) -> XICMap:
        ion_df = pl.from_pandas(
            exp.get_massql_df()[0][["mz","rt","i"]],
            schema_overrides = {
                "mz": pl.Float32,
                "rt": pl.Float32,
                "i": pl.Float32,
            }
        )
        ion_index = gpd.GeoDataFrame(
            {"iloc": range(len(ion_df))},
            geometry=gpd.points_from_xy(
                ion_df['rt'],
                ion_df['mz'],
                ion_df['i'],
            )
        )
        return cls(
            exp_name=exp_name,
            ion_index=ion_index,
            ion_df=ion_df,
        )

    def search_ion_by_range(
        self,
        coordinates: tuple[
            float, # min_rt
            float, # min_mz
            float, # max_rt
            float, # max_mz
        ],
        return_type: Literal["id", "indices", "df"] = "id",
    ) -> list[int] | list[str] | pl.DataFrame:
        iloc = list(self.ion_index.sindex.intersection(coordinates))
        if return_type == "id":
            return self.ion_index[iloc].index.tolist()
        elif return_type == "df":
            return self.ion_df[iloc]
        else:
            return iloc

In [67]:
xic_map = XICMap.from_oms(datas.exps[0],datas.exp_names[0])

In [68]:
xic_map.save("../cache/test_xic_map")

In [69]:
reload_xic_map = XICMap.load("../cache/test_xic_map")

In [70]:
reload_xic_map: XICMap = pickle.loads(pickle.dumps(reload_xic_map))

In [71]:
reload_xic_map.ion_index

Unnamed: 0,iloc,geometry
0,0,POINT Z (3.344 200.092 4236.865)
1,1,POINT Z (3.344 200.128 2695.841)
2,2,POINT Z (3.344 200.173 6285.766)
3,3,POINT Z (3.344 200.184 11543.618)
4,4,POINT Z (3.344 200.201 15958.018)
...,...,...
259199,259199,POINT Z (9.986 797.68 2920.09)
259200,259200,POINT Z (9.986 797.905 2668.783)
259201,259201,POINT Z (9.986 798.187 2438.909)
259202,259202,POINT Z (9.986 798.975 2471.288)


In [72]:
reload_xic_map.ion_df

mz,rt,i
f32,f32,f32
200.09166,3.343986,4236.865234
200.128098,3.343986,2695.840576
200.172775,3.343986,6285.765625
200.183792,3.343986,11543.618164
200.200943,3.343986,15958.017578
…,…,…
797.680359,9.985891,2920.090332
797.904968,9.985891,2668.782715
798.186646,9.985891,2438.909424
798.975037,9.985891,2471.288086


In [73]:
reload_xic_map.search_ion_by_range(
    (3.5-0.25,250-0.1,3.5+0.25,250+0.1),"df"
)

mz,rt,i
f32,f32,f32
249.970184,3.446201,3814.350342
249.970428,3.343986,1605.962036
249.943741,3.40474,4938.681152
249.970322,3.40474,1874.907959
249.970169,3.486936,679.727905
…,…,…
249.970123,3.73187,3304.230469
249.970169,3.711202,1800.305664
249.970291,3.69046,4475.595703
249.970306,3.670388,1380.609131


In [21]:
class ConsensusMap(BaseModel):

    model_config = ConfigDict({"arbitrary_types_allowed": True})

    consensus_df: pd.DataFrame
    consensus_feature_mapping: pd.Series
    feature_consensus_mapping: pd.Series

    @classmethod
    def from_oms(cls, consensus_map: oms.ConsensusMap) -> ConsensusMap:
        raw_consensus_df = consensus_map.get_df()
        exp_names = raw_consensus_df.columns[5:][::-1]
        consensus_bag = db.from_sequence(consensus_map)
        feature_id_bag = consensus_bag.map(
            lambda x: \
                [
                    f.getUniqueId() if f.getUniqueId() is str \
                    else f"{exp_names[f.getMapIndex()]}::{f.getUniqueId()}" \
                    for f in x.getFeatureList()
                ]
        )
        consensus_df = consensus_map.get_df().iloc[:,2:].reset_index(drop=True)
        consensus_df.index.name = "consensus_id"
        consensus_feature_mapping = pd.Series(
            feature_id_bag.compute(scheduler="threads"),
            index=consensus_df.index
        )
        consensus_feature_mapping.name = "feature_ids"
        consensus_feature_mapping.index.name = "consensus_id"
        feature_consensus_mapping = {}
        for cid,fids in consensus_feature_mapping.items():
            for fid in fids:
                feature_consensus_mapping[fid] = cid
        feature_consensus_mapping = pd.Series(feature_consensus_mapping)
        feature_consensus_mapping.name = "consensus_id"
        feature_consensus_mapping.index.name = "feature_id"
        return cls(
            consensus_df=consensus_df,
            consensus_feature_mapping=consensus_feature_mapping,
            feature_consensus_mapping=feature_consensus_mapping
        )

    def as_oms_feature_map(self) -> oms.FeatureMap:
        feature_map = oms.FeatureMap()
        for i,row in self.consensus_df.iterrows():
            feature = oms.Feature()
            feature.setUniqueId(i)
            feature.setMZ(row["mz"])
            feature.setRT(row["RT"])
            feature.setIntensity(row.iloc[3:].max())
            feature_map.push_back(feature)
        return feature_map

In [22]:
consensus_map = ConsensusMap.from_oms(datas.consensus_map)

In [23]:
consensus_map.consensus_df

Unnamed: 0_level_0,RT,mz,quality,QC2.mzML,QC1.mzML
consensus_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,404.847657,695.425855,0.557743,4.537626e+06,1.026271e+05
1,576.286209,209.189986,0.952328,1.552605e+05,1.354698e+05
2,360.167820,207.159064,0.935903,2.195314e+06,5.035236e+06
3,349.334946,533.324174,0.840550,7.844999e+04,1.233872e+05
4,507.315594,293.109054,0.990750,1.042219e+05,9.651988e+04
...,...,...,...,...,...
126,446.029355,725.372814,0.000000,2.473700e+05,0.000000e+00
127,446.029355,709.356876,0.000000,6.133417e+05,0.000000e+00
128,446.029355,741.365603,0.000000,2.869066e+05,0.000000e+00
129,446.029355,618.372397,0.000000,5.882982e+06,0.000000e+00


In [24]:
consensus_map.consensus_feature_mapping

consensus_id
0      [QC1.mzML::1058058667861826189, QC2.mzML::5829...
1      [QC1.mzML::5198541186238866932, QC2.mzML::9748...
2      [QC1.mzML::6356764604209267192, QC2.mzML::1215...
3      [QC1.mzML::8102842054760010345, QC2.mzML::1834...
4      [QC1.mzML::10946711822686169604, QC2.mzML::526...
                             ...                        
126                     [QC2.mzML::15378851355750419626]
127                     [QC2.mzML::17722064519600937218]
128                     [QC2.mzML::18049518165761848085]
129                     [QC2.mzML::18163286508655570769]
130                     [QC2.mzML::18444431267099520080]
Name: feature_ids, Length: 131, dtype: object

In [25]:
consensus_map.feature_consensus_mapping

feature_id
QC1.mzML::1058058667861826189       0
QC2.mzML::5829911669225560868       0
QC1.mzML::5198541186238866932       1
QC2.mzML::974875078852663642        1
QC1.mzML::6356764604209267192       2
                                 ... 
QC2.mzML::15378851355750419626    126
QC2.mzML::17722064519600937218    127
QC2.mzML::18049518165761848085    128
QC2.mzML::18163286508655570769    129
QC2.mzML::18444431267099520080    130
Name: consensus_id, Length: 139, dtype: int64

In [26]:
consensus_map.get_oms_feature_map().get_df()

Unnamed: 0_level_0,peptide_sequence,peptide_score,ID_filename,ID_native_id,charge,RT,mz,RTstart,RTend,MZstart,MZend,quality,intensity
feature_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
0,,,,,0,404.847657,695.425855,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,4.537626e+06
1,,,,,0,576.286209,209.189986,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,1.552605e+05
2,,,,,0,360.167820,207.159064,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,5.035236e+06
3,,,,,0,349.334946,533.324174,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,1.233872e+05
4,,,,,0,507.315594,293.109054,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,1.042219e+05
...,...,...,...,...,...,...,...,...,...,...,...,...,...
126,,,,,0,446.029355,725.372814,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,2.473700e+05
127,,,,,0,446.029355,709.356876,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,6.133417e+05
128,,,,,0,446.029355,741.365603,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,2.869066e+05
129,,,,,0,446.029355,618.372397,1.797693e+308,-1.797693e+308,1.797693e+308,-1.797693e+308,0.0,5.882982e+06
