In [802]:
import argparse
import sys

import mlflow
import numpy as np
import scanpy as sc
import squidpy as sq
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit, RandomNodeSplit

from autotalker.data import download_nichenet_ligand_target_mx
from autotalker.data import load_spatial_adata_from_csv
from autotalker.data import SpatialAnnDataset
from autotalker.models import Autotalker

dataset = "squidpy_seqfish"
n_epochs = 10
lr = 0.01
batch_size = 128
n_hidden = 32
n_latent = 16
dropout_rate = 0.

print(f"Using dataset {dataset}.")

if dataset == "deeplinc_seqfish":
    adata = load_spatial_adata_from_csv("datasets/seqFISH/counts.csv",
                                        "datasets/seqFISH/adj.csv")
    cell_type_key = None
elif dataset == "squidpy_seqfish":
    adata = sq.datasets.seqfish()
    sq.gr.spatial_neighbors(adata, radius = 0.04, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"
elif dataset == "squidpy_slideseqv2":
    adata = sq.datasets.slideseqv2()
    sq.gr.spatial_neighbors(adata, radius = 30.0, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"

print(f"Number of nodes: {adata.X.shape[0]}")
print(f"Number of node features: {adata.X.shape[1]}")
avg_edges_per_node = round(
    adata.obsp['spatial_connectivities'].toarray().sum(axis=0).mean(),2)
print(f"Average number of edges per node: {avg_edges_per_node}")
n_edges = int(np.triu(adata.obsp['spatial_connectivities'].toarray()).sum())
print(f"Number of edges: {n_edges}", sep="")

dataset = SpatialAnnDataset(adata, adj_key="spatial_connectivities")
data = Data(x=dataset.x,
            edge_index=dataset.edge_index,
            conditions=dataset.conditions,
            size_factors=dataset.size_factors)

Using dataset squidpy_seqfish.


  self._set_arrayXarray(i, j, x)


Number of nodes: 19416
Number of node features: 351
Average number of edges per node: 4.4
Number of edges: 42694


In [821]:
data

Data(x=[19416, 351], edge_index=[2, 85388], conditions=[19416], size_factors=[19416])

In [836]:
# Split data on edge level
transform = RandomLinkSplit(num_val=0.1,
                            num_test=0.1,
                            is_undirected=True,
                            neg_sampling_ratio=1)
node_split = RandomNodeSplit(num_val=0.1,
                             num_test=0,
                             key="x")


train_data, valid_data, test_data = transform(data)
# Gene expression is there for all nodes in all datasets
# edge index is there for all edge index pairs (symmetric) for training edges (not valid and test edges) in the dataset for all datasets
# Conditions are there for all nodes in all datasets
# Size factors are there for all nodes in all datasets
# Pos edge labels are only there for split
# Pos edge label indices are only there for split
# ned edge labels are only there for split
# neg edge label indices are only there for split


print(train_data)
print(valid_data)

Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], edge_label=[68312], edge_label_index=[2, 68312])
Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], edge_label=[8538], edge_label_index=[2, 8538])


In [838]:
for edge in zip(train_data.edge_label_index[0], train_data.edge_label_index[1]) :
    print(edge)

