In [1]:
from geneformer.tokenizer import TOKEN_DICTIONARY_FILE

from __future__ import annotations

import logging
import pickle
import warnings
from pathlib import Path
from typing import Literal

import anndata as ad
import numpy as np
import scipy.sparse as sp
from datasets import Dataset

warnings.filterwarnings("ignore", message=".*The 'nopython' keyword.*")  # noqa
import loompy as lp  # noqa

logger = logging.getLogger(__name__)

import loompy as lp
import numpy as np

In [2]:
def rank_genes(gene_vector, gene_tokens):
    """
    Return the tokens reordered from highest to lowest expression value.

    ``gene_vector`` and ``gene_tokens`` are index-aligned arrays: the token at
    position ``i`` belongs to the value at position ``i``.
    """
    # argsort of the negated values yields indices in descending value order
    return gene_tokens[np.argsort(-gene_vector)]


def tokenize_ind(gene_vector, gene_tokens):
    """
    Build the rank value encoding for one individual.

    Entries of ``gene_vector`` equal to zero are dropped together with their
    tokens; the remaining tokens are returned ordered by descending value.
    """
    # positions of the non-zero (detected) entries
    detected = np.nonzero(gene_vector)[0]
    detected_values = gene_vector[detected]
    detected_tokens = gene_tokens[detected]
    # order the surviving tokens by their values, highest first
    return rank_genes(detected_values, detected_tokens)


