# Dataset Splitting Strategies

Here we present various data splitting strategies for our datasets

In [1]:
import os.path

import dask.utils
import numpy as np
from numpy.random import SeedSequence

import dask.bag as db

In [2]:
entropy = 0x87351080e25cb0fad77a44a3be03b491
seed = np.random.SeedSequence(entropy)

In [3]:
from typing import List, Dict, Any, Union, Iterable

from matchms import Spectrum
from matchms.importing import load_from_msp
from matchms.exporting import save_as_msp


def extract_spectrum(spectrum: Spectrum) -> Dict[str, Any]:
    return {
        'mz': np.asarray(spectrum.peaks.mz),
        'intensity': np.asarray(spectrum.peaks.intensities),
        **spectrum.metadata
    }


def read_msp(files: Union[str, List[str]]) -> db.Bag:
    spectra = db.from_sequence(files).map(load_from_msp).flatten()
    spectra = spectra.map(extract_spectrum)
    return spectra

def write_msp_splits(splits: Iterable[db.Bag], filenames: Iterable[str]) -> None:
    for split, filename in zip(splits, filenames):
        save_as_msp(iter(split), filename)

In [4]:
nist = read_msp('data/NIST_EI_MS.msp')
mona = read_msp('data/MONA.msp')

### Random Split

The most trivial splitting strategy based on pure randomness.

In [None]:
from operator import getitem
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

def category_split(bag: db.Bag, categories):
    ncategories = len(np.unique(categories))

    token = tokenize(bag, categories)
    name = 'split-' + token
    layer = {
        (name, i): (getitem, ) for i in range(bag.npartitions)
    }

    out = []
    for i in range(ncategories):
        name2 = f'category-split-{i}-{token}'
        dask2 = {
            (name2, j): (getitem, (bag.name, i), i) for j in range(bag.npartitions)
        }
        graph = HighLevelGraph.from_collections(name2, dask2, dependencies=[bag])
        out.append(db.Bag(graph, name2, bag.npartitions))
    return out


def cat_split(bag: db.Bag, categories: db.Bag):
    bag.map_partitions(getitem, )


def random_split(data, path, random_state = None):
    index = dask.utils.pseudorandom(len(data), [.7, .15, .15], random_state=random_state)

    splits = spectra.random_split([.7, .15, .15], seed)

    write_msp_splits(splits, ['train.msp', 'test.msp', 'validation.msp'])

In [None]:
random_split(nist, 'nist_random_split', seed)
random_split(mona, 'mona_random_split', seed)

### Split A

Take the first occurrence of inchikey and put it in the training split, the rest will fall into the test split. This split will be focused on recognizing known compounds but with previously unseen fragmentation.

In [None]:
def duplicated(spectra):
    spectra['inchikey'].duplicated()

def unique_split(data, path, key = 'inchikey'):
    splits = category_split(data, category)

    write_msp_splits(splits, [''])

In [None]:
unique_split(nist, 'nist_unique_split')
unique_split(mona, 'mona_unique_split')

### Split B

In [9]:
from operator import getitem

x = db.from_sequence(range(10), npartitions=2)
y = db.from_sequence([0,0,0,0,0,1,1,1,1,1], npartitions=2)

x = x.map_partitions(lambda a: [a[y == i] for i in range(2)])
x.compute()
#x.map_partitions(getitem, y).compute()

[0, 0, 5, 5]

In [62]:
from typing import List
from dask.bag import Bag
from dask.highlevelgraph import HighLevelGraph

def unique(seq: Bag) -> Bag:
    return seq.fold(lambda acc, item: acc | {item}, set.union, initial=set())


def _select(seq, cat, sel):
    return list(a for a, b in zip(seq, cat) if b == sel)

def partition(bag: Bag, cat: Bag) -> List[Bag]:
    assert (bag.npartitions == cat.npartitions)

    out = []
    name = 'partition-' + tokenize(bag, cat)

    for sel in unique(cat).compute():
        name2 = f'{name}-{sel}'
        dsk = {
            (name2, i): (_select, (bag.name, i), (cat.name, i), sel) for i in range(bag.npartitions)
        }
        graph = HighLevelGraph.from_collections(name2, dsk, dependencies=[bag, cat])
        out.append(Bag(graph, name2, bag.npartitions))
    return out


