In [1]:
import argparse
import sys

import mlflow
import numpy as np
import scanpy as sc
import squidpy as sq
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit, RandomNodeSplit

from autotalker.utils import download_nichenet_ligand_target_mx
from autotalker.data import load_spatial_adata_from_csv
from autotalker.data import SpatialAnnTorchDataset
from autotalker.models import Autotalker

dataset = "squidpy_seqfish"
n_epochs = 10
lr = 0.01
batch_size = 128
n_hidden = 32
n_latent = 16
dropout_rate = 0.

print(f"Using dataset {dataset}.")

if dataset == "deeplinc_seqfish":
    adata = load_spatial_adata_from_csv("datasets/seqFISH/counts.csv",
                                        "datasets/seqFISH/adj.csv")
    cell_type_key = None
elif dataset == "squidpy_seqfish":
    adata = sq.datasets.seqfish()
    sq.gr.spatial_neighbors(adata, radius = 0.04, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"
elif dataset == "squidpy_slideseqv2":
    adata = sq.datasets.slideseqv2()
    sq.gr.spatial_neighbors(adata, radius = 30.0, coord_type="generic")
    cell_type_key = "celltype_mapped_refined"

print(f"Number of nodes: {adata.X.shape[0]}")
print(f"Number of node features: {adata.X.shape[1]}")
avg_edges_per_node = round(
    adata.obsp['spatial_connectivities'].toarray().sum(axis=0).mean(),2)
print(f"Average number of edges per node: {avg_edges_per_node}")
n_edges = int(np.triu(adata.obsp['spatial_connectivities'].toarray()).sum())
print(f"Number of edges: {n_edges}", sep="")

dataset = SpatialAnnTorchDataset(adata, adj_key="spatial_connectivities")
data = Data(x=dataset.x,
            edge_index=dataset.edge_index,
            size_factors=dataset.size_factors)

Using dataset squidpy_seqfish.
Number of nodes: 19416
Number of node features: 351
Average number of edges per node: 4.4
Number of edges: 42694


In [2]:
data

Data(x=[19416, 351], edge_index=[2, 85388], size_factors=[19416])

In [3]:
# Split data on edge level
random_link_split = RandomLinkSplit(num_val=0.1,
                                    num_test=0.1,
                                    is_undirected=True,
                                    neg_sampling_ratio=1.)
random_node_split = RandomNodeSplit(num_val=0.1,
                                    num_test=0,
                                    key="x")

data = random_node_split(data)
train_data, val_data, test_data = random_link_split(data)

print(train_data)
print(val_data)

Data(x=[19416, 351], edge_index=[2, 68312], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416], edge_label=[68312], edge_label_index=[2, 68312])
Data(x=[19416, 351], edge_index=[2, 68312], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416], edge_label=[8538], edge_label_index=[2, 8538])


In [4]:
data

Data(x=[19416, 351], edge_index=[2, 85388], size_factors=[19416], train_mask=[19416], val_mask=[19416], test_mask=[19416])

In [5]:
n_nodes = train_data.train_mask.sum()
node_loader_batch_size = 4
n_nodes / node_loader_batch_size

tensor(4368.5000)

In [6]:
edge_train_loader = torch_geometric.loader.LinkNeighborLoader(
    train_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=16,
    edge_label_index=train_data.edge_label_index,
    # edge_label=train_data.edge_label,
    directed=False,
    shuffle=False,
    neg_sampling_ratio=1.0)

In [7]:
node_train_loader = torch_geometric.loader.NeighborLoader(
    train_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=4,
    directed=True,
    shuffle=False,
    input_nodes=train_data.train_mask)

In [8]:
edge_val_loader = torch_geometric.loader.LinkNeighborLoader(
    val_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=4,
    edge_label_index=val_data.edge_label_index,
    # edge_label=val_data.edge_label,
    directed=False,
    shuffle=True,
    neg_sampling_ratio=1.0)

In [9]:
for epoch in range(2):
    for batch in node_train_loader:
        print(batch)
        break

Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)
Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)


In [10]:
node_val_loader = torch_geometric.loader.NeighborLoader(
    val_data,
    num_neighbors=[-1]*1, # iterations of neighbor sampling
    batch_size=8,
    input_nodes=train_data.train_mask)

In [11]:
n_edges = train_data.edge_label_index.size(1)

In [12]:
n_edges

68312

In [13]:
edge_batch_size = 256

In [14]:
edge_train_loader_iters = int(np.ceil(n_edges / edge_batch_size))

