In [1]:
import os
import torch
import pandas as pd
import scanpy as sc
from preprocess import preprocess_scbert
from graph_datasets import visium_anndata_to_graphdataset, GraphDataset

  from pkg_resources import DistributionNotFound, get_distribution
  left = partial(_left_join_spatialelement_table)
  left_exclusive = partial(_left_exclusive_join_spatialelement_table)
  inner = partial(_inner_join_spatialelement_table)
  right = partial(_right_join_spatialelement_table)
  right_exclusive = partial(_right_exclusive_join_spatialelement_table)
  return module_get_attr_redirect(attr_name, deprecated_mapping=_DEPRECATED)


In [2]:
# Root directory of the BA46 data. Set the BA46_DATA_DIR environment variable
# to point elsewhere; otherwise the original project path is used.
data_dir = os.environ.get('BA46_DATA_DIR', '/proj/berzelius-2024-407/data/human_ba46')

### Create a test dataset from 3 Visium arrays

In [3]:
# Disabled: the code is wrapped in a string literal, so running this cell only
# echoes the string. When enabled, it reads the full BA46 AnnData, keeps the
# first 3 array_name values and de-duplicates var names.
'''
adata = sc.read_h5ad(os.path.join(data_dir, 'adata_ba46_counts_scbert.h5ad'))
sel_arrs = adata.obs.array_name.unique()[:3]
adata = adata[adata.obs.array_name.isin(sel_arrs)]
adata.var_names_make_unique()
'''

"\nadata = sc.read_h5ad(os.path.join(data_dir, 'adata_ba46_counts_scbert.h5ad'))\nsel_arrs = adata.obs.array_name.unique()[:3]\nadata = adata[adata.obs.array_name.isin(sel_arrs)]\nadata.var_names_make_unique()\n"

In [4]:
# Disabled: the code is wrapped in a string literal (no-op). When enabled, it
# runs preprocess_scbert on the subset, restricted to the gene list in
# data/gene2vec_names.csv.
'''
target_genes = 'data/gene2vec_names.csv'

adata = preprocess_scbert(adata, target_genes=target_genes)
'''

"\ntarget_genes = 'data/gene2vec_names.csv'\n\nadata = preprocess_scbert(adata, target_genes=target_genes)\n"

In [5]:
# Uncomment to (re)write the subset produced by the disabled cells above;
# the next cell loads it back from this same path.
#adata.write(os.path.join(data_dir, 'adata_test.h5ad'))

### Load into a PyTorch Geometric (PyG) dataset

In [6]:
# Load the saved test subset and fail fast if it lacks the obs columns that the
# graph-construction step uses (annotation, batch and spot coordinates).
adata = sc.read_h5ad(os.path.join(data_dir, 'adata_test.h5ad'))
required_obs = {'aar', 'array_name', 'x_arr', 'y_arr'}
missing_obs = required_obs - set(adata.obs.columns)
assert not missing_obs, f"adata.obs is missing columns: {sorted(missing_obs)}"
adata

AnnData object with n_obs × n_vars = 9734 × 16906
    obs: 'diagnosis', 'phenotype', 'individual', 'aar', 'cell_type', 'array_name', 'x_arr', 'y_arr'
    uns: 'log1p'

In [7]:
# Convert the AnnData into PyG graphs using 'array_name' as the batch column
# (presumably one graph per Visium array), then wrap them in a dataset.
graph_opts = dict(
    annot_col='aar',
    batch_col='array_name',
    x_col='x_arr',
    y_col='y_arr',
    pseudo_hex=True,
)
graph_list, graph_names = visium_anndata_to_graphdataset(adata, **graph_opts)

gdat = GraphDataset(graph_list)

In [8]:
# Inspect the first graph: node features x, edge_index/edge_attr, y
# (presumably the 'aar' annotation) and pos (spot coordinates).
gdat[0]

Data(x=[3538, 16906], edge_index=[2, 20444], edge_attr=[20444, 1], y=[3538], pos=[3538, 2])

### Instantiate the graph-based cross-attention model

In [9]:
from graph_attention import GraphPerformerLM

In [10]:
# scBERT encoding parameters (same values as the BERST constructor defaults).
# n_genes must equal the number of node-feature columns (16906 in gdat[0].x).
n_genes = 16_906   # gene vocabulary size of the gene2vec model
bin_num = 5        # expression bins, not counting the "zero" and "mask" tokens
dim = 200          # token-embedding width (e.g. gene2vec dimension)
depth = 6          # number of attention layers
heads = 10         # attention heads per layer