class ProteomicsTokenizer:
    """
    Rank-value tokenizer for expression matrices stored in .loom files.

    Each column of the loom matrix (one individual) is converted into the
    sequence of token ids of its non-zero rows, ordered by descending value.
    Values are ranked exactly as stored: no count or median normalization is
    applied (the original author notes the inputs are NPX values from UKB).
    """

    def __init__(
        self,
        custom_attr_name_dict=None,
        nproc=1,
        chunk_size=512,
        model_input_size=2048,
        special_token=False,
        token_dictionary_file=TOKEN_DICTIONARY_FILE,
    ):
        """
        Initialize tokenizer.

        **Parameters:**

        custom_attr_name_dict : None, dict
            | Dictionary of custom attributes to be added to the dataset.
            | Keys are the names of the column attributes in the loom file.
            | Values are the names of the columns in the output dataset.
        nproc : int
            | Number of processes to use for dataset mapping.
        chunk_size : int = 512
            | Number of loom columns read per batch while scanning the file.
        model_input_size : int = 2048
            | Max input size of model to truncate input to.
        special_token : bool = False
            | Adds CLS token before and SEP token after rank value encoding.
        token_dictionary_file : Path
            | Path to pickle file containing token dictionary (Ensembl IDs:token).

        """
        # dictionary of custom attributes {input .loom column attribute: output dataset column name}
        self.custom_attr_name_dict = custom_attr_name_dict

        # number of processes for dataset mapping
        self.nproc = nproc

        # number of loom columns scanned per batch
        self.chunk_size = chunk_size

        # input size for tokenization
        self.model_input_size = model_input_size

        # add CLS and SEP tokens
        self.special_token = special_token

        # load token dictionary (Ensembl IDs:token)
        # SECURITY: pickle.load can execute arbitrary code; only load trusted dictionary files.
        with open(token_dictionary_file, "rb") as f:
            self.gene_token_dict = pickle.load(f)

        # gene keys for full vocabulary
        self.gene_keys = list(self.gene_token_dict.keys())

        # membership lookup used to select the .loom rows that can be tokenized
        self.genelist_dict = dict.fromkeys(self.gene_keys, True)

    def tokenize_loom(self, loom_file_path, target_sum=10_000):
        """
        Tokenize every individual (column) of a .loom file.

        **Parameters:**

        loom_file_path : str, Path
            | Path to the .loom file. Rows must carry an ``ensembl_id`` row attribute.
        target_sum : int = 10_000
            | Unused: no normalization is performed. Kept for signature compatibility.

        **Returns:**

        tuple
            | ``(tokenized_ind, file_ind_metadata)`` where ``tokenized_ind`` is a list
            | with one token array per scanned column and ``file_ind_metadata`` is a
            | dict ``{loom column attribute name: list of values}`` aligned with it,
            | or ``None`` when ``custom_attr_name_dict`` is ``None``.

        **Raises:**

        ValueError
            | If no ``ensembl_id`` of the file is present in the token dictionary.
        """
        # Initialise before scanning so the return value is defined even when the
        # scan yields no batches (previously this raised UnboundLocalError).
        if self.custom_attr_name_dict is not None:
            file_ind_metadata = {
                attr_key: [] for attr_key in self.custom_attr_name_dict
            }
        else:
            file_ind_metadata = None

        with lp.connect(str(loom_file_path)) as data:
            # read the row ids once instead of on every use
            row_ids = data.ra["ensembl_id"]

            # coordinates of the rows whose id is present in the token dictionary
            coding_miRNA_loc = np.where(
                [self.genelist_dict.get(i, False) for i in row_ids]
            )[0]
            coding_miRNA_ids = row_ids[coding_miRNA_loc]

            not_in_gene_ids = set(row_ids) - set(self.gene_keys)
            print(
                f"{len(not_in_gene_ids)} genes not in gene token dictionary, skipping them, some are: {list(not_in_gene_ids)[:5]}"
            )

            # Without a single matching row every individual would silently be
            # tokenized to an empty array; fail loudly instead.
            if len(coding_miRNA_loc) == 0:
                raise ValueError(
                    f"None of the 'ensembl_id' values in {loom_file_path} are present in "
                    "the token dictionary; check that the row ids use the same "
                    "identifier type as the dictionary keys."
                )

            coding_miRNA_tokens = np.array(
                [self.gene_token_dict[i] for i in coding_miRNA_ids]
            )

            # define coordinates of individuals passing filters for inclusion (e.g. QC)
            if "filter_pass" in data.ca.keys():
                filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
            else:
                print(
                    f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all inds."
                )
                filter_pass_loc = np.arange(data.shape[1])

            # scan through the .loom file in batches of columns and tokenize inds
            tokenized_ind = []
            for _ix, _selection, view in data.scan(
                items=filter_pass_loc, axis=1, batch_size=self.chunk_size
            ):
                # keep only the rows that have a token
                subview = view.view[coding_miRNA_loc, :]
                # Values are intentionally not normalized (NPX values from UKB).

                # tokenize each column of the batch
                tokenized_ind += [
                    tokenize_ind(subview[:, i], coding_miRNA_tokens)
                    for i in range(subview.shape[1])
                ]

                # collect the requested column attributes for this batch
                if file_ind_metadata is not None:
                    for k in file_ind_metadata:
                        file_ind_metadata[k] += subview.ca[k].tolist()

        return tokenized_ind, file_ind_metadata

    def create_dataset(
        self,
        tokenized_inds,
        ind_metadata,
        use_generator=False,
        keep_uncropped_input_ids=False,
    ):
        """
        Build a Hugging Face ``Dataset`` from tokenized individuals.

        **Parameters:**

        tokenized_inds : list
            | One token sequence per individual (as returned by ``tokenize_loom``).
        ind_metadata : None, dict
            | ``{loom column attribute name: list of values}`` aligned with
            | ``tokenized_inds``. Only used when ``custom_attr_name_dict`` is set;
            | each key is renamed to the dataset column name given by
            | ``custom_attr_name_dict`` (keys without a mapping keep their name).
        use_generator : bool = False
            | Build the dataset with ``Dataset.from_generator`` instead of
            | ``Dataset.from_dict``.
        keep_uncropped_input_ids : bool = False
            | Also store the untruncated sequence and its length in
            | ``input_ids_uncropped`` / ``length_uncropped``.

        **Returns:**

        Dataset
            | Dataset with ``input_ids`` truncated to ``model_input_size`` (including
            | CLS/SEP when ``special_token`` is set) and a ``length`` column.

        **Raises:**

        KeyError
            | If ``special_token`` is set but ``<cls>`` or ``<sep>`` is missing from
            | the token dictionary.
        """
        print("Creating dataset.")

        # Resolve special tokens up front so a missing token fails with a clear
        # message here instead of inserting None inside the map workers.
        add_special_tokens = bool(self.special_token)
        cls_id = sep_id = None
        if add_special_tokens:
            missing = [
                tok for tok in ("<cls>", "<sep>") if tok not in self.gene_token_dict
            ]
            if missing:
                raise KeyError(
                    f"special_token=True requires {missing} in the token dictionary."
                )
            cls_id = int(self.gene_token_dict["<cls>"])
            sep_id = int(self.gene_token_dict["<sep>"])
            # leave space for CLS and SEP token
            max_len = max(self.model_input_size - 2, 0)
        else:
            max_len = self.model_input_size

        # create dict for dataset creation
        dataset_dict = {"input_ids": tokenized_inds}
        if self.custom_attr_name_dict is not None:
            # apply the documented {loom attribute name: dataset column name} mapping
            for loom_name, values in ind_metadata.items():
                column_name = self.custom_attr_name_dict.get(loom_name, loom_name)
                dataset_dict[column_name] = values

        # create dataset
        if use_generator:

            def dict_generator():
                for i in range(len(tokenized_inds)):
                    yield {k: v[i] for k, v in dataset_dict.items()}

            output_dataset = Dataset.from_generator(dict_generator, num_proc=self.nproc)
        else:
            output_dataset = Dataset.from_dict(dataset_dict)

        def format_ind_features(example):
            # Store original uncropped input_ids in separate feature
            if keep_uncropped_input_ids:
                example["input_ids_uncropped"] = example["input_ids"]
                example["length_uncropped"] = len(example["input_ids"])

            # Truncate/Crop input_ids to input size
            cropped = example["input_ids"][0:max_len]
            if add_special_tokens:
                # wrap the cropped encoding in CLS ... SEP
                example["input_ids"] = [cls_id, *cropped, sep_id]
            else:
                example["input_ids"] = cropped
            example["length"] = len(example["input_ids"])

            return example

        output_dataset_truncated = output_dataset.map(
            format_ind_features, num_proc=self.nproc
        )
        return output_dataset_truncated

