# Set Up

## Overview

Plan: integrate normal (atlas) with T21 data and ToF data from Hill2022

`cell_type_key` will be a modification of `mid_grain` labels from the atlas.

Modified because the `CardiacConductionSystem` cells in the `mid_grain` category are functionally linked, but transcriptionally distinct. For scANVI we want labels to reflect transcriptionally close cells.

`batch_key` will be `batch_key` (donor, kit, cell v nuclei, which is present on both objects). This will need to be added to Hill2022

T21 and ToF data will be concatenated into a single query dataset. The query dataset needs to have the following columns:
- `donor`
- `cell_or_nuclei`
- `kit_10x`
- `batch_key` (the 3 above concatenated)
- `diagnosis` (T21 or ToF) - will also add this column to reference (atlas)

For both refernce and query datasets:
- `"counts"` layer containing raw counts
- `.X` contains norm-logged data
- features will be the HVGs **from the reference atlas** (so features are the same across both datasets)

https://docs.scarches.org/en/latest/scanvi_surgery_pipeline.html

In [1]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import os
import gdown
import pandas as pd
import anndata

INFO:lightning_fabric.utilities.seed:Global seed set to 0
 captum (see https://github.com/pytorch/captum).


In [2]:
# Prints current CPU and GPU RAM usage
import psutil
import os

def print_memory_usage():
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    total_memory = psutil.virtual_memory().total
    
    relative_memory = (mem_info.rss / total_memory) * 100
    
    print(f"Absolute memory usage: {round(mem_info.rss / 1024 ** 2)} MB")
    print(f"Relative memory usage: {round(relative_memory)}%")
    
import GPUtil
def print_gpu_usage():
    GPUs = GPUtil.getGPUs()
    for i, gpu in enumerate(GPUs):
        print(f"GPU {i}: {round(gpu.load*100, 2)}%")

In [3]:
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
sc.set_figure_params(figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

In [4]:
%%bash
nvidia-smi

Fri Jun 23 07:27:19 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.91.03    Driver Version: 460.91.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:62:00.0 Off |                    0 |
| N/A   39C    P0    53W / 300W |    972MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:89:00.0 Off |                    0 |
| N/A   36C    P0    38W / 300W |      0MiB / 32510MiB |      0%      Default |
|       

## Paths & variables

In [5]:
#set relevant label key
compartment='AllCompartments'
cell_type_key = 'MidGrainModified'
batch_key='batch_key'
job_name=f'AtlasT21ToF_WithContCovariates_{cell_type_key}_{compartment}'
job_name

'AtlasT21ToF_WithContCovariates_MidGrainModified_AllCompartments'

In [6]:
#paths
base_path='/lustre/scratch126/cellgen/team205/jc48/jupyter/scArches'
data_dir=os.path.join(base_path,'data')
results_dir=os.path.join(base_path,f'results/{job_name}')
adata_path=os.path.join(results_dir,'anndata/')
models_path=os.path.join(results_dir,'models/')


#make directories
os.makedirs(data_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)
os.makedirs(adata_path, exist_ok=True)
os.makedirs(models_path, exist_ok=True)

# Read in data

## Reference data

In [7]:
adata_ref = sc.read_h5ad('/nfs/team205/heart/anndata_objects/Foetal/Feb28ObjectRaw.h5ad')

adata_ref.var['mt'] = adata_ref.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes
adata_ref.var['ribo'] = adata_ref.var_names.str.startswith('RPS','RPL')  # annotate the group of ribosomal genes
sc.pp.calculate_qc_metrics(adata_ref, qc_vars=['mt','ribo'], percent_top=None, log1p=False, inplace=True)

print(adata_ref.X.data[:10])

adata_ref.layers["counts"]=adata_ref.X.copy()

sc.pp.normalize_total(adata_ref,target_sum=1e4)
sc.pp.log1p(adata_ref)

sc.pp.highly_variable_genes(adata_ref, batch_key=batch_key)

print(adata_ref.shape)
adata_ref = adata_ref[:, adata_ref.var.highly_variable]

# Modify the mid_grain label since the CCS cells bridge two major clusters
rename_dict = {
    "SinoatrialNodeCardiomyocytes": "AtrialCardiomyocytes",
    "AtrioventricularNodeCardiomyocytes": "AtrialCardiomyocytes",
    "VentricularConductionSystemDistal": "VentricularCardiomyocytes",
    "VentricularConductionSystemProximal": "VentricularCardiomyocytes"
}

# Create the "MidGrainModified" column using the mapping logic
adata_ref.obs["MidGrainModified"] = adata_ref.obs["fine_grain"].map(rename_dict).fillna(adata_ref.obs["mid_grain"])

adata_ref.obs['diagnosis']='atlas'

print('data prepared...')
print(adata_ref.X.data[:10])
print(adata_ref.shape)

[ 1.  1.  5.  5.  8.  4. 13.  4.  3.  1.]
(297473, 36601)


  adata_ref.obs["MidGrainModified"] = adata_ref.obs["fine_grain"].map(rename_dict).fillna(adata_ref.obs["mid_grain"])


data prepared...
[0.42 0.34 1.12 0.1  0.34 0.34 0.27 0.1  0.19 1.08]
(297473, 3392)


The data is raw counts and has all features.
We:
- Calculate QC metrics (to be used as covariates), in case the object has been subset since this was last done
- store counts in "counts" layer.
- Norm-Log
- Calculate HVGs
- Subset to those
- Add a custom cell type label

## Target data

### T21

In [8]:
# T21 data
file_path=os.path.join(data_dir,'T21HeartsExtracardiacRemovedRaw.h5ad')
T21=sc.read(file_path)

# add extra columns
T21.obs['diagnosis']='T21'

print(T21.X.data[:10])
T21

[ 3.  1.  1.  2. 24.  7.  4.  1.  1. 14.]


AnnData object with n_obs × n_vars = 76358 × 36601
    obs: 'latent_RT_efficiency', 'latent_cell_probability', 'latent_scale', 'sangerID', 'combinedID', 'donor', 'region', 'age', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'scrublet_score', 'doublet_pval', 'doublet_bh_pval', 'n_genes', 'n_counts', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'HB_score', 'batch_key', 'S_score', 'G2M_score', '_scvi_batch', 'hb1_score', 'leiden', 'coarse_grain', 'FACSgate', 'sex', 'week', 'trimester', 'cycling', 'phase', 'diagnosis'
    var: 'feature_type', 'ambient_expression-0', 'ambient_expression-1', 'ambient_expression-2', 'ambient_expression-3', 'ambient_expression-4', 'ambient_expression-5', 'ambient_expression-6', 'ambient_expression-7', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by

### ToF

In [9]:
# ToF data
file_path=os.path.join(data_dir, 'Hill2022_QCapplied.h5ad')
ToF = sc.read_h5ad(file_path)

# subset to ToF only
print(ToF.shape)
ToF=ToF[ToF.obs.diagnosis=='TOF']
print(ToF.shape)

# add extra columns
ToF.obs['donor']=ToF.obs['patientID']
ToF.obs['cell_or_nuclei']='Nuclei'
ToF.obs['kit_10x']='3prime-v3'
ToF.obs['batch_key']=ToF.obs['donor'].astype('str')+'_'+ToF.obs['cell_or_nuclei'].astype('str')+'_'+ToF.obs['kit_10x'].astype('str')

print(ToF.X.data[:10])
ToF

(157293, 29266)
(15836, 29266)


  ToF.obs['donor']=ToF.obs['patientID']


[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


AnnData object with n_obs × n_vars = 15836 × 29266
    obs: 'orig_ident', 'nCount_RNA', 'nFeature_RNA', 'labID', 'procedure', 'age', 'gender', 'echoEF', 'vers10X', 'diagnosis', 'ageCont', 'region', 'batch_indices', 'percent_mt', 'ClinicalRank', 'DEid', 'MainCellType', 'Cluster', 'labID2', 'colors', 'Diagnosis', 'patientID', 'cell_id', 'biosample_id', 'donor_id', 'species', 'species__ontology_label', 'disease', 'disease__ontology_label', 'organ', 'organ__ontology_label', 'library_preparation_protocol', 'sex', 'library_preparation_protocol__ontology_label', 'n_genes', 'n_counts', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'donor', 'cell_or_nuclei', 'kit_10x', 'batch_key'
    var: 'gene_ids', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'

### Combine ToF and T21

In [10]:
# Get variable names from both adata objects
vars_T21 = set(T21.var_names)
vars_ToF = set(ToF.var_names)

# Find intersection (shared variables)
shared_vars = vars_T21 & vars_ToF

# Print the number of shared variables
print("Number of shared variables between T21 and ToF: ", len(shared_vars))

Number of shared variables between T21 and ToF:  18642


In [11]:
# Get column names from .obs of both adata objects
obs_cols_T21 = set(T21.obs.columns)
obs_cols_ToF = set(ToF.obs.columns)

# Find intersection (shared columns)
shared_obs_cols = obs_cols_T21 & obs_cols_ToF

# Print shared columns
shared_obs_cols

{'age',
 'batch_key',
 'cell_or_nuclei',
 'diagnosis',
 'donor',
 'kit_10x',
 'n_counts',
 'n_genes',
 'n_genes_by_counts',
 'pct_counts_mt',
 'pct_counts_ribo',
 'region',
 'sex',
 'total_counts',
 'total_counts_mt',
 'total_counts_ribo'}

In [12]:
# Concatenate T21 and ToF objects
# Get shared variables/features
vars_T21 = set(T21.var_names)
vars_ToF = set(ToF.var_names)
shared_vars = vars_T21 & vars_ToF

# Keep only shared variables/features
T21 = T21[:, list(shared_vars)]
ToF = ToF[:, list(shared_vars)]

# Get shared .obs columns
obs_cols_T21 = set(T21.obs.columns)
obs_cols_ToF = set(ToF.obs.columns)
shared_obs_cols = obs_cols_T21 & obs_cols_ToF

# Keep only shared .obs columns
T21.obs = T21.obs.loc[:, list(shared_obs_cols)]
ToF.obs = ToF.obs.loc[:, list(shared_obs_cols)]

# Sort variables/features
T21 = T21[:, sorted(list(shared_vars))]
ToF = ToF[:, sorted(list(shared_vars))]

# Sort .obs columns
T21.obs = T21.obs.sort_index(axis=1)
ToF.obs = ToF.obs.sort_index(axis=1)

# Concatenate the two adata objects
adata_que = anndata.concat([T21, ToF])
print(adata_que.X.data[:10])
adata_que

[ 1. 24.  7.  4.  1. 14.  3.  1.  3.  2.]


AnnData object with n_obs × n_vars = 92194 × 18642
    obs: 'age', 'batch_key', 'cell_or_nuclei', 'diagnosis', 'donor', 'kit_10x', 'n_counts', 'n_genes', 'n_genes_by_counts', 'pct_counts_mt', 'pct_counts_ribo', 'region', 'sex', 'total_counts', 'total_counts_mt', 'total_counts_ribo'

In [13]:
# save adata_que with raw counts in .X
file_path=os.path.join(adata_path,f'{job_name}_adata_que_T21ToF.h5ad')
adata_que.write(file_path)

### Prepare adata_que for scANVI (subset to features shared with adata_ref HVGs)

In [14]:
# Get variable names from both adata objects
vars_ref = set(adata_ref.var_names)
vars_que = set(adata_que.var_names)

# Find intersection (shared variables)
shared_vars = vars_ref & vars_que

# Print the number of shared variables
print("Number of features in adata_ref: ", len(adata_ref.var_names))
print("Number of features in adata_ref also found in adata_que: ", len(shared_vars))

Number of features in adata_ref:  3392
Number of features in adata_ref also found in adata_que:  2872


In [15]:
adata_que.var['mt'] = adata_que.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes
adata_que.var['ribo'] = adata_que.var_names.str.startswith('RPS','RPL')  # annotate the group of ribosomal genes
sc.pp.calculate_qc_metrics(adata_que, qc_vars=['mt','ribo'], percent_top=None, log1p=False, inplace=True)

print(adata_que.X.data[:10])
adata_que.layers["counts"] = adata_que.X.copy()
sc.pp.normalize_total(adata_que,target_sum=1e4)
sc.pp.log1p(adata_que)
adata_que.raw = adata_que

# Slice to keep only the features found in the adata_ref HVGs
# Convert set to list
shared_vars_list = list(shared_vars)
# Slice adata_que to only contain shared variables
adata_que = adata_que[:, shared_vars_list]

print('data prepared...')
print(adata_que.X.data[:10])
print(adata_que.shape)

[ 1. 24.  7.  4.  1. 14.  3.  1.  3.  2.]
data prepared...
[0.53 0.34 0.26 0.83 0.4  0.1  0.18 0.26 0.1  0.26]
(92194, 2872)


In [16]:
pd.crosstab(adata_que.obs.donor,adata_que.obs.diagnosis)

diagnosis,T21,TOF
donor,Unnamed: 1_level_1,Unnamed: 2_level_1
Hst42,54262,0
Hst44,22096,0
P26,0,2428
P28,0,6195
P33,0,7213


In [17]:
adata_que.obs.head(3)

Unnamed: 0,age,batch_key,cell_or_nuclei,diagnosis,donor,kit_10x,n_counts,n_genes,n_genes_by_counts,pct_counts_mt,pct_counts_ribo,region,sex,total_counts,total_counts_mt,total_counts_ribo
BHF_F_Hea13188319_BHF_F_Hea13187621_TAAGGACCATGTCAAT-1,14W0D,Hst42_Nuclei_Multiome-v1,nuclei,T21,Hst42,Multiome-v1,108008.0,12409,10363,1.983755,1.538707,pericardium,male,100214.0,1988.0,1542.0
BHF_F_Hea13188319_BHF_F_Hea13187621_CGTTTCTCATGAATCT-1,14W0D,Hst42_Nuclei_Multiome-v1,nuclei,T21,Hst42,Multiome-v1,85032.0,11220,9634,3.359229,2.350827,pericardium,male,79036.0,2655.0,1858.0
BHF_F_Hea13188319_BHF_F_Hea13187621_GATCACTAGCACGTTG-1,14W0D,Hst42_Nuclei_Multiome-v1,nuclei,T21,Hst42,Multiome-v1,74917.0,10797,9212,1.703957,0.972046,pericardium,male,69544.0,1185.0,676.0


### Finally, reduce adata_ref to shared variables as well

In [18]:
print(adata_ref.shape)
adata_ref = adata_ref[:, shared_vars_list]
print(adata_ref.shape)

(297473, 3392)
(297473, 2872)


Similarly, data is raw counts and has all features.
We:
- Calculate QC metrics (to be used as covariates), in case the object has been subset since this was last done
- store counts in "counts" layer.
- Norm-Log
- **Subset to match the HVGs of the reference**

# Make reference model

In [20]:
# set up the anndata

adata_ref=adata_ref.copy()

sca.models.SCVI.setup_anndata(adata_ref,
                              batch_key=batch_key,
                              labels_key=cell_type_key,
                              continuous_covariate_keys=['total_counts','pct_counts_mt','pct_counts_ribo'],
                              layer="counts")



In [21]:
# initialise the VAE model
vae = sca.models.SCVI(
    adata_ref,
    n_hidden = 128, n_latent = 50, n_layers = 3, dispersion = 'gene-batch',
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

In [22]:
vae.view_anndata_setup(adata_ref)

In [None]:
# Train the VAE
vae.train()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 1/27:   0%|                                                                                                                     | 0/27 [00:00<?, ?it/s]

In [None]:
# Initialise a SCANVI model from the VAE
scanvae = sca.models.SCANVI.from_scvi_model(vae,unlabeled_category = "Unknown")

In [None]:
print("Labelled Indices: ", len(scanvae._labeled_indices))
print("Unlabelled Indices: ", len(scanvae._unlabeled_indices))

In [None]:
scanvae.view_anndata_setup(adata_ref)

In [None]:
# Train the SCANVI model
scanvae.train(max_epochs=20)

## Save models

In [None]:
vae_path=os.path.join(models_path,f'{job_name}_ref_vae')
scanvae_path = os.path.join(models_path,f'{job_name}_ref_scanvae')

In [None]:
vae.save(vae_path, overwrite=True)
scanvae.save(scanvae_path, overwrite=True)

# Extract latent variables for reference object

In [None]:
reference_latent = sc.AnnData(scanvae.get_latent_representation())
reference_latent.obs["cell_type"] = adata_ref.obs[cell_type_key].tolist()
reference_latent.obs["batch"] = adata_ref.obs[batch_key].tolist()

In [None]:
reference_latent.obs['predictions'] = scanvae.predict()
print("Acc: {}".format(np.mean(reference_latent.obs.predictions == reference_latent.obs.cell_type)))

## Inspect reference latent

In [None]:
reference_latent

In [None]:
%%time
# Optional: UMAP of reference embedding using scANVI latent variables
'''
sc.pp.neighbors(reference_latent)
sc.tl.umap(reference_latent)
sc.pl.umap(reference_latent,
           color=['batch', 'cell_type'],
           wspace=0.8
)
'''

## Save reference latent adata

In [None]:
# save
file_path=os.path.join(adata_path,f'{job_name}_reference_latent.h5ad')
reference_latent.write(file_path)

# Update the base model with the query data ('perform surgery')

one has to preprocess ‘target_adata’ in the following way:

If there are new celltypes in there, save the original labels in other column and replace all labels with unlabeled category:

`target_adata.obs['orig_cell_types'] = target_adata.obs[cell_type_key].copy()`

`target_adata.obs[cell_type_key] = scanvae.unlabeled_category_`

If there is no ‘.obs’ column for cell types:

`target_adata.obs[cell_type_key] = scanvae.unlabeled_category_`

In [None]:
adata_que.obs[cell_type_key] = scanvae.unlabeled_category_

In [None]:
# Initialise
# This take the original scanvae model we just made and embeds the query data within it

adata_que=adata_que.copy()

model = sca.models.SCANVI.load_query_data(adata_que,
                                          scanvae_path,
                                          freeze_dropout = True)

In [None]:
model._unlabeled_indices = np.arange(adata_que.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))

In [None]:
# train the model
model.train(max_epochs=100,
            plan_kwargs=dict(weight_decay=0.0),
            check_val_every_n_epoch=10,
)

## Save model

In [None]:
# Save the 'surgery' model
file_path=os.path.join(models_path,f'{job_name}_surgery_model')
model.save(file_path, overwrite=True)

# Extract latent variables for query object

In [None]:
query_latent = sc.AnnData(model.get_latent_representation())
query_latent.obs['cell_type'] = adata_que.obs[cell_type_key].tolist()
query_latent.obs['batch'] = adata_que.obs[batch_key].tolist()
query_latent.obs_names = model.adata.obs_names.copy() # added

In [None]:
query_latent.obs['predictions'] = model.predict()
print("Acc: {}".format(np.mean(query_latent.obs.predictions == query_latent.obs.cell_type)))

## Inspect query latent

In [None]:
query_latent

## Save query latent adata

In [None]:
# save
file_path=os.path.join(adata_path,f'{job_name}_query_latent.h5ad')
query_latent.write(file_path)

# Combine query and reference objects

In [None]:
# add scANVI prediction outs to reference adata
adata_ref.obs['scANVI_predictions'] = scanvae.predict()

In [None]:
adata_full = adata_ref.concatenate(adata_que,
                                  index_unique = None,
                                  batch_key= 'original_or_new',
                                  batch_categories=['original','new'])

full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))

full_latent.obs['cell_type'] = adata_full.obs[cell_type_key].tolist()
full_latent.obs['batch'] = adata_full.obs[batch_key].tolist()
full_latent.obs['predictions'] = model.predict(adata=adata_full)
full_latent.obs_names = adata_full.obs_names.copy()
full_latent.obs['original_or_new'] = adata_full.obs['original_or_new'].tolist()

## Inspect full latent

In [None]:
full_latent

In [None]:
%%time
sc.pp.neighbors(full_latent)

In [None]:
%%time
sc.tl.umap(full_latent)

In [None]:
%%time
sc.pl.umap(full_latent,
           color=["predictions","batch", "cell_type","original_or_new"],
           wspace=0.8,
           ncols=2
)

## Save full latent adata

In [None]:
# save
file_path=os.path.join(adata_path,f'{job_name}_full_latent.h5ad')
full_latent.write(file_path)