In [8]:
import sys
from pathlib import Path

# Put the directory two levels above the current working directory on the
# import path so project-local packages can be imported from this notebook.
# NOTE(review): this depends on the kernel's working directory — confirm the
# notebook is always launched from its own folder.
project_root = Path().resolve().parent.parent
# Guard against duplicates: re-running this cell must not grow sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))


import torch
import pandas as pd
import pickle
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict import load_model
from docs.notebooks.visualization.batch_visualization import extract_attention_df, visualize_attention_df
from docs.notebooks.visualization.gene_visualization import summarize_attention_by_gene_df
from docs.notebooks.visualization.gene_embedding import visualize_attention_embedding, visualize_all_embeddings, visualize_average_embedding, gene_embedding
from docs.notebooks.visualization.utils import safe_divide_sparse_numpy, get_top_genes_by_attention, load_attention_data
from scipy.sparse import lil_matrix
import numpy as np

In [2]:
# Locate the checkpoint directory of the chosen training run
model_version = 1
model_path = Path('models', "lightning_logs", f"version_{model_version}")

# Restore the trained module and switch it to evaluation mode
ls = load_model(model_path / "checkpoints")
ls.eval()

# Raw Xenium tables: transcripts, cell boundaries and gene functional categories
# NOTE(review): `gene_types` is rebound from a CSV in a later cell, so this
# Excel load may be redundant — confirm before removing.
xenium_dir = Path('data_xenium')
transcripts = pd.read_parquet(xenium_dir / 'transcripts.parquet')
cells = pd.read_parquet(xenium_dir / 'cell_boundaries.parquet')
gene_types = pd.read_excel(xenium_dir / 'Gene_Functional_Categories.xlsx')

# Prefer the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
ls = ls.to(device)

# Build and prepare the Lightning data module
dm = SeggerDataModule(data_dir=Path('data_segger'), batch_size=2, num_workers=2)
dm.setup()

# Display how many entries the training split holds
len(dm.train)

/Users/finleyyu/anaconda3/envs/segger/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


969

In [3]:
# Pull the first training sample and place it on the same device as the model
batch = dm.train[0].to(device)

# Resolve the batch's transcript ids to gene names via the transcripts table
transcript_ids = batch['tx'].id.cpu().numpy()
id_to_gene = dict(zip(transcripts['transcript_id'], transcripts['feature_name']))
gene_names_batch = [id_to_gene[tx_id] for tx_id in transcript_ids]

# Give every distinct gene name a position index (order of first appearance)
gene_to_idx = {
    gene: position
    for position, gene in enumerate(transcripts['feature_name'].unique())
}

# Ids stored on the batch's 'bd' nodes
cell_ids_batch = batch['bd'].id

# gene -> group lookup (this rebinds `gene_types` to the CSV table)
gene_types = pd.read_csv(Path('data_xenium') / 'gene_groups.csv')
gene_types_dict = dict(zip(gene_types['gene'], gene_types['group']))

# cell id -> group lookup
cell_types = pd.read_csv(Path('data_xenium') / 'cell_groups.csv')
cell_types_dict = dict(zip(cell_types['cell_id'], cell_types['group']))

# Forward pass with gradient tracking disabled; the second value returned by
# the model is kept as the attention weights
hetero_model = ls.model
with torch.no_grad():
    x_dict = batch.x_dict
    edge_index_dict = batch.edge_index_dict
    _, attention_weights = hetero_model(x_dict, edge_index_dict)

# Edge type whose attention is extracted and inspected below
edge_type = "tx-bd"

# Turn the raw attention weights into a DataFrame for that edge type
attention_df = extract_attention_df(
    attention_weights,
    gene_names_batch,
    cell_ids_batch,
    cell_types_dict,
    edge_type,
)

  data = torch.load(filepath)


In [13]:
# Number of distinct cell ids present in the cell-group table
len(pd.unique(cell_types['cell_id']))

190962

In [9]:
# Visualize the extracted attention for `edge_type`, passing the gene-group
# and cell-group lookup dicts.
# NOTE(review): the positional `1, 1` presumably select the layer and head —
# confirm against visualize_attention_df's signature.
# NOTE(review): the recorded run of this cell ended with
# "IndexError: index 25 is out of bounds for axis 1 with size 10"; the cause is
# inside the helper and still needs to be tracked down.
visualize_attention_df(attention_df, 1, 1, edge_type, gene_types_dict=gene_types_dict, cell_types_dict=cell_types_dict)

IndexError: index 25 is out of bounds for axis 1 with size 10

In [7]:
# List the distinct group labels found in the cell-group table
print(pd.unique(cell_types['group']))

['Fibroblasts' 'B Cells' 'CXCL9/10 Cells' 'T Cells' 'Mast Cells'
 'Lymphatic Endothelial Cells' 'Macrophages' 'Endothelial' 'Endocrine 2'
 'Tumor Cells' 'CFTR- Tumor Cells' 'Metaplastic Cells' 'Endocrine 1'
 'Smooth Muscle Cells' 'Acinar' 'Ductal']


In [4]:
# Preview the first five rows of the attention table (head() defaults to 5)
attention_df.head()

Unnamed: 0,source,target,edge_type,layer,head,attention_weight,source_gene,target_cell
0,8822,70,tx-bd,1,1,0.000258,CTHRC1,Endothelial
1,9640,199,tx-bd,1,1,0.000973,GNAS,CFTR- Tumor Cells
2,14513,32,tx-bd,1,1,0.002149,PPA1,T Cells
3,18315,156,tx-bd,1,1,0.00282,MORF4L1,CXCL9/10 Cells
4,9943,154,tx-bd,1,1,0.000574,HSPA8,Fibroblasts