In [5]:
# Tokenize the training loom file and save the resulting dataset to disk.
# The names assigned here (loom_file_path, outputpath, proteomics_tokenizer,
# tokenized_ind, file_, output_dataset_truncated) are read again by later cells.
loom_file_path = "2_train_imputed.loom"
outputpath = "2_Data/GeneFormer/imputed/train"
Path(outputpath).mkdir(parents=True, exist_ok=True)

# carry the `incident_cad` and `eid` column attributes over into the dataset
proteomics_tokenizer = ProteomicsTokenizer(
    {"incident_cad": "incident_cad", "eid": "eid"},
    model_input_size=2048,
    special_token=False,
)  # TODO: model_input_size may be larger if it is ok; special_token=True if we want to add CLS and SEP tokens
tokenized_ind, file_ = proteomics_tokenizer.tokenize_loom(loom_file_path)  # tokenize
output_dataset_truncated = proteomics_tokenizer.create_dataset(
    tokenized_ind, file_
)  # create dataset
output_dataset_truncated.save_to_disk(outputpath)  # save to disk

13 genes not in gene token dictionary, skipping them, some are: ['ENSG00000248546', 'ENSG00000204936', 'ENSG00000275841', 'ENSG00000221957', 'ENSG00000275960']
2_train_imputed.loom has no column attribute 'filter_pass'; tokenizing all inds.
Creating dataset.


Map:   0%|          | 0/40806 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/40806 [00:00<?, ? examples/s]

