## SpaRED Library MISC DEMO

This demonstration illustrates the usage of various functions available in our SpaRED PyPI library. The functions we will explore are categorized as follows:

* Spot Features
* Graph operations
* Dataloaders
* Models
* Metrics

These functions provide essential tools for preparing your data for model training and inference, as well as for evaluating model performance. By leveraging these capabilities, users can streamline their workflow and enhance the efficiency and accuracy of their spatial transcriptomics analysis.

In [2]:
import matplotlib.pyplot as plt
import matplotlib.image as im
import os
import sys
import argparse
import torch
from pathlib import Path

currentdir = os.getcwd()
parentdir = str(Path(currentdir).parents[2])
sys.path.insert(0, parentdir)
print(parentdir)

import spared

/media/SSD4/dvegaa/SpaRED


### Load Datasets

The `datasets` module provides the `get_dataset` function, which loads any of the available datasets and gives access to both the adata and the parameter dictionary. The returned adata is already filtered and processed. The function has a parameter called *visualize* that generates all visualizations when set to `True`. The function also saves the raw_adata (not processed) in case it is required.

We will begin by loading a dataset and setting the *visualize* parameter to `False`, since no images are required for the functions analyzed in this DEMO.

In [4]:
from spared.datasets import get_dataset
import anndata as ad

#get dataset
data = get_dataset("vicari_mouse_brain", visualize=False)

#adata
adata = data.adata

#parameters dictionary
param_dict = data.param_dict

#loading raw adata 
dataset_path = os.getcwd()
files_path = os.path.join(dataset_path, "processed_data/vicari_data/vicari_mouse_brain/")
files = os.listdir(files_path)
adata_path = os.path.join(files_path, files[0], "adata_raw.h5ad")
raw_adata = ad.read_h5ad(adata_path)

Loading vicari_mouse_brain dataset with the following data split:
train data: ['V11L12-038_A1', 'V11L12-038_B1', 'V11L12-038_C1', 'V11L12-038_D1', 'V11L12-109_A1', 'V11L12-109_B1', 'V11L12-109_C1', 'V11L12-109_D1']
val data: ['V11T16-085_A1', 'V11T16-085_B1', 'V11T16-085_C1', 'V11T16-085_D1']
test data: ['V11T17-101_A1', 'V11T17-101_B1']
Parameters already saved in /media/SSD4/dvegaa/SpaRED/spared/processed_data/vicari_data/vicari_mouse_brain/2024-07-08-11-11-47/parameters.json
Loading main adata file from disk (/media/SSD4/dvegaa/SpaRED/docs/notebooks/tutorials/processed_data/vicari_data/vicari_mouse_brain/2024-07-08-11-11-47/adata.h5ad)...
The loaded adata object looks like this:
AnnData object with n_obs × n_vars = 43804 × 128
    obs: 'in_tissue', 'array_row', 'array_col', 'patient', 'slide_id', 'split', 'unique_id', 'n_genes_by_counts', 'total_counts'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_symbol', 'exp_frac', 'glob_exp_frac', 'n_cells_by_counts', 'mean_counts', 'pc

We are ready to explore the functions one by one. This tutorial will demonstrate how to use each function, what to provide as input, and what output to expect.

### Spot Features Functions

We will begin with the spot feature functions available in the SpaRED library. These functions are crucial for extracting and analyzing the unique characteristics and spatial relationships of each spot within spatial transcriptomics data. This tutorial will guide you on how to utilize each function, specifying the required inputs and showcasing the expected outputs.

### Function: `compute_patches_embeddings`

The `compute_patches_embeddings` function computes embeddings for image patches stored in an AnnData object using a specified backbone model. The embeddings are stored in the `AnnData` object's `.obsm` attribute. This function allows for the use of either a pretrained model from PyTorch or a custom model provided via a file path.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the data to be processed.
* **backbone (str):** Specifies the backbone model to use for generating embeddings. Defaults to 'densenet'.
* **model_path (str):** Path to a stored model. If set to 'None', a pretrained model is used. Defaults to "None".
* **patch_size (int):** The size of the patches. Defaults to 224.

##### <u>Returns:</u>

This function modifies the `AnnData` object in place, adding the computed embeddings to `adata.obsm[f'embeddings_{backbone}']`. It does not return any value.

### Explanation

The `compute_patches_embeddings` function is designed to extract embeddings from image patches using a deep learning model specified by the backbone parameter. The function provides flexibility in using either pretrained models or custom models stored locally. This function is critical for generating embeddings that can be used for downstream analysis in spatial transcriptomics.


In [6]:
from spared.spot_features import compute_patches_embeddings

compute_patches_embeddings(adata=adata, backbone='densenet', model_path="None", patch_size= 224)
adata

Getting embeddings: 100%|██████████| 172/172 [01:00<00:00,  2.84it/s]


AnnData object with n_obs × n_vars = 43804 × 128
    obs: 'in_tissue', 'array_row', 'array_col', 'patient', 'slide_id', 'split', 'unique_id', 'n_genes_by_counts', 'total_counts'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_symbol', 'exp_frac', 'glob_exp_frac', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_length', 'd_log1p_moran', 'log1p_avg_exp', 'd_log1p_avg_exp', 'c_log1p_avg_exp', 'c_d_log1p_avg_exp'
    uns: 'spatial'
    obsm: 'patches_scale_1.0', 'spatial', 'embeddings_densenet'
    layers: 'c_d_deltas', 'c_d_log1p', 'c_deltas', 'c_log1p', 'counts', 'd_deltas', 'd_log1p', 'deltas', 'log1p', 'mask', 'tpm'

### Function: `compute_patches_predictions`

The `compute_patches_predictions` function computes predictions for image patches stored in an `AnnData` object using a specified backbone model. The predictions are stored in the `.obsm` attribute of the AnnData object. This function allows for the use of either a pretrained model from PyTorch or a custom model provided via a file path.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the data to be processed.
* **backbone (str):** Specifies the backbone model to use for generating predictions. Defaults to 'densenet'.
* **model_path (str):** Path to a stored model. If set to 'None', a pretrained model is used. Defaults to "None".
* **patch_size (int):** The size of the patches. Defaults to 224.

##### <u>Returns:</u>

This function modifies the AnnData object in place, adding the computed predictions to `adata.obsm[f'predictions_{backbone}']`. It does not return any value.

### Explanation

The `compute_patches_predictions` function is designed to generate predictions for image patches using a deep learning model specified by the backbone parameter. This function is very similar to the `compute_patches_embeddings` function but is focused on producing predictions rather than embeddings. This function is particularly useful for applying a trained model to new data or evaluating patches from a dataset to predict specific outcomes.

In [7]:
from spared.spot_features import compute_patches_predictions

compute_patches_predictions(adata=adata, backbone='densenet', model_path="None", patch_size= 224)
adata

Getting predictions: 100%|██████████| 172/172 [00:59<00:00,  2.89it/s]


AnnData object with n_obs × n_vars = 43804 × 128
    obs: 'in_tissue', 'array_row', 'array_col', 'patient', 'slide_id', 'split', 'unique_id', 'n_genes_by_counts', 'total_counts'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_symbol', 'exp_frac', 'glob_exp_frac', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_length', 'd_log1p_moran', 'log1p_avg_exp', 'd_log1p_avg_exp', 'c_log1p_avg_exp', 'c_d_log1p_avg_exp'
    uns: 'spatial'
    obsm: 'patches_scale_1.0', 'spatial', 'embeddings_densenet', 'predictions_densenet'
    layers: 'c_d_deltas', 'c_d_log1p', 'c_deltas', 'c_log1p', 'counts', 'd_deltas', 'd_log1p', 'deltas', 'log1p', 'mask', 'tpm'

### Function: `compute_dim_red`

The `compute_dim_red` function is a streamlined utility for performing dimensionality reduction and clustering on single-cell data stored in an `AnnData` object. It automates the process of computing PCA, constructing a nearest neighbors graph, generating UMAP embeddings, and performing Leiden clustering, all based on the expression matrix stored in a specified layer of the `AnnData` object.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object that contains the data for analysis. The data for dimensionality reduction and clustering should be stored in a specified layer within this object.
* **from_layer (str):** The key in adata.layers where the expression matrix is located. This layer will be used as the input for all computations.

##### <u>Returns:</u>

The function returns a new `AnnData` object with the computed PCA, UMAP embeddings, and Leiden clusters added.

### Explanation

The `compute_dim_red` function simplifies the analysis of single-cell RNA-seq data by automating several critical steps:

##### PCA Computation:

The function starts by performing Principal Component Analysis (PCA) on the expression matrix stored in the specified layer (from_layer). PCA reduces the dimensionality of the data, capturing the most significant variations. The resulting PCA embeddings are stored in `adata.obsm['X_pca']`.

##### Neighbor Graph Construction:

After PCA, the function constructs a nearest neighbors graph using the PCA-reduced data. This graph is crucial for downstream analyses like UMAP and clustering, as it defines the local neighborhood relationships between cells. The neighbor graph is stored in `adata.obsp['distances']` and `adata.obsp['connectivities']`.

##### UMAP Embedding:

The function then applies Uniform Manifold Approximation and Projection (UMAP) to generate a low-dimensional embedding of the data. UMAP is widely used for visualizing single-cell datasets, making complex data more interpretable. The UMAP embeddings are saved in `adata.obsm['X_umap']`.

##### Leiden Clustering:

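Finally, the function performs Leiden clustering to identify distinct cell populations. Leiden is a graph-based method, so the clustering operates on the neighbor graph built earlier rather than on the UMAP coordinates themselves. The clusters are stored in `adata.obs['cluster']`, enabling further exploration and analysis of the data.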

In [8]:
from spared.spot_features import compute_dim_red

adata = compute_dim_red(adata=adata, from_layer="c_d_log1p")
adata

AnnData object with n_obs × n_vars = 43804 × 128
    obs: 'in_tissue', 'array_row', 'array_col', 'patient', 'slide_id', 'split', 'unique_id', 'n_genes_by_counts', 'total_counts', 'cluster'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_symbol', 'exp_frac', 'glob_exp_frac', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'gene_length', 'd_log1p_moran', 'log1p_avg_exp', 'd_log1p_avg_exp', 'c_log1p_avg_exp', 'c_d_log1p_avg_exp'
    uns: 'spatial', 'pca', 'neighbors', 'umap', 'leiden'
    obsm: 'patches_scale_1.0', 'spatial', 'embeddings_densenet', 'predictions_densenet', 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'c_d_deltas', 'c_d_log1p', 'c_deltas', 'c_log1p', 'counts', 'd_deltas', 'd_log1p', 'deltas', 'log1p', 'mask', 'tpm'
    obsp: 'distances', 'connectivities'

### Function: `get_spatial_neighbors`

The `get_spatial_neighbors` function computes a dictionary of spatial neighbors for observations in a single-slide `AnnData` object. It calculates neighbors based on topological distances within a graph defined by either a hexagonal or grid geometry. The resulting dictionary contains the indexes of each observation as keys and lists of neighboring observation indexes as values. Neighbors include the observation itself and those within a specified number of hops.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the data. The function assumes that the data is from a single slide and cannot be a collection of slides.
* **n_hops (int):** The number of hops (or steps) to consider when determining the neighborhood of each observation. This defines the size of the vicinity.
* **hex_geometry (bool):** A boolean flag indicating the geometry of the graph. If `True`, a hexagonal grid is assumed (typically for Visium datasets); if `False`, a standard grid is used.

##### <u>Returns:</u>

A dictionary where each key corresponds to the index of an observation, and the value is a list of indexes representing the neighbors of that observation, including itself.

### Explanation

The `get_spatial_neighbors` function is designed to identify and return spatial neighbors for each observation in a single-slide AnnData object. It works by first constructing a spatial graph based on the geometry (either hexagonal or grid) and then calculating neighbors within a specified number of hops. The spatial graph is computed with Squidpy's `sq.gr.spatial_neighbors` function and stored in the adjacency matrix `adata.obsp['spatial_connectivities']`.
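
As a rough illustration of the hop-based neighborhood idea, the following dependency-free sketch runs a breadth-first search over spot coordinates. It is not the SpaRED implementation (which relies on `sq.gr.spatial_neighbors`). The helper name `n_hop_neighbors` and the neighbor offsets are assumptions made for this example.

```python
from collections import deque

# Neighbor offsets in (array_row, array_col) space. These are assumptions for
# illustration: a 4-connected square grid, and a Visium-style hexagonal layout
# in which spots on the same row are 2 columns apart.
GRID_OFFSETS = [(-1, 0), (1, 0), (0, -1), (0, 1)]
HEX_OFFSETS = [(0, -2), (0, 2), (-1, -1), (-1, 1), (1, -1), (1, 1)]


def n_hop_neighbors(coords, n_hops, hex_geometry=True):
    """Map each spot index to the sorted indexes within n_hops (itself included)."""
    offsets = HEX_OFFSETS if hex_geometry else GRID_OFFSETS
    index_of = {rc: i for i, rc in enumerate(coords)}
    result = {}
    for start, (row, col) in enumerate(coords):
        seen = {start}
        queue = deque([(row, col, 0)])
        while queue:
            r, c, depth = queue.popleft()
            if depth == n_hops:
                continue  # do not expand beyond the requested number of hops
            for dr, dc in offsets:
                j = index_of.get((r + dr, c + dc))
                if j is not None and j not in seen:
                    seen.add(j)
                    queue.append((r + dr, c + dc, depth + 1))
        result[start] = sorted(seen)
    return result


# A 3x3 square grid of spots, indexed row by row (0..8).
coords = [(r, c) for r in range(3) for c in range(3)]
neighbors = n_hop_neighbors(coords, n_hops=1, hex_geometry=False)
print(neighbors[4])  # [1, 3, 4, 5, 7]: the center spot plus its four direct neighbors
```

Increasing `n_hops` widens the neighborhood, which is why the tutorial call with `n_hops=6` produces fairly large neighbor lists for each spot.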

In [9]:
from spared.spot_features import get_spatial_neighbors

dict_spatial_neighbors = get_spatial_neighbors(adata=adata, n_hops=6, hex_geometry=param_dict["hex_geometry"])

### Graph Operations Functions

In this section, we will explore the graph operations functions available in the SpaRED library. These functions are essential for converting spatial transcriptomics data into geometric graphs, where the central nodes represent spots and the adjacent nodes represent the spot's neighbors. These graphs integrate information from multiple patches, enabling the prediction of gene expression for the central node by leveraging the spatial information from its neighbors.

### Function: `get_graphs_one_slide`

The `get_graphs_one_slide` function computes spatial neighbor graphs for each node in a single-slide `AnnData` object. These graphs are constructed within an n_hops radius and are formatted for use with `PyTorch Geometric`. The function assumes that both embeddings and predictions are already computed and stored in the `adata.obsm` attribute.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing data from a single slide. This object must include both embeddings and predictions.
* **n_hops (int):** The number of hops to consider when constructing the graph around each node.
* **layer (str):** The specific layer in the AnnData object to use for predictions, which will be added as the y attribute in the resulting graphs.
* **hex_geometry (bool):** A boolean flag indicating whether the slide data is organized in a hexagonal grid (e.g., for Visium datasets) or a regular grid.

##### <u>Returns:</u>

The function returns a tuple containing two elements: a dictionary and an integer. The dictionary maps each patch name (central node) to a PyTorch Geometric graph, where the graph represents the spatial neighborhood around that node, including attributes such as embeddings, predictions, positional information, and data from the specified layer. The first node in each graph is always the central node. The integer represents the maximum absolute difference in position (either row or column) between the central node and its neighbors across all graphs, which can be useful for positional encoding.

### Explanation
The `get_graphs_one_slide` function is designed to create localized spatial graphs for each node in a single-slide spatial transcriptomics dataset. The function begins by constructing a spatial connectivity graph based on the geometry of the data, determined by the `hex_geometry` parameter (hexagonal for Visium or grid for other datasets). Using this graph, it computes neighbors within a specified number of hops, capturing the local spatial relationships around each node.
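
To make the two return values more concrete, here is a minimal plain-Python sketch of the bookkeeping described above: placing the central node first and tracking the largest positional offset. The helper `build_local_graph`, the dictionary layout, and the toy inputs are hypothetical. The real function returns PyTorch Geometric graphs with additional attributes.

```python
def build_local_graph(center, neighbor_ids, coords):
    """Order nodes with the central node first and compute positions relative to it."""
    nodes = [center] + [i for i in neighbor_ids if i != center]
    c_row, c_col = coords[center]
    rel_pos = [(coords[i][0] - c_row, coords[i][1] - c_col) for i in nodes]
    return {"nodes": nodes, "rel_pos": rel_pos}


# Hypothetical toy inputs: spot coordinates and a precomputed neighbor dictionary.
coords = {0: (0, 0), 1: (0, 2), 2: (1, 1), 3: (2, 0)}
neighbors = {0: [0, 1, 2], 1: [0, 1, 2], 2: [0, 1, 2, 3], 3: [2, 3]}

graphs = {c: build_local_graph(c, ids, coords) for c, ids in neighbors.items()}

# Largest absolute row or column offset between any central node and its neighbors.
max_abs_d_pos = max(
    abs(d) for g in graphs.values() for pos in g["rel_pos"] for d in pos
)

print(graphs[2]["nodes"])  # [2, 0, 1, 3]: central node first
print(max_abs_d_pos)       # 2
```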


In [10]:
from spared.graph_operations import get_graphs_one_slide

#Graph operations require embeddings and predictions stored in adata.obsm
compute_patches_embeddings(adata=adata, backbone='densenet', model_path="None", patch_size= 224)
compute_patches_predictions(adata=adata, backbone='densenet', model_path="None", patch_size= 224)

#Get slide adata
slide_id = adata.obs.slide_id.unique()[0]
slide_adata = adata[adata.obs.slide_id == slide_id]

#Get graph for one slide
dict_graph_slide, max_pos = get_graphs_one_slide(adata=slide_adata, n_hops=6, layer="c_d_log1p", hex_geometry=param_dict["hex_geometry"])

Getting embeddings: 100%|██████████| 172/172 [01:00<00:00,  2.86it/s]
Getting predictions: 100%|██████████| 172/172 [01:01<00:00,  2.82it/s]


### Function: `get_sin_cos_positional_embeddings`

The `get_sin_cos_positional_embeddings` function adds transformer-like sinusoidal positional encodings to each graph in a given dictionary of graphs. These positional encodings are designed to capture the relative spatial positions of nodes within each graph and are added as an attribute (`positional_embeddings`) to each graph.

##### <u>Parameters:</u>

* **graph_dict (dict):** A dictionary where the keys are patch names (central node identifiers) and the values are PyTorch Geometric graphs. Each graph represents the spatial neighborhood around the central node.
* **max_d_pos (int):** The maximum absolute value in the relative position matrix, used to determine the size of the positional encoding grid.

##### <u>Returns:</u>

This function returns the input `graph_dict`, now augmented with positional encodings for each graph under the attribute `positional_embeddings`.

### Explanation
The `get_sin_cos_positional_embeddings` function adds positional encodings to each graph, which is crucial for capturing the spatial relationships between nodes within a graph. In transformer-based models, these positional encodings help the model understand the relative positions of elements, enabling it to better learn patterns that depend on spatial arrangement. By incorporating these positional embeddings, the graphs become more informative, allowing models to leverage spatial context more effectively, which can significantly improve performance in tasks like spatial transcriptomics analysis.
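
The sketch below shows the standard transformer sinusoidal formula applied to a two-dimensional relative position. It is only an illustration under stated assumptions: the function names, the split of half the channels per axis, and the embedding size `dim=8` are choices made here, and the exact layout used by SpaRED may differ.

```python
import math


def sin_cos_encoding(position, dim):
    """Standard transformer encoding of one scalar position into `dim` values (dim even)."""
    enc = []
    for i in range(0, dim, 2):
        angle = position / (10000 ** (i / dim))
        enc.extend([math.sin(angle), math.cos(angle)])
    return enc


def encode_relative_position(d_row, d_col, max_d_pos, dim=8):
    """Encode a (row, col) offset, using half of the channels for each axis."""
    # Shift offsets from [-max_d_pos, max_d_pos] to [0, 2 * max_d_pos] so that
    # every relative position maps to a non-negative grid index.
    row_idx, col_idx = d_row + max_d_pos, d_col + max_d_pos
    half = dim // 2
    return sin_cos_encoding(row_idx, half) + sin_cos_encoding(col_idx, half)


max_d_pos = 2
central = encode_relative_position(0, 0, max_d_pos)   # the central node
corner = encode_relative_position(-2, -2, max_d_pos)  # grid index (0, 0)
print(len(central))  # 8
print(corner)        # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```

Because each offset gets a distinct, bounded vector, a model can tell neighbors apart by where they sit relative to the central node.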

In [11]:
from spared.graph_operations import get_sin_cos_positional_embeddings

dict_pos_emb = get_sin_cos_positional_embeddings(graph_dict=dict_graph_slide, max_d_pos=max_pos)

### Function: `get_graphs`

The `get_graphs` function generates spatial neighbor graphs for all slides in a dataset, building upon the functionality provided by the `get_graphs_one_slide` function. It processes each slide individually, constructs the graphs within a specified number of hops, and then combines these graphs into a single dictionary. Afterward, it enriches the graphs with sinusoidal positional embeddings using the `get_sin_cos_positional_embeddings` function. The function assumes that both embeddings and predictions are already computed and stored in the `adata.obsm` attribute.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the dataset, which includes multiple slides.
* **n_hops (int):** The number of hops to consider when constructing each graph, defining the neighborhood size around each node.
* **layer (str):** The specific layer in the AnnData object to use for predictions, which will be added as the *y* attribute in the resulting graphs.
* **hex_geometry (bool):** A boolean flag indicating whether the slide data is organized in a hexagonal grid (e.g., for Visium datasets). Defaults to `True`.

##### <u>Returns:</u>

This function returns a dictionary where the keys are spot names (or other unique identifiers) and the values are `PyTorch Geometric` graphs, each representing a spatial neighborhood around a central node, enriched with positional embeddings.

In [12]:
from spared.graph_operations import get_graphs

#Graph operations require embeddings and predictions stored in adata.obsm
compute_patches_embeddings(adata=adata, backbone='densenet', model_path="None", patch_size= 224)
compute_patches_predictions(adata=adata, backbone='densenet', model_path="None", patch_size= 224)

#Get graphs
dict_graphs = get_graphs(adata=adata, n_hops=6, layer="c_d_log1p", hex_geometry=param_dict["hex_geometry"])

Getting embeddings: 100%|██████████| 172/172 [01:00<00:00,  2.85it/s]
Getting predictions: 100%|██████████| 172/172 [01:01<00:00,  2.81it/s]


Computing graphs...


100%|██████████| 14/14 [01:24<00:00,  6.04s/it]


### Dataloader Functions

In this section, we will explore the dataloader functions provided in the SpaRED library. These functions are essential for preparing spatial transcriptomics data for machine learning tasks. They cover a range of operations, from generating dataloaders for pretraining vision-based models to creating dataloaders for graph-based models. Whether you're training on raw image patches or leveraging spatial graphs to predict gene expression, these functions automate the data preparation process, ensuring your datasets are efficiently organized and ready for training and evaluation.

### Function: `get_pretrain_dataloaders`

The `get_pretrain_dataloaders` function generates and returns dataloaders for pretraining an image encoder on spatial transcriptomics data. The goal is to train a vision-based model to predict gene expression directly from image patches. It prepares the dataset by selecting the appropriate layer and handling noisy data, and then returns dataloaders for the training, validation, and test sets.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the dataset to be used for pretraining.
* **layer (str):** The specific layer in the AnnData object to use for pretraining. Defaults to `c_d_log1p`.
* **batch_size (int):** The batch size for the dataloaders. Defaults to 128.
* **shuffle (bool):** Whether to shuffle the data in the dataloaders. Defaults to `True`.
* **use_cuda (bool):** Whether to use CUDA for loading the data. Defaults to `False`.

##### <u>Returns:</u>

This function returns a tuple containing the train, validation, and test dataloaders. If the dataset does not include a test set, the test dataloader is `None`.


In [13]:
from spared.dataloaders import get_pretrain_dataloaders

train_loader, val_loader, test_loader = get_pretrain_dataloaders(adata = adata, layer = 'c_d_log1p', batch_size = 128, shuffle = True, use_cuda = False)

Using noisy_delta layer for training. This will probably yield bad results.
Percentage of imputed observations with median filter: 27.503%


### Function: `get_graph_dataloaders`

The `get_graph_dataloaders` function creates and returns dataloaders for training, validation, and testing on graph data derived from a spatial transcriptomics dataset. The function goes through a series of steps to either compute or load precomputed graph data, ensuring that the data is ready for use in model training.

##### <u>Parameters:</u>

* **adata (ad.AnnData):** The AnnData object containing the dataset to be processed.
* **dataset_path (str):** The path where the dataset and graphs are stored or will be saved. Defaults to an empty string.
* **layer (str):** The specific layer in the AnnData object to use for prediction. Defaults to `c_t_log1p`.
* **n_hops (int):** The number of hops to consider when constructing each graph. Defaults to 2.
* **backbone (str):** The backbone model to use for computing embeddings and predictions. Defaults to *densenet*.
* **model_path (str):** The path to a pretrained model to use for generating embeddings and predictions. Defaults to `None`.
* **batch_size (int):** The batch size for the dataloaders. Defaults to 128.
* **shuffle (bool):** Whether to shuffle the data in the dataloaders. Defaults to `True`.
* **hex_geometry (bool):** Whether the graph is based on hexagonal geometry (e.g., for Visium datasets). Defaults to `True`.
* **patch_size (int):** The size of the patches for computing embeddings. Defaults to 224.

##### <u>Returns:</u>

This function returns a tuple containing the train, validation, and test dataloaders for the graphs. If no test set is available, the test dataloader is `None`.

### Explanation
The `get_graph_dataloaders` function is essential for preparing the data pipeline needed to train models on spatial transcriptomics data in graph form. The function first checks if the required graph data has already been computed and saved in the specified dataset_path. If the graphs are found, they are loaded directly to save time. If not, the function computes the necessary embeddings and predictions using the specified backbone model, generates the graph data, and saves it for future use.

The graphs are constructed by considering a specified number of hops (n_hops) around each node, capturing the spatial relationships within the dataset. The function then creates dataloaders for the training, validation, and test sets, which are returned for use in model training.
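
The load-or-compute behavior described above follows a common caching pattern, sketched here with the standard library only. The helper `load_or_compute`, the pickle format, and the file name are hypothetical and are not taken from SpaRED, whose on-disk format may differ.

```python
import os
import pickle
import tempfile


def load_or_compute(cache_file, compute_fn):
    """Return (result, loaded_from_cache): load cache_file if present, else compute and save."""
    if os.path.exists(cache_file):
        with open(cache_file, "rb") as handle:
            return pickle.load(handle), True
    result = compute_fn()
    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, "wb") as handle:
        pickle.dump(result, handle)
    return result, False


def compute_toy_graphs():
    # Stand-in for the expensive embedding, prediction, and graph computation.
    return {"spot_0": [0, 1, 2], "spot_1": [1, 0]}


with tempfile.TemporaryDirectory() as tmp_dir:
    cache_file = os.path.join(tmp_dir, "graphs", "toy_graphs.pkl")  # hypothetical name
    first, first_from_cache = load_or_compute(cache_file, compute_toy_graphs)
    second, second_from_cache = load_or_compute(cache_file, compute_toy_graphs)

print(first_from_cache, second_from_cache)  # False True
```

This is why the first call in the cell below reports that graphs were not found and computes them, while later calls with the same `dataset_path` can skip that step.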

In [14]:
from spared.dataloaders import get_graph_dataloaders

#Path to where the graphs will be saved
graphs_path = os.path.join(parentdir, "processed_data/vicari_data/vicari_mouse_brain/graphs")
os.makedirs(graphs_path, exist_ok=True)

train_graph_loader, val_graph_loader, test_graph_loader = get_graph_dataloaders(adata = adata, dataset_path = graphs_path, layer = 'c_d_log1p', n_hops = 2, backbone = 'densenet', model_path = "None", batch_size = 128, shuffle = True, hex_geometry = param_dict["hex_geometry"], patch_size = 224)


Graphs not found in file, computing graphs...
Using noisy_delta layer for training. This will probably yield bad results.


Getting embeddings: 100%|██████████| 172/172 [01:05<00:00,  2.63it/s]
Getting predictions: 100%|██████████| 172/172 [01:05<00:00,  2.63it/s]


Computing graphs...


100%|██████████| 14/14 [00:47<00:00,  3.37s/it]


Saving graphs...


### Models Functions

In this section, we explore the model functions provided in the SpaRED library. These functions are designed to facilitate the training and evaluation of deep learning models on spatial transcriptomics data. They include customizable modules that allow you to easily define and train models using various backbone architectures, process image patches, and predict gene expression levels. By leveraging the flexibility of PyTorch Lightning, these functions enable streamlined model development and provide a solid foundation for conducting advanced spatial biology research.

### Function: `ImageBackbone`
The `ImageBackbone` class is designed for training vision-based models on image patches with the goal of predicting gene expression levels. It supports a variety of backbone architectures, including ResNet, ShuffleNet, and Vision Transformer (ViT), and can be used with both pretrained models and models trained from scratch.


##### <u>Parameters:</u>

* **args (argparse.Namespace):** Argument namespace that includes various settings for the model, such as:

    * **img_backbone (str):** Specifies the backbone model to use (e.g., 'ShuffleNetV2', 'ResNet', 'ViT').
    * **img_use_pretrained (bool):** Indicates whether to use pretrained weights (True) or start training from scratch (False).
    * **average_test (bool):** If True, the model applies test-time augmentation by averaging predictions over different transformations.
    * **optim_metric (str):** Specifies the optimization metric (e.g., 'MSE' for Mean Squared Error) to evaluate model performance.
    * **robust_loss (bool):** Determines whether to use a robust loss function that can handle outliers by ignoring zeros.
    * **optimizer (str):** The type of optimizer to use (e.g., 'Adam', 'SGD').
    * **lr (float):** The learning rate for the optimizer.
    * **momentum (float):** Momentum factor for the optimizer, applicable when using optimizers like SGD.

* **latent_dim (int):** The size of the latent space, representing the number of output features from the final layer of the model. This typically corresponds to the number of variables in your dataset (e.g., data.adata.n_vars).

##### <u>Returns:</u>

This function returns the Image Backbone model configured with the specified arguments and backbone. The available backbones are listed in the SpaRED library documentation.

### Explanation

`ImageBackbone` is designed for use with `PyTorch Lightning`. It provides a flexible framework for training models on spatial transcriptomics data, allowing for easy customization through the args parameter and supporting a wide range of backbone architectures to suit different datasets and tasks.
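
Since `args` is a plain `argparse.Namespace`, it can also be built in a single step by unpacking a dictionary of settings. The snippet below is a small sketch of that alternative, using the same example values as this tutorial. The variable name `settings` is chosen for illustration only.

```python
import argparse

# Example settings, matching the values used in this tutorial.
settings = {
    "img_backbone": "ShuffleNetV2",
    "img_use_pretrained": True,
    "average_test": False,
    "optim_metric": "MSE",
    "robust_loss": False,
    "optimizer": "Adam",
    "lr": 0.0001,
    "momentum": 0.9,
}

# Unpack the dictionary directly into the namespace constructor.
args = argparse.Namespace(**settings)

print(args.img_backbone, args.lr)  # ShuffleNetV2 0.0001
```

Both approaches produce an equivalent namespace. Filling `vars(namespace)` in a loop can be convenient when the namespace already exists and only some fields need to be overridden.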

In [15]:
from spared.models import ImageBackbone

# Define argparse variables
test_args = argparse.Namespace()
arg_dict = vars(test_args)
input_dict = {
    'img_backbone': 'ShuffleNetV2',
    'img_use_pretrained': True,
    'average_test': False,
    'optim_metric': 'MSE',
    'robust_loss': False,
    'optimizer': 'Adam',
    'lr': 0.0001,
    'momentum': 0.9,
}

for key,value in input_dict.items():
    arg_dict[key]= value


# Declare device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImageBackbone(args=test_args,  latent_dim=data.adata.n_vars).to(device)
model

ImageEncoder(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, trac

### Metrics Functions

Now let's delve into the metrics functions available in the SpaRED library. These functions are crucial for evaluating the performance of your models. By using the provided metrics, you can accurately assess various aspects of model accuracy, such as prediction errors and the quality of spatial predictions. These metrics help in quantifying the effectiveness of your models in predicting gene expression.

### Function: `get_pearsonr`

The `get_pearsonr` function computes the average Pearson correlation coefficient (PCC) between corresponding elements of two matrices (ground truth and predictions), considering only the unmasked values. This function is useful for evaluating the similarity between predicted and actual values across multiple variables or observations.

##### <u>Parameters:</u>

* **gt_mat (torch.Tensor):** The ground truth matrix of shape (n_observations, n_variables), containing the true values to compare against the predictions.
* **pred_mat (torch.Tensor):** The predicted matrix of shape (n_observations, n_variables), containing the model's predictions.
* **mask (torch.Tensor):** A boolean mask of shape (n_observations, n_variables) where `False` values indicate positions that should be ignored during the PCC computation.
* **axis (int):** Specifies the axis along which to compute the Pearson correlation. Use `axis=0` to compute the correlation across columns (variables) and `axis=1` to compute across rows (observations).

##### <u>Returns:</u>

The function returns a tuple consisting of two elements: the first element is a float representing the average Pearson correlation coefficient computed across the specified axis (either columns or rows). The second element is a list containing the individual Pearson correlation coefficients for each column or row. 

### Explanation

The Pearson correlation coefficient (PCC) is a measure of the linear correlation between two sets of data. It ranges from -1 to 1, where 1 indicates a perfect positive linear relationship, -1 indicates a perfect negative linear relationship, and 0 indicates no linear relationship. In the context of this function, PCC is used to evaluate how well the predicted values align with the ground truth values across multiple variables or observations.
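
For intuition, the following pure-Python sketch computes a masked, column-wise PCC on a tiny hand-checkable example. It assumes that masked entries are simply dropped before the correlation is computed and that the per-column values are then averaged. The helper names are hypothetical, and the actual `get_pearsonr` implementation may handle edge cases (such as constant columns) differently.

```python
import math


def masked_pearson(gt_col, pred_col, mask_col):
    """Pearson correlation using only the entries where mask_col is True."""
    pairs = [(g, p) for g, p, keep in zip(gt_col, pred_col, mask_col) if keep]
    n = len(pairs)
    mean_g = sum(g for g, _ in pairs) / n
    mean_p = sum(p for _, p in pairs) / n
    cov = sum((g - mean_g) * (p - mean_p) for g, p in pairs)
    var_g = sum((g - mean_g) ** 2 for g, _ in pairs)
    var_p = sum((p - mean_p) ** 2 for _, p in pairs)
    # Note: undefined (division by zero) if either column is constant.
    return cov / math.sqrt(var_g * var_p)


def mean_masked_pearson(gt, pred, mask):
    """Column-wise (axis=0) masked PCC; returns (mean, per-column list)."""
    per_col = [
        masked_pearson(g, p, m)
        for g, p, m in zip(zip(*gt), zip(*pred), zip(*mask))
    ]
    return sum(per_col) / len(per_col), per_col


# Rows are observations, columns are genes. The last entry of gene 1 is masked out.
gt = [[1.0, 4.0], [2.0, 3.0], [3.0, 2.0], [4.0, 9.0]]
pred = [[2.0, 2.0], [4.0, 3.0], [6.0, 4.0], [8.0, 0.0]]
mask = [[True, True], [True, True], [True, True], [True, False]]

mean_pcc, pcc_per_gene = mean_masked_pearson(gt, pred, mask)
print(pcc_per_gene)  # [1.0, -1.0]
print(mean_pcc)      # 0.0
```

Gene 0 is perfectly correlated and gene 1 is perfectly anti-correlated once the masked value is ignored, so the average is zero even though each gene shows a strong linear relationship. This is why inspecting the per-column list alongside the mean can be informative.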


We will be generating random prediction and ground truth matrices as well as a random mask where *26%* of the values will be masked. 

In [17]:
import torch

# Set number of observations and genes (hypothetical)
obs = 10
genes = 8
imputed_fraction = 0.26 # Fraction of entries set to False (masked out) in the mask

# Generate random matrices
pred = torch.randn((obs,genes))
gt = torch.randn((obs,genes))
mask = torch.rand((obs,genes))>imputed_fraction

In [18]:
from spared.metrics import get_pearsonr

mean_pcc_col, list_pcc_col = get_pearsonr(gt_mat=gt, pred_mat=pred, mask=mask, axis=0)
mean_pcc_row, list_pcc_row = get_pearsonr(gt_mat=gt, pred_mat=pred, mask=mask, axis=1)

print("PCC by columns: " + str(mean_pcc_col))
print(list_pcc_col)

print("PCC by rows: " + str(mean_pcc_row))
print(list_pcc_row)

PCC by columns: -0.18081757426261902
[-0.5434356331825256, 0.2459755688905716, 0.7937927842140198, -0.3614409863948822, -0.3860357403755188, -0.06298591196537018, -0.6647549271583557, -0.4676556885242462]
PCC by rows: -0.06456981599330902
[-0.10244599729776382, -0.30918097496032715, 0.8755841255187988, 0.16263891756534576, -0.5979809165000916, 0.22747549414634705, -0.2783399522304535, -0.32421642541885376, 0.02806819975376129, -0.32730066776275635]


### Function: `get_r2_score`

The `get_r2_score` function calculates the R² (coefficient of determination) score between corresponding elements of two matrices (ground truth and predictions), considering only the unmasked values. The R² score is computed along a specified axis, either across rows or columns, and provides a measure of how well the predictions approximate the actual data. The function returns both the average R² score across the specified axis and a detailed list of R² scores for each individual row or column.

##### <u>Parameters:</u>

* **gt_mat (torch.Tensor):** The ground truth matrix of shape (n_observations, n_variables), containing the true values to compare against the predictions.
* **pred_mat (torch.Tensor):** The predicted matrix of shape (n_observations, n_variables), containing the model's predictions.
* **mask (torch.Tensor):** A boolean mask of shape (n_observations, n_variables) where `False` values indicate positions that should be ignored during the R² computation.
* **axis (int):** Specifies the axis along which to compute the R² score. Use `axis=0` to compute the score across columns (variables) and `axis=1` to compute across rows (observations).

##### <u>Returns:</u>

The function returns a tuple consisting of two elements: the first element is a float representing the average R² score computed across the specified axis. The second element is a list containing the individual R² scores for each column or row. 

### Explanation
The R² score, also known as the coefficient of determination, is a statistical measure that indicates the proportion of the variance in the dependent variable that is predictable from the independent variable(s). A score of 1 indicates perfect prediction, and a score of 0 indicates a model that does no better than always predicting the mean of the ground truth. The score has no lower bound: predictions that fit worse than the mean yield negative values, as the random matrices in the example below show.
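
As an illustration of the definition, the following sketch computes a column-wise R² as one minus the ratio of the residual sum of squares to the total sum of squares, using only unmasked entries. The helper name `masked_r2_per_column` is hypothetical and not part of SpaRED; `get_r2_score` may be implemented differently.

```python
import torch

def masked_r2_per_column(gt, pred, mask):
    """Hypothetical helper: R² per column over unmasked entries."""
    m = mask.to(gt.dtype)                            # 1.0 where the entry is used
    n = m.sum(dim=0)                                 # unmasked count per column
    gt_mean = (gt * m).sum(dim=0) / n                # masked mean of the ground truth
    ss_res = (((gt - pred) ** 2) * m).sum(dim=0)     # residual sum of squares
    ss_tot = (((gt - gt_mean) ** 2) * m).sum(dim=0)  # total sum of squares
    return 1.0 - ss_res / ss_tot

torch.manual_seed(0)
gt_demo = torch.randn(10, 8)
pred_demo = torch.randn(10, 8)
mask_demo = torch.rand(10, 8) > 0.26

r2_cols = masked_r2_per_column(gt_demo, pred_demo, mask_demo)
print("Mean R2 by columns:", r2_cols.mean().item())
print(r2_cols.tolist())
```

As with the correlation, a row-wise score can be sketched by transposing the three inputs before calling the helper.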

We will be generating random prediction and ground truth matrices as well as a random mask where *26%* of the values will be masked.

In [19]:
import torch

# Set number of observations and genes (hypothetical)
obs = 10
genes = 8
imputed_fraction = 0.26 # Fraction of entries set to False (masked out) in the mask

# Generate random matrices
pred = torch.randn((obs,genes))
gt = torch.randn((obs,genes))
mask = torch.rand((obs,genes))>imputed_fraction

In [20]:
from spared.metrics import get_r2_score

mean_r2_col, list_r2_col = get_r2_score(gt_mat=gt, pred_mat=pred, mask=mask, axis=0)
mean_r2_row, list_r2_row = get_r2_score(gt_mat=gt, pred_mat=pred, mask=mask, axis=1)

print("R2 Score by columns: " + str(mean_r2_col))
print(list_r2_col)

print("R2 Score by rows: " + str(mean_r2_row))
print(list_r2_row)

R2 Score by columns: -1.4981153011322021
[-0.34498167037963867, -2.279686689376831, -2.8530077934265137, -1.471015214920044, -0.3273383378982544, -1.969834327697754, -0.3219904899597168, -2.4170680046081543]
R2 Score by rows: -1.3743867874145508
[-0.8520510196685791, -0.5073205232620239, -1.1874570846557617, -0.6010171175003052, -2.011044979095459, -0.8247621059417725, -1.4701313972473145, -1.2510161399841309, -4.709837913513184, -0.3292292356491089]


### Function: `get_metrics`
The `get_metrics` function computes a set of regression metrics to evaluate the performance of predictions against ground truth values across a dataset. It returns a dictionary of these metrics, which include Pearson correlation coefficients, R² scores, Mean Squared Error (MSE), Mean Absolute Error (MAE), and a combined Global metric. The function can also return detailed metrics for each individual gene and patch if the `detailed` flag is set to `True`.

##### <u>Parameters:</u>

* **gt_mat (Union[np.array, torch.Tensor]):** The ground truth matrix of shape (n_samples, n_genes), containing the actual values to compare against the predictions.
* **pred_mat (Union[np.array, torch.Tensor]):** The predicted matrix of shape (n_samples, n_genes), containing the model's predictions.
* **mask (Union[np.array, torch.Tensor]):** A boolean mask of shape (n_samples, n_genes) where `False` values indicate positions that should be ignored during the metric computation.
* **detailed (bool):** If set to `True`, the function returns detailed metrics for each gene and patch in addition to the general metrics. Defaults to `False`.

##### <u>Returns:</u>

The function returns a dictionary containing the computed metrics. The main metrics include `PCC-Gene`, `PCC-Patch`, `R2-Gene`, `R2-Patch`, `MSE`, `MAE`, and a combined `Global` metric. If `detailed=True`, the dictionary is extended with additional keys such as `detailed_PCC-Gene`, `detailed_PCC-Patch`, `detailed_R2-Gene`, `detailed_R2-Patch`, and various error metrics for each gene.

### Explanation
The `get_metrics` function is designed to provide a comprehensive evaluation of model predictions by calculating key regression metrics that measure how well the predicted values match the ground truth values across a dataset. Here's a brief explanation of the metrics computed:

1. **Pearson Correlation Coefficient (PCC):** Measures the linear correlation between predicted and actual values. The function calculates this both gene-wise (PCC-Gene) and patch-wise (PCC-Patch). A high PCC indicates a strong linear relationship.

2. **R² Score:** Indicates the proportion of variance in the dependent variable that is predictable from the independent variable. The function computes this score both gene-wise (R2-Gene) and patch-wise (R2-Patch). Higher R² scores signify better predictive accuracy.

3. **Mean Squared Error (MSE):** Represents the average of the squares of the differences between predicted and actual values. MSE penalizes larger errors more heavily.

4. **Mean Absolute Error (MAE):** Represents the average of the absolute differences between predicted and actual values. MAE provides a straightforward interpretation of prediction errors.

5. **Global Metric:** A composite metric that combines PCC and R² scores while penalizing MSE and MAE. It provides a single value summarizing overall performance.
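
The two error metrics in the list can be sketched in a few lines. The helper below, `masked_mse_mae`, is a hypothetical name and not the SpaRED implementation; it assumes the errors are averaged over all unmasked entries at once, which may differ from how `get_metrics` aggregates them. The exact formula of the `Global` metric is not stated here, so it is not reproduced.

```python
import torch

def masked_mse_mae(gt, pred, mask):
    """Hypothetical helper: MSE and MAE over entries where mask is True."""
    diff = (gt - pred)[mask]        # 1-D tensor with the unmasked differences
    mse = (diff ** 2).mean().item()
    mae = diff.abs().mean().item()
    return mse, mae

torch.manual_seed(0)
gt_demo = torch.randn(10, 8)
pred_demo = torch.randn(10, 8)
mask_demo = torch.rand(10, 8) > 0.26

mse_demo, mae_demo = masked_mse_mae(gt_demo, pred_demo, mask_demo)
print("MSE:", mse_demo, "MAE:", mae_demo)
```

Because the squared differences grow faster than the absolute ones, a few large errors raise the MSE much more than the MAE, which is why the two are usually reported together.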

We will be generating random prediction and ground truth matrices as well as a random mask where *26%* of the values will be masked.

In [21]:
import torch

# Set number of observations and genes (hypothetical)
obs = 10
genes = 8
imputed_fraction = 0.26 # Fraction of entries set to False (masked out) in the mask

# Generate random matrices and a random mask
pred = torch.randn((obs,genes))
gt = torch.randn((obs,genes))
mask = torch.rand((obs,genes))>imputed_fraction

In [22]:
from spared.metrics import get_metrics

dict_metrics = get_metrics(gt_mat = gt, pred_mat = pred, mask = mask, detailed = False)
dict_metrics_detailed = get_metrics(gt_mat = gt, pred_mat = pred, mask = mask, detailed = True)

print("Metrics dictionary:")
print(dict_metrics)
print("Detailed metrics dictionary:")
print(dict_metrics_detailed)

Metrics dictionary:
{'PCC-Gene': 0.2479938268661499, 'PCC-Patch': 0.21199902892112732, 'R2-Gene': -0.48487168550491333, 'R2-Patch': -2.052999973297119, 'MSE': 1.822934627532959, 'MAE': 1.0708446502685547, 'Global': -4.971658080816269}
Detailed metrics dictionary:
{'PCC-Gene': 0.2479938268661499, 'PCC-Patch': 0.21199902892112732, 'R2-Gene': -0.48487168550491333, 'R2-Patch': -2.052999973297119, 'MSE': 1.822934627532959, 'MAE': 1.0708446502685547, 'Global': -4.971658080816269, 'detailed_PCC-Gene': [-0.0015925765037536621, 0.22179384529590607, -0.27338194847106934, 0.04636658728122711, 0.773283064365387, 0.6881754398345947, -0.19805178046226501, 0.7273579835891724], 'detailed_PCC-Patch': [0.029849424958229065, 0.5682603120803833, 0.3311068117618561, 0.14430180191993713, 0.1837208867073059, -0.1209012120962143, -0.439031720161438, 0.18234789371490479, 0.9513199925422668, 0.2890160083770752], 'detailed_R2-Gene': [-0.8768579959869385, -0.727558970451355, -1.0784060955047607, -1.43119192123413