In [5]:
# Keep only attention rows whose gene AND target cell type are annotated.
# Previously the first filter's result was discarded (only the last expression
# of a cell is displayed), and the second compared `target_cell` against the
# dict's keys. Those keys are cell ids, while `target_cell` holds cell-type
# labels (the dict's values), so that comparison could never match.
known_gene_mask = attention_df['source_gene'].isin(gene_types_dict.keys())
# Deduplicate the labels once instead of scanning every per-cell entry
known_cell_type_mask = attention_df['target_cell'].isin(set(cell_types_dict.values()))
attention_df[known_gene_mask & known_cell_type_mask]

Unnamed: 0,source,target,edge_type,layer,head,attention_weight,source_gene,target_cell
0,8822,70,tx-bd,1,1,0.000258,CTHRC1,Endothelial
2,14513,32,tx-bd,1,1,0.002149,PPA1,T Cells
5,19129,96,tx-bd,1,1,0.087638,CCL5,CFTR- Tumor Cells
18,3459,195,tx-bd,1,1,0.000759,SERPINB1,CFTR- Tumor Cells
20,10995,18,tx-bd,1,1,0.014192,ACTA2,Fibroblasts
...,...,...,...,...,...,...,...,...
164180,4052,212,tx-bd,5,4,0.005588,CTHRC1,Fibroblasts
164183,8064,191,tx-bd,5,4,0.006070,DSP,CFTR- Tumor Cells
164192,4386,69,tx-bd,5,4,0.007327,CTHRC1,Endothelial
164197,3116,36,tx-bd,5,4,0.447217,CD163,Macrophages


dict_items([('aaaanbjb-1', 'Fibroblasts'), ('aaabbnlb-1', 'Fibroblasts'), ('aaabdean-1', 'B Cells'), ('aaabkppc-1', 'CXCL9/10 Cells'), ('aaablfle-1', 'T Cells'), ('aaacbebp-1', 'Mast Cells'), ('aaacjhoh-1', 'Lymphatic Endothelial Cells'), ('aaacjpei-1', 'Macrophages'), ('aaaclnja-1', 'Macrophages'), ('aaacmppf-1', 'CXCL9/10 Cells'), ('aaacojdd-1', 'Fibroblasts'), ('aaadjkgh-1', 'T Cells'), ('aaadnocp-1', 'Fibroblasts'), ('aaaedgpn-1', 'Fibroblasts'), ('aaaffjph-1', 'Fibroblasts'), ('aaaffnge-1', 'Macrophages'), ('aaafmebj-1', 'Macrophages'), ('aaafongp-1', 'Macrophages'), ('aaagapbb-1', 'Lymphatic Endothelial Cells'), ('aaagbane-1', 'Macrophages'), ('aaagkdae-1', 'Fibroblasts'), ('aaahdaij-1', 'Fibroblasts'), ('aaahelod-1', 'Fibroblasts'), ('aaahihap-1', 'Fibroblasts'), ('aaahldhe-1', 'Fibroblasts'), ('aaaihgia-1', 'Macrophages'), ('aaaiocjc-1', 'Fibroblasts'), ('aaajeafb-1', 'Fibroblasts'), ('aaajijkl-1', 'Macrophages'), ('aaajkhla-1', 'Fibroblasts'), ('aaajknif-1', 'Fibroblasts'), ('

In [13]:
# Fetch the stored attention data; its "adj_matrix" entry is indexed
# as [layer][head]
attention_gene_matrix_dict = load_attention_data()

# Derive the layer and head counts from that nesting
n_layers, n_heads = (
    len(attention_gene_matrix_dict["adj_matrix"]),
    len(attention_gene_matrix_dict["adj_matrix"][0]),
)

# Distinct gene names, in order of first appearance in the transcripts table
transcripts = pd.read_parquet(Path('data_xenium') / 'transcripts.parquet')
gene_names = transcripts['feature_name'].unique().tolist()

In [17]:
# Dense float accumulator shaped like a single per-head matrix.
# The first matrix is densified only if it exposes `toarray`, the same rule
# the loop below applies, so already-dense entries no longer crash here.
# The dtype is forced to float so the in-place division further down works
# even when the stored matrices use an integer dtype.
first_matrix = attention_gene_matrix_dict['adj_matrix'][0][0]
if hasattr(first_matrix, 'toarray'):
    first_matrix = first_matrix.toarray()
avg_attention = np.zeros_like(first_matrix, dtype=float)

# Sum the attention matrices over every layer and head
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        attention_matrix = attention_gene_matrix_dict['adj_matrix'][layer_idx][head_idx]
        # Densify any sparse format (not only lil_matrix) before accumulating
        if hasattr(attention_matrix, 'toarray'):
            attention_matrix = attention_matrix.toarray()
        avg_attention += attention_matrix

# Turn the sum into a mean over all layer/head combinations
avg_attention /= (n_layers * n_heads)

# Gene importance: row sums of the averaged attention matrix
gene_importance = np.array(avg_attention.sum(axis=1)).flatten()

# Sort once and slice both ends; indices are in ascending order of importance,
# so the most important gene is the last entry of top_10_idx
importance_order = np.argsort(gene_importance)
top_10_idx = importance_order[-10:]
bottom_10_idx = importance_order[:10]