In [7]:
# Tokenize the test loom file and save the resulting dataset to disk.
# This rebinds the same global names as the training cell (loom_file_path,
# outputpath, proteomics_tokenizer, tokenized_ind, file_, output_dataset_truncated).
loom_file_path = "2_test_imputed.loom"
outputpath = "2_Data/GeneFormer/imputed/test"
Path(outputpath).mkdir(parents=True, exist_ok=True)

# NOTE(review): only `incident_cad` is carried over here, whereas the training cell
# also keeps `eid` — confirm whether the test dataset should include `eid` as well.
proteomics_tokenizer = ProteomicsTokenizer(
    {"incident_cad": "incident_cad"}, model_input_size=2048, special_token=False
)  # TODO: model_input_size may be larger if it is ok; special_token=True if we want to add CLS and SEP tokens
tokenized_ind, file_ = proteomics_tokenizer.tokenize_loom(loom_file_path)  # tokenize
output_dataset_truncated = proteomics_tokenizer.create_dataset(
    tokenized_ind, file_
)  # create dataset
output_dataset_truncated.save_to_disk(outputpath)  # save to disk

13 genes not in gene token dictionary, skipping them, some are: ['ENSG00000248546', 'ENSG00000204936', 'ENSG00000275841', 'ENSG00000221957', 'ENSG00000275960']
2_test_imputed.loom has no column attribute 'filter_pass'; tokenizing all inds.
Creating dataset.


Map:   0%|          | 0/10195 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10195 [00:00<?, ? examples/s]

In [8]:
# Display the full list of per-individual token arrays currently bound to `tokenized_ind`.
tokenized_ind

