In [8]:
import sys
from pathlib import Path

# Add project root to Python path
project_root = Path().resolve().parent.parent
sys.path.append(str(project_root))


import torch
import pandas as pd
import pickle
from segger.training.segger_data_module import SeggerDataModule
from segger.prediction.predict import load_model
from docs.notebooks.visualization.transcript_visualization import extract_attention_df, visualize_attention_difference
from docs.notebooks.visualization.gene_visualization import summarize_attention_by_gene_df, compare_attention_patterns, visualize_all_attention_patterns, visualize_all_attention_patterns_with_metrics
from docs.notebooks.visualization.gene_embedding import visualize_attention_embedding, visualize_all_embeddings, visualize_average_embedding, gene_embedding
from docs.notebooks.visualization.utils import safe_divide_sparse_numpy
from docs.notebooks.visualization.interactive_attention import get_top_genes_across_all_layers
from docs.notebooks.visualization.interactive_attention import load_attention_data
from scipy.sparse import lil_matrix
import numpy as np

In [None]:
# Paths to data and models
model_version = 1
model_path = Path('models') / "lightning_logs" / f"version_{model_version}"
ls = load_model(model_path / "checkpoints")
ls.eval()

# Load transcripts
transcripts = pd.read_parquet(Path('data_xenium') / 'transcripts.parquet')

# Move model to device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ls = ls.to(device)

# Initialize the Lightning data module
dm = SeggerDataModule(
    data_dir=Path('data_segger'),
    batch_size=2,
    num_workers=2,
)
dm.setup()


# Get a sample batch from the data module
len(dm.train)

/Users/finleyyu/anaconda3/envs/segger/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
  data = torch.load(filepath)


969

In [6]:
# Get a sample batch
batch = dm.train[0].to(device)

# Get gene names
transcript_ids = batch['tx'].id.cpu().numpy()
id_to_gene = dict(zip(transcripts['transcript_id'], transcripts['feature_name']))
gene_names = [id_to_gene[id] for id in transcript_ids]
# Create a dictionary to map gene names to indices
gene_to_idx = {gene: idx for idx, gene in enumerate(transcripts['feature_name'].unique())}

# Run forward pass to get attention weights
with torch.no_grad():
    hetero_model = ls.model
    x_dict = batch.x_dict
    edge_index_dict = batch.edge_index_dict
    _, attention_weights = hetero_model(x_dict, edge_index_dict)

edge_type = "tx-bd"

# Extract attention weights
attention_df = extract_attention_df(attention_weights, gene_names)

In [23]:
gene_names_batch = [id_to_gene[id] for id in transcript_ids]
print(gene_names_batch)
gene_names = transcripts['feature_name'].unique().tolist()
print(gene_names)

# check if the gene names are the same
print(gene_names_batch == gene_names)