In [15]:
edge_train_loader_iters

267

In [16]:
n_nodes = train_data.train_mask.sum()

In [17]:
n_nodes.item()

17474

In [18]:
node_batch_size = int(np.floor(n_nodes / edge_train_loader_iters))

In [19]:
node_batch_size

65

In [20]:
len(node_train_loader)

4369

In [21]:
len(edge_train_loader)

4270

In [22]:
len(node_train_loader)

4369

In [23]:
for epoch in range(4):
    for batch in node_train_loader:
        print(batch)
        break

Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)
Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)
Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)
Data(x=[17, 351], edge_index=[2, 13], size_factors=[17], train_mask=[17], val_mask=[17], test_mask=[17], edge_label=[13], edge_label_index=[2, 13], batch_size=4)


for i, zipped in enumerate(zip(edge_train_loader, node_train_loader)):
    if i > 2180:
        print(zipped)

In [25]:
len(edge_train_loader)

4270

In [26]:
len(node_train_loader)

4369

In [31]:
sampled_data = next(iter(edge_val_loader))
print(sampled_data)

Data(x=[55, 351], edge_index=[2, 114], size_factors=[55], train_mask=[55], val_mask=[55], test_mask=[55], edge_label=[8], edge_label_index=[2, 8])


In [13]:
sampled_data.edge_label_index

tensor([[ 0,  2, 13,  4,  8, 15, 12, 11],
        [ 1,  3, 14,  5,  9,  6, 10,  7]])

In [12]:
sampled_data = node_split(sampled_data)

In [28]:
sampled_data.train_mask

NameError: name 'sampled_data' is not defined

In [16]:
sampled_data_train = sampled_data[sampled_data.train_mask]

KeyError: tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True, False,  True,  True,  True,  True,  True,  True,  True, False,
         True, False,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True])

In [27]:
len(sampled_data.x[sampled_data.val_mask])

7

In [28]:
sampled_data["val_mask"]

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False, False,  True, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False,  True,
        False,  True, False, False, False,  True, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False])

In [25]:
len(sampled_data.x)

66

In [372]:
from torch_geometric.utils import coalesce
coalesce(sampled_data.edge_label_index)

tensor([[ 0,  2,  4,  6,  8,  9, 12, 14, 16, 18, 20, 21, 24, 26, 28, 30],
        [ 1,  3,  5,  7, 10, 11, 13, 15, 17, 19, 23, 22, 25, 27, 29, 31]])

In [373]:
sampled_data.edge_label_index[0].sort(dim=-1)

torch.return_types.sort(
values=tensor([ 0,  2,  4,  6,  8,  9, 12, 14, 16, 18, 20, 21, 24, 26, 28, 30]),
indices=tensor([ 3, 10, 12,  5,  4,  9, 15,  0,  6, 13,  8,  7,  1,  2, 11, 14]))

In [296]:
sampled_data.edge_label

tensor([2., 2., 2., 0., 0., 0.])

In [120]:
sort_index = sampled_data.edge_label_index[0].sort(dim=-1).indices
edge_label_index_sorted = sampled_data.edge_label_index[:, sort_index]
edge_labels_sorted = sampled_data.edge_label[sort_index]

In [124]:
edge_label_index_sorted

tensor([[ 0,  2,  4,  6, 10, 11, 12, 15],
        [ 1,  3,  5, 14,  7, 13,  8,  9]])

In [123]:
edge_labels_sorted

tensor([1., 1., 1., 0., 0., 1., 0., 0.])

In [121]:
edge_labels = sampled_data.edge_label

In [122]:
edge_labels

tensor([1., 1., 1., 1., 0., 0., 0., 0.])

In [98]:
edge_label_index_sorted

torch.return_types.sort(
values=tensor([ 0,  2,  4,  6, 10, 11, 12, 15]),
indices=tensor([3, 0, 2, 7, 5, 1, 4, 6]))

In [60]:
edge_label_index_sorted.indices[0]

tensor([1, 0, 2, 3])

In [59]:
sampled_data.edge_label[edge_label_index_sorted.indices[0]]

tensor([0., 0., 0., 0.])

In [42]:
torch.arange(x.size(0)).unsqueeze(1)

tensor([[0],
        [1],
        [2]])

In [None]:
onehot.scatter_(1, idx.long(), 1)

In [18]:
sort_edge_index(sampled_data.edge_label)

IndexError: too many indices for tensor of dimension 1