[array([ 2529,   392,  9230, ...,  6870,  3749, 16193], dtype=int16),
 array([ 1690,  1561,  3035, ..., 10788,  3749,  5295], dtype=int16),
 array([ 7588,  2529, 14127, ...,  3749,  6721, 16193], dtype=int16),
 array([ 7189,  4262,  2364, ..., 10785,  3749, 12072], dtype=int16),
 array([ 4454, 19950, 13055, ..., 10912, 15808,  3749], dtype=int16),
 array([ 6262,  5222,  2513, ..., 15808,  7652,  3749], dtype=int16),
 array([12504, 10898,  9563, ...,  9687,  4454,  2990], dtype=int16),
 array([ 8211,  2390,  3511, ...,  5372,  6721, 16193], dtype=int16),
 array([24267, 21163,  6797, ..., 17183,  2529, 14678], dtype=int16),
 array([16693, 15757,  4482, ..., 12072, 14678,  6165], dtype=int16),
 array([ 5222,  3367, 21163, ..., 11531,  8069,  3749], dtype=int16),
 array([14678,  5734,  5222, ..., 10898,  2135, 16741], dtype=int16),
 array([16693,  2529,  6053, ...,  3749, 16193, 12072], dtype=int16),
 array([  278, 15349,  7548, ...,  3749, 17183,  4183], dtype=int16),
 array([ 2529,  9497

In [9]:
# Display `tokenized_ind` again; the value shown depends on which cell assigned it last.
tokenized_ind

[array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 

In [8]:
# Display the `input_ids` column of the most recently created dataset.
output_dataset_truncated["input_ids"]

[[],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],
 [],


In [82]:
# Point `outputpath` at the test output directory and create it if it is missing.
# NOTE(review): the recorded output of this cell shows a dataset being saved, but no
# save call is present here — presumably it was removed after the cell last ran.
outputpath = "2_Data/GeneFormer/imputed/test"
Path(outputpath).mkdir(parents=True, exist_ok=True)


Saving the dataset (0/1 shards):   0%|          | 0/10195 [00:00<?, ? examples/s]

In [None]:
# Display the metadata returned as the second element of the last tokenize_loom call.
file_

In [86]:
# List the column attribute names of the loom connection bound to `data`.
# NOTE(review): `data` is not defined in this cell; it presumably comes from the
# cell that calls lp.connect(loom_file_path) — confirm execution order.
data.ca.keys()

['BMI',
 'BSA',
 'CAD',
 'Cr',
 'PC1',
 'PC10',
 'PC2',
 'PC3',
 'PC4',
 'PC5',
 'PC6',
 'PC7',
 'PC8',
 'PC9',
 'PRS',
 'age',
 'age_squared',
 'apob',
 'assessment_center',
 'birth_date',
 'cad',
 'cad_age',
 'cad_date',
 'cad_status',
 'crp',
 'dbp',
 'dbp_a',
 'death_age',
 'death_date',
 'eGFR',
 'future_cad_time_days',
 'future_cad_time_months',
 'future_cad_time_years',
 'genotype_array',
 'hdl',
 'hdl_a',
 'height',
 'incident_cad',
 'is_earily_CAD',
 'ldl',
 'ldl_a',
 'mi',
 'prevalent_cad',
 'recuit_age',
 'recuit_date',
 'sbp',
 'sbp_a',
 'sex',
 'survival_time',
 'tc',
 'tc_a',
 'tg',
 'tg_a',
 'weight',
 'year_of_cad_after_recuit']

In [15]:
# `target_sum` is assigned here, but only commented-out normalization code refers to it.
target_sum = 10000
# NOTE(review): the connection is opened without a `with` block, so it stays open for
# the later debugging cells that read `data`; presumably it should be closed once
# they are done — confirm.
data = lp.connect(loom_file_path)
# list the column attribute names available in this loom file
data.ca.keys()

['BMI',
 'BSA',
 'CAD',
 'Cr',
 'PC1',
 'PC10',
 'PC2',
 'PC3',
 'PC4',
 'PC5',
 'PC6',
 'PC7',
 'PC8',
 'PC9',
 'PRS',
 'age',
 'age_squared',
 'apob',
 'assessment_center',
 'birth_date',
 'cad',
 'cad_age',
 'cad_date',
 'cad_status',
 'crp',
 'dbp',
 'dbp_a',
 'death_age',
 'death_date',
 'eGFR',
 'eid',
 'future_cad_time_days',
 'future_cad_time_months',
 'future_cad_time_years',
 'genotype_array',
 'hdl',
 'hdl_a',
 'height',
 'incident_cad',
 'is_earily_CAD',
 'ldl',
 'ldl_a',
 'mi',
 'prevalent_cad',
 'recuit_age',
 'recuit_date',
 'sbp',
 'sbp_a',
 'sex',
 'survival_time',
 'tc',
 'tc_a',
 'tg',
 'tg_a',
 'weight',
 'year_of_cad_after_recuit']

In [18]:
def rank_genes(gene_vector, gene_tokens):
    """
    Sort tokens by their expression value, largest value first.

    Both arguments are index-aligned arrays; the result is ``gene_tokens``
    reordered according to the descending order of ``gene_vector``.
    """
    # negate so that an ascending argsort produces a descending ranking
    negated_values = -gene_vector
    descending_order = np.argsort(negated_values)
    return gene_tokens[descending_order]


def tokenize_cell(gene_vector, gene_tokens):
    """
    Turn one expression vector into its rank value encoding.

    Zero-valued entries are removed along with their tokens, and the tokens
    that remain are returned in descending order of their values.
    """
    # indices of entries with a non-zero value
    kept = np.nonzero(gene_vector)[0]
    kept_values, kept_tokens = gene_vector[kept], gene_tokens[kept]
    # rank the kept tokens by value, highest first
    return rank_genes(kept_values, kept_tokens)


# Notebook-scope replay of ProteomicsTokenizer.tokenize_loom for debugging.
# It reads the already-open `data` connection and the existing `proteomics_tokenizer`,
# and leaves its intermediates (coding_miRNA_loc, coding_miRNA_ids, coding_miRNA_tokens,
# subview, tokenized_ind) as globals so the following cells can inspect them.
# NOTE(review): the recorded output of the next cell shows empty float64 arrays; since
# np.array([]) is float64, this suggests coding_miRNA_tokens was itself empty (or held
# float values) when this ran — TODO confirm against the loom file's 'ensembl_id' values.

# coordinates of the rows whose 'ensembl_id' is a key of the token dictionary
coding_miRNA_loc = np.where(
    [proteomics_tokenizer.genelist_dict.get(i, False) for i in data.ra["ensembl_id"]]
)[0]


# norm_factor_vector = np.array(
#     [
#         proteomics_tokenizer.gene_median_dict[i]
#         for i in data.ra["ensembl_id"][coding_miRNA_loc]
#     ]
# )
coding_miRNA_ids = data.ra["ensembl_id"][coding_miRNA_loc]

# report the row ids that have no entry in the token dictionary
not_in_gene_ids = set(data.ra["ensembl_id"]) - set(proteomics_tokenizer.gene_keys)
print(
    f"{len(not_in_gene_ids)} genes not in gene token dictionary, skipping them, some are: {list(not_in_gene_ids)[:5]}"
)

# token id for each selected row, aligned with coding_miRNA_loc
coding_miRNA_tokens = np.array(
    [proteomics_tokenizer.gene_token_dict[i] for i in coding_miRNA_ids]
)

# define coordinates of individuals passing filters for inclusion (e.g. QC)
try:
    data.ca["filter_pass"]
except AttributeError:
    var_exists = False
else:
    var_exists = True

if var_exists:
    filter_pass_loc = np.where([i == 1 for i in data.ca["filter_pass"]])[0]
elif not var_exists:
    print(
        f"{loom_file_path} has no column attribute 'filter_pass'; tokenizing all Individuals."
    )
    filter_pass_loc = np.array([i for i in range(data.shape[1])])

# scan through the .loom file in batches of columns and tokenize individuals
tokenized_ind = []
for _ix, _selection, view in data.scan(
    items=filter_pass_loc, axis=1, batch_size=proteomics_tokenizer.chunk_size
):
    # keep only the rows selected by coding_miRNA_loc
    subview = view.view[coding_miRNA_loc, :]

    # Values are currently not normalized, as they are NPX values from UKB

    # normalize by total counts per cell and multiply by 10,000 to allocate bits to precision
    # and normalize by gene normalization factors
    # subview_norm_array = (
    #     subview[:, :] / subview.ca.n_counts * target_sum / norm_factor_vector[:, None]
    # )

    # tokenize each column of the batch
    tokenized_ind += [
        tokenize_cell(subview[:, i], coding_miRNA_tokens)
        for i in range(subview.shape[1])
    ]

    # # add custom attributes for subview to dict
    # if proteomics_tokenizer.custom_attr_name_dict is not None:
    #     for k in file_cell_metadata.keys():
    #         file_cell_metadata[k] += subview.ca[k].tolist()
    # else:
    #     file_cell_metadata = None
    # break

28 genes not in gene token dictionary, skipping them, some are: ['IL12A_IL12B', 'CTAG1A_CTAG1B', 'SKIV2L', 'CGB3_CGB5_CGB8', 'MYLPF']
2_test_imputed.loom has no column attribute 'filter_pass'; tokenizing all Individuals.


In [19]:
tokenized_ind

[array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 

In [None]:
# NOTE(review): this import sits mid-notebook, while the other imports are in the first cell.
import pandas as pd

# Pair the values of column index 2 of the `subview` left over from the scan loop with
# their tokens, then sort the pairs by value (column 0) in descending order.
pd.DataFrame([subview[:, 2], coding_miRNA_tokens]).T.sort_values(0, ascending=False)

In [None]:
# Display the token arrays currently bound to `tokenized_ind`.
tokenized_ind

In [None]:
# Row ids of the loom file that were not selected for tokenization.
set(data.ra["ensembl_id"]) - set(coding_miRNA_ids)

In [None]:
# Display the tokenizer's {dictionary key: True} membership lookup.
proteomics_tokenizer.genelist_dict

In [None]:
# Look up the token id stored for one Ensembl gene id; raises KeyError if it is absent.
proteomics_tokenizer.gene_token_dict["ENSG00000175164"]