## Bulk2Space Tutorial
Jie Liao, Jingyang Qian, Yin Fang, Zhuo Chen, Xiang Zhuang et al.

## Outline
1. [Installation](#Installation)
2. [Import modules](#Import-modules)
3. [Parameter definition](#Parameter-definition)
4. [Load data](#Load-data)
5. [Calculate marker genes of each celltype](#Marker-used)
6. [Data processing](#Data-processing)
7. [Celltype ratio calculation](#Celltype-ratio-calculation)
8. [Prepare the model input](#Prepare-the-model-input)
9. [Model training/loading](#Model-training/loading)
10. [Data generation](#Data-generation)
11. [Data saving](#Data-saving)
12. [Data mapping](#Data-Mapping)


### 1. <a id="Installation">Installation</a>
Installation should take a few minutes on a typical computer. Bulk2Space requires Python `3.8` or later. If you don't know which Python version you have, you can check it with:

In [15]:
import platform
print(platform.python_version())

3.7.10


Note: Because Bulk2Space depends on PyTorch, make sure that `torch` is correctly installed.
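
Both requirements can be checked in one go. This is a minimal standard-library sketch; `python_version_ok` and `torch_available` are hypothetical helpers, not part of Bulk2Space, and `torch_available` only checks that the package can be found (GPU support is checked separately by `torch.cuda.is_available()`, as in the device cell below).

```python
import importlib.util
import sys


def python_version_ok(min_version=(3, 8), current=None):
    """Return True if the running (or given) Python version is at least min_version."""
    current = tuple(current) if current is not None else tuple(sys.version_info[:3])
    # Tuples compare element by element, so (3, 7, 10) < (3, 8) and (3, 10, 1) > (3, 8)
    return current >= tuple(min_version)


def torch_available():
    """Return True if the `torch` package can be found, without importing it."""
    return importlib.util.find_spec("torch") is not None


print("Python", ".".join(map(str, sys.version_info[:3])),
      "OK" if python_version_ok() else "is too old (3.8 or later required)")
print("PyTorch found:", torch_available())
```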

### 2. <a id="Import-modules">Import modules</a>

In [16]:
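from utils.tool import *  # Bulk2Space helpers used below (train_vae, load_vae, generate_vae, CreatData, Runner, ...)
from utils.config import cfg, loadArgums
import os
import os.path as osp  # os and osp are used when saving the generated data
import numpy as np
import pandas as pd
import torch
import scanpy
from scipy.optimize import nnls
from collections import defaultdict
import argparse
import warnings
warnings.filterwarnings('ignore')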

### 3. <a id="Parameter-definition">Parameter definition</a>
In the current version of Bulk2Space, some parameters should be adjusted to match your running environment and file hierarchy:

- `gpu_id`: The GPU ID, e.g., `--gpu_id 0` (the notebook below uses `-1`, which runs on the CPU)
- `project_name`: The name of your project, e.g., `--project_name experiment1`
- `input_bulk_path`: The path to the input bulk RNA-seq data, e.g., `--input_bulk_path bulk_data.csv`
- `input_sc_data_path`: The path to the input scRNA-seq expression data, e.g., `--input_sc_data_path sc_data.csv`
- `input_sc_meta_path`: The path to the input scRNA-seq metadata (cell-type annotations), e.g., `--input_sc_meta_path sc_meta.csv`
- `input_st_data_path`: The path to the input spatial transcriptomics expression data, e.g., `--input_st_data_path st_data.csv`
- `input_st_meta_path`: The path to the input spatial transcriptomics metadata (spot coordinates), e.g., `--input_st_meta_path st_meta.csv`
- `load_model_1`: Whether to load a trained bulk-deconvolution model instead of training one, e.g., `--load_model_1 False`
- `load_path_1`: The path of the trained bulk-deconvolution model to load
- `train_model_2`: Whether to train the spatial mapping model, e.g., `--train_model_2 True`
- `load_path_2`: The path of the trained spatial mapping model to load
- `output_path`: The folder in which the output data are stored, e.g., `--output_path output_data`


Other parameters can be adjusted as needed:
- `BetaVAE_H`: Whether to use the β-VAE model instead of a plain VAE, e.g., `--BetaVAE_H`
- `batch_size`: The batch size for β-VAE/VAE model training, e.g., `--batch_size 512`
- `learning_rate`: The learning rate for β-VAE/VAE model training, e.g., `--learning_rate 0.0001`
- `hidden_size`: The hidden size of the β-VAE/VAE model, e.g., `--hidden_size 256`
- `hidden_lay`: The hidden-layer configuration of the β-VAE/VAE model (0: [2048, 1024, 512]; 1: [4096, 2048, 1024, 512]; 2: [8192, 4096, 2048, 1024]), e.g., `--hidden_lay 0`
- `epoch_num`: The number of epochs for β-VAE/VAE model training, e.g., `--epoch_num 5000`
- `not_early_stop`: Whether to disable the early-stopping strategy, e.g., `--not_early_stop False`
- `early_stop`: The number of epochs to wait before stopping when there is no progress on the validation set or the training loss does not decrease, e.g., `--early_stop 50`
- `k`: The number of cells assigned to each spot in the spatial mapping step, e.g., `--k 10`
- `marker_used`: Whether to use only the marker genes of each cell type when calculating cell-type proportions, e.g., `--marker_used True`
- `top_marker_num`: The number of top marker genes used per cell type, e.g., `--top_marker_num 500`
- `ratio_num`: The multiple of the number of input cells to generate (with `1`, approximately as many cells are generated as the input scRNA-seq data contains), e.g., `--ratio_num 1`
- `spot_data`: The type of the input spatial transcriptomics data: `True` for barcode-based ST data (such as ST, 10x Visium or Slide-seq) and `False` for image-based ST data (such as MERFISH, seqFISH or STARmap)
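
The options above are written as command-line flags. Bulk2Space's own parser is not shown in this tutorial, so the following is only a hedged `argparse` sketch covering a few of them; `build_parser` and `str2bool` are hypothetical helpers. Boolean values such as `--load_model_1 False` need a converter, because `type=bool` would turn the non-empty string `"False"` into `True`. `--BetaVAE_H` is modelled as a switch that takes no value, following its example above.

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings; argparse's type=bool would treat 'False' as True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    # Hypothetical subset of the options listed above; defaults mirror the args dict below
    parser = argparse.ArgumentParser(description="Bulk2Space-style options (sketch)")
    parser.add_argument("--gpu_id", type=int, default=-1)
    parser.add_argument("--project_name", type=str, default="test1")
    parser.add_argument("--load_model_1", type=str2bool, default=False)
    parser.add_argument("--train_model_2", type=str2bool, default=True)
    parser.add_argument("--BetaVAE_H", action="store_true")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--hidden_lay", type=int, choices=[0, 1, 2], default=0)
    parser.add_argument("--early_stop", type=int, default=50)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--ratio_num", type=float, default=1.0)
    return parser


cli_args = build_parser().parse_args(["--gpu_id", "0", "--load_model_1", "False", "--BetaVAE_H"])
print(cli_args.gpu_id, cli_args.load_model_1, cli_args.BetaVAE_H)
```

The resulting `argparse.Namespace` has the same shape as the one built from a dict in the next cell, so either route yields an `args` object with attribute access.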


In [17]:
# Replace the absolute paths below with paths on your own machine
args = dict(
    BetaVAE_H=True,
    batch_size=512,
    cell_num=10,
    data_path='example_data/demo1',
    dump_path='/data/zhuangxiang/code/bulk2space/bulk2space/dump',
    early_stop=50,
    epoch_num=10,
    exp_id='LR_0.0001_hiddenSize_256_lay_choice_0',
    exp_name='test1',
    feature_size=6588,
    gpu_id=-1,
    hidden_lay=0,
    hidden_size=256,
    input_bulk_path='/data/zhuangxiang/code/bulk2space/bulk2space/data/example_data/demo1/demo1_bulk.csv',
    input_sc_data_path='/data/zhuangxiang/code/bulk2space/bulk2space/data/example_data/demo1/demo1_sc_data.csv',
    input_sc_meta_path='/data/zhuangxiang/code/bulk2space/bulk2space/data/example_data/demo1/demo1_sc_meta.csv',
    input_st_data_path='/data/zhuangxiang/code/bulk2space/bulk2space/data/example_data/demo1/demo1_st_data.csv',
    input_st_meta_path='/data/zhuangxiang/code/bulk2space/bulk2space/data/example_data/demo1/demo1_st_meta.csv',
    k=10,
    kl_loss=False,
    learning_rate=0.0001,
    load_model_1=False,
    load_path_1='/data/zhuangxiang/code/bulk2space/bulk2space/save_model/',
    load_path_2='/data/zhuangxiang/code/bulk2space/bulk2space/save_model/',
    marker_used=True,
    max_cell_in_diff_spot_ratio=None,
    model_choice_1='vae', model_choice_2='df',
    mul_test=5,
    mul_train=1,
    no_tensorboard=False,
    not_early_stop=False,
    num_workers=12,
    output_path='/data/zhuangxiang/code/bulk2space/bulk2space/output_data',
    previous_project_name='demo', project_name='test1',
    random_seed=12345, ratio_num=1,
    save='/data/zhuangxiang/code/bulk2space/bulk2space/save_model',
    spot_data=True, spot_num=500,
    top_marker_num=500, train_model_2=True,
    xtest='xtest', xtrain='xtrain', ytest='ytest', ytrain='ytrain'
)
args = argparse.Namespace(**args)

In [18]:
used_device = torch.device(f"cuda:{args.gpu_id}") if args.gpu_id >= 0 and torch.cuda.is_available() else torch.device('cpu')

In [19]:
input_sc_meta_path = args.input_sc_meta_path
input_sc_data_path = args.input_sc_data_path
input_bulk_path = args.input_bulk_path
input_st_meta_path = args.input_st_meta_path
input_st_data_path = args.input_st_data_path

### 4. <a id="Load-data">Load data</a>
`Bulk2Space` requires five formatted input files:
- Bulk-seq Normalized Data
    - a `.csv` file with genes as rows and the sample as a single column
     
    
    |         | Sample  | 
    | :-----: | :-----: | 
    | Gene1   | 5.22    |
    | Gene2   | 3.67    |
    | ...     | ...     |
    | GeneN   | 15.76   |
    
- Single Cell RNA-seq Normalized Data
    - a `.csv` file with genes as rows and cells as columns
- Single Cell RNA-seq Annotation Data
    - a `.csv` file with cell names and celltype annotation columns. The column containing cell names should be named `Cell` and the column containing the labels should be named `Cell_type`
- Spatial Transcriptomics Normalized Data
    - a `.csv` file with genes as rows and cells/spots as columns
- Spatial Transcriptomics Coordinates Data
    - a `.csv` file with cell/spot names and coordinate columns. The column containing cell/spot names should be named `Spot`, and the coordinate columns should be named `xcoord` and `ycoord`
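
Before running the pipeline it can help to check that the five tables follow the layout described above. This is a hedged pandas sketch: `check_inputs` is a hypothetical helper, and the checks that cell and spot names match the expression-table columns are an assumption about how the files are linked (the data-processing step below selects single-cell columns by the `Cell` names). The toy tables reuse the example bulk values from the table above.

```python
import pandas as pd


def check_inputs(bulk, sc_data, sc_meta, st_data, st_meta):
    """Return a list of problems found in the five input tables (empty if none)."""
    problems = []
    if bulk.shape[1] != 1:
        problems.append(f"bulk data should have one sample column, found {bulk.shape[1]}")
    for col in ("Cell", "Cell_type"):
        if col not in sc_meta.columns:
            problems.append(f"sc meta is missing column {col!r}")
    for col in ("Spot", "xcoord", "ycoord"):
        if col not in st_meta.columns:
            problems.append(f"st meta is missing column {col!r}")
    if "Cell" in sc_meta.columns and set(sc_meta["Cell"]) != set(sc_data.columns):
        problems.append("cell names in sc meta do not match the sc data columns")
    if "Spot" in st_meta.columns and set(st_meta["Spot"]) != set(st_data.columns):
        problems.append("spot names in st meta do not match the st data columns")
    if not set(bulk.index) & set(sc_data.index):
        problems.append("bulk and scRNA-seq data share no genes")
    return problems


# Toy tables with the expected layout (genes as rows in the expression tables)
toy_genes = ["Gene1", "Gene2", "GeneN"]
toy_bulk = pd.DataFrame({"Sample": [5.22, 3.67, 15.76]}, index=toy_genes)
toy_sc_data = pd.DataFrame({"C_1": [1.0, 0.0, 2.0], "C_2": [0.0, 1.5, 0.3]}, index=toy_genes)
toy_sc_meta = pd.DataFrame({"Cell": ["C_1", "C_2"], "Cell_type": ["T cell", "B cell"]})
toy_st_data = pd.DataFrame({"spot_1": [0.5, 0.2, 1.1]}, index=toy_genes)
toy_st_meta = pd.DataFrame({"Spot": ["spot_1"], "xcoord": [1], "ycoord": [1]})

print(check_inputs(toy_bulk, toy_sc_data, toy_sc_meta, toy_st_data, toy_st_meta))  # []
```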


In [20]:
print("loading data......")

# load sc_meta.csv file, containing two columns of cell name and cell type
input_sc_meta = pd.read_csv(input_sc_meta_path, index_col=0)
# load sc_data.csv file, containing gene expression of each cell
input_sc_data = pd.read_csv(input_sc_data_path, index_col=0)
sc_gene = input_sc_data.index.tolist()
# load bulk.csv file, containing one column of gene expression in bulk
input_bulk = pd.read_csv(input_bulk_path, index_col=0)
bulk_gene = input_bulk.index.tolist()
# keep only the genes shared by the scRNA-seq and bulk data
intersect_gene = list(set(sc_gene).intersection(set(bulk_gene)))
input_sc_data = input_sc_data.loc[intersect_gene]
input_bulk = input_bulk.loc[intersect_gene]
# load st_meta.csv and st_data.csv, containing coordinates and gene expression of each spot respectively.
input_st_meta = pd.read_csv(input_st_meta_path, index_col=0)
input_st_data = pd.read_csv(input_st_data_path, index_col=0)
print("load data ok")

loading data......
load data ok


input_sc_meta

In [None]:
input_sc_data

Unnamed: 0,C_417,C_369,C_265,C_450,C_440,C_347,C_463,C_429,C_442,C_430,...,C_269,C_257,C_50,C_304,C_313,C_439,C_363,C_90,C_54,C_289
SAMD3,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,0.000000,2.178143,...,0.000000,0.0000,0.000000,0.0,0.000000,0.000000,0.0,0.00000,0.0,0.0
NDRG3,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,0.000000,0.0000,0.000000,0.0,1.923612,0.000000,0.0,0.00000,0.0,0.0
ERAP2,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,0.000000,2.1534,1.554921,0.0,0.000000,0.000000,0.0,0.00000,0.0,0.0
IL16,0.0,0.0,0.00000,1.948529,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,2.198419,2.1534,0.000000,0.0,1.923612,0.000000,0.0,1.57153,0.0,0.0
ELP2,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,0.000000,0.0000,0.000000,0.0,0.000000,0.000000,0.0,1.57153,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GUCD1,0.0,0.0,1.43938,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,0.000000,0.0000,0.000000,0.0,0.000000,0.000000,0.0,0.00000,0.0,0.0
TNK2,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,1.534554,0.000000,...,0.000000,0.0000,0.000000,0.0,0.000000,0.000000,0.0,0.00000,0.0,0.0
ABI3,0.0,0.0,0.00000,0.000000,0.0,0.0,0.000000,0.0,0.000000,0.000000,...,0.000000,0.0000,0.000000,0.0,0.000000,1.775942,0.0,0.00000,0.0,0.0
SEC31B,0.0,0.0,1.43938,0.000000,0.0,0.0,2.099617,0.0,1.534554,0.000000,...,2.198419,0.0000,0.000000,0.0,0.000000,0.000000,0.0,0.00000,0.0,0.0


In [22]:
input_st_meta

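|     | Spot    | xcoord | ycoord |
| --: | :------ | -----: | -----: |
| 1   | spot_1  | 1      | 1      |
| 2   | spot_2  | 2      | 1      |
| 3   | spot_3  | 3      | 1      |
| 4   | spot_4  | 4      | 1      |
| 5   | spot_5  | 5      | 1      |
| 6   | spot_6  | 6      | 1      |
| 7   | spot_7  | 1      | 2      |
| 8   | spot_8  | 2      | 2      |
| 9   | spot_9  | 3      | 2      |
| 10  | spot_10 | 4      | 2      |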


In [23]:
input_st_data

Unnamed: 0,spot_1,spot_2,spot_3,spot_4,spot_5,spot_6,spot_7,spot_8,spot_9,spot_10,...,spot_21,spot_22,spot_23,spot_24,spot_25,spot_26,spot_27,spot_28,spot_29,spot_30
AAGAB,0.000000,0.000000,0.000000,2.365428,2.365428,0.793433,0.000000,2.337321,4.702748,0.000000,...,0.000000,0.000000,0.793433,0.000000,0.793433,0.000000,0.000000,2.337321,0.000000,0.793433
AAK1,0.000000,0.000000,0.000000,2.556952,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,2.556952,0.000000,0.000000,0.000000,0.000000,2.556952,0.000000,0.000000
AAMP,1.718439,0.000000,1.718439,1.718439,2.057984,4.690727,0.000000,2.057984,0.000000,0.000000,...,1.084092,3.776423,2.851417,0.000000,0.793433,1.839311,3.142075,1.238926,1.839311,2.511872
AASDHPPT,2.525116,0.000000,0.000000,3.477199,0.000000,0.000000,0.952409,1.883208,2.204595,5.538057,...,0.000000,4.866499,4.087803,2.702439,3.157004,0.952409,5.049906,5.227229,4.626074,0.778696
AATF,1.572888,0.000000,1.572888,0.952409,1.572888,1.630440,3.139488,2.187079,3.741098,2.525298,...,0.000000,1.449957,3.080397,2.291141,5.605695,0.952409,2.187079,1.238926,1.449957,1.630440
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
BUD23,1.572888,0.000000,3.503805,2.177013,3.503805,0.793433,0.952409,0.000000,2.291141,2.525298,...,3.015008,1.163867,2.724349,4.222057,5.249647,2.177013,1.084092,1.930916,1.224603,3.888216
NSD2,1.978389,0.000000,1.978389,0.000000,1.978389,0.000000,0.000000,0.000000,2.604381,0.000000,...,2.604381,2.424774,0.000000,0.000000,2.424774,1.978389,0.000000,0.000000,2.604381,0.000000
NSD3,11.510267,6.187656,1.978389,4.434076,5.874775,9.493980,7.769172,6.230352,3.810758,3.326291,...,3.056448,8.722214,2.851417,4.418391,4.191878,10.662594,11.980052,2.028423,5.933309,3.025372
PYM1,2.874875,0.000000,2.874875,3.459819,1.572888,0.000000,1.301986,0.000000,1.449957,1.572888,...,1.301986,2.751944,1.449957,3.459819,3.022846,0.000000,0.000000,0.000000,1.449957,0.000000


### 5. <a id="Marker-used">Calculate marker genes of each celltype</a>


In [24]:
sc = scanpy.AnnData(input_sc_data.T)
sc.obs = input_sc_meta[['Cell_type']]
scanpy.tl.rank_genes_groups(sc, 'Cell_type', method='wilcoxon')
marker_df = pd.DataFrame(sc.uns['rank_genes_groups']['names']).head(args.top_marker_num)
marker_array = np.array(marker_df)
marker_array = np.ravel(marker_array)
marker_array = np.unique(marker_array)
marker = list(marker_array)
sc_marker = input_sc_data.loc[marker, :]
bulk_marker = input_bulk.loc[marker]
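
`scanpy.tl.rank_genes_groups` stores, under `uns['rank_genes_groups']['names']`, one column per cell type with genes ordered from most to least significant. The cell above keeps the first `top_marker_num` rows and takes the union across cell types, so at most `top_marker_num` × (number of cell types) genes are kept. The sketch below reproduces that step on a hand-written table so it runs without scanpy; the gene names are toy values and `union_top_markers` is a hypothetical helper. Note that `np.unique` already flattens its input, so the separate `np.ravel` call is not strictly needed.

```python
import numpy as np
import pandas as pd


def union_top_markers(names, top_n):
    """Union of the top_n ranked genes of every cell type, returned sorted.

    `names` has one column per cell type, each listing genes from most to least
    significant (the layout of pd.DataFrame(adata.uns['rank_genes_groups']['names'])).
    """
    top = names.head(top_n).to_numpy()
    # np.unique flattens the 2-D array, drops duplicates and sorts the result
    return list(np.unique(top))


toy_names = pd.DataFrame({
    "B cell": ["MS4A1", "CD79A", "CD74", "HLA-DRA"],
    "T cell": ["CD3E", "IL7R", "CD74", "LTB"],
})
print(union_top_markers(toy_names, top_n=3))  # ['CD3E', 'CD74', 'CD79A', 'IL7R', 'MS4A1']
```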

### 6. <a id="Data-processing">Data processing</a>

In [25]:
breed = input_sc_meta['Cell_type']
breed_np = breed.values
breed_set = set(breed_np)
id2label = sorted(list(breed_set))  # sorted list of cell types ("breed" means cell type in this code)
label2id = {label: idx for idx, label in enumerate(id2label)}  # map cell type to cell-type id

cell2label = dict()  # map cell-name to breed-id
label2cell = defaultdict(set)  # map breed-id to cell-names
for row in input_sc_meta.itertuples():
    cell_name = getattr(row, 'Cell')
    cell_type = label2id[getattr(row, 'Cell_type')]
    cell2label[cell_name] = cell_type
    label2cell[cell_type].add(cell_name)
label_devide_data = dict()
for label, cells in label2cell.items():
    label_devide_data[label] = sc_marker[list(cells)]

single_cell_splitby_breed_np = {}
for key in label_devide_data.keys():
    single_cell_splitby_breed_np[key] = label_devide_data[key].values  # [gene_num, cell_num]
    single_cell_splitby_breed_np[key] = single_cell_splitby_breed_np[key].mean(axis=1)

max_decade = len(single_cell_splitby_breed_np.keys())  # number of cell types
single_cell_matrix = []
# stack the mean-expression vectors in label order 0, 1, ..., n_types - 1
for i in range(max_decade):
    single_cell_matrix.append(single_cell_splitby_breed_np[i].tolist())


single_cell_matrix = np.array(single_cell_matrix)
single_cell_matrix = np.transpose(single_cell_matrix)  # (gene_num, label_num)

bulk_marker = bulk_marker.values  # (gene_num, 1)
bulk_rep = bulk_marker.reshape(bulk_marker.shape[0],)
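
The loops above build a signature matrix: column `j` of `single_cell_matrix` is the mean expression of the marker genes over all cells of cell type `id2label[j]`, and `bulk_rep` is the bulk expression of the same genes. Assuming the metadata's `Cell` column matches the column names of the expression table, the same matrix can be computed with a pandas `groupby`, which also sorts the cell types like `id2label`. This is a sketch on toy data; `signature_matrix` is a hypothetical helper.

```python
import pandas as pd


def signature_matrix(sc_marker, sc_meta):
    """Mean expression per cell type: rows are genes, columns are sorted cell types."""
    # Look up the cell type of every column (cell) of the expression table
    cell_types = sc_meta.set_index("Cell").loc[sc_marker.columns, "Cell_type"]
    # Transpose to cells x genes, average within each cell type, transpose back
    return sc_marker.T.groupby(cell_types.values).mean().T


toy_marker = pd.DataFrame(
    {"C_1": [2.0, 0.0], "C_2": [4.0, 0.0], "C_3": [0.0, 3.0]},
    index=["GeneA", "GeneB"],
)
toy_meta = pd.DataFrame({"Cell": ["C_1", "C_2", "C_3"],
                         "Cell_type": ["T cell", "T cell", "B cell"]})
print(signature_matrix(toy_marker, toy_meta))
```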

### 7. <a id="Celltype-ratio-calculation">Celltype ratio calculation</a>


In [26]:
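# estimate cell-type proportions in the bulk sample by non-negative least squares (NNLS)
ratio = nnls(single_cell_matrix, bulk_rep)[0]
ratio = ratio / ratio.sum()  # normalize so the proportions sum to 1

# number of cells to generate per cell type = proportion x number of input cells x ratio_num
ratio_array = np.round(ratio * input_sc_meta.shape[0] * args.ratio_num)
ratio_list = list(ratio_array)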

cell_target_num = dict(zip(id2label, ratio_list))
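
The NNLS step solves $\min_{r \ge 0} \lVert S r - b \rVert_2$, where $S$ is `single_cell_matrix` (marker genes × cell types) and $b$ is `bulk_rep`; the solution is rescaled to sum to 1 and multiplied by the number of input cells and `ratio_num`. The worked example below builds a noise-free toy bulk profile from known proportions and checks that NNLS recovers them; the signature values are made up and `estimate_cell_counts` is a hypothetical helper that mirrors the cell above.

```python
import numpy as np
from scipy.optimize import nnls


def estimate_cell_counts(signature, bulk, n_cells, ratio_num=1):
    """NNLS cell-type proportions and the rounded number of cells to generate per type."""
    # nnls minimises ||signature @ coef - bulk||_2 subject to coef >= 0
    coef, _residual_norm = nnls(signature, bulk)
    # rescale to proportions (assumes at least one coefficient is positive)
    proportions = coef / coef.sum()
    counts = np.round(proportions * n_cells * ratio_num)
    return proportions, counts


# Toy signature: 4 marker genes x 3 cell types (mean expression per type)
toy_signature = np.array([
    [5.0, 0.0, 1.0],
    [0.0, 4.0, 0.5],
    [1.0, 1.0, 6.0],
    [2.0, 0.0, 0.0],
])
toy_true_props = np.array([0.5, 0.3, 0.2])
# A noise-free "bulk" profile; its overall scale cancels out after normalisation
toy_bulk_profile = 10.0 * toy_signature @ toy_true_props

est_props, est_counts = estimate_cell_counts(toy_signature, toy_bulk_profile, n_cells=600)
print(est_props)   # close to [0.5, 0.3, 0.2]
print(est_counts)  # [300. 180. 120.]
```

With noisy real data the fit is no longer exact, and because each count is rounded separately the counts need not add up exactly to `n_cells * ratio_num`.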

### 8. <a id="Prepare-the-model-input">Prepare the model input</a>



In [27]:
# *********************************************************************
# inputs: scRNA-seq data, cell types, bulk data; outputs: label, dic, single_cell
single_cell = input_sc_data.values.T  # single-cell data, shape (cell_num, gene_num), e.g. (600, 6588)
index_2_gene = (input_sc_data.index).tolist()
breed = input_sc_meta['Cell_type']
breed_np = breed.values
breed_set = set(breed_np)
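breed_2_list = sorted(breed_set)  # sorted so that the type-to-index mapping is the same in every session
dic = {}  # cell type -> index, e.g. {'B cell': 0, 'Dendritic cell': 1, 'Monocyte': 2, 'T cell': 3}
label = []  # integer label of each cell, in the same order as the rows of single_cell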
cfg.nclass = len(breed_set)

cfg.ntrain = single_cell.shape[0]
cfg.FeaSize = single_cell.shape[1]
args.feature_size = single_cell.shape[1]
assert cfg.nclass == len(cell_target_num.keys()), "number of cell types does not match!"

for i in range(len(breed_set)):
    dic[breed_2_list[i]] = i
cell = input_sc_meta["Cell"].values

for i in range(cell.shape[0]):
    label.append(dic[breed_np[i]])

label = np.array(label)

# map each label index to the number of cells to generate for that cell type
cell_number_target_num = {}
for k, v in cell_target_num.items():
    cell_number_target_num[dic[k]] = v
# *********************************************************************
# generate data by vae
load_model_1 = args.load_model_1
model_choice_1 = args.model_choice_1
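
The cell above turns each cell-type name into an integer label (`dic`, `label`) and re-keys the target counts by that integer (`cell_number_target_num`). Python randomises string hashing per process, so the iteration order of a `set` of names, and therefore any mapping built from `list(set(...))`, can change between sessions; this matters when a model trained in one session is reloaded later with `load_model_1=True`. A minimal sketch of a reproducible encoding based on a sorted vocabulary; `encode_labels` is a hypothetical helper and the counts are toy values.

```python
import numpy as np


def encode_labels(cell_types, target_counts):
    """Map cell-type names to stable integer ids and re-key the target counts by id.

    cell_types: per-cell type names, in the same order as the rows of the training matrix.
    target_counts: {cell type name: number of cells to generate}.
    """
    vocab = sorted(set(cell_types))  # sorted, so the ids do not depend on hash order
    type_to_id = {name: idx for idx, name in enumerate(vocab)}
    labels = np.array([type_to_id[t] for t in cell_types])
    counts_by_id = {type_to_id[name]: n for name, n in target_counts.items()}
    return vocab, labels, counts_by_id


toy_vocab, toy_labels, toy_counts_by_id = encode_labels(
    ["T cell", "B cell", "T cell", "Monocyte"],
    {"B cell": 2.0, "Monocyte": 1.0, "T cell": 3.0},
)
print(toy_vocab, toy_labels, toy_counts_by_id)
```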

### 9. <a id="Model-training/loading">Model training/loading</a>

In [28]:
ratio = -1
if not load_model_1:  # train
    print("begin vae model training...")
    # ********************* training *********************
    net = train_vae(args, single_cell, cfg, label, used_device)
    # ************** training finished *******************
    print("vae training finished!")
else:  # load model
    print("begin vae model loading...")
    net = load_vae(args, cfg, used_device)
    print("vae load finished!")


begin vae model training...


Train Epoch: 9: 100%|██████████| 10/10 [00:19<00:00,  1.94s/it, loss=0.7268, min_loss=0.7268]


min loss = 0.7268315553665161
vae training finished!
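
The `early_stop` option in the parameter list describes a patience-style rule. The internals of `train_vae` are not shown in this tutorial, so the following is only a generic sketch of such a rule applied to a list of per-epoch losses; `early_stopping_epoch` is a hypothetical helper, and the real implementation may monitor a validation loss instead. In the demo run above (`epoch_num=10`, `early_stop=50`) the rule cannot trigger, so all 10 epochs run.

```python
def early_stopping_epoch(losses, patience):
    """Return the epoch index at which training would stop, or None if it never stops.

    Training stops once `patience` consecutive epochs pass without the loss
    dropping below the best value seen so far.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


toy_losses = [1.0, 0.8, 0.75, 0.76, 0.77, 0.78, 0.70]
print(early_stopping_epoch(toy_losses, patience=3))  # 5
print(early_stopping_epoch(toy_losses, patience=4))  # None
```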


### 10. <a id="Data-generation">Data generation</a>


In [29]:
# generate single-cell data with the trained model and return the metadata and expression matrix
generate_sc_meta, generate_sc_data = generate_vae(net, args, ratio, single_cell, cfg, label, breed_2_list, index_2_gene, cell_number_target_num, used_device)


Generate Epoch: 3: 100%|██████████| 249/249.0 [00:02<00:00, 91.67it/s] 

generated done!
begin data to spatial mapping...
Data have been prepared...





In [30]:
generate_sc_meta

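|     | Cell  | Cell_type      |
| --: | :---- | :------------- |
| 0   | C_1   | T cell         |
| 1   | C_2   | T cell         |
| 2   | C_3   | Monocyte       |
| 3   | C_4   | B cell         |
| 4   | C_5   | Monocyte       |
| ... | ...   | ...            |
| 244 | C_245 | Dendritic cell |
| 245 | C_246 | Dendritic cell |
| 246 | C_247 | Dendritic cell |
| 247 | C_248 | Dendritic cell |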


In [31]:
generate_sc_data

Unnamed: 0,C_1,C_2,C_3,C_4,C_5,C_6,C_7,C_8,C_9,C_10,...,C_240,C_241,C_242,C_243,C_244,C_245,C_246,C_247,C_248,C_249
SAMD3,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
NDRG3,0.056692,0.083544,0.061471,0.056519,0.086942,0.069131,0.094286,0.073867,0.083930,0.102327,...,0.090966,0.052713,0.102111,0.111872,0.119177,0.097710,0.076162,0.054031,0.090936,0.069547
ERAP2,0.065073,0.095656,0.041923,0.081253,0.073224,0.082682,0.095799,0.054389,0.033598,0.096507,...,0.088366,0.085646,0.051459,0.131833,0.106358,0.039055,0.102615,0.033069,0.080997,0.038181
IL16,0.100663,0.171415,0.138631,0.185857,0.111951,0.172269,0.143623,0.181735,0.128849,0.137029,...,0.116545,0.127886,0.168123,0.175001,0.165426,0.125347,0.145840,0.206552,0.185543,0.117299
ELP2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.012013,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GUCD1,0.000000,0.003734,0.040970,0.025784,0.000000,0.000000,0.000000,0.050009,0.005577,0.000000,...,0.007043,0.048570,0.010683,0.007464,0.000000,0.033975,0.010976,0.000000,0.015121,0.000000
TNK2,0.026259,0.014029,0.037063,0.034572,0.000000,0.000000,0.000000,0.034273,0.011471,0.011890,...,0.003083,0.006813,0.052507,0.010723,0.040493,0.029212,0.031700,0.000000,0.000000,0.025666
ABI3,0.076454,0.052532,0.043987,0.022894,0.029788,0.014362,0.000000,0.070021,0.006761,0.035401,...,0.069355,0.025740,0.024902,0.022617,0.000000,0.008169,0.029530,0.026254,0.071665,0.050895
SEC31B,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
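
Because the per-type targets are rounded, it can be worth checking how many cells of each type were actually generated. A hedged sketch on toy data: `compare_generated_counts` is a hypothetical helper; in the notebook you would pass `generate_sc_meta` and `cell_target_num`.

```python
import pandas as pd


def compare_generated_counts(generated_meta, target_counts):
    """Table of requested vs generated cells per cell type."""
    generated = generated_meta["Cell_type"].value_counts()
    table = pd.DataFrame({"target": pd.Series(target_counts, dtype=float)})
    # cell types that were requested but never generated get a count of 0
    table["generated"] = generated.reindex(table.index, fill_value=0)
    table["difference"] = table["generated"] - table["target"]
    return table


toy_generated_meta = pd.DataFrame({
    "Cell": ["C_1", "C_2", "C_3", "C_4", "C_5"],
    "Cell_type": ["T cell", "T cell", "Monocyte", "B cell", "Monocyte"],
})
print(compare_generated_counts(toy_generated_meta, {"B cell": 1.0, "Monocyte": 2.0, "T cell": 3.0}))
```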


### 11. <a id="Data-saving">Data saving</a>

In [32]:
# save the generated metadata and expression matrix
path = osp.join(args.output_path, args.project_name, 'predata')
os.makedirs(path, exist_ok=True)
# the file names record whether the β-VAE was used
name = "vae"
if args.BetaVAE_H:
    name = "BetaVAE"
path_label_generate_csv = os.path.join(path, args.project_name + "_celltype_pred_" + name + "_epoch" + str(args.epoch_num) + '_lr' + str(args.learning_rate) + ".csv")
path_cell_generate_csv = os.path.join(path, args.project_name + "_data_pred_" + name + "_epoch" + str(args.epoch_num) + '_lr' + str(args.learning_rate) + ".csv")

generate_sc_meta.to_csv(path_label_generate_csv)
generate_sc_data.to_csv(path_cell_generate_csv)

print("bulk deconvolution finish!")

print('start to map data to space...')

bulk deconvolution finish!
start to map data to space...
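
The saving cell encodes the model type, epoch count and learning rate in the file names. The sketch below rebuilds those two paths with the same rules, so that a later session can reload the results (for example with `pd.read_csv(path, index_col=0)`) instead of rerunning the VAE; `prediction_paths` is a hypothetical helper, not a Bulk2Space function.

```python
import os


def prediction_paths(output_path, project_name, beta_vae, epoch_num, learning_rate):
    """Rebuild the metadata and expression CSV paths written by the saving cell."""
    folder = os.path.join(output_path, project_name, "predata")
    name = "BetaVAE" if beta_vae else "vae"
    # str(0.0001) == '0.0001', matching str(args.learning_rate) in the saving cell
    suffix = f"{name}_epoch{epoch_num}_lr{learning_rate}.csv"
    meta_csv = os.path.join(folder, f"{project_name}_celltype_pred_{suffix}")
    data_csv = os.path.join(folder, f"{project_name}_data_pred_{suffix}")
    return meta_csv, data_csv


meta_csv, data_csv = prediction_paths("output_data", "test1", True, 10, 0.0001)
print(os.path.basename(meta_csv))  # test1_celltype_pred_BetaVAE_epoch10_lr0.0001.csv
print(os.path.basename(data_csv))  # test1_data_pred_BetaVAE_epoch10_lr0.0001.csv
```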


### 12. <a id="Data-Mapping">Mapping generated single cells to spatial locations</a>

After generating the single-cell data, we map each generated cell to the spot it belongs to, which also gives each cell spatial coordinates.

In [33]:
processer = CreatData(generate_sc_data, generate_sc_meta, input_st_data, args)
processer.cre_data()
runner = Runner(generate_sc_data, generate_sc_meta, input_st_data, input_st_meta, args)
print('start to train the model...')
pred_meta, pred_spot = runner.run()
print("spatial mapping done!")

preparing train data...
sucessfully create positive data
sucessfully create negative data
save xtrain ok！
save ytrain ok！
train data already prepared.
load xtrain ok!
load ytrain ok!
select top 500 marker genes of each cell type...
start to train the model...
saving model done!
model trained sucessfully, start saving output ...
Calculating scores...
Calculating scores done.
save csv ok
spatial mapping done!


Predicted results: for each cell, its cell type, the spot it was assigned to, the spot coordinates and the cell coordinates.

In [34]:
pred_meta

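|     | Cell  | Cell_type | Spot    | Spot_xcoord | Spot_ycoord | Cell_xcoord | Cell_ycoord |
| --: | :---- | :-------- | :------ | ----------: | ----------: | ----------: | ----------: |
| 0   | C_1   | T cell    | spot_1  | 1           | 1           | 1.02        | 0.98        |
| 1   | C_2   | T cell    | spot_1  | 1           | 1           | 0.88        | 1.27        |
| 2   | C_3   | T cell    | spot_1  | 1           | 1           | 0.80        | 1.17        |
| 3   | C_4   | Monocyte  | spot_1  | 1           | 1           | 1.02        | 0.96        |
| 4   | C_5   | Monocyte  | spot_1  | 1           | 1           | 0.92        | 1.07        |
| ... | ...   | ...       | ...     | ...         | ...         | ...         | ...         |
| 295 | C_296 | T cell    | spot_30 | 6           | 5           | 5.92        | 4.99        |
| 296 | C_297 | Monocyte  | spot_30 | 6           | 5           | 5.97        | 4.51        |
| 297 | C_298 | Monocyte  | spot_30 | 6           | 5           | 5.75        | 4.68        |
| 298 | C_299 | T cell    | spot_30 | 6           | 5           | 5.63        | 4.80        |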
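
A common next step is to summarise `pred_meta` as a spot-by-cell-type table. This sketch uses `pandas.crosstab` on a toy table with the same `Spot` and `Cell_type` columns; `spot_composition` is a hypothetical helper. Without normalisation, the row sums give the number of cells assigned to each spot, which can be compared with `k`.

```python
import pandas as pd


def spot_composition(pred_meta, normalize=False):
    """Count (or fraction) of cells of each cell type assigned to each spot."""
    table = pd.crosstab(pred_meta["Spot"], pred_meta["Cell_type"])
    if normalize:
        table = table.div(table.sum(axis=1), axis=0)  # each row now sums to 1
    return table


toy_pred_meta = pd.DataFrame({
    "Cell": ["C_1", "C_2", "C_3", "C_4", "C_5", "C_6"],
    "Cell_type": ["T cell", "T cell", "Monocyte", "Monocyte", "B cell", "T cell"],
    "Spot": ["spot_1", "spot_1", "spot_1", "spot_2", "spot_2", "spot_2"],
})
print(spot_composition(toy_pred_meta))
print(spot_composition(toy_pred_meta, normalize=True))
```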


Predicted results: the gene-by-cell expression matrix.

In [35]:
pred_spot

Unnamed: 0,C_1,C_2,C_3,C_4,C_5,C_6,C_7,C_8,C_9,C_10,...,C_291,C_292,C_293,C_294,C_295,C_296,C_297,C_298,C_299,C_300
SAMD3,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
NDRG3,0.148384,0.051662,0.086545,0.034085,0.040875,0.087099,0.080949,0.093981,0.087721,0.093598,...,0.090921,0.082046,0.110118,0.087099,0.080949,0.074255,0.101741,0.072627,0.093981,0.046041
ERAP2,0.053677,0.071285,0.102918,0.086292,0.072123,0.131853,0.115909,0.088720,0.109861,0.045992,...,0.057311,0.077989,0.024104,0.131853,0.115909,0.048906,0.094665,0.054679,0.088720,0.063340
IL16,0.193890,0.119845,0.120675,0.115754,0.102137,0.156529,0.105136,0.127973,0.152416,0.121232,...,0.112025,0.171651,0.142742,0.156529,0.105136,0.186450,0.148717,0.156713,0.127973,0.171026
ELP2,0.000000,0.000000,0.019100,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.006391,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GUCD1,0.004715,0.049983,0.016409,0.000000,0.048246,0.007240,0.001086,0.013888,0.008084,0.000000,...,0.018007,0.020545,0.001012,0.007240,0.001086,0.056634,0.000000,0.000000,0.013888,0.023775
TNK2,0.058865,0.016446,0.003978,0.009388,0.004503,0.053652,0.039288,0.018662,0.004980,0.000000,...,0.000000,0.000000,0.053869,0.053652,0.039288,0.000000,0.000000,0.000000,0.018662,0.049675
ABI3,0.026339,0.089545,0.027039,0.007665,0.002503,0.047799,0.000000,0.027250,0.027641,0.048826,...,0.000000,0.000000,0.023334,0.047799,0.000000,0.018365,0.060058,0.022673,0.027250,0.010164
SEC31B,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
