# Training Demo

Here we demonstrate how to train Popari on a preprocessed multisample spatial transcriptomics dataset. In particular, we will be working with the **Alzheimer's Disease (AD)** dataset from the "Preprocessing Demo" notebook.

In [1]:
# Disable warnings for prettier notebook
import warnings
warnings.filterwarnings("ignore")

In [2]:
from pathlib import Path
from tqdm.auto import trange

import torch

import popari
from popari.model import Popari
from popari import pl, tl

In [3]:
data_directory = Path("/path/to/directory/")

In [4]:
data_directory = Path("/work/magroup/shahula/spatiotemporal_transcriptomics_integration/data/STARmapPlus/SCP1375/")

In [5]:
K = 15
dataset_path = data_directory / "preprocessed_dataset.h5ad"
context = {"device": "cuda:1", "dtype": torch.float64}

popari_example = Popari(
    K=K,
    dataset_path=dataset_path,
    torch_context=context,
    initial_context=context,
    verbose=0
)
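
The context above hard-codes `"cuda:1"`, which only exists on machines with at least two GPUs. As a minimal sketch, assuming you want a CPU fallback, the hypothetical helper below (`resolve_device` is not part of Popari) picks a usable device string from the number of visible GPUs.

```python
def resolve_device(preferred: str, n_gpus: int) -> str:
    """Return `preferred` if it names an available device, otherwise "cpu".

    Hypothetical helper for illustration; not part of the Popari API.
    """
    if preferred == "cpu":
        return "cpu"
    if preferred.startswith("cuda"):
        # A bare "cuda" is treated as device index 0 here; "cuda:N" means index N.
        _, _, index = preferred.partition(":")
        device_index = int(index) if index else 0
        if device_index < n_gpus:
            return preferred
    # Unknown device name, or a GPU index that is not present.
    return "cpu"


# On a single-GPU machine, "cuda:1" is not available, so we fall back.
print(resolve_device("cuda:1", n_gpus=1))
```

In the notebook, `n_gpus` could be taken from `torch.cuda.device_count()`, and the returned string used as the `"device"` entry of `context`.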

## Pretraining

In [6]:
# Initialization with NMF
progress_bar = trange(10, leave=True)
for preiteration in progress_bar:
    popari_example.estimate_parameters(update_spatial_affinities=False)
    popari_example.estimate_weights(use_neighbors=False)

In [7]:
# Reinitialize spatial affinities
popari_example.parameter_optimizer.reinitialize_spatial_affinities()

In [8]:
# Initialization with SpiceMix
progress_bar = trange(50, leave=True)
for iteration in progress_bar:
    popari_example.estimate_parameters()
    popari_example.estimate_weights()

In [9]:
from popari.model import from_pretrained

popari_pretrained = from_pretrained(popari_example, popari_context=context, lambda_Sigma_bar=1e-4)

In [10]:
# Train the model initialized from the pretrained one
progress_bar = trange(200, leave=True)
for iteration in progress_bar:
    popari_pretrained.estimate_parameters()
    popari_pretrained.estimate_weights()

## Hierarchical Training

In hierarchical mode, Popari is trained on a lower-resolution view of the original spatial transcriptomics data, which can make training more robust. The embeddings can then be "superresolved" back to the original resolution to regain a fine-grained view.
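
To make the idea of a lower-resolution view concrete, here is a minimal sketch that groups spots into square spatial bins, averages a per-spot value within each bin, and then copies each bin's value back to its member spots as a starting point for refinement. This only illustrates the general coarsen-then-refine pattern. It is not Popari's implementation, and the function names, the square-grid binning, and the toy data are all assumptions.

```python
from collections import defaultdict


def bin_spots(coords, values, bin_size):
    """Group spots into square bins and average their values.

    Returns (coarse, assignment): `coarse` maps a bin key to the mean value
    of the spots in that bin, and `assignment[i]` is the bin key of spot i.
    """
    sums = defaultdict(float)
    counts = defaultdict(int)
    assignment = []
    for (x, y), value in zip(coords, values):
        key = (int(x // bin_size), int(y // bin_size))
        assignment.append(key)
        sums[key] += value
        counts[key] += 1
    coarse = {key: sums[key] / counts[key] for key in sums}
    return coarse, assignment


def broadcast_to_spots(coarse, assignment):
    """Copy each bin's value back to every spot assigned to that bin."""
    return [coarse[key] for key in assignment]


# Toy data: four spots with 2D coordinates and one value per spot.
coords = [(0.2, 0.3), (0.8, 0.1), (1.5, 0.4), (1.7, 1.9)]
values = [1.0, 3.0, 10.0, 4.0]

coarse, assignment = bin_spots(coords, values, bin_size=1.0)
fine_init = broadcast_to_spots(coarse, assignment)

print(len(values), "spots ->", len(coarse), "bins")
print(fine_init)
```

The first two spots share a bin, so the coarse view has three entries instead of four. A real superresolution step would then optimize the per-spot values starting from `fine_init` rather than stopping at the broadcast.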

In [12]:
K = 15
dataset_path = data_directory / "preprocessed_dataset.h5ad"
context = {"device": "cuda:1", "dtype": torch.float64}
hierarchical_levels = 2
superresolution_lr = 1e-1

hierarchical_example = Popari(
    K=K,
    dataset_path=dataset_path,
    torch_context=context,
    initial_context=context,
    hierarchical_levels=hierarchical_levels,
    superresolution_lr=superresolution_lr,
    verbose=0
)

[2023/06/22 17:56:43]	 Initializing hierarchy level 1
[2023/06/22 17:56:44]	 Downsized dataset from 8186 to 1298 spots.
[2023/06/22 17:56:44]	 Downsized dataset from 10372 to 1167 spots.


In [13]:
# Initialization with NMF
progress_bar = trange(10, leave=True)
for preiteration in progress_bar:
    hierarchical_example.estimate_parameters(update_spatial_affinities=False)
    hierarchical_example.estimate_weights(use_neighbors=False)

In [14]:
# Reinitialize spatial affinities
hierarchical_example.parameter_optimizer.reinitialize_spatial_affinities()

In [15]:
# Train with spatial affinities
progress_bar = trange(50, leave=True)
for iteration in progress_bar:
    hierarchical_example.estimate_parameters()
    hierarchical_example.estimate_weights()

In [16]:
hierarchical_example.superresolve(n_epochs=5000, tol=1e-6)
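
The `n_epochs` and `tol` arguments suggest an iterative optimizer with an epoch cap and a convergence tolerance. The notebook does not say how Popari defines convergence, so the sketch below shows only the common pattern, under the assumption that optimization stops once an update becomes smaller than `tol`. It uses plain gradient descent on a toy quadratic, and `minimize` is a hypothetical function, not Popari code.

```python
def minimize(grad, x0, lr, n_epochs, tol):
    """Gradient descent that stops early once the update is smaller than `tol`.

    Returns the final value and the number of epochs actually run.
    """
    x = x0
    for epoch in range(n_epochs):
        step = lr * grad(x)
        x -= step
        if abs(step) < tol:
            # Converged: the last update was below the tolerance.
            return x, epoch + 1
    # Epoch budget exhausted without meeting the tolerance.
    return x, n_epochs


# Toy objective f(x) = (x - 3)^2 with gradient 2 * (x - 3).
x_final, epochs_used = minimize(
    lambda x: 2.0 * (x - 3.0), x0=0.0, lr=1e-1, n_epochs=5000, tol=1e-6
)
print(x_final, epochs_used)
```

Under this pattern, `n_epochs` acts as an upper bound, and the tolerance usually ends the run much earlier on well-conditioned problems.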

### Save results to disk

In [18]:
hierarchical_example.save_results(data_directory / "hierarchical_results.h5ad")

### Load a pretrained model

In [19]:
from popari.model import load_trained_model

In [20]:
reloaded_model = load_trained_model(data_directory / "hierarchical_results.h5ad")