In [799]:
sampled_data = next(iter(train_loader))
print(sampled_data)

Data(x=[1064, 351], edge_index=[2, 1052], conditions=[1064], size_factors=[1064], edge_label=[128], edge_label_index=[2, 128])


In [800]:
len(valid_loader)

67

In [801]:
len(train_loader)

534

In [783]:
1423*3

4269

In [723]:
len(train_loader)

11386

In [724]:
sampled_data = node_split(sampled_data)

In [725]:
sampled_data

Data(x=[53, 351], edge_index=[2, 50], conditions=[53], size_factors=[53], edge_label=[6], edge_label_index=[2, 6], train_mask=[53], val_mask=[53], test_mask=[53])

In [726]:
import torch
sampled_data.edge_label
pos_edge_label_mask = (sampled_data.edge_label > 0)

In [842]:
adj_recon_logits = torch.randn(10, 10)

In [728]:
adj_recon_logits

tensor([[ 1.0437, -1.0448, -0.3835,  0.7467, -0.0287,  0.3265,  0.1181,  2.2507,
          1.9912, -0.9735],
        [-1.1402, -0.4663,  1.3603, -1.2761, -0.3817,  0.7807,  1.9599,  0.8239,
          0.5776, -0.8431],
        [ 0.6757, -0.8224,  1.1456,  0.1753, -0.0155, -0.8384, -1.1537,  1.7968,
          1.2332, -1.7170],
        [ 0.3345, -0.1133,  0.2663,  0.9104,  0.5621, -0.2775, -1.2293, -0.6008,
         -0.4959, -1.5902],
        [-0.1176, -0.9366,  1.0543, -1.2974,  0.5777, -0.3199, -0.3814,  1.0258,
         -0.8282,  0.3318],
        [-0.9692, -1.1611, -0.5451, -0.4355,  0.1138,  2.2474,  0.1553, -1.0575,
          1.0705, -0.2537],
        [ 0.6907, -0.4729, -0.4832,  0.2529,  1.0338, -0.9106,  0.3217, -0.8446,
         -0.7785,  2.1865],
        [-0.1002, -0.9662, -0.0562,  1.3300, -0.8733, -0.8903,  0.7502,  0.3109,
         -0.2238,  1.2834],
        [ 0.3995, -0.4492, -0.1522, -0.0981, -1.3563,  1.2207, -1.0342,  0.0348,
          1.1774,  1.6985],
        [-0.7843,  

In [730]:
(torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()

2

In [731]:
torch.tensor(mask.shape)

tensor([8, 8])

In [835]:
n_nodes=adj_recon_logits.shape[0]
n_nodes

10

In [732]:
pad_dim = (torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()

In [733]:
padded_mask = F.pad(mask, (0, pad_dim, 0, pad_dim), "constant", False)

In [734]:
mask.size()

torch.Size([8, 8])

In [735]:
padded_mask.size()

torch.Size([10, 10])

In [736]:
sampled_data.edge_label_index

tensor([[ 3,  1,  8,  7,  6,  0],
        [ 4,  2,  9,  5, 10, 11]])

In [737]:
adj_recon_logits[7,2]

tensor(-0.0562)

In [686]:
import torch.nn.functional as F
pad_dim = (torch.tensor(adj_recon_logits.shape[0]) - torch.tensor(mask.shape[0])).item()
mask = torch.squeeze(torch_geometric.utils.to_dense_adj(sampled_data.edge_label_index)) > 0
padded_mask = F.pad(mask, (0, pad_dim, 0, pad_dim), "constant", False)
torch.masked_select(adj_recon_logits, padded_mask)

tensor([ 0.9267, -0.4892, -0.1091, -0.3577])

In [670]:
padded_mask

tensor([[False, False, False,  True, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False, False,  True, False, False, False, False,
         False, False],
        [False, False, False, False, False, False, False, False,  True, False,
         False, False],
        [False, False, False, False, False, False, False, False, False, False,
         False, False],
        [False, False, False, False,  True, False, False, False, False, False,
         False, False]])

In [574]:
sampled_data.edge_label_index[:, pos_edge_label_mask]

tensor([[2, 8, 5],
        [3, 9, 6]])

In [747]:
from torch_geometric.utils import add_self_loops

edge_index_self_loops = add_self_loops(sampled_data.edge_label_index)[0]

In [476]:
edge_label_index = sampled_data.edge_label_index
edge_labels = sampled_data.edge_label

In [756]:
edge_label_index.shape[0]

2

In [773]:
edge_index_self_loops

tensor([[ 3,  1,  8,  7,  6,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
        [ 4,  2,  9,  5, 10, 11,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])

In [766]:
n_self_loops = edge_index_self_loops.shape[1] - edge_label_index.shape[1]

In [768]:
n_self_loops

12

In [777]:
sampled_data.x.shape[0] ** 2

2809

In [780]:
(sampled_data.edge_label == 0).sum()

tensor(3)

In [771]:
torch.cat((edge_labels, torch.ones(n_self_loops)))

tensor([1., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [742]:
edge_index_self_loops.size()

AttributeError: 'tuple' object has no attribute 'size'

In [450]:
sampled_data.val_mask

tensor([False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False, False, False, False,  True, False,
        False,  True, False, False, False, False])

In [451]:
sampled_data.x[sampled_data.val_mask]

tensor([[0., 0., 0.,  ..., 3., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 1., 0., 0.]])

In [392]:
t1 = torch.tensor([[1, 2, 3 ,4, 2, 5], [1, 2, 7, 4, 2, 5]])
t1
t1_no_dups = torch.unique(t1, dim=1,return_inverse=True)

In [393]:
t1

tensor([[1, 2, 3, 4, 2, 5],
        [1, 2, 7, 4, 2, 5]])

In [394]:
t1_no_dups

(tensor([[1, 2, 3, 4, 5],
         [1, 2, 7, 4, 5]]),
 tensor([0, 1, 2, 3, 1, 4]))

In [456]:
sampled_data.edge_label

tensor([1., 1., 1., 1., 1., 1., 1., 1.])

In [458]:
torch.tensor([1])

tensor([1])

In [None]:
sort_index = edge_label_index[0].sort(dim=-1).indices
# edge_labels_sorted = edge_label[sort_index]

In [45]:
edge_label_index = torch.cat((sampled_data.edge_label_index, torch.tensor([[2, 1], [3, 6]])), dim=1)
edge_label = torch.cat((sampled_data.edge_label, torch.tensor([0, 1.])), dim=0)

NameError: name 'torch' is not defined

In [502]:
edge_label_index

tensor([[ 4, 11, 13,  0,  7,  2,  9, 15,  2,  1],
        [ 5, 12, 14,  1, 10,  3,  8,  6,  3,  6]])

In [503]:
edge_label

tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 1.])

In [504]:
def unique(x, dim=-1):
    unique, inverse = torch.unique(x, return_inverse=True, dim=dim)
    perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([dim]), perm.flip([dim])
    return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm)

In [505]:
edge_label_index, sort_index = unique(edge_label_index)

In [506]:
edge_label_index

tensor([[ 0,  1,  2,  4,  7,  9, 11, 13, 15],
        [ 1,  6,  3,  5, 10,  8, 12, 14,  6]])

In [507]:
sort_index

tensor([3, 9, 5, 0, 4, 6, 1, 2, 7])

In [486]:
edge_label[sort_index]

tensor([1., 0., 1., 0., 0., 1., 1., 0.])

In [444]:
print(sort_index)

tensor([3, 5, 4, 0, 6, 7, 1, 2])


In [434]:
sort_index = edge_label_index[0].sort(dim=-1).indices
# edge_labels_sorted = edge_label[sort_index]

In [435]:
print(sort_index)

tensor([3, 5, 8, 9, 4, 0, 6, 7, 1, 2])


In [417]:
edge_label_index

tensor([[ 6, 12, 14,  0,  4,  2,  8, 10],
        [ 7, 13, 15,  1,  5,  3,  9, 11]])

In [32]:
sampled_data.edge_label_index

tensor([[ 2, 11,  0,  7, 14,  3, 10,  5],
        [ 6, 12,  8, 13,  1,  4,  9, 15]])

In [43]:
from torch_geometric.utils import index_to_mask, mask_to_index, to_dense_adj

In [44]:
to_dense_adj(sampled_data.edge_label_index)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0.,

In [53]:
A = mask_to_index(index_to_mask(sampled_data.edge_label_index, 16))

In [56]:
import torch
x = torch.cat((A, torch.tensor([])), 0)

In [57]:
x

tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
        14., 15.])

In [101]:
x = torch.tensor([[2, 3],[4, 7]])

In [102]:
x.sum(1)

tensor([ 5, 11])

In [103]:
x

tensor([[2, 3],
        [4, 7]])