In [None]:
%load_ext autoreload
%autoreload 2

# Segger Minimal Example

In [None]:
from segger.data.io import XeniumSample
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer
from pathlib import Path
from lightning.pytorch.plugins.environments import LightningEnvironment

In [None]:
from pytorch_lightning.plugins.environments import SLURMEnvironment
SLURMEnvironment.detect = lambda: False

## Create Dataset

Explain things here

In [None]:
xenium_data_dir = Path('../../dev/tutorial/xenium_data/')
segger_data_dir = Path('../../dev/tutorial/segger_data/')

In [6]:
# Setup Xenium sample to create dataset
xs = XeniumSample(verbose=False)
xs.set_file_paths(
    transcripts_path=xenium_data_dir / 'transcripts.parquet',
    boundaries_path=xenium_data_dir / 'nucleus_boundaries.parquet',
)
xs.set_metadata()

In [14]:
!rm -r ../../dev/tutorial/segger_data/*

In [15]:
try:
    xs.save_dataset_for_segger(
        processed_dir=segger_data_dir,
        r_tx=3,
        k_tx=15,
        receptive_field={'k_bd': 3, 'dist_bd': 15,'k_tx': 15, 'dist_tx': 3},
        x_size=250,
        y_size=250,
        d_x=250,
        d_y=250,
        margin_x=10,
        margin_y=10,
    )
except AssertionError as err:
    print(f'Dataset already exists at {segger_data_dir}')

[########################################] | 100% Completed | 33.24 s


## Train Segger Model

Explain things here

In [5]:
# Base directory to store Pytorch Lightning models
models_dir = Path('../../dev/tutorial/models/')

In [44]:
# Initialize the Lightning model
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
ls = LitSegger(
    num_tx_tokens=500,
    init_emb=8,
    hidden_channels=64,
    out_channels=16,
    heads=4,
    num_mid_layers=1,
    aggr='sum',
    metadata=metadata,
)

In [9]:
# Initialize the Lightning data module
dm = SeggerDataModule(
    data_dir=segger_data_dir,
    batch_size=2,  # explain other options above
    num_workers=4,  # explain other options above
)

In [46]:
# Initialize the Lightning trainer
trainer = Trainer(
    accelerator='cuda',  # explain other options above
    strategy='auto',
    precision='16-mixed',
    devices=1,  # explain other options above
    max_epochs=100,  # explain other options above
    default_root_dir=models_dir,
    logger=CSVLogger(models_dir),
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [47]:
# Fit model
trainer.fit(model=ls, datamodule=dm)

Processing...
Done!
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [MIG-9b1ee058-0b73-52d6-b909-b056af809b4b,MIG-e8861a3f-0f14-562d-ad70-c9f56ba1db06,MIG-e9c80ed6-e922-5f99-84ba-fd31a06f7f82,MIG-ddac4c2a-2ccf-5d86-be01-c9e27dd3cd76]

  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | GraphModule       | 12.6 K | train
1 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
12.6 K    Trainable params
0         Non-trainable params
12.6 K    Total params
0.050     Total estimated model params size (MB)
36        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


## Predict Segmentation Assignments

Explain things here

In [6]:
from segger.prediction.predict import predict, load_model

In [11]:
model_version = 44026  # from training output above
model_path = models_dir / 'lightning_logs' / f'version_{model_version}'
model = load_model(model_path / 'checkpoints')

In [33]:
dm.setup()

Processing...
Done!


In [34]:
batch = next(dm.train_dataloader().__iter__())

In [92]:
from segger.data.utils import get_edge_index, coo_to_dense_adj
import torch

receptive_field={'k_bd': 3, 'dist_bd': 15,'k_tx': 15, 'dist_tx': 3}

edge_index = get_edge_index(
    batch['tx'].pos[:, :2],
    batch['bd'].pos[:, :2],
    k=receptive_field['k_bd'],
    dist=receptive_field['dist_bd'],
    method='kd_tree',
)
batch['bd']['tx_field'] = coo_to_dense_adj(edge_index.T)

In [93]:
from_type = 'tx'
to_type = 'bd'

In [95]:
batch = batch.to("cuda")
y_hat = model.model(batch.x_dict, batch.edge_index_dict)

# Similarity of each 'from_type' to 'to_type' neighbors in embedding
nbr_idx = batch[from_type][f'{to_type}_field']
m = torch.nn.ZeroPad2d((0, 0, 0, 1))  # pad bottom with zeros
similarity = torch.bmm(
    m(y_hat[to_type])[nbr_idx],    # 'to' x 'from' neighbors x embed
    y_hat[from_type].unsqueeze(-1) # 'to' x embed x 1
)  

/home/conda/feedstock_root/build_artifacts/libtorch_1715185017593/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [52,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/home/conda/feedstock_root/build_artifacts/libtorch_1715185017593/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [52,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/home/conda/feedstock_root/build_artifacts/libtorch_1715185017593/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [52,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/home/conda/feedstock_root/build_artifacts/libtorch_1715185017593/work/aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [52,0,0], thread: [67,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
/home/conda/feed

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [95]:
from segger.data.utils import get_edge_index

In [159]:
batch = batch.to('cuda')

In [163]:
edge_index = edge_index.to('cuda')

In [154]:
nbrs

tensor([[  487, 54800],
        [  487, 46648],
        [  487, 43892],
        [  487, 51152],
        [  487, 49180],
        [  487, 55531],
        [  487, 58221],
        [  487, 70179],
        [   -1,    -1],
        [   -1,    -1]])

In [112]:
batch['bd'].pos.shape

torch.Size([488, 2])

In [113]:
batch['tx'].pos.shape

torch.Size([72169, 2])

In [94]:
batch['bd'].pos[:, :2].shape

torch.Size([488, 2])

In [86]:
segmentation = predict(
    model,
    dm.train_dataloader(),
    score_cut = 0.5,
    use_cc = False,
)


  0%|          | 0/7 [00:00<?, ?it/s][A


KeyError: 'bd_field'