In [1]:
import sys
sys.path.append("/home/kevin/Documents")  # parent of `perturbgene` directory

import json
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import scipy
import torch
import transformers
from tqdm import tqdm  # https://discuss.pytorch.org/t/error-while-multiprocessing-in-dataloader/46845/9
from braceexpand import braceexpand

from perturbgene.data_utils import GeneTokenizer, IterableAnnDataset, EvalJsonDataset, read_h5ad_file
from perturbgene.data_utils.data_collators import collate_fn_wrapper
from perturbgene.data_utils.tokenization import _prepend_bin, phenotype_to_token
from perturbgene.eval_utils import mlm_metrics_wrapper, cls_metrics_wrapper, preprocess_logits_argmax
from perturbgene.model import GeneBertForPhenotypicMLM, GeneBertForClassification
from perturbgene.inference_utils import get_inference_config, prepare_cell, test_cell, mlm_for_phenotype_cls

In [2]:
# phenotype_category = "tissue"  # category to perform classifcation on
phenotype_category = "cell_type"

model_type = "mlm"
model_checkpt_path = "/home/kevin/Documents/perturbgene/outputs/base_v0_1024_mlm_bins1/checkpoint-32000"
# model_checkpt_path = "/home/kevin/Documents/transformeromics/test_mlm_inference_130_cls_bins1/checkpoint-3200"

# model_type = "cls"
# model_checkpt_path = "/home/kevin/Documents/perturbgene/test_cls_inference_130_cls_bins1/checkpoint-3200"
assert model_type in ("mlm", "cls")

# device = "cpu"
device = "cuda:0"

# Load tokenizer first, so that we can get the config
expected_tokenizer_path = os.path.join(
    os.path.dirname(model_checkpt_path),
    "tokenizer.pkl",
)

if os.path.isfile(expected_tokenizer_path):
    with open(expected_tokenizer_path, "rb") as f:
        tokenizer = pickle.load(f)
else:
    print("Saved tokenizer not found, creating tokenizer with common parameters")
    tokenizer = GeneTokenizer(get_inference_config(  # change these parameters
        bin_edges=[0.1], 
        pretrained_model_path="/dev/null",  # needs to be a path that exists
        max_length=130,        
        num_top_genes=128
    ))

tokenizer.config.vocab_path = os.path.join("/home/kevin/Documents/", tokenizer.config.vocab_path)  # rel path -> abs path
config = tokenizer.config

if model_type == "mlm":
    tokenizer.config.binary_label = None
    tokenizer.config.phenotype_category = phenotype_category
    config = tokenizer.config
    
    model_class = GeneBertForPhenotypicMLM
elif model_type == "cls":
    assert tokenizer.config.phenotype_category == phenotype_category, tokenizer.config.phenotype_category
    
    model_class = GeneBertForClassification
else:
    raise NotImplementedError

data_collator = collate_fn_wrapper(tokenizer)
compute_metrics = cls_metrics_wrapper(tokenizer)  # always using cls_metrics

# Largely copied from IterableAnnDataset
phenotype_category_labels = tokenizer.phenotypic_tokens_map[config.phenotype_category]

if config.binary_label is None:
    label2id = {label: i for i, label in enumerate(phenotype_category_labels)}
else:
    label2id = {label: int(label == config.binary_label)
                     for label in phenotype_category_labels}

In [3]:
model = model_class.from_pretrained(model_checkpt_path)
model.eval()
model.to(device);

Some weights of the model checkpoint at /home/kevin/Documents/perturbgene/outputs/base_v0_1024_mlm_bins1/checkpoint-32000 were not used when initializing GeneBertForPhenotypicMLM: ['mlm_loss_fct.weight']
- This IS expected if you are initializing GeneBertForPhenotypicMLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GeneBertForPhenotypicMLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Inference

In [4]:
validation_data = read_h5ad_file("/home/shared_folder/TabulaSapiens/ranked/Tabula_Sapiens_ranked_47.h5ad", config.num_top_genes)

## MLM (for phenotype classification)

### One cell

In [5]:
assert model_type == "mlm"
cell = validation_data[0]

top_id = mlm_for_phenotype_cls(cell, phenotype_category, model, tokenizer, data_collator)
print(f"Pred: {tokenizer.flattened_tokens[top_id]}\n"
      f"Label: {phenotype_to_token(cell.obs[phenotype_category].item())}")

Pred: [CD4-positive,_alpha-beta_memory_T_cell]
Label: [CD4-positive,_alpha-beta_memory_T_cell]


### Entire h5ad file

In [6]:
assert model_type == "mlm"

all_preds = []
all_labels = []
for cell in tqdm(validation_data):
    top_id = mlm_for_phenotype_cls(cell, phenotype_category, model, tokenizer, data_collator)
    all_preds.append(phenotype_category_labels.index(tokenizer.flattened_tokens[top_id]))
    all_labels.append(phenotype_category_labels.index(phenotype_to_token(cell.obs[phenotype_category].item())))

metrics = compute_metrics(transformers.EvalPrediction(
    predictions=np.array(all_preds), 
    label_ids=np.array(all_labels).reshape(-1, 1),
))

print(f"{metrics=}")

100%|█████████████████████████████████████████████████████████████| 10000/10000 [05:24<00:00, 30.79it/s]

metrics={'accuracy': 0.938, 'precision': 0.8227611374546668, 'recall': 0.8150250505737551, 'f1': 0.8188748233019938}





## CLS

### One cell

In [12]:
assert model_type == "cls"
cell = validation_data[0]

phenotype_category = config.phenotype_category
print(f"{phenotype_category=}")

prepared_cell = prepare_cell(cell, model_type, tokenizer, label2id)
output = test_cell(prepared_cell, model, data_collator)
print(f"Pred: {tokenizer.flattened_tokens[output.logits.argmax(dim=-1).item()]}\n"
      f"Label: {phenotype_to_token(cell.obs[phenotype_category].item())}")

phenotype_category='tissue'
Pred: [kidney_epithelial_cell]
Label: [blood]


### Entire h5ad file

In [13]:
assert model_type == "cls"
eval_batch_size = 64
eval_dataset = IterableAnnDataset(["/home/shared_folder/TabulaSapiens/ranked/Tabula_Sapiens_ranked_47.h5ad"], config)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=eval_batch_size,
    collate_fn=data_collator,
    num_workers=1,
    pin_memory=True
)

all_preds = []
all_labels = []
for batch in tqdm(eval_dataloader):
    with torch.no_grad():
        preds = model(**{key: val.to(device) for key, val in batch.items()})
        all_preds.extend(preds.logits.argmax(dim=-1))
        all_labels.extend(batch["labels"])

metrics = compute_metrics(transformers.EvalPrediction(
    predictions=torch.stack(all_preds).cpu().numpy(), 
    label_ids=torch.stack(all_labels).cpu().numpy()
))
print(f"{metrics=}")

100%|█████████████████████████████████████████| 157/157 [00:41<00:00,  3.74it/s]

metrics={'accuracy': 0.5765, 'precision': 0.5889780711350704, 'recall': 0.5946053679796419, 'f1': 0.5917783421862598}