(tensor(8930), tensor(8955))
(tensor(6037), tensor(6058))
(tensor(15214), tensor(15278))
(tensor(1647), tensor(1652))
(tensor(12568), tensor(12654))
(tensor(16100), tensor(16268))
(tensor(4611), tensor(4618))
(tensor(2677), tensor(2850))
(tensor(2248), tensor(2287))
(tensor(1104), tensor(1159))
(tensor(4951), tensor(5074))
(tensor(9683), tensor(9804))
(tensor(114), tensor(132))
(tensor(5148), tensor(5254))
(tensor(16611), tensor(16638))
(tensor(4046), tensor(4102))
(tensor(12539), tensor(12629))
(tensor(3181), tensor(3183))
(tensor(3668), tensor(3677))
(tensor(3223), tensor(3239))
(tensor(4307), tensor(4317))
(tensor(17397), tensor(17407))
(tensor(11764), tensor(11860))
(tensor(10715), tensor(10738))
(tensor(16089), tensor(16252))
(tensor(6164), tensor(6182))
(tensor(7078), tensor(7165))
(tensor(8290), tensor(8294))
(tensor(2187), tensor(2554))
(tensor(13695), tensor(13736))
(tensor(19172), tensor(19177))
(tensor(13677), tensor(13730))
(tensor(712), tensor(729))
(tensor(16249), tensor(

In [823]:
train_loader = torch_geometric.loader.LinkNeighborLoader(
    train_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=64,
    edge_label_index=train_data.edge_label_index,
    directed=False,
    neg_sampling_ratio=0)

In [824]:
len(train_data.edge_label)

68312

In [825]:
valid_loader = torch_geometric.loader.LinkNeighborLoader(
    valid_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=64,
    edge_label_index=valid_data.edge_label_index,
    directed=False,
    neg_sampling_ratio=0)

In [831]:
valid_data

Data(x=[19416, 351], edge_index=[2, 68312], conditions=[19416], size_factors=[19416], edge_label=[8538], edge_label_index=[2, 8538])

In [827]:
sampled_data = next(iter(valid_loader))
print(sampled_data)

Data(x=[483, 351], edge_index=[2, 1402], conditions=[483], size_factors=[483], edge_label=[64], edge_label_index=[2, 64])


In [828]:
sampled_data.edge_label

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [799]:
sampled_data = next(iter(train_loader))
print(sampled_data)

Data(x=[1064, 351], edge_index=[2, 1052], conditions=[1064], size_factors=[1064], edge_label=[128], edge_label_index=[2, 128])


In [800]:
len(valid_loader)

67

In [801]:
len(train_loader)

534

In [783]:
1423*3

4269

In [723]:
len(train_loader)

11386

In [724]:
sampled_data = node_split(sampled_data)

In [725]:
sampled_data

Data(x=[53, 351], edge_index=[2, 50], conditions=[53], size_factors=[53], edge_label=[6], edge_label_index=[2, 6], train_mask=[53], val_mask=[53], test_mask=[53])

In [726]:
import torch
sampled_data.edge_label
pos_edge_label_mask = (sampled_data.edge_label > 0)

In [727]:
adj_recon_logits = torch.randn(10, 10)

In [728]:
adj_recon_logits

tensor([[ 1.0437, -1.0448, -0.3835,  0.7467, -0.0287,  0.3265,  0.1181,  2.2507,
          1.9912, -0.9735],
        [-1.1402, -0.4663,  1.3603, -1.2761, -0.3817,  0.7807,  1.9599,  0.8239,
          0.5776, -0.8431],
        [ 0.6757, -0.8224,  1.1456,  0.1753, -0.0155, -0.8384, -1.1537,  1.7968,
          1.2332, -1.7170],
        [ 0.3345, -0.1133,  0.2663,  0.9104,  0.5621, -0.2775, -1.2293, -0.6008,
         -0.4959, -1.5902],
        [-0.1176, -0.9366,  1.0543, -1.2974,  0.5777, -0.3199, -0.3814,  1.0258,
         -0.8282,  0.3318],
        [-0.9692, -1.1611, -0.5451, -0.4355,  0.1138,  2.2474,  0.1553, -1.0575,
          1.0705, -0.2537],
        [ 0.6907, -0.4729, -0.4832,  0.2529,  1.0338, -0.9106,  0.3217, -0.8446,
         -0.7785,  2.1865],
        [-0.1002, -0.9662, -0.0562,  1.3300, -0.8733, -0.8903,  0.7502,  0.3109,
         -0.2238,  1.2834],
        [ 0.3995, -0.4492, -0.1522, -0.0981, -1.3563,  1.2207, -1.0342,  0.0348,
          1.1774,  1.6985],
        [-0.7843,  

In [730]:
(torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()

2

In [731]:
torch.tensor(mask.shape)

tensor([8, 8])

In [835]:
n_nodes=adj_recon_logits.shape[0]
n_nodes

10

In [732]:
pad_dim = (torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()

In [733]:
padded_mask = F.pad(mask, (0, pad_dim, 0, pad_dim), "constant", False)

In [734]:
mask.size()

torch.Size([8, 8])

In [735]:
padded_mask.size()

torch.Size([10, 10])

In [736]:
sampled_data.edge_label_index

tensor([[ 3,  1,  8,  7,  6,  0],
        [ 4,  2,  9,  5, 10, 11]])

In [737]:
adj_recon_logits[7,2]

tensor(-0.0562)

In [686]:
import torch.nn.functional as F
pad_dim = (torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()
mask = torch.squeeze(torch_geometric.utils.to_dense_adj(sampled_data.edge_label_index)) > 0
padded_mask = F.pad(mask, (0, pad_dim, 0, pad_dim), "constant", False)
torch.masked_select(adj_recon_logits, padded_mask)

tensor([ 0.9267, -0.4892, -0.1091, -0.3577])

In [670]:
padded_mask

tensor([[False, False, False,  True, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False,  True, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False,  True, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False,  True, False, False, False, False, False,
         False, False]])

In [574]:
sampled_data.edge_label_index[:, pos_edge_label_mask]

tensor([[2, 8, 5],
        [3, 9, 6]])

In [747]:
from torch_geometric.utils import add_self_loops

edge_index_self_loops = add_self_loops(sampled_data.edge_label_index)[0]

In [764]:
edge_label_index = sampled_data.edge_label_index
edge_labels = sampled_data.edge_label

In [756]:
edge_label_index.shape[0]

2

In [773]:
edge_index_self_loops

tensor([[ 3,  1,  8,  7,  6,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  2,  9,  5, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])

In [766]:
n_self_loops = edge_index_self_loops.shape[1] - edge_label_index.shape[1]

In [768]:
n_self_loops

12

In [777]:
sampled_data.x.shape[0] ** 2

2809

In [780]:
(sampled_data.edge_label == 0).sum()

tensor(3)

In [771]:
torch.cat((edge_labels, torch.ones(n_self_loops)))

tensor([1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [742]:
edge_index_self_loops.size()

AttributeError: 'tuple' object has no attribute 'size'

In [450]:
sampled_data.val_mask

tensor([False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False,  True, False,
        False,  True, False, False, False, False])

In [451]:
sampled_data.x[sampled_data.val_mask]

tensor([[0., 0., 0.,  ..., 3., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.]])