In [23]:
import os
import torch
import pandas as pd
import scanpy as sc
from preprocess import preprocess_scbert
from graph_datasets import visium_anndata_to_graphdataset, GraphDataset

### Create test dataset of 3 Visium arrays

In [2]:
# Load the raw-count AnnData and keep only the first three arrays as a small test set.
adata = sc.read_h5ad('/Volumes/Aidan_NYGC/Visium/2024_LSC_MTC/adata_lsc_raw_counts.h5ad')
sel_arrs = adata.obs['array'].unique()[:3]
# Boolean indexing returns a *view* of the full AnnData. The preprocessing cell
# below assigns into adata.var, and writing to a view forces an implicit copy
# (which appears to be the source of the warning printed under that cell).
# Materialise the subset explicitly so downstream modifications are well-defined.
adata = adata[adata.obs['array'].isin(sel_arrs)].copy()

In [3]:
# Path to the CSV of target gene names handed to preprocess_scbert
# (presumably the gene2vec vocabulary, judging by the file name — confirm).
target_genes = 'data/gene2vec_names.csv'

# Preprocess the subset for scBERT-style input. gene_symbols names the
# adata.var column that holds gene symbols (the captured output below shows
# preprocess_scbert casting adata.var[gene_symbols] to str).
adata = preprocess_scbert(adata, target_genes=target_genes, gene_symbols='gene_symbol')

  adata.var[gene_symbols] = adata.var[gene_symbols].astype(str)


In [4]:
# Save the preprocessed test subset to disk; the next section reloads this file.
adata.write('adata_test.h5ad')

### Load into PyG dataset

In [9]:
# Reload the preprocessed test subset that was written to 'adata_test.h5ad' earlier.
adata = sc.read_h5ad('adata_test.h5ad')

In [10]:
# Convert the AnnData into a list of graphs plus a matching list of names.
# 'x' / 'y' name the coordinate columns (presumably in adata.obs — confirm);
# NOTE(review): the meaning of pseudo_hex=True is defined in graph_datasets, not shown here.
graph_list, graph_names = visium_anndata_to_graphdataset(adata, x_col='x', y_col='y', pseudo_hex=True)

# Wrap the graphs in the project's GraphDataset so they can be indexed (e.g. gdat[0]).
gdat = GraphDataset(graph_list)

In [11]:
# Display the first graph to check its tensor shapes (x, edge_index, edge_attr, y, pos).
gdat[0]

Data(x=[4880, 16906], edge_index=[2, 28718], edge_attr=[28718, 1], y=[4880], pos=[4880, 2])

### Instantiate graph-based cross-attention model

In [12]:
from graph_attention import GraphPerformerLM

In [16]:
# scBERT encoding parameters
# NOTE(review): these values duplicate the keyword defaults of BERST.__init__
# defined further down; keep the two in sync if either is changed.
n_genes = 16906  # number of genes in gene2vec model
bin_num = 5  # discrete bins for transcriptomic data (excl. "zero" and "mask")
dim = 200    # dimension of token embeddings (e.g., gene2vec)
depth = 6    # number of attention layers
heads = 10   # number of attention heads per layer

In [24]:
class BERST(GraphPerformerLM):
    """GraphPerformerLM configured for binned gene-expression tokens.

    Args:
        n_genes: Number of gene features per node. The parent model is built
            with ``max_seq_len = n_genes + 1`` because ``forward`` appends one
            extra token to every node.
        bin_num: Largest expression bin value; inputs above it are clipped to
            it in ``forward``. The parent vocabulary size is ``bin_num + 2``
            (presumably the extra two are the "zero" and "mask" tokens —
            confirm against GraphPerformerLM).
        dim, depth, heads, dim_head, g2v_position_emb: Forwarded unchanged to
            ``GraphPerformerLM.__init__``.
    """

    def __init__(
        self,
        n_genes=16906,
        bin_num=5,
        dim=200,
        depth=6,
        heads=10,
        dim_head=64,
        g2v_position_emb=True
    ):
        # Zero-argument super() does not look up the global name BERST, so it
        # keeps working when this notebook cell is re-executed.
        super().__init__(num_tokens=bin_num+2, max_seq_len=n_genes+1,
                         dim=dim, depth=depth, heads=heads, dim_head=dim_head,
                         g2v_position_emb=g2v_position_emb)
        self.bin_num = bin_num

    def forward(self, x, edge_index, return_encodings=False, output_attentions=False, **kwargs):
        """Tokenise node features and run the parent model.

        Args:
            x: 2-D tensor of per-node expression values (rows are nodes).
                Values greater than ``self.bin_num`` are clipped to it. The
                caller's tensor is left unmodified.
            edge_index: Graph connectivity, passed through to the parent.
            return_encodings: Passed through to the parent ``forward``.
            output_attentions: Passed through to the parent ``forward``.
            **kwargs: Any further keyword arguments for the parent ``forward``.

        Returns:
            Whatever ``GraphPerformerLM.forward`` returns for the tokenised input.
        """
        # Out-of-place clamp: the previous masked assignment wrote into the
        # caller's tensor, permanently clipping the dataset's stored features.
        x = x.clamp(max=self.bin_num).long()
        # One extra all-zero token per node, allocated with x's dtype and
        # device so the concatenation also works when x is not on the CPU.
        new_feat = x.new_zeros((x.shape[0], 1))
        x = torch.cat((x, new_feat), dim=-1)

        return super().forward(x, edge_index,
                               return_encodings=return_encodings,
                               output_attentions=output_attentions,
                               **kwargs)

# Build the model from the encoding parameters defined in the configuration
# cell, so that editing those values actually changes the model (they equal
# the class defaults at present, so the resulting model is unchanged).
berst = BERST(n_genes=n_genes, bin_num=bin_num, dim=dim, depth=depth, heads=heads)

In [None]:
# Smoke test: run one forward pass on the first graph's node features and edges.
# NOTE(review): no torch.no_grad() here, so autograd state is presumably kept
# for this call — consider wrapping it if this is only a sanity check.
berst(gdat[0].x, gdat[0].edge_index)

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
