In [1]:
import pickle
import numpy as np
import pandas as pd
import networkx as nx
from pathlib import Path
from typing import Dict, Tuple, Optional
import torch
import scanpy as sc
import anndata

In [2]:
# --- Configuration: input locations and analysis switches ---
# Pickled trajectory data.
trajectory_file = '../data/trajectories/traj_data.pkl'
# Pickled prior graph.
prior_graph_file = '../data/trajectories/cancer_granger_prior_graph_nx_20.pkl'
# Plain-text file holding one gene name per line.
gene_names_file = '../data/trajectories/gene_names.txt'
# How many highly variable genes to keep.
n_top_genes = 20
# When True, node features are averaged over the trajectory axis.
use_mean_trajectory = True


In [3]:
# Load the trajectory data.
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
trajectory_path = Path(trajectory_file)
with trajectory_path.open("rb") as f:
    trajectories = pickle.load(f)

print(trajectories.shape)

# Load the prior graph (the same pickle caveat applies).
prior_graph_path = Path(prior_graph_file)
with prior_graph_path.open("rb") as f:
    prior_graph = pickle.load(f)

# Read the gene names, one per line, stripping surrounding whitespace.
gene_names_path = Path(gene_names_file)
with gene_names_path.open("r") as f:
    gene_names = [line.strip() for line in f]

# Fail early with a clear message if the names cannot label the last
# (gene) axis of the trajectory data one-to-one.
if len(gene_names) != trajectories.shape[-1]:
    raise ValueError(
        f"Read {len(gene_names)} gene names from {gene_names_path}, but the "
        f"trajectory data has {trajectories.shape[-1]} entries on its last axis."
    )

(100, 100, 21465)


In [4]:
# Unpack the three axes of the trajectory array.
n_timepoints, n_trajectories, n_genes = trajectories.shape

# scanpy works on a 2-D (samples x genes) matrix, so fold the first two axes
# (timepoints and trajectories) into a single sample axis.
entire_trajectory = trajectories.reshape((-1, n_genes))

# Wrap the flattened matrix in an AnnData container and label its columns
# with the gene names.
adata = anndata.AnnData(entire_trajectory)
adata.var_names = gene_names

In [None]:
# Flag the `n_top_genes` most variable genes; the result is read back from
# adata.var['highly_variable'] by later cells.
# NOTE(review): scanpy's default flavor expects log-normalized input — confirm
# that the trajectory values are on a log scale.
print(f"Identifying {n_top_genes} highly variable genes.")
sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)


Identifying 20 highly variable genes.


  dispersion = np.log(dispersion)


In [6]:
# Integer positions of the genes whose 'highly_variable' flag is set.
selected_genes = np.flatnonzero(adata.var['highly_variable'])

In [7]:
# Names of the selected genes, in the same order as `selected_genes`.
filtered_gene_names = adata.var_names[selected_genes]
# Keep only the selected genes along the last (gene) axis.
filtered_trajectories = trajectories[..., selected_genes]

# Node features as a float32 tensor:
#   use_mean_trajectory=True  -> averaged over axis 1 (trajectories),
#                                shape (n_timepoints, n_selected_genes)
#   use_mean_trajectory=False -> all trajectories kept,
#                                shape (n_timepoints, n_trajectories, n_selected_genes)
node_features = torch.tensor(
    filtered_trajectories.mean(axis=1) if use_mean_trajectory else filtered_trajectories,
    dtype=torch.float32,
)

# Dense 0/1 adjacency matrix built from the prior graph's edges.
# NOTE(review): indexing by edge endpoints assumes node labels are integers in
# [0, n_nodes) — confirm; whether node i corresponds to selected gene i is not
# established in this cell and should be verified.
n_nodes = len(prior_graph.nodes())
prior_adjacency = torch.zeros((n_nodes, n_nodes))
for edge in prior_graph.edges():
    src, dst = edge[0], edge[1]
    # Mark both directions so the matrix is symmetric.
    prior_adjacency[src, dst] = prior_adjacency[dst, src] = 1

In [8]:
# Display the names of the selected genes (last expression of the cell).
filtered_gene_names

Index(['ID3', 'ODF2L', 'LAPTM4A', 'IGFBP5', 'HLA-C', 'NPTX2', 'SAT1', 'TIMP1',
       'MSMP', 'ACTA2', 'MGP', 'KRT7', 'KRT81', 'KRT19', 'APLP1', 'APOE',
       'TTYH1', 'COL6A1', 'COL6A2', 'MT-ND3'],
      dtype='object')