In [1]:
%load_ext autoreload
%autoreload 2

# Segger Minimal Example

In [2]:
from segger.data.io import XeniumSample
from segger.training.train import LitSegger
from segger.training.segger_data_module import SeggerDataModule
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning import Trainer
from pathlib import Path

In [3]:
from pytorch_lightning.plugins.environments import SLURMEnvironment
SLURMEnvironment.detect = lambda: False

## Create Dataset

Explain things here

In [147]:
xenium_data_dir = Path('../../dev/tutorial/xenium_data/')
segger_data_dir = Path('../../dev/tutorial/segger_data/')

In [148]:
# Setup Xenium sample to create dataset
xs = XeniumSample(verbose=True)
xs.set_file_paths(
    transcripts_path=xenium_data_dir / 'transcripts.parquet',
    boundaries_path=xenium_data_dir / 'nucleus_boundaries.parquet',
)

Set transcripts file path to ../../dev/tutorial/xenium_data/transcripts.parquet
Set boundaries file path to ../../dev/tutorial/xenium_data/nucleus_boundaries.parquet


In [171]:
!rm -r ../../dev/tutorial/segger_data/*

In [172]:
try:
    xs.save_dataset_for_segger(
        processed_dir=segger_data_dir,
        r_tx=3,
        k_tx=15,
        receptive_field={'k_bd': 3, 'dist_bd': 15,'k_tx': 15, 'dist_tx': 3},
        x_size=250,
        y_size=250,
        d_x=250,
        d_y=250,
        margin_x=10,
        margin_y=10,
    )
except AssertionError as err:
    print(f'Dataset already exists at {segger_data_dir}')

[########################################] | 100% Completed | 31.21 s


## Train Segger Model

Explain things here

In [5]:
# Base directory to store Pytorch Lightning models
models_dir = Path('../../dev/tutorial/models/')

In [6]:
# Initialize the Lightning model
metadata = (["tx", "bd"], [("tx", "belongs", "bd"), ("tx", "neighbors", "tx")])
ls = LitSegger(
    num_tx_tokens=500,
    init_emb=8,
    hidden_channels=64,
    out_channels=16,
    heads=4,
    num_mid_layers=1,
    aggr='sum',
    metadata=metadata,
)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


In [174]:
# Initialize the Lightning data module
dm = SeggerDataModule(
    data_dir=segger_data_dir,
    batch_size=1,  # explain other options above
    num_workers=4,  # explain other options above
)

In [8]:
# Initialize the Lightning trainer
trainer = Trainer(
    accelerator='cuda',  # explain other options above
    strategy='auto',
    precision='16-mixed',
    devices=1,  # explain other options above
    max_epochs=100,  # explain other options above
    default_root_dir=models_dir,
    logger=CSVLogger(models_dir),
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [200]:
import pandas as pd
import numpy as np
import torch
import shapely
from geopandas import GeoSeries
from segger.data.io import SpatialTranscriptomicsSample

In [202]:
dm.setup()  # LightningDataModule
nan_summary = pd.DataFrame()
nan_ids = []
for i, batch in enumerate(dm.train_dataloader()):
    cell_ids = batch['bd']['id'][0]
    nan_summary.loc[i, 'Total'] = len(cell_ids)
    for key in ['x', 'pos']:
        is_nan = batch['bd'][key].isnan().any(1)
        nan_summary.loc[i, f'NaN in {key}'] = is_nan.sum().item()
        nan_ids.append(cell_ids[is_nan].astype(int))

nan_ids = np.hstack(nan_ids)
nan_summary.index.name = 'batch'
nan_summary.astype(int).head()

Processing...
Done!


Unnamed: 0_level_0,Total,NaN in x,NaN in pos
batch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,315,1,1
1,169,1,1
2,116,2,2
3,400,2,2
4,281,3,3


In [204]:
nan_ids[:10]

array([104104, 104104,  96337,  96337, 105091, 107718, 105091, 107718,
       106365, 106730])

In [207]:
xs.verbose = False
df = pd.read_parquet(xs.boundaries_path)
cell_ids = [int(i) for i in ['107025', '107671', '107712']]
boundaries_df = df.loc[df['cell_id'].isin(nan_ids)]
bd_gdf = xs.compute_boundaries_geometries(boundaries_df)
bd_gdf.head()

Unnamed: 0,geometry,cell_id,centroid_x,centroid_y,area,convexity,elongation,circularity
0,"POLYGON ((3261.875 3800.625, 3261.78001 3800.6...",96043,3262.394341,3805.044646,47.560525,1.002316,16.463885,2.163828
1,"POLYGON ((3265.9126 3836.75, 3265.80472 3836.7...",96063,3264.154193,3844.47605,133.714509,1.03096,17.605377,1.656342
2,"POLYGON ((3101.6499 3734.11255, 3101.55491 373...",96144,3102.560951,3737.868083,49.990781,1.0,16.025747,2.654613
3,"POLYGON ((3261.44995 3479.11255, 3261.34289 34...",96333,3261.160832,3481.926085,25.388122,1.001656,16.0,2.547631
4,"POLYGON ((3261.44995 3517.1499, 3261.35733 351...",96337,3261.2478,3520.710203,27.774425,1.0,16.646432,2.179706


No precomputed polygons provided. Computing polygons from boundaries with a scale factor of 1.0.
Adding centroids to the polygons...
Polygons are available. Proceeding with geometrical computations.
Computing area...
Computing convexity...
Computing elongation...
Computing circularity...
Geometrical computations completed.


Unnamed: 0,geometry,cell_id,centroid_x,centroid_y,area,convexity,elongation,circularity
0,"POLYGON ((3509.86255 3656.55005, 3509.76453 36...",107025,3513.399805,3664.016222,88.17845,1.009637,18.672552,1.382146
1,"POLYGON ((3329.44995 3755.36255, 3329.3471 375...",107671,3327.914109,3762.51844,86.371948,1.009086,17.894785,1.584323
2,"POLYGON ((3510.5 3486.55005, 3510.39953 3486.5...",107712,3511.994986,3492.59325,63.488113,1.00239,17.848407,1.608155


In [130]:
torch.as_tensor(bd_gdf[['centroid_x', 'centroid_y']].values.astype(float))

tensor([[3926.6391, 3963.4658],
        [3936.1926, 3966.2522],
        [3896.9759, 3964.2134],
        [3900.3009, 3968.2333],
        [3923.7249, 3970.0328],
        [3893.7169, 3970.7377],
        [3929.4759, 3972.1830],
        [3911.8555, 3975.3812],
        [3906.8622, 3976.0429],
        [3936.7611, 3975.9349],
        [3899.0634, 3975.4511],
        [3916.9490, 3976.2784],
        [3920.0802, 3978.0556],
        [3973.5543, 3978.7783],
        [3913.1479, 3979.9234],
        [3926.7720, 3980.1006],
        [3934.2694, 3981.2204],
        [3916.6300, 3983.0027],
        [3920.8678, 3983.8628],
        [3907.8610, 3983.4038],
        [3988.4076, 3982.8885],
        [3897.2974, 3986.0511],
        [3926.2433, 3986.9642],
        [3917.8334, 3987.9329],
        [3910.9940, 3988.7892],
        [3921.9359, 3990.1196],
        [3927.7383, 3992.1927],
        [3905.3212, 3991.0185],
        [3932.1959, 3994.1976],
        [3901.0472, 3994.7502],
        [3967.5587, 3996.4936],
        

In [121]:
bd_gdf

Unnamed: 0,geometry,cell_id,centroid_x,centroid_y,area,convexity,elongation,circularity
0,"POLYGON ((3888.5376 3927.0625, 3888.43323 3927...",105853,3888.221029,3929.647922,24.439403,1.002589,16.269334,2.448092
1,"POLYGON ((3889.6001 3881.375, 3889.49718 3881....",106135,3888.51911,3888.255852,92.111427,1.007715,17.28246,1.810704


In [19]:
# Forward pass to get the logits
z = ls.model(batch.x_dict, batch.edge_index_dict)
z['bd']  # NaN values are exlusively in z['bd']

tensor([[ 0.0946,  0.0508, -0.2296,  ..., -0.1565,  0.1716, -0.0835],
        [ 0.0906,  0.0560, -0.2173,  ..., -0.1641,  0.1864, -0.0926],
        [ 0.0785,  0.0415, -0.2433,  ..., -0.1505,  0.1540, -0.0480],
        ...,
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],
       grad_fn=<DivBackward0>)

In [29]:
import torch
import numpy as np
frac = torch.isnan(z['bd']).all(1).sum() / z['bd'].shape[0]
print(f"Fraction of boundaries which are all NaN: {np.around(frac, 2)}")
frac = torch.isnan(z['bd']).any(1).sum() / z['bd'].shape[0]
print(f"Fraction of boundaries which have any NaN: {np.around(frac, 2)}")

Fraction of boundaries which are all NaN: 0.75
Fraction of boundaries which have any NaN: 0.75
