In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import argparse
import os
import sys

import mlflow
import numpy as np
import scanpy as sc
import squidpy as sq

from autotalker.data import load_spatial_adata_from_csv
from autotalker.models import Autotalker
from autotalker.utils import download_nichenet_ligand_target_mx
from autotalker.utils import extract_gps_from_ligand_target_mx
from autotalker.utils import mask_adata_with_gp_dict

In [None]:
# --- Data selection ---------------------------------------------------------
# Name of the dataset branch taken by the loading cell below.
dataset = "squidpy_seqfish"

# --- Optimisation settings --------------------------------------------------
# n_epochs and lr are forwarded to model.train().
n_epochs = 10
lr = 0.01
# NOTE(review): batch_size is not referenced by any other visible cell.
batch_size = 64

# --- Architecture settings --------------------------------------------------
# Passed to the Autotalker constructor as n_hidden_encoder.
n_hidden_encoder = 32
# Number of rows of the all-ones gene program mask built further down.
n_latent = 16
# Shared dropout rate for the encoder and the graph decoder.
dropout_rate = 0.0

In [None]:
print(f"Using dataset {dataset}.")

# Load the selected dataset into `adata` and record the obs column that holds
# cell type labels in `cell_type_key` (None when no such column is available).
# For the squidpy datasets the spatial neighbourhood graph is computed here;
# the CSV loader is handed an adjacency file directly.
if dataset == "deeplinc_seqfish":
    adata = load_spatial_adata_from_csv("datasets/seqFISH/counts.csv",
                                        "datasets/seqFISH/adj.csv")
    cell_type_key = None
