# Training Demo

Here we demonstrate how to train Popari on a multisample spatial transcriptomics dataset. In particular, we will be working with the **Alzheimer's Disease (AD)** dataset that we preprocessed in the "Preprocessing Demo" notebook.

In [1]:
# Disable warnings for prettier notebook
import warnings
warnings.filterwarnings("ignore")

In [2]:
from pathlib import Path
from tqdm.auto import trange

import torch

import popari
from popari.model import Popari
from popari import pl, tl
from popari.train import TrainParameters, Trainer

In [3]:
# Directory containing `preprocessed_dataset.h5ad` from the Preprocessing Demo
data_directory = Path("/path/to/directory/")

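Because the rest of the notebook reads `preprocessed_dataset.h5ad` from `data_directory`, it can help to fail early with a clear message if that file is missing. The helper below is a minimal sketch of such a check; `find_preprocessed_dataset` is a hypothetical name of our own, not part of Popari.

```python
from pathlib import Path


def find_preprocessed_dataset(directory, filename="preprocessed_dataset.h5ad"):
    """Return the path to the preprocessed dataset, failing early if it is missing.

    Hypothetical helper (not a Popari function). The default filename mirrors
    the one used in this notebook; adjust it if your preprocessing output
    was saved under a different name.
    """
    path = Path(directory) / filename
    if not path.is_file():
        raise FileNotFoundError(
            f"Expected a preprocessed dataset at {path}; "
            "run the Preprocessing Demo first or update data_directory."
        )
    return path
```

For example, `dataset_path = find_preprocessed_dataset(data_directory)` could then be used as the `'dataset_path'` entry in the parameter dictionaries below.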

In [5]:
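model_parameters = {
    'K': 15,  # number of metagenes
    'dataset_path': data_directory / "preprocessed_dataset.h5ad",
    'lambda_Sigma_x_inv': 1e-4,  # regularization weight for the spatial affinities
    'lambda_Sigma_bar': 1e-4,  # regularization weight for the shared affinity Sigma_bar
    'initial_context': {  # device/dtype used while initializing the model
        'device': 'cuda:0',
        'dtype': torch.float64
    },
    'torch_context': {  # device/dtype used during training
        'device': 'cuda:0',
        'dtype': torch.float64
    },
    'verbose': 0,
    'spatial_affinity_mode': 'differential lookup',  # per-sample spatial affinities
}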

In [6]:
popari_example = Popari(**model_parameters)
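Both `initial_context` and `torch_context` hard-code `'cuda:0'`. Below is a minimal sketch of a helper that moves both contexts to another device in one place. `with_device` is a hypothetical helper of our own, and it assumes the dictionary layout used in `model_parameters` above.

```python
def with_device(parameters, device):
    """Return a copy of a Popari parameter dict with both contexts moved to `device`.

    Hypothetical helper: assumes the 'initial_context' / 'torch_context'
    layout used in this notebook. Other entries (including each context's
    'dtype') are kept unchanged, and the input dict is not modified.
    """
    updated = dict(parameters)  # shallow copy of the top-level dict
    for key in ("initial_context", "torch_context"):
        if key in updated:
            # Build a new context dict so the original one is left untouched.
            updated[key] = {**updated[key], "device": device}
    return updated
```

For example, `device = "cuda:0" if torch.cuda.is_available() else "cpu"` followed by `Popari(**with_device(model_parameters, device))`. We have not benchmarked the CPU path here; training on a CPU is likely to be considerably slower than on a GPU.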

## Training Loop

In [8]:
train_parameters = TrainParameters(
    nmf_iterations=10,
    iterations=200,
    savepath=data_directory / "results.h5ad",
)

In [9]:
trainer = Trainer(
    parameters=train_parameters,
    model=popari_example,
    verbose=1,
)

Below, we train Popari for `10` NMF iterations (`nmf_iterations`) followed by `200` Popari iterations (`iterations`); this should take about 30 minutes on a standard GPU.

In [10]:
trainer.train()

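The `nmf_iterations` setting runs a short NMF phase before the main `iterations`. To give a feel for what a single NMF iteration does, here is a toy sketch using the classic Lee-Seung multiplicative updates on a nonnegative spots-by-genes matrix `Y ≈ X @ M.T`, with embeddings `X` and metagenes `M`. The shapes, Frobenius loss, and random initialization are our own assumptions; Popari's actual NMF phase may use different updates, normalization, or regularization, and `nmf_warmup` is not a Popari function.

```python
import numpy as np


def nmf_warmup(Y, K, n_iterations=10, seed=0, eps=1e-12):
    """Toy NMF with Lee-Seung multiplicative updates (Frobenius loss).

    Factorizes a nonnegative spots x genes matrix Y as Y ~= X @ M.T, with
    X (spots x K) as embeddings and M (genes x K) as metagenes.
    Returns X, M and the reconstruction error after each iteration.
    """
    rng = np.random.default_rng(seed)
    n_spots, n_genes = Y.shape
    X = rng.random((n_spots, K))
    M = rng.random((n_genes, K))
    errors = []
    for _ in range(n_iterations):
        # Multiplicative updates keep both factors nonnegative;
        # eps guards against division by zero.
        M *= (Y.T @ X) / (M @ (X.T @ X) + eps)
        X *= (Y @ M) / (X @ (M.T @ M) + eps)
        errors.append(np.linalg.norm(Y - X @ M.T))
    return X, M, errors
```

Because this toy version has no spatial term, each spot is fitted independently of its neighbors, which is the kind of starting point the later spatially informed iterations would then refine.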

## Hierarchical Training

Using hierarchical mode, we can train Popari more robustly on a lower-resolution view of the original spatial transcriptomics data. We can then "superresolve" the learned embeddings back to the higher resolution to regain a fine-grained view.
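As a rough illustration of what a lower-resolution view can look like, the sketch below bins spots on a square grid and sums the expression of all spots in each bin, returning the bin centroids and each spot's bin assignment. This is plain grid binning of our own choosing; it is not necessarily how Popari's `'partition'` downsampling method works, and `grid_downsample` and `bin_size` are hypothetical names.

```python
import numpy as np


def grid_downsample(coords, Y, bin_size):
    """Aggregate spots into square spatial bins (a toy lower-resolution view).

    coords: (n_spots, 2) spatial coordinates; Y: (n_spots, n_genes) expression.
    Returns bin centroids, summed expression per bin, and each spot's bin index.
    """
    cells = np.floor(coords / bin_size).astype(int)
    # Map each occupied grid cell to a compact bin index 0..n_bins-1.
    _, assignment = np.unique(cells, axis=0, return_inverse=True)
    assignment = assignment.ravel()
    n_bins = assignment.max() + 1
    counts = np.bincount(assignment, minlength=n_bins)
    # Centroid of each bin = mean coordinate of the spots it contains.
    centroids = np.zeros((n_bins, coords.shape[1]))
    np.add.at(centroids, assignment, coords)
    centroids /= counts[:, None]
    # Coarse expression = sum of expression over the spots in each bin.
    Y_coarse = np.zeros((n_bins, Y.shape[1]))
    np.add.at(Y_coarse, assignment, Y)
    return centroids, Y_coarse, assignment
```

Larger `bin_size` values give coarser views. The `assignment` array records which coarse bin each original spot fell into, which is the kind of coarse-to-fine mapping a superresolution step needs in order to initialize fine-resolution embeddings.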

In [5]:
hierarchical_parameters = {
    'K': 15,
    'dataset_path': data_directory / "preprocessed_dataset.h5ad",
    'lambda_Sigma_x_inv': 1e-4,
    'lambda_Sigma_bar': 1e-4,
    'initial_context': {
        'device': 'cuda:0',
        'dtype': torch.float64
    },
    'torch_context': {
        'device': 'cuda:0',
        'dtype': torch.float64
    },
    'verbose': 0,
    'spatial_affinity_mode': 'differential lookup',
    'downsampling_method': 'partition',  # how the lower-resolution levels are built
    'hierarchical_levels': 2,  # level 0 is the original data; level 1 is downsampled
    'superresolution_lr': 1e-1,  # learning rate for the superresolution step
}

In [8]:
hierarchical_example = Popari(**hierarchical_parameters)

[2024/10/15 00:37:39]	 Initializing hierarchy level 1
[2024/10/15 00:37:40]	 Downsized dataset from 8186 to 1637 spots.
[2024/10/15 00:37:42]	 Downsized dataset from 10372 to 2074 spots.


In [9]:
hierarchical_train_parameters = TrainParameters(
    nmf_iterations=10,
    iterations=200,
    savepath=data_directory / "hierarchical_results.h5ad",
)

In [10]:
hierarchical_trainer = Trainer(
    parameters=hierarchical_train_parameters,
    model=hierarchical_example,
    verbose=True,
)

In [11]:
hierarchical_trainer.train()

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

The hierarchical trainer optimizes the model at the lowest resolution (`level = model.hierarchical_levels - 1`). To recover spatially informed embeddings `X` at the higher resolutions, we provide a superresolution subroutine that can be run after the main training loop.

In [12]:
hierarchical_trainer.superresolve(n_epochs=10000, tol=1e-6)
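To make the idea of superresolution more concrete, the following toy sketch warm-starts each fine-resolution spot with the embedding of the coarse bin it belongs to, then refits the embeddings against the fine-resolution expression with nonnegative projected gradient descent while keeping the metagenes `M` fixed. It stops when the relative decrease in loss drops to `tol` or after `n_epochs`. The names `superresolve_sketch`, `X_coarse`, and `assignment` are hypothetical; `lr` here is a step size relative to the gradient's Lipschitz constant (our choice, not necessarily what `superresolution_lr` means); and, unlike Popari, the sketch ignores spatial affinities entirely, so it only shows the warm-start-and-refine pattern.

```python
import numpy as np


def superresolve_sketch(Y_fine, M, X_coarse, assignment,
                        lr=1e-1, n_epochs=10000, tol=1e-6):
    """Toy superresolution: refine fine-resolution embeddings from coarse ones.

    Y_fine: (n_fine, n_genes) expression; M: (n_genes, K) fixed metagenes;
    X_coarse: (n_bins, K) coarse embeddings; assignment: (n_fine,) bin index
    of each fine spot. Minimizes 0.5 * ||X @ M.T - Y_fine||^2 over X >= 0.
    Returns the refined X and the loss recorded at each epoch.
    """
    # Warm start: every fine spot inherits its coarse bin's embedding.
    X = X_coarse[assignment].copy()
    # Largest eigenvalue of M.T @ M bounds the gradient's Lipschitz constant,
    # so lr <= 1 gives a monotonically non-increasing loss.
    step = lr / (np.linalg.norm(M.T @ M, 2) + 1e-12)
    losses = []
    for _ in range(n_epochs):
        residual = X @ M.T - Y_fine
        losses.append(0.5 * float(np.sum(residual ** 2)))
        # Stop once the relative improvement is at or below tol.
        if len(losses) > 1 and losses[-2] - losses[-1] <= tol * losses[-2]:
            break
        # Gradient step followed by projection onto the nonnegative orthant.
        X = np.maximum(X - step * (residual @ M), 0.0)
    return X, losses
```

In practice, the coarse embeddings and the spot-to-bin mapping would come from the lowest hierarchy level; here they are plain arrays so the pattern can be run in isolation.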

### Save results to disk

In [1]:
hierarchical_example.save_results(data_directory / f"hierarchical_results", ignore_raw_data=False)


### Load a pretrained model

In [20]:
from popari.model import load_trained_model

In [21]:
reloaded_model = load_trained_model(data_directory / f"hierarchical_results")

[2024/10/13 15:53:55]	 Reloading level 0
[2024/10/13 15:53:55]	 Reloading level 1
