## Example for model inference

In [3]:
import os

import hydra
import numpy as np
import torch
from hydra import compose, initialize, initialize_config_dir
from omegaconf import OmegaConf

import sys
import pascient  # your new package
# Register `pascient` under the module name `cellm` so that `import cellm`
# and the `cellm.*` dotted paths used later in this notebook (the Hydra
# `_target_` override in `load_model`, `from cellm.data... import ...`)
# resolve to this package.
# NOTE(review): presumably the saved config/checkpoint still refer to the
# package by its former name `cellm` -- confirm.
sys.modules['cellm'] = pascient


def load_model(config_path, checkpoint_path, verbose=True):
    """
    Utility function to load a model from a checkpoint.

    Composes ``config.yaml`` from ``config_path`` with Hydra, instantiates the
    metrics and the model described by that config, restores the weights
    stored under the checkpoint's ``"state_dict"`` key and puts the model in
    eval mode.

    Parameters
    ----------
    config_path : str
        Directory holding ``config.yaml``; forwarded to Hydra's
        ``initialize_config_dir`` as ``config_dir``.
    checkpoint_path : str
        File readable by ``torch.load`` that contains a ``"state_dict"`` entry.
    verbose : bool, default True
        When True, print the composed configuration as YAML (the original
        behaviour). Pass False to keep the notebook output short.

    Returns
    -------
    The instantiated model with the checkpoint weights loaded, in eval mode.
    """
    overrides = [
        "data.multiprocessing_context=null",
        "data.batch_size=16",
        "data.sampler_cls._target_=cellm.data.data_samplers.BaseSampler",
        "+data.output_map.return_index=True",
    ]

    with initialize_config_dir(version_base=None, config_dir=config_path, job_name="test_app"):
        cfg = compose(config_name="config.yaml", return_hydra_config=True, overrides=overrides)
        if verbose:
            print(OmegaConf.to_yaml(cfg))

    # map_location="cpu" lets a checkpoint that was saved from GPU tensors be
    # loaded on a machine without CUDA; load_state_dict copies the values into
    # the model's own parameters afterwards.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    metrics = hydra.utils.instantiate(cfg.get("metrics"))
    model = hydra.utils.instantiate(cfg.model, metrics = metrics)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # Kept from the original. NOTE(review): presumably blanks the output
    # directory so that config nodes referencing it do not fail outside a
    # Hydra run -- confirm whether this is still needed.
    cfg.paths.output_dir = ""

    return model

class ForwardModel(torch.nn.Module):
    """
    Wraps a base model and chains its sub-modules into a single forward pass.

    The pass runs, in order: ``gene2cell_encoder`` -> ``cell2cell_encoder`` ->
    ``cell2patient_aggregation.aggregate`` -> ``patient_encoder`` ->
    ``patient_predictor``.

    ``last_layer`` selects which patient embedding is returned: the output of
    ``patient_encoder`` when True, the output of the aggregation step when
    False.

    Output (tuple) :
    - patient embedding
    - cell cross embedding
    - patient prediction
    """

    def __init__(self, base_model, last_layer = True):
        super().__init__()
        self.base_model = base_model
        self.last_layer = last_layer

    def forward(self, x, padding_mask):
        net = self.base_model

        per_cell = net.gene2cell_encoder(x)
        cross_cell = net.cell2cell_encoder(per_cell, padding_mask=padding_mask)
        aggregated = net.cell2patient_aggregation.aggregate(data=cross_cell, mask=padding_mask)
        encoded = net.patient_encoder(aggregated)
        predictions = net.patient_predictor(encoded)

        # The predictor always consumes the encoded embedding; only the
        # embedding handed back to the caller depends on `last_layer`.
        patient_embedding = encoded if self.last_layer else aggregated
        return patient_embedding, cross_cell, predictions


