In [270]:
import argparse
import sys

import mlflow
import numpy as np
import scanpy as sc
import squidpy as sq
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit, RandomNodeSplit

from autotalker.data import download_nichenet_ligand_target_mx
from autotalker.data import load_spatial_adata_from_csv
from autotalker.data import SpatialAnnDataset
from autotalker.models import Autotalker

dataset = "squidpy_seqfish"
n_epochs = 10
lr = 0.01
batch_size = 128
n_hidden = 32
n_latent = 16
dropout_rate = 0.

print(f"Using dataset {dataset}.")

if dataset == "deeplinc_seqfish":
    adata = load_spatial_adata_from_csv("datasets/seqFISH/counts.csv",
                                        "datasets/seqFISH/adj.csv")
    cell_type_key = None
elif dataset == "squidpy_seqfish":
    adata = sq.datasets.seqfish()
    sq.gr.spatial_neighbors(adata, radius = 0.04, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"
elif dataset == "squidpy_slideseqv2":
    adata = sq.datasets.slideseqv2()
    sq.gr.spatial_neighbors(adata, radius = 30.0, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"

print(f"Number of nodes: {adata.X.shape[0]}")
print(f"Number of node features: {adata.X.shape[1]}")
avg_edges_per_node = round(
    adata.obsp['spatial_connectivities'].toarray().sum(axis=0).mean(),2)
print(f"Average number of edges per node: {avg_edges_per_node}")
n_edges = int(np.triu(adata.obsp['spatial_connectivities'].toarray()).sum())
print(f"Number of edges: {n_edges}", sep="")

dataset = SpatialAnnDataset(adata, adj_key="spatial_connectivities")
data = Data(x=dataset.x,
            edge_index=dataset.edge_index,
            conditions=dataset.conditions,
            size_factors=dataset.size_factors)

Using dataset squidpy_seqfish.


  self._set_arrayXarray(i, j, x)


Number of nodes: 19416
Number of node features: 351
Average number of edges per node: 4.4
Number of edges: 42694


In [274]:
data

Data(x=[19416, 351], edge_index=[2, 85388], conditions=[19416], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416])

In [275]:
# Split data on edge level
transform = RandomLinkSplit(num_val=0.1,
                            num_test=0.1,
                            is_undirected=True,
                            neg_sampling_ratio=0)
node_split = RandomNodeSplit(num_val=0.1,
                             num_test=0,
                             key="x")


train_data, valid_data, test_data = transform(data)
# Gene expression is there for all nodes in all datasets
# edge index is there for all edge index pairs (symmetric) for training edges (not valid and test edges) in the dataset for all datasets
# Conditions are there for all nodes in all datasets
# Size factors are there for all nodes in all datasets
# Pos edge labels are only there for split
# Pos edge label indices are only there for split
# ned edge labels are only there for split
# neg edge label indices are only there for split


print(train_data)
print(valid_data)

Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416], edge_label=[34156], edge_label_index=[2, 34156])
Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416], edge_label=[4269], edge_label_index=[2, 4269])


In [287]:
train_loader = torch_geometric.loader.LinkNeighborLoader(
    train_data,
    num_neighbors=[-1]*2, # iterations of neighbor sampling
    batch_size=1,
    edge_label_index=train_data.edge_label_index,
    directed=True,
    neg_sampling_ratio=1)

In [397]:
valid_loader = torch_geometric.loader.LinkNeighborLoader(
    valid_data,
    num_neighbors=[-1]*2, # iterations of neighbor sampling
    batch_size=1,
    edge_label_index=valid_data.edge_label_index,
    directed=True,
    neg_sampling_ratio=1)

In [346]:
valid_data

Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416], edge_label=[4269], edge_label_index=[2, 4269])

In [432]:
sampled_data = next(iter(valid_loader))
print(sampled_data)

Data(x=[46, 351], edge_index=[2, 80], conditions=[46], size_factors=[46], train_mask=[46], val_mask=[46], test_mask=[46], edge_label=[2], edge_label_index=[2, 2])


In [241]:
sampled_data = next(iter(train_loader))
print(sampled_data)

Data(x=[64, 351], edge_index=[2, 143], conditions=[64], size_factors=[64], edge_label=[2], edge_label_index=[2, 2])


In [None]:
sampled_data = 

In [48]:
[30]*2

[30, 30]

In [24]:
print(train_data)

Data(x=[19416, 351], edge_index=[2, 54652], conditions=[19416], size_factors=[19416], pos_edge_label=[27326], pos_edge_label_index=[2, 27326], neg_edge_label=[27326], neg_edge_label_index=[2, 27326], train_mask=[19416], val_mask=[19416], test_mask=[19416])


In [27]:
print(valid_data)

Data(x=[19416, 351], edge_index=[2, 54652], conditions=[19416], size_factors=[19416], pos_edge_label=[3415], pos_edge_label_index=[2, 3415], neg_edge_label=[3415], neg_edge_label_index=[2, 3415], train_mask=[19416], val_mask=[19416], test_mask=[19416])


In [8]:
print(data)

Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], pos_edge_label=[34156], pos_edge_label_index=[2, 34156], neg_edge_label=[34156], neg_edge_label_index=[2, 34156])