a = db.from_sequence(range(10), npartitions=2)
b = db.from_sequence([0,1,0,1,0,1,0,1,0,1], npartitions=2)

for p in partition(a, b):
    print(p.compute())

[0, 2, 4, 6, 8]
[1, 3, 5, 7, 9]


In [98]:
def _foo(seq, cat, fact):
    return [[a for a, b in zip(seq, cat) if b == f] for f in fact]

def _unique(seq: Bag) -> Bag:
    return seq.fold(lambda acc, item: acc | {item}, set.union, initial=set())

def _split2(sequence: Bag, categories: Bag):
    name = 'split-' + tokenize(sequence, categories)
    fact = _unique(categories).compute()

    g1 = {
        (name, i): (_foo, (sequence.name, i), (categories.name, i), fact) for i in range(sequence.npartitions)
    }

    return Bag(g1, name, npartitions=sequence.npartitions)
    #[Bag(g2, n2, sequence.npartitions) for f in fact]

_split2(a, b).compute()

[[0], [], [], [], [], [], [1], [], [], []]

In [97]:
def _split(seq, cat):
    return [[a for a, b in zip(seq, cat) if b == c] for c in set(cat)]

_foo(list(range(10)), [0,1,0,1,0,1,0,1,0,1], set([0,1]))