def lognormalize(x, padded_mask, target_sum = 1e4):
    """
    Normalize the input tensor using log normalization.

    Each observed cell (last axis = genes) is rescaled so that its values sum
    to ``target_sum`` and then transformed with ``log1p``. Cells whose mask
    entry is False are not rescaled (their divisor is set to 1); they only go
    through ``log1p``.

    Parameters
    ----------
    x : torch.Tensor
        Count tensor whose last axis is summed to get the per-cell total.
    padded_mask : torch.Tensor (or array-like)
        Mask with the shape of ``x`` minus its last axis; truthy where the
        cell is observed. Non-boolean masks (e.g. 0/1 integers or floats) are
        converted to bool first, because applying ``~`` to an integer tensor
        is a bitwise NOT that would be used as integer indices and silently
        reset the wrong cells.
    target_sum : float, default 1e4
        Total each observed cell is scaled to before the log transform.

    Returns
    -------
    (normalized, padded_mask)
        The log-normalized tensor (a new tensor; ``x`` is not modified) and
        the mask exactly as it was passed in.
    """
    observed = torch.as_tensor(padded_mask, dtype=torch.bool, device=x.device)

    # The small epsilon keeps the divisor non-zero for all-zero cells.
    counts_per_cell = (x.sum(dim=-1) + 1e-8) / target_sum

    # Masked-out cells are divided by 1, i.e. left unscaled.
    counts_per_cell[~observed] = 1

    normalized = x / counts_per_cell[..., None]

    return normalized.log1p(), padded_mask

## Loading the model

In [12]:
# Add path to model
# Set the PASCIENT_RESOURCES_PATH environment variable to point at your own
# copy of the model resources; the default below is the original
# machine-specific location. The path must be absolute because it is handed
# to Hydra's `initialize_config_dir`.
resources_path = os.environ.get(
    "PASCIENT_RESOURCES_PATH",
    "/homefs/home/debroue1/projects/pascient_github/resources/multilabel_model",
)
config_path = f"{resources_path}/.hydra/"
checkpoint_path = f"{resources_path}/checkpoints/pascient.ckpt"
model = load_model(config_path, checkpoint_path)
# last_layer=True: return the embedding produced by the patient encoder.
model_fwd = ForwardModel(model, last_layer = True)

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
  launcher:
    _target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
  sweeper:
    _target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
    max_batch_size: null
    params: null
  help:
    app_name: ${hydra.job.name}
    header: '${hydra.help.app_name} is powered by Hydra.

      '
    footer: 'Powered by Hydra (https://hydra.cc)

      Use --hydra-help to view Hydra specific help

      '
    template: '${hydra.help.header}

      == Configuration groups ==

      Compose your configuration from those groups (group=option)


      $APP_CONFIG_GROUPS


      == Config ==

      Override anything in the config (foo.bar=value)


      $CONFIG


      ${hydra.help.footer}

      '
  hydra_help:
    template: 'Hydra (${hydra.runtime.version})

      See https://hydra.cc for more info.


      == Flags ==

      

  checkpoint = torch.load(checkpoint_path)
/homefs/home/debroue1/miniforge3/envs/pascient_test/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'gene2cell_encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['gene2cell_encoder'])`.
/homefs/home/debroue1/miniforge3/envs/pascient_test/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'cell2cell_encoder' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['cell2cell_encoder'])`.
/homefs/home/debroue1/miniforge3/envs/pascient_test/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'cell2patient_aggregation' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(igno

## Running inference on the model

In [13]:
# NOTE(review): SampleBatch is not used in this cell; kept from the original.
from cellm.data.data_structures import SampleBatch

# Seed the generator so the synthetic example is reproducible across runs.
torch.manual_seed(0)

# Example for original count data
# Tensor size should be Samples x 1 x Cells x Genes
x = 50 + torch.randn(16,1,1000,28231)
# Padding mask is True if cell is observed and False if cell is masked
padding_mask = torch.ones(16,1,1000).bool()

x, padding_mask = lognormalize(x, padding_mask)

# Inference only: disable autograd so no graph is recorded for this large
# input, which keeps memory usage down.
with torch.no_grad():
    sample_embeds, cell_embds, sample_preds = model_fwd(x, padding_mask)

In [15]:
# Shape of the patient-level embeddings returned first by the wrapper;
# for the example input above this displays torch.Size([16, 1, 512]).
sample_embeds.shape

torch.Size([16, 1, 512])