In [11]:
class BERST(GraphPerformerLM):
    """scBERT-style tokenisation on top of a graph Performer language model.

    Node expression values are discretised into integer tokens: values above
    ``bin_num`` are clamped to ``bin_num``, then cast to ``long`` (truncation).
    One extra zero token is appended per node, so each node sequence has
    ``n_genes + 1`` tokens, matching the ``max_seq_len`` given to the parent.

    Args:
        n_genes: Number of gene features per node (columns of ``x``).
        bin_num: Number of expression bins; the token vocabulary has
            ``bin_num + 2`` entries.
        dim: Token embedding dimension.
        depth: Number of attention layers.
        heads: Attention heads per layer.
        dim_head: Dimension of each attention head.
        g2v_position_emb: Forwarded to ``GraphPerformerLM``.
    """

    def __init__(
        self,
        n_genes=16906,
        bin_num=5,
        dim=200,
        depth=6,
        heads=10,
        dim_head=64,
        g2v_position_emb=True
    ):
        super().__init__(num_tokens=bin_num + 2, max_seq_len=n_genes + 1,
                         dim=dim, depth=depth, heads=heads, dim_head=dim_head,
                         g2v_position_emb=g2v_position_emb)
        self.bin_num = bin_num

    def forward(self, x, edge_index, return_encodings=False, output_attentions=False, **kwargs):
        """Tokenise node features and run the graph Performer.

        Args:
            x: Node feature matrix; presumably (num_nodes, n_genes) of
                non-negative expression values (TODO confirm). It is NOT
                modified.
            edge_index: Graph connectivity, passed through unchanged.
            return_encodings, output_attentions, **kwargs: Forwarded to
                ``GraphPerformerLM.forward``.

        Returns:
            Whatever ``GraphPerformerLM.forward`` returns.
        """
        # Clamp out of place. Assigning into x would overwrite the caller's
        # tensor (e.g. gdat[0].x when passed directly).
        tokens = x.clamp(max=self.bin_num).long()
        # Extra zero token per node (sequence length n_genes + 1). Allocate it on
        # the input's device so the concatenation also works on GPU.
        pad = torch.zeros((tokens.shape[0], 1), dtype=torch.long, device=tokens.device)
        tokens = torch.cat((tokens, pad), dim=-1)

        return super().forward(tokens, edge_index,
                               return_encodings=return_encodings,
                               output_attentions=output_attentions,
                               **kwargs)

# Build the model from the parameters defined in the config cell, so editing
# them there actually takes effect.
berst = BERST(n_genes=n_genes, bin_num=bin_num, dim=dim, depth=depth, heads=heads)

In [12]:
# Feeding a full Visium array (e.g. gdat[0]) through the model exhausts memory
# quickly, so extract the 1-hop neighbourhood of node 0 and relabel its nodes
# 0..k-1 to get a small sub-graph instead.
from torch_geometric.utils import k_hop_subgraph

g = gdat[0]
subset, edge_index, mapping, edge_mask = k_hop_subgraph(
    node_idx=0,
    num_hops=1,
    edge_index=g.edge_index,
    relabel_nodes=True,
)

print(g.x[subset])
print(edge_index)

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.1897, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 1.8032],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.3946, 1.3946]])
tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
        [1, 2, 3, 4, 5, 6, 0, 4, 5, 0, 3, 5, 0, 2, 6, 0, 1, 6, 0, 1, 2, 0, 3, 4]])


In [13]:
# Inference only. Under no_grad, autograd does not keep activations for a
# backward pass, which matters given how quickly forward passes exhaust memory.
with torch.no_grad():
    subgraph_out = berst(g.x[subset], edge_index)
subgraph_out

GraphPerformer
(tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
        [1, 2, 3, 4, 5, 6, 0, 4, 5, 0, 3, 5, 0, 2, 6, 0, 1, 6, 0, 1, 2, 0, 3, 4]]),)
{'pos_emb': None}
SequentialSequence
(tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6],
        [1, 2, 3, 4, 5, 6, 0, 4, 5, 0, 3, 5, 0, 2, 6, 0, 1, 6, 0, 1, 2, 0, 3, 4]]),)
{'pos_emb': None}


tensor([[[-0.6378,  0.1029, -0.4765,  ..., -1.2899, -0.5108, -0.3404],
         [-0.4916,  0.0834, -0.5365,  ..., -1.2764, -0.4062, -0.3476],
         [-0.5212,  0.1320, -0.4535,  ..., -1.1853, -0.4875, -0.2458],
         ...,
         [-0.7137,  0.3805,  0.1309,  ..., -0.6904,  0.6976, -0.6426],
         [-0.5408,  0.0324, -0.3774,  ..., -1.2139, -0.4519, -0.4303],
         [-0.5623,  0.1184, -0.4049,  ..., -1.3034, -0.4487, -0.3094]],

        [[-0.7157,  0.2403, -0.3214,  ..., -1.1994, -0.4828, -0.2815],
         [-0.5715,  0.2193, -0.3871,  ..., -1.1782, -0.3803, -0.2908],
         [-0.6044,  0.2774, -0.3026,  ..., -1.0923, -0.4700, -0.1829],
         ...,
         [-0.5484,  0.2660, -0.2722,  ..., -1.1948, -0.3292, -0.3553],
         [-0.6260,  0.1722, -0.2172,  ..., -1.1226, -0.4159, -0.3573],
         [-0.6414,  0.2552, -0.2520,  ..., -1.2098, -0.4196, -0.2480]],

        [[-0.6913,  0.1543, -0.4241,  ..., -1.2356, -0.5080, -0.3206],
         [-0.5435,  0.1350, -0.4849,  ..., -1