[[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]

In [90]:
dsk = {('x', 0): (range, 5),
       ('x', 1): (range, 5),
       ('x', 2): (range, 5)}
db.Bag(dsk, 'x', npartitions=3).compute()

[0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]

In [None]:
import numpy

from matchms.importing import load_from_msp

def random_split(data, frac, seed = None):
    data = np.asarray(data)
    frac = np.asarray(frac)

    if not numpy.allclose(sum(frac), 1):
        raise ValueError('frac should sum to 1')

    n = len(spectra)
    x = default_rng(seed).random(n)
    cp = numpy.cumsum(numpy.append([0], frac))
    indices = np.empty(n, dtype=numpy.int64)

    for i, (low, high) in enumerate(zip(cp[:-1], cp[1:])):
        indices[(x >= low) & (x < high)] = i

    return [data[(x >= low) & (x < high)] for low, high in zip(cp[:-1], cp[1:])]

In [21]:
from numpy.random import default_rng

rng = default_rng()
vals = rng.random(10)
print(vals)

[0.66759752 0.62637313 0.02760391 0.29911756 0.68579046 0.58472615
 0.35768609 0.30958816 0.8702187  0.41426965]


In [1]:
from typing import Tuple, Optional, Generator, Dict

import matchms.filtering
import matchms.importing
import numpy
import torch
import torch.nn.functional as F
from spec2vec import SpectrumDocument
from torch import Tensor
from torch.utils.data import Dataset


def process_spectrum(spectrum: Optional[matchms.Spectrum], n_required_peaks: Optional[int] = 10,
                     n_max_peaks: Optional[int] = None, min_relative_intensity: Optional[int] = None) -> Optional[
    matchms.Spectrum]:
    spectrum = matchms.filtering.select_by_mz(spectrum, mz_from=0, mz_to=1000)
    spectrum = matchms.filtering.normalize_intensities(spectrum)

    if n_required_peaks is not None:
        spectrum = matchms.filtering.require_minimum_number_of_peaks(spectrum, n_required=n_required_peaks)
    if n_max_peaks is not None:
        spectrum = matchms.filtering.reduce_to_number_of_peaks(spectrum, n_max=n_max_peaks)
    if min_relative_intensity is not None:
        spectrum = matchms.filtering.select_by_relative_intensity(spectrum, intensity_from=min_relative_intensity)
    return spectrum


def load_msp_documents(filename: str, n_max_peaks: Optional[int] = None) -> Generator[SpectrumDocument, None, None]:
    spectra = (process_spectrum(spectrum, n_max_peaks=n_max_peaks) for spectrum in
               matchms.importing.load_from_msp(filename))
    spectra = (SpectrumDocument(spectrum, n_decimals=0) for spectrum in spectra if spectrum is not None)
    return spectra


class HuggingfaceDataset(Dataset):
    def __init__(self, filename: str, vocabulary: Dict, max_length: int = 256, include_intensity: bool = False,
                 quadratic_bins: bool = False):
        super().__init__()

        self.spectrum_documents = list(load_msp_documents(filename))
        self.vocabulary = vocabulary
        self.max_length = max_length
        self.include_intensity = include_intensity

        if quadratic_bins:
            self.bins = ((numpy.arange(max_length) ** 2) / ((max_length - 1) ** 2))[::-1]
        else:
            self.bins = (numpy.arange(max_length) / (max_length - 1))[::-1]

    def __len__(self) -> int:
        return len(self.spectrum_documents)

    def __getitem__(self, index) -> Dict[str, Tensor]:
        document = self.spectrum_documents[index]
        indices = torch.argsort(document.peaks.intensities)[::-1]
        indices = indices[:self.max_length]

        x = self.encode_spectrum(torch.asarray(document.words)[indices])
        x_padded = torch.cat([x, torch.zeros(self.max_length - len(x), dtype=torch.int)])

        attention_mask = torch.zeros(self.max_length, dtype=torch.int)
        attention_mask[:len(x)] = 1

        result = {"input_ids": x_padded, "attention_mask": attention_mask}

        if self.include_intensity:
            position_ids = torch.asarray(numpy.digitize(document.peaks.intensities[indices], self.bins, right=False))
            padding = torch.zeros(self.max_length - len(position_ids), dtype=torch.int) + self.max_length - 1
            result["position_ids"] = torch.cat([position_ids, padding])
        return result

    def size(self) -> int:
        return len(self.vocabulary)

    def encode_peak(self, peak) -> Tensor:
        return torch.tensor(self.vocabulary[peak], dtype=torch.int)

    def encode_spectrum(self, spectrum) -> Tensor:
        encoded_peaks = [self.encode_peak(peak) for peak in spectrum if peak in self.vocabulary]
        return torch.tensor(encoded_peaks, dtype=torch.int)


class GenerativeDataset(Dataset):
    def __init__(self, filename: str, vocabulary: Dict, onehot: bool = True, include_intensity: bool = True):
        super().__init__()

        self.spectrum_documents = list(load_msp_documents(filename))
        self.vocabulary = vocabulary
        self.onehot = onehot
        self.include_intensity = include_intensity

    def __len__(self) -> int:
        return len(self.spectrum_documents)

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        indices = numpy.argsort(self.spectrum_documents[index].peaks.intensities)[::-1]
        x_indices = indices[:-1]
        y_indices = indices[1:]

        intensity = self.spectrum_documents[index].peaks.intensities[x_indices]

        x = self.encode_spectrum(self.spectrum_documents[index].words[x_indices], onehot=self.onehot)
        y = self.encode_spectrum(self.spectrum_documents[index].words[y_indices], onehot=False)

        if self.include_intensity:
            intensity = self.spectrum_documents[index].peaks.intensities[x_indices]
            x = self.encode_spec_intens(self.spectrum_documents[index].words[x_indices], intensity)
        else:
            x = self.encode_spectrum(self.spectrum_documents[index].words[x_indices])
        return x, y

    def size(self) -> int:
        return len(self.vocabulary)

    def encode_peak(self, peak) -> Tensor:
        return torch.tensor(self.vocabulary[peak], dtype=torch.int)

    def encode_spectrum(self, spectrum, intensity, onehot: bool):
        encoded_peaks = [self.encode_peak(peak) for peak in spectrum if peak in self.vocabulary]

        if onehot:
            encoded_peaks = [F.one_hot(p, num_classes=len(self.vocabulary)) for p in encoded_peaks]
        return torch.stack(encoded_peaks)

In [5]:
spectra = load_msp_documents('data/NIST_EI_MS.msp')
doc = next(spectra)

In [14]:
type(doc.peaks.mz)

numpy.ndarray