In [None]:
import torch
import anndata
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob
pt_files = sorted(glob("/gpfs/gibbs/pi/krishnaswamy_smita/hm638/SCGFM/data/merfish_brain_preprocessed/*"))
for i, file in tqdm(enumerate(pt_files)):
    slide_name = file.split("/")[-1][:-3]
    graphs = torch.load(file, weights_only=False)
    X = []
    for k in range(1, len(graphs)):
        X.append(graphs[k].X.squeeze(1).tolist())
    pos = graphs[0].X.float().numpy()
    X = np.array(X)
    adata = anndata.AnnData(X)
    adata.obsm['spatial'] = pos
    adata.obs['cell_type'] = graphs[0].cell_type
    adata.write("/gpfs/gibbs/pi/krishnaswamy_smita/hm638/SCGFM/data/adata/merfish-brain/"+slide_name+".h5ad")

In [4]:
graphs = torch.load(pt_files[0], weights_only=False)

In [None]:
graphs[0]

In [1]:
from glob import glob

In [None]:
glob("/gpfs/gibbs/pi/krishnaswamy_smita/hm638/SCGFM/data/adata/space-gm/charville/*")

In [5]:
import anndata

In [15]:
adata= anndata.read_h5ad( '/gpfs/gibbs/pi/krishnaswamy_smita/hm638/SCGFM/data/adata/space-gm/charville/Charville_c001_v001_r001_reg017.h5ad',
)

In [None]:
adata.uns

In [2]:
from glob import glob

In [None]:
glob("/gpfs/gibbs/pi/krishnaswamy_smita/hm638/SCGFM/data/adata/*")

In [23]:
import torch