['BTF3', 'SUMO2', 'NNMT', 'SUMO2', 'BTF3', 'CAP1', 'SUMO2', 'SUMO2', 'CAP1', 'FBLN1', 'APCDD1', 'FBLN1', 'CAP1', 'SUMO2', 'SUMO2', 'SUMO2', 'CAP1', 'SUMO2', 'NNMT', 'CAP1', 'CLEC10A', 'KRT7', 'BTF3', 'BTF3', 'SUMO2', 'NNMT', 'SSR4', 'PDGFRB', 'BTF3', 'BTF3', 'SUMO2', 'BTF3', 'BTF3', 'SUMO2', 'SUMO2', 'PDGFRB', 'SUMO2', 'BTF3', 'FBLN1', 'BTF3', 'BTF3', 'BTF3', 'CAP1', 'BTF3', 'BTF3', 'CAP1', 'CAP1', 'CAP1', 'SUMO2', 'BTF3', 'KRT7', 'SUMO2', 'PDGFRB', 'BTF3', 'NNMT', 'CAP1', 'SUMO2', 'LAMP3', 'BTF3', 'NNMT', 'SUMO2', 'CAP1', 'BTF3', 'KRT7', 'SUMO2', 'SUMO2', 'SUMO2', 'IL1B', 'BTF3', 'BTF3', 'BTF3', 'SUMO2', 'IL1B', 'FBLN1', 'BTF3', 'KRT7', 'CAP1', 'BTF3', 'SUMO2', 'IL1B', 'BTF3', 'SUMO2', 'SSR4', 'NNMT', 'SSR4', 'KRT7', 'SUMO2', 'SUMO2', 'BTF3', 'SUMO2', 'BTF3', 'SUMO2', 'BTF3', 'BTF3', 'NNMT', 'NNMT', 'SUMO2', 'SUMO2', 'NNMT', 'PTN', 'BTF3', 'CAP1', 'SSR4', 'CAP1', 'BTF3', 'BTF3', 'SUMO2', 'NNMT', 'CAP1', 'ANGPT2', 'SUMO2', 'SUMO2', 'CAP1', 'CAP1', 'DUSP2', 'BTF3', 'NNMT', 'SUMO2', 'CAP

['RHOA', 'PFDN5', 'CAPN8', 'HNRNPA2B1', 'ASAH1', 'PTPRC', 'ATP5MD', 'ATP5F1B', 'NOP53', 'AMY2A', 'GNAS', 'GPR183', 'NegControlProbe_00034', 'PABPC1', 'SLC25A3', 'HSP90B1', 'LILRA4', 'FSTL1', 'SUMO2', 'CD34', 'AQP3', 'TMBIM6', 'MS4A4A', 'CFAP53', 'CIB1', 'DUSP1', 'IL1B', 'OST4', 'TM4SF4', 'STEAP4', 'ACTG2', 'SFTA2', 'DSP', 'SLC26A2', 'NDUFC2', 'EPCAM', 'PMP22', 'CFB', 'CD28', 'CAP1', 'EDN1', 'FKBP11', 'VAMP8', 'DES', 'SST', 'SERPINB1', 'EIF2S2', 'SSR3', 'OGN', 'CXCR4', 'GEM', 'MEDAG', 'PPA1', 'FGL2', 'CLEC14A', 'RTN4', 'UPK3B', 'CD3E', 'TFPI', 'APCDD1', 'PCSK2', 'MZB1', 'YAF2', 'TCIM', 'HINT1', 'BCL2L11', 'NUPR1', 'SKP1', 'SCGB2A1', 'OSTC', 'MFAP5', 'SFRP2', 'KNG1', 'CD3D', 'DLK1', 'SSR2', 'FSTL3', 'EPAS1', 'BIRC3', 'FBN1', 'MYH9', 'GLIPR1', 'KLF6', 'BASP1', 'UGP2', 'ADGRL4', 'FBLN1', 'C7', 'CD4', 'ALDH1A3', 'LYVE1', 'VWF', 'DMBT1', 'MRC1', 'SSR4', 'LTBP2', 'CYP2F1', 'MYLK', 'PDGFRA', 'BTF3', 'SERPINB3', 'C1R', 'NNMT', 'LAPTM5', 'MMRN2', 'MS4A6A', 'PPP1R12B', 'PRDM1', 'HMGCS2', 'HSPA8',

In [None]:
adj_matrix, count_matrix = summarize_attention_by_gene_df(
    attention_df, 
    layer_idx=layer_idx, 
    head_idx=head_idx, 
    edge_type=edge_type, 
    gene_to_idx=gene_to_idx, 
    visualize=False
)

In [13]:
# Load data
attention_gene_matrix_dict = load_attention_data()

# Get dimensions
n_layers = len(attention_gene_matrix_dict["adj_matrix"])
n_heads = len(attention_gene_matrix_dict["adj_matrix"][0])

# Load gene names
transcripts = pd.read_parquet(Path('data_xenium') / 'transcripts.parquet')
gene_names = transcripts['feature_name'].unique().tolist()

In [17]:
# Initialize average attention matrix
avg_attention = np.zeros_like(attention_gene_matrix_dict['adj_matrix'][0][0].toarray())

# Average attention matrices across all layers and heads
for layer_idx in range(n_layers):
    for head_idx in range(n_heads):
        attention_matrix = attention_gene_matrix_dict['adj_matrix'][layer_idx][head_idx]
        if isinstance(attention_matrix, (lil_matrix)):
            attention_matrix = attention_matrix.toarray()
        avg_attention += attention_matrix

avg_attention /= (n_layers * n_heads)

# Calculate gene importance (sum of attention weights)
gene_importance = np.array(avg_attention.sum(axis=1)).flatten()

# Get indices of top and bottom 10 genes
top_10_idx = np.argsort(gene_importance)[-10:]
bottom_10_idx = np.argsort(gene_importance)[:10]



In [16]:
# print out the names of the top 10 and bottom 10 genes
print(gene_to_idx[top_10_idx[0]])
print(gene_to_idx[bottom_10_idx])

# print out the names of the top 10 and bottom 10 genes
print(gene_to_idx[top_10_idx])
print(gene_to_idx[bottom_10_idx])

KeyError: 208

In [20]:
id_to_gene[top_10_idx.tolist()]

TypeError: unhashable type: 'list'