elif dataset == "squidpy_seqfish":
    adata = sq.datasets.seqfish()
    sq.gr.spatial_neighbors(adata, radius=0.04, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"
elif dataset == "squidpy_slideseqv2":
    adata = sq.datasets.slideseqv2()
    sq.gr.spatial_neighbors(adata, radius=30.0, coord_type="generic")
    # The Slide-seqV2 data stores its annotation under "cluster"; the seqFISH
    # column name "celltype_mapped_refined" does not exist in this dataset.
    cell_type_key = "cluster"
else:
    # Fail here with a clear message instead of a NameError on `adata` in a
    # later cell.
    raise ValueError(
        f"Unknown dataset '{dataset}'. Expected one of: 'deeplinc_seqfish', "
        "'squidpy_seqfish', 'squidpy_slideseqv2'.")

In [None]:
# Summarise the spatial graph. The adjacency matrix is kept sparse throughout:
# converting it with .toarray() allocates a dense n_nodes x n_nodes array,
# which becomes several gigabytes for tens of thousands of cells.
adjacency = adata.obsp['spatial_connectivities']

print(f"Number of nodes: {adata.X.shape[0]}")
print(f"Number of node features: {adata.X.shape[1]}")

# Mean of the per-node column sums of the adjacency matrix.
avg_edges_per_node = round(adjacency.sum(axis=0).mean(), 2)
print(f"Average number of edges per node: {avg_edges_per_node}")

# Count every undirected edge once by summing only the entries on or above the
# diagonal (the same entries np.triu keeps), read from the COO triplets.
adjacency_coo = adjacency.tocoo()
upper_triangle = adjacency_coo.row <= adjacency_coo.col
n_edges = int(adjacency_coo.data[upper_triangle].sum())
print(f"Number of edges: {n_edges}")

In [None]:
# Ensure the "mlruns" directory exists (presumably the local MLflow tracking
# store — TODO confirm). exist_ok=True makes re-running this cell a no-op.
os.makedirs("mlruns", exist_ok=True)

In [None]:
# Select the "autotalker" MLflow experiment and keep the returned handle; its
# experiment_id is handed to model.train() further down.
experiment = mlflow.set_experiment("autotalker")
# Record which dataset this notebook run used.
mlflow.log_param("dataset", dataset)

In [None]:
# All-ones gene program mask: one row per latent dimension, one column per
# gene in adata, so no gene is masked out.
mask = np.ones((n_latent, adata.n_vars))

In [None]:
# Build the Autotalker model on `adata` using the "VGPGAE" module. The same
# dropout rate is applied to the encoder and to the graph decoder, and the
# gene program mask is supplied through gp_mask.
# NOTE(review): n_latent is only used to size the mask and is not passed here
# explicitly — confirm the model derives its latent size from gp_mask.
model = Autotalker(adata,
                   autotalker_module="VGPGAE",
                   n_hidden_encoder=n_hidden_encoder,
                   dropout_rate_encoder=dropout_rate,
                   dropout_rate_graph_decoder=dropout_rate,
                   gp_mask=mask)

In [None]:
# Train with the epoch count and learning rate from the config cell, passing
# the id of the MLflow experiment selected above.
# NOTE(review): batch_size from the config cell is not forwarded here.
model.train(n_epochs=n_epochs,
            lr=lr,
            mlflow_experiment_id=experiment.experiment_id)

In [None]:
# Persist the trained model to ./model_artefacts, overwriting any previous
# contents, and store the AnnData object alongside it as "adata.h5ad".
model.save(dir_path="./model_artefacts",
           overwrite=True,
           save_adata=True,
           adata_file_name="adata.h5ad")

In [None]:
# Reload the model from the directory written by model.save() above.
# adata=None together with adata_file_name presumably makes the loader read
# the saved "adata.h5ad" instead of using an in-memory object — TODO confirm.
model = Autotalker.load(dir_path="./model_artefacts",
                        adata=None,
                        adata_file_name="adata.h5ad")

In [None]:
# Latent representation computed without passing any data explicitly
# (presumably for the adata attached to the loaded model — TODO confirm).
# NOTE(review): `latent` is not used by any later visible cell.
latent = model.get_latent_representation()

In [None]:
# Latent representation for an explicitly passed AnnData object (here the
# in-memory `adata`).
latent_new_data = model.get_latent_representation(adata)

In [None]:
# Store the embedding in adata.obsm so scanpy can use it via
# use_rep="latent_autotalker" in the next section.
adata.obsm["latent_autotalker"] = latent_new_data

## Interoperability with scanpy

In [None]:
# Build the neighbour graph on the autotalker latent space and compute a UMAP
# embedding from it.
sc.pp.neighbors(adata, use_rep="latent_autotalker")
sc.tl.umap(adata, min_dist=0.3)
# Colour by the cell type column chosen in the dataset loading cell rather
# than a hard-coded seqFISH column name. cell_type_key is None for datasets
# without annotations, in which case the embedding is drawn uncoloured.
sc.pl.umap(adata, color=cell_type_key, frameon=False)

## NicheNet Gene Programs (GPs)

In [None]:
# Folder that holds gene program resources; created on first run.
gp_data_folder_path = "datasets/gp_data"
os.makedirs(gp_data_folder_path, exist_ok=True)

# Location of the ligand-target matrix CSV inside that folder.
gp_data_file_path = f"{gp_data_folder_path}/ligand_target_matrix.csv"

In [None]:
# Fetch the NicheNet ligand-target matrix and write it to gp_data_file_path.
# NOTE(review): this runs on every execution of the cell — confirm whether the
# helper skips the download when the file already exists.
download_nichenet_ligand_target_mx(
    save_path=gp_data_file_path)

In [None]:
# Derive gene programs from the ligand-target matrix file. The result is used
# below as a mapping (iterated with .items()) from GP name to its content.
gp_dict = extract_gps_from_ligand_target_mx(
    path=gp_data_file_path)

In [None]:
# Annotate adata with the gene programs. The return value is discarded, so the
# helper presumably modifies adata in place (the next cell reads
# adata.varm["I"]) — TODO confirm.
mask_adata_with_gp_dict(adata, gp_dict)

In [None]:
# Rebind `mask` to the transposed "I" entry of adata.varm (presumably written
# by mask_adata_with_gp_dict above — TODO confirm). This replaces the all-ones
# mask defined earlier in the notebook.
mask = adata.varm["I"].T

In [None]:
# Display the dimensions of the gene program mask.
mask.shape

In [None]:
# NOTE(review): verbatim repeat of the previous cell's shape check; candidate
# for removal.
mask.shape

In [None]:
# Manually rebuild the gene x gene-program indicator matrix: entry [g, p] is 1
# when gene g is contained in gene program p, else 0.
# adata_genes was never defined in this notebook, which raised a NameError on
# a fresh kernel; the gene identifiers are taken from adata.var_names.
# NOTE(review): membership is tested by exact string match — confirm that the
# gene naming/casing in gp_dict matches adata.var_names.
adata_genes = adata.var_names
I = np.asarray(
    [[int(gene in gp) for gp in gp_dict.values()] for gene in adata_genes],
    dtype="int32")

In [None]:
# Display the full gene program dictionary.
# NOTE(review): this renders every entry; a size or a small sample may be
# easier to read if the dictionary is large.
gp_dict

In [None]:
# Total number of 1-entries in the indicator matrix, i.e. how many
# (gene, gene program) memberships were found.
I.sum()

In [None]:
# Peek at the first gene program only: print its name, then its content.
# Nothing is printed when the dictionary is empty.
first_entry = next(iter(gp_dict.items()), None)
if first_entry is not None:
    print(first_entry[0])
    print(first_entry[1])

## SCVI

In [3]:
# Setup for the SCVI section.
# NOTE(review): scanpy is already imported as `sc` in the import cell at the
# top of the notebook, and the recorded output of this cell ends in a jaxlib
# RuntimeError (AVX instructions unsupported on the executing machine).
import scvi
import scanpy as sc
import matplotlib.pyplot as plt

# Default figure size for scanpy plots.
sc.set_figure_params(figsize=(4, 4))

# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

Global seed set to 0


RuntimeError: This version of jaxlib was built using AVX instructions, which your CPU and/or operating system do not support. You may be able work around this issue by building jaxlib from source.