def calculate_sinusoidal_pe(high_level_graph, low_level_graphs, pe_dim):
    # Step 1: Calculate cell positional encodings (Dist_i)
    num_nodes = high_level_graph.num_nodes
    cell_locations = high_level_graph.X  # Shape: [num_cells, 2]
    anchor_nodes = torch.randint(0, num_nodes, (pe_dim,))
    
    # Compute distance vectors (Dist_i) between each cell and the anchors
    # dist_matrix = torch.cdist(cell_locations, cell_locations[anchor_nodes])  # Shape: [num_cells, num_cells]
    x, y = cell_locations[:, 0], cell_locations[:, 1]  # Extract x and y coordinates
    half_dim = pe_dim // 2  # Half for x, half for y
    i = torch.arange(half_dim // 2, device=cell_locations.device)  # i indices
    j = torch.arange(half_dim // 2, device=cell_locations.device)  # j indices

    # Compute denominator 10000^(4i/d) and 10000^(4j/d)
    div_term_x = 10000 ** (4 * i / half_dim)
    div_term_y = 10000 ** (4 * j / half_dim)

    # Compute positional encodings
    pe_x = torch.zeros((cell_locations.shape[0], half_dim), device=cell_locations.device)
    pe_y = torch.zeros((cell_locations.shape[0], half_dim), device=cell_locations.device)

    pe_x[:, 0::2] = torch.sin(x[:, None] / div_term_x)  # sin terms for x
    pe_x[:, 1::2] = torch.cos(x[:, None] / div_term_x)  # cos terms for x

    pe_y[:, 0::2] = torch.sin(y[:, None] / div_term_y)  # sin terms for y
    pe_y[:, 1::2] = torch.cos(y[:, None] / div_term_y)  # cos terms for y

    # Concatenate positional encodings for x and y
    pe = torch.cat([pe_x, pe_y], dim=-1)

    # Use the distance matrix as positional encoding for the high-level graph
    high_level_graph.pe = pe

    # Step 2: Calculate gene positional encodings (RankNorm * Dist_i)
    gene_expressions = low_level_graphs.X.squeeze(-1)  # Shape: [num_genes]
    gene_batches = low_level_graphs.batch  # Shape: [num_genes]

    # Filter by batch and calculate RankNorm for each gene batch
    rank_norm = torch.zeros_like(gene_expressions).to(high_level_graph.X.device).float()
    unique_batches = gene_batches.unique()
    for b in unique_batches:
        batch_mask = (gene_batches == b)
        batch_gene_expressions = gene_expressions[batch_mask]
        batch_ranks = torch.argsort(-batch_gene_expressions, dim=0)  # Descending order
        batch_rank_norm = torch.linspace(0, 1, steps=batch_ranks.size(0), device=batch_gene_expressions.device)
        rank_norm[torch.where(batch_mask)[0]] = batch_rank_norm[torch.argsort(batch_ranks)]

    div_term = 10000 * (2 * torch.arange(0, pe_dim, 2, device=rank_norm.device).float() /half_dim)
    print(div_term)
    gene_sinusoidal_pe = torch.zeros(rank_norm.size(0), pe_dim, device=rank_norm.device)
    gene_sinusoidal_pe[:, 0::2] = torch.sin(rank_norm.unsqueeze(-1) * div_term)
    gene_sinusoidal_pe[:, 1::2] = torch.cos(rank_norm.unsqueeze(-1) * div_term)
    # Set low-level graph PE
    low_level_graphs.pe = gene_sinusoidal_pe.to(high_level_graph.X.device)

    return high_level_graph, low_level_graphs


In [3]:

import torch
from sklearn.model_selection import train_test_split
from glob import glob
from utils.dataloader import create_dataloader_ddp
import numpy as np
from tqdm import tqdm
from torch_scatter import scatter
from torch_scatter import scatter_add

from torch_geometric.nn.pool import global_mean_pool

import torch_geometric as tg
import networkx as nx

In [4]:
import torch
import torch_geometric.utils as pyg_utils
from matplotlib import pyplot as plt
data = torch.load("/gpfs/gibbs/project/ying_rex/hm638/SCGFM/data/pretraining/vizgen_preprocessed/HumanBreastCancerPatient1_113.pt")
dataloader = create_dataloader_ddp(data, 10, 0, 1)
for high_level_subgraph, low_level_batch, batch_idx in (dataloader):
    pass

  data = torch.load("/gpfs/gibbs/project/ying_rex/hm638/SCGFM/data/pretraining/vizgen_preprocessed/HumanBreastCancerPatient1_113.pt")


In [24]:
high_level_graph, low_level_graphs = calculate_sinusoidal_pe(high_level_subgraph, low_level_batch, 64)

tensor([    0.,  1250.,  2500.,  3750.,  5000.,  6250.,  7500.,  8750., 10000.,
        11250., 12500., 13750., 15000., 16250., 17500., 18750., 20000., 21250.,
        22500., 23750., 25000., 26250., 27500., 28750., 30000., 31250., 32500.,
        33750., 35000., 36250., 37500., 38750.])


In [16]:
high_level_graph, low_level_graphs = calculate_sinusoidal_pe(high_level_subgraph, low_level_batch, 64)

tensor([1.0000e+00, 7.4989e-01, 5.6234e-01, 4.2170e-01, 3.1623e-01, 2.3714e-01,
        1.7783e-01, 1.3335e-01, 1.0000e-01, 7.4989e-02, 5.6234e-02, 4.2170e-02,
        3.1623e-02, 2.3714e-02, 1.7783e-02, 1.3335e-02, 1.0000e-02, 7.4989e-03,
        5.6234e-03, 4.2170e-03, 3.1623e-03, 2.3714e-03, 1.7783e-03, 1.3335e-03,
        1.0000e-03, 7.4989e-04, 5.6234e-04, 4.2170e-04, 3.1623e-04, 2.3714e-04,
        1.7783e-04, 1.3335e-04])


In [25]:
low_level_graphs.pe

tensor([[ 0.0000,  1.0000, -0.0835,  ..., -0.8054, -0.5230, -0.8523],
        [ 0.0000,  1.0000, -0.1259,  ..., -0.7987,  0.6978, -0.7163],
        [ 0.0000,  1.0000, -0.1734,  ...,  0.4938,  0.7709,  0.6369],
        ...,
        [ 0.0000,  1.0000, -0.3045,  ..., -0.9898,  0.1658, -0.9862],
        [ 0.0000,  1.0000, -0.1189,  ..., -0.9077,  0.5243, -0.8515],
        [ 0.0000,  1.0000, -0.1471,  ..., -0.2803,  0.9907, -0.1357]])

In [17]:
low_level_graphs.pe

tensor([[2.3399e-01, 9.7224e-01, 1.7619e-01,  ..., 1.0000e+00, 3.1495e-05,
         1.0000e+00],
        [3.4926e-01, 9.3702e-01, 2.6437e-01,  ..., 1.0000e+00, 4.7578e-05,
         1.0000e+00],
        [4.7280e-01, 8.8117e-01, 3.6096e-01,  ..., 1.0000e+00, 6.5671e-05,
         1.0000e+00],
        ...,
        [7.6714e-01, 6.4148e-01, 6.0970e-01,  ..., 1.0000e+00, 1.1660e-04,
         1.0000e+00],
        [3.3036e-01, 9.4386e-01, 2.4980e-01,  ..., 1.0000e+00, 4.4897e-05,
         1.0000e+00],
        [4.0510e-01, 9.1427e-01, 3.0770e-01,  ..., 1.0000e+00, 5.5619e-05,
         1.0000e+00]])