# SPEED workflow: training on spatial epigenomic data without prior information from single-cell data

Dataset: the E13 mouse embryo spatial CUT&Tag-RNA-seq dataset from Zhang et al. (available [here](https://doi.org/10.5281/zenodo.14948507))

In [1]:
import torch
print("Whether GPU is detected:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)

Whether GPU is detected: True
CUDA version: 11.7


In [2]:
import SPEED
import scanpy as sc

adata_input_path = 'spCUT_Tag/tile_H3K27ac.h5ad'
adata_output_path = './H3K27ac_out'


## Load the data

Load the spatial epigenomic data without the corresponding single-cell data. 

In [3]:
adata = sc.read(adata_input_path)

## Initialize the SPEED model

Initialize the model with spatial data.

`is_spatial` is set to `True` when training on spatial data; in the workflow with a single-cell reference, this is the second training stage.

`k_degree` is the degree of spatial neighbors used for the spatial relative position encoding. The default of 5 suits data at 50 μm resolution; for data at 20 μm resolution, 12 is recommended.

`adata_sc` is set to `None` when training without prior information from single-cell data. `image` is also `None` here, as no image is provided.
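The resolution guidance for `k_degree` can be written as a small lookup so that the choice is explicit in a script. This is a minimal sketch: the helper `recommended_k_degree` is hypothetical (not part of SPEED) and encodes only the two values stated above; for any other resolution it raises an error instead of guessing.

```python
# Hypothetical helper (not part of SPEED): encodes the k_degree guidance above.
RECOMMENDED_K_DEGREE = {
    50: 5,   # 50 μm resolution: SPEED's default k_degree
    20: 12,  # 20 μm resolution: recommended k_degree
}


def recommended_k_degree(resolution_um: int) -> int:
    """Return the recommended k_degree for a spot resolution given in micrometres."""
    try:
        return RECOMMENDED_K_DEGREE[resolution_um]
    except KeyError:
        # No guidance is stated for other resolutions, so refuse to guess.
        raise ValueError(
            f"no stated recommendation for {resolution_um} um; choose k_degree manually"
        ) from None


print(recommended_k_degree(20))  # 12
```

For example, `SPEED.SPEED(adata, image=None, is_spatial=True, k_degree=recommended_k_degree(20), adata_sc=None)` passes the same `k_degree=12` as the call below.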

In [4]:
speed = SPEED.SPEED(adata, image=None, is_spatial=True, k_degree=12, adata_sc=None)

matrix ready...
use 0-1 matrix...
cell_features ready...
peak features ready...
Without single-cell reference


### Splitting the training and validation sets.

`num_workers` is the number of subprocesses for data loading (default = 4).

`data_type` sets the input data format used by SPEED. SPEED handles this format internally, so no action is required from the user. For lower GPU memory use and faster training, `dense = False` (the default) is recommended when training on GPU, and `dense = True` when training on CPU.

`batch_size_cell` and `batch_size_peak` are the batch sizes at the cell level and the peak level. SPEED chooses them automatically based on dataset size; if they are too large for your GPU memory, you can reduce them manually.

`split_ratio` sets the proportion of the validation set at both the cell level and the peak level (default `[1/6, 1/6]`).
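The split sizes and batch counts that appear in the logs of this notebook follow from these settings. The sketch below assumes the validation set receives `int(n * ratio)` items and the training set the rest; this rounding rule is an assumption, chosen because it reproduces the logged training shapes `(7809, 245219)` and `(204350, 9370)` for 9370 spots and 245219 peaks. The helpers `split_sizes` and `n_batches` are hypothetical, not SPEED functions.

```python
import math


def split_sizes(n: int, val_ratio: float) -> tuple[int, int]:
    """Assumed rule: validation gets int(n * val_ratio) items, training gets the rest."""
    n_val = int(n * val_ratio)
    return n - n_val, n_val


def n_batches(n: int, batch_size: int) -> int:
    """Number of batches when the last, partial batch is kept."""
    return math.ceil(n / batch_size)


n_spots, n_peaks = 9370, 245219  # sizes taken from the logs in this notebook
train_spots, val_spots = split_sizes(n_spots, 1 / 6)
train_peaks, val_peaks = split_sizes(n_peaks, 1 / 6)

print(train_spots, val_spots)  # 7809 1561
print(train_peaks, val_peaks)  # 204350 40869
print(n_batches(train_spots, 1024), n_batches(val_spots, 1024))  # 8 2
```

With `batch_size_cell = 1024`, this gives 8 training and 2 validation batches, consistent with the `8/8` and `2/2` progress bars printed each epoch during training. That suggests one iteration per cell batch, although this is an inference from the logs rather than documented behavior.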

In [5]:
speed.setup_data(num_workers=4)

batch_size_cell = 1024, batch_size_peak = 32768
split ready...
labels ready...
peak embedding is given
dataset ready...


### Build the neural network model for SPEED.

`emb_features` is the number of embedding features (default = 32).

`dropout_p` is the dropout probability of the model. For training on spatial data, a `dropout_p` of 0.4 is recommended.

In [6]:
speed.build_model(emb_features=32, dropout_p=0.4)
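To make `dropout_p` concrete, here is a minimal sketch of standard inverted dropout, which is how `torch.nn.Dropout` behaves: during training each activation is zeroed with probability `p` and the survivors are scaled by `1/(1 - p)`, while at inference the layer passes values through unchanged. It assumes SPEED uses standard dropout layers; the helper `inverted_dropout` is hypothetical and operates on plain Python lists rather than tensors.

```python
import random


def inverted_dropout(values, p, training=True, rng=None):
    """Zero each value with probability p and rescale survivors by 1/(1 - p).

    Mirrors the semantics of torch.nn.Dropout: identity when training=False.
    """
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    if not training or p == 0.0:
        return list(values)
    rng = rng if rng is not None else random.Random()
    keep = 1.0 - p
    # rng.random() is uniform on [0, 1), so each value survives with probability 1 - p
    return [v / keep if rng.random() >= p else 0.0 for v in values]


print(inverted_dropout([1.0, 2.0, 3.0], p=0.4, training=False))  # [1.0, 2.0, 3.0]
```

With `p = 0.4`, about 40% of the values come back as zero and the rest as `v / 0.6`, so the expected value of each output equals its input.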

## Train the SPEED model

`lr` is the learning rate. `device` specifies whether to train on GPU or CPU; it takes a device string such as `'cuda:2'` (the GPU with index 2, as in this run) or `'cpu'`, so adjust the GPU index to one available on your machine.

`epoch_num` is the maximum number of training epochs (default 500). If the validation loss does not improve for `epo_max` consecutive epochs (default `epo_max=30`), training is considered converged and stops.

`alpha` represents the weight of the constraint on the similarity between peak embeddings of spatial data. The default value is 10. A larger `alpha` means the model relies more on single-cell prior information. 

`beta` represents the importance of image information for spot embedding. The default value is 1. A larger `beta` means the model relies more on image information.

In [7]:
speed.train(lr=1e-5, device='cuda:2')

Use spatial information...
Starting training...
trainset:  (7809, 245219) (204350, 9370)


100%|██████████| 8/8 [00:19<00:00,  2.42s/it]
100%|██████████| 2/2 [00:01<00:00,  1.47it/s]


Epoch[1/500], Loss: 0.86955, Val Loss: 0.86762, 


100%|██████████| 8/8 [00:19<00:00,  2.50s/it]
100%|██████████| 2/2 [00:01<00:00,  1.22it/s]


Epoch[2/500], Loss: 0.86895, Val Loss: 0.86786, 


100%|██████████| 8/8 [00:20<00:00,  2.51s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[3/500], Loss: 0.86744, Val Loss: 0.86627, 


100%|██████████| 8/8 [00:20<00:00,  2.54s/it]
100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


Epoch[4/500], Loss: 0.86524, Val Loss: 0.86435, 


100%|██████████| 8/8 [00:20<00:00,  2.53s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[5/500], Loss: 0.86336, Val Loss: 0.86189, 


100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


Epoch[6/500], Loss: 0.86147, Val Loss: 0.86035, 


100%|██████████| 8/8 [00:20<00:00,  2.54s/it]
100%|██████████| 2/2 [00:01<00:00,  1.20it/s]


Epoch[7/500], Loss: 0.85942, Val Loss: 0.85643, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Epoch[8/500], Loss: 0.85776, Val Loss: 0.85600, 


100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


Epoch[9/500], Loss: 0.85609, Val Loss: 0.85489, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


Epoch[10/500], Loss: 0.85437, Val Loss: 0.85245, 


100%|██████████| 8/8 [00:20<00:00,  2.51s/it]
100%|██████████| 2/2 [00:01<00:00,  1.18it/s]


Epoch[11/500], Loss: 0.85246, Val Loss: 0.84992, 


100%|██████████| 8/8 [00:20<00:00,  2.54s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[12/500], Loss: 0.85066, Val Loss: 0.84763, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


Epoch[13/500], Loss: 0.84903, Val Loss: 0.84766, 


100%|██████████| 8/8 [00:20<00:00,  2.61s/it]
100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


Epoch[14/500], Loss: 0.84731, Val Loss: 0.84530, 


100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[15/500], Loss: 0.84602, Val Loss: 0.84621, 


100%|██████████| 8/8 [00:20<00:00,  2.56s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[16/500], Loss: 0.84474, Val Loss: 0.84445, 


100%|██████████| 8/8 [00:20<00:00,  2.56s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[17/500], Loss: 0.84383, Val Loss: 0.84484, 


100%|██████████| 8/8 [00:20<00:00,  2.56s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[18/500], Loss: 0.84268, Val Loss: 0.84359, 


100%|██████████| 8/8 [00:20<00:00,  2.55s/it]
100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


Epoch[19/500], Loss: 0.84199, Val Loss: 0.84330, 


100%|██████████| 8/8 [00:20<00:00,  2.56s/it]
100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


Epoch[20/500], Loss: 0.84126, Val Loss: 0.84392, 


100%|██████████| 8/8 [00:20<00:00,  2.55s/it]
100%|██████████| 2/2 [00:01<00:00,  1.11it/s]


Epoch[21/500], Loss: 0.84067, Val Loss: 0.84347, 


100%|██████████| 8/8 [00:21<00:00,  2.64s/it]
100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


Epoch[22/500], Loss: 0.84026, Val Loss: 0.84361, 


100%|██████████| 8/8 [00:21<00:00,  2.70s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Epoch[23/500], Loss: 0.83979, Val Loss: 0.84260, 


100%|██████████| 8/8 [00:20<00:00,  2.62s/it]
100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


Epoch[24/500], Loss: 0.83935, Val Loss: 0.84179, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


Epoch[25/500], Loss: 0.83882, Val Loss: 0.84163, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:01<00:00,  1.08it/s]


Epoch[26/500], Loss: 0.83810, Val Loss: 0.84141, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


Epoch[27/500], Loss: 0.83776, Val Loss: 0.84155, 


100%|██████████| 8/8 [00:20<00:00,  2.54s/it]
100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch[28/500], Loss: 0.83725, Val Loss: 0.84171, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.11it/s]


Epoch[29/500], Loss: 0.83690, Val Loss: 0.84214, 


100%|██████████| 8/8 [00:20<00:00,  2.52s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[30/500], Loss: 0.83630, Val Loss: 0.84075, 


100%|██████████| 8/8 [00:20<00:00,  2.61s/it]
100%|██████████| 2/2 [00:01<00:00,  1.05it/s]


Epoch[31/500], Loss: 0.83602, Val Loss: 0.84199, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.12it/s]


Epoch[32/500], Loss: 0.83576, Val Loss: 0.84052, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.18it/s]


Epoch[33/500], Loss: 0.83504, Val Loss: 0.84230, 


100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


Epoch[34/500], Loss: 0.83495, Val Loss: 0.84138, 


100%|██████████| 8/8 [00:20<00:00,  2.61s/it]
100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch[35/500], Loss: 0.83446, Val Loss: 0.84128, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[36/500], Loss: 0.83404, Val Loss: 0.84249, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


Epoch[37/500], Loss: 0.83336, Val Loss: 0.84247, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


Epoch[38/500], Loss: 0.83298, Val Loss: 0.84182, 


100%|██████████| 8/8 [00:21<00:00,  2.66s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Epoch[39/500], Loss: 0.83282, Val Loss: 0.84135, 


100%|██████████| 8/8 [00:21<00:00,  2.69s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[40/500], Loss: 0.83219, Val Loss: 0.84137, 


100%|██████████| 8/8 [00:20<00:00,  2.61s/it]
100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


Epoch[41/500], Loss: 0.83182, Val Loss: 0.84249, 


100%|██████████| 8/8 [00:22<00:00,  2.75s/it]
100%|██████████| 2/2 [00:02<00:00,  1.00s/it]


Epoch[42/500], Loss: 0.83168, Val Loss: 0.84162, 


100%|██████████| 8/8 [00:21<00:00,  2.73s/it]
100%|██████████| 2/2 [00:02<00:00,  1.00s/it]


Epoch[43/500], Loss: 0.83126, Val Loss: 0.84161, 


100%|██████████| 8/8 [00:21<00:00,  2.73s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Epoch[44/500], Loss: 0.83068, Val Loss: 0.84193, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.09it/s]


Epoch[45/500], Loss: 0.83001, Val Loss: 0.84129, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[46/500], Loss: 0.83010, Val Loss: 0.84087, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[47/500], Loss: 0.82956, Val Loss: 0.84172, 


100%|██████████| 8/8 [00:21<00:00,  2.71s/it]
100%|██████████| 2/2 [00:02<00:00,  1.03s/it]


Epoch[48/500], Loss: 0.82938, Val Loss: 0.84206, 


100%|██████████| 8/8 [00:22<00:00,  2.77s/it]
100%|██████████| 2/2 [00:02<00:00,  1.00s/it]


Epoch[49/500], Loss: 0.82898, Val Loss: 0.84133, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


Epoch[50/500], Loss: 0.82846, Val Loss: 0.84076, 


100%|██████████| 8/8 [00:20<00:00,  2.56s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[51/500], Loss: 0.82814, Val Loss: 0.84154, 


100%|██████████| 8/8 [00:21<00:00,  2.65s/it]
100%|██████████| 2/2 [00:02<00:00,  1.10s/it]


Epoch[52/500], Loss: 0.82793, Val Loss: 0.84104, 


100%|██████████| 8/8 [00:20<00:00,  2.61s/it]
100%|██████████| 2/2 [00:01<00:00,  1.06it/s]


Epoch[53/500], Loss: 0.82784, Val Loss: 0.84144, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


Epoch[54/500], Loss: 0.82745, Val Loss: 0.84216, 


100%|██████████| 8/8 [00:20<00:00,  2.57s/it]
100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


Epoch[55/500], Loss: 0.82715, Val Loss: 0.84213, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch[56/500], Loss: 0.82727, Val Loss: 0.84095, 


100%|██████████| 8/8 [00:20<00:00,  2.60s/it]
100%|██████████| 2/2 [00:01<00:00,  1.17it/s]


Epoch[57/500], Loss: 0.82696, Val Loss: 0.84152, 


100%|██████████| 8/8 [00:20<00:00,  2.62s/it]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]


Epoch[58/500], Loss: 0.82646, Val Loss: 0.84053, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.18it/s]


Epoch[59/500], Loss: 0.82661, Val Loss: 0.84155, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


Epoch[60/500], Loss: 0.82596, Val Loss: 0.84205, 


100%|██████████| 8/8 [00:20<00:00,  2.59s/it]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]


Epoch[61/500], Loss: 0.82562, Val Loss: 0.84239, 


100%|██████████| 8/8 [00:20<00:00,  2.58s/it]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]


Epoch[62/500], Loss: 0.82576, Val Loss: 0.84169, 
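The stopping rule described for `epoch_num` and `epo_max` can be checked against this log. The sketch below replays the 62 logged validation losses through a patience rule. The helper `early_stopping_epoch` is hypothetical, and it assumes that an improvement means a strictly lower validation loss than the best so far; SPEED's internal criterion may differ in detail. Under that assumption the replay matches the run: the lowest validation loss, 0.84052, occurs at epoch 32, and training ends 30 epochs later at epoch 62, well before `epoch_num = 500`.

```python
# Validation losses copied from the training log above (epochs 1 to 62).
VAL_LOSSES = [
    0.86762, 0.86786, 0.86627, 0.86435, 0.86189, 0.86035, 0.85643, 0.85600, 0.85489, 0.85245,
    0.84992, 0.84763, 0.84766, 0.84530, 0.84621, 0.84445, 0.84484, 0.84359, 0.84330, 0.84392,
    0.84347, 0.84361, 0.84260, 0.84179, 0.84163, 0.84141, 0.84155, 0.84171, 0.84214, 0.84075,
    0.84199, 0.84052, 0.84230, 0.84138, 0.84128, 0.84249, 0.84247, 0.84182, 0.84135, 0.84137,
    0.84249, 0.84162, 0.84161, 0.84193, 0.84129, 0.84087, 0.84172, 0.84206, 0.84133, 0.84076,
    0.84154, 0.84104, 0.84144, 0.84216, 0.84213, 0.84095, 0.84152, 0.84053, 0.84155, 0.84205,
    0.84239, 0.84169,
]


def early_stopping_epoch(val_losses, patience=30, max_epochs=500):
    """Return (stop_epoch, best_epoch), both 1-based.

    Assumed rule: stop once `patience` epochs have passed since the last
    strictly lower validation loss, or when `max_epochs` is reached.
    """
    best_loss = float("inf")
    best_epoch = 0
    last_epoch = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        last_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return epoch, best_epoch
    return last_epoch, best_epoch


print(early_stopping_epoch(VAL_LOSSES, patience=30))  # (62, 32)
```

Note that the training loss keeps falling after epoch 32 while the validation loss plateaus, which is the situation this patience rule is meant to catch.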


## Get the results

Use `SPEED.SPEED.get_embedding` to get the low-dimensional embedding.

The spot/cell embeddings are stored in `adata.obsm['X_SPEED']` and the peak embeddings in `adata.varm['peak_SPEED']`.

In [8]:
adata = speed.get_embedding(adata)

get cell/spot embedding...


100%|██████████| 5/5 [00:04<00:00,  1.13it/s]


get peak embedding...


100%|██████████| 5/5 [00:01<00:00,  4.31it/s]


get spatial embedding...



(1875, 9370)
(1875, 9370)


100%|██████████| 5/5 [00:00<00:00, 24.61it/s]

(1875, 9370)
(1875, 9370)
(1870, 9370)
the shape of embedding: (9370, 32)





Use `SPEED.SPEED.get_denoise_result` to get the denoised matrix.

In [9]:
adata.X = speed.get_denoise_result()

In [10]:
adata = speed.binarize(adata)

100%|██████████| 9370/9370 [00:29<00:00, 313.07it/s]
100%|██████████| 245219/245219 [01:27<00:00, 2797.61it/s]


In [11]:
adata.write(f'{adata_output_path}/adata_speed_cpu.h5ad')
