### Running Segger

This notebook runs the complete Segger pipeline end-to-end from a single YAML configuration file (`config.yaml`). It will:

- Create and validate the dataset  
- Train the GNN model with your specified parameters  
- Predict edges in the resulting graph  

Below you will find:  
1. A description of the configuration parameters  
2. A minimal code snippet to execute the pipeline  


#### 1. Config YAML Documentation

##### create_dataset

- **save_dir** (DirectoryPath): Directory where the generated dataset is saved; created if its parent directory exists

- **transcripts_parquet** (FilePath): Path to the transcripts Parquet file

- **boundaries_parquet** (FilePath): Path to the boundaries Parquet file (GeoParquet polygons)

- **transcripts_feature_name** (str): Column name for transcript feature labels

- **transcripts_cell_id** (str): Column name for transcript cell IDs

- **transcripts_x** (str): Column name for transcript x-coordinates

- **transcripts_y** (str): Column name for transcript y-coordinates

- **transcripts_k** (PositiveInt): Number of neighbors for transcript graph

- **transcripts_dist** (PositiveFloat): Max distance for transcript neighbor search

- **boundaries_k** (PositiveInt): Number of neighbors for boundary graph

- **boundaries_dist** (PositiveFloat): Max distance for boundary neighbor search

- **tile_margin** (PositiveFloat): Margin around each tile

- **max_cells_per_tile** (PositiveInt, optional): Max cells allowed per tile

- **max_transcripts_per_tile** (PositiveInt, optional): Max transcripts allowed per tile

- **fraction_train** (PositiveFloat < 1): Fraction of data for training

- **fraction_test** (PositiveFloat < 1): Fraction of data for testing

- **fraction_val** (PositiveFloat < 1): Fraction of data for validation

- **n_workers** (int ≥ -1): Number of worker processes (-1 = all cores)
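
The `*_k` and `*_dist` pairs above combine a neighbor count with a distance cutoff. As a rough illustration of how such a pair interacts, here is a brute-force sketch in plain Python. It is not Segger's implementation; the function name `knn_within_dist`, the inclusive cutoff, and the sample coordinates are all assumptions made for this example.

```python
import math


def knn_within_dist(points, k, max_dist):
    """For each point, list the indices of up to k nearest other points
    that lie within max_dist (hypothetical helper, brute force)."""
    neighbors = []
    for i, (xi, yi) in enumerate(points):
        candidates = []
        for j, (xj, yj) in enumerate(points):
            if i == j:
                continue
            d = math.hypot(xi - xj, yi - yj)
            # Assumption: the cutoff is inclusive
            if d <= max_dist:
                candidates.append((d, j))
        # Nearest first, then keep at most k
        candidates.sort()
        neighbors.append([j for _, j in candidates[:k]])
    return neighbors


# Made-up coordinates: three nearby points and one far away
points = [(0.0, 0.0), (1.0, 0.0), (0.0, 2.0), (10.0, 10.0)]
edges = knn_within_dist(points, k=2, max_dist=3.0)
print(edges)
```

In this toy example the distant point ends up with no neighbors, because the distance cutoff applies even when fewer than `k` neighbors have been found.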


##### train

- **save_dir** (DirectoryPath): Directory where training outputs are saved; created if its parent directory exists

- **checkpoint_path** (FilePath, optional): Path to resume training from a checkpoint

- **in_channels** (PositiveInt): Number of input channels

- **hidden_channels** (PositiveInt): Number of hidden channels

- **out_channels** (PositiveInt): Number of output channels

- **n_mid_layers** (PositiveInt): Number of middle GNN layers

- **n_heads** (PositiveInt): Number of attention heads

- **gene_embedding_weights** (FilePath, optional): CSV file of pretrained gene embeddings; its format is validated

- **learning_rate** (PositiveFloat): Learning rate

- **batch_size** (PositiveInt): Batch size

- **n_workers** (int ≥ -1): Number of worker processes (-1 = all cores)

- **max_transcripts_k** (int, optional): Max k for transcript neighbors during training

- **max_transcripts_dist** (float, optional): Max distance for transcript neighbors during training

- **negative_edge_sampling_ratio** (PositiveInt): Ratio for negative edge sampling

- **n_epochs** (PositiveInt): Number of training epochs

- **random_seed** (int, optional): Seed for reproducibility
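
Both `create_dataset` and `train` accept `n_workers`, where -1 stands for all cores. The sketch below shows one way such a value could be resolved to a concrete count. The helper name `resolve_n_workers` is hypothetical, and how Segger resolves the value internally is not shown in this notebook.

```python
import os


def resolve_n_workers(n_workers):
    """Map the documented n_workers convention to a concrete count
    (hypothetical helper, not part of Segger)."""
    if n_workers < -1:
        raise ValueError("n_workers must be >= -1")
    if n_workers == -1:
        # os.cpu_count() may return None, so fall back to a single worker
        return os.cpu_count() or 1
    return n_workers


print(resolve_n_workers(4))
print(resolve_n_workers(-1))
```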

##### predict

- **save_dir** (DirectoryPath): Directory where prediction outputs are saved; created if its parent directory exists

- **receptive_field_k** (PositiveInt): Number of neighbors for receptive field

- **receptive_field_dist** (PositiveFloat): Max distance for receptive field

- **min_score** (float in [0, 1]): Minimum score threshold for predicted edges
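
To see how the three sections might fit together in one file, the sketch below builds a skeleton config as a Python dict and renders it as YAML text. Every value is a made-up placeholder (paths, column names, and numbers alike), the optional parameters are omitted, and the nesting under `create_dataset`, `train`, and `predict` is inferred from the section headings above. The fractions summing to 1 is also an assumption. The helper `to_yaml` is hypothetical and only handles this flat two-level layout.

```python
# All values below are hypothetical placeholders; replace them with your own.
skeleton = {
    "create_dataset": {
        "save_dir": "path/to/dataset",
        "transcripts_parquet": "path/to/transcripts.parquet",
        "boundaries_parquet": "path/to/boundaries.parquet",
        "transcripts_feature_name": "feature_name",
        "transcripts_cell_id": "cell_id",
        "transcripts_x": "x",
        "transcripts_y": "y",
        "transcripts_k": 5,
        "transcripts_dist": 10.0,
        "boundaries_k": 5,
        "boundaries_dist": 20.0,
        "tile_margin": 10.0,
        "fraction_train": 0.8,
        "fraction_test": 0.1,
        "fraction_val": 0.1,
        "n_workers": -1,
    },
    "train": {
        "save_dir": "path/to/training",
        "in_channels": 16,
        "hidden_channels": 32,
        "out_channels": 16,
        "n_mid_layers": 2,
        "n_heads": 2,
        "learning_rate": 0.001,
        "batch_size": 2,
        "n_workers": -1,
        "negative_edge_sampling_ratio": 1,
        "n_epochs": 10,
    },
    "predict": {
        "save_dir": "path/to/predictions",
        "receptive_field_k": 5,
        "receptive_field_dist": 10.0,
        "min_score": 0.5,
    },
}


def to_yaml(sections):
    """Render a two-level dict of plain scalars as YAML text."""
    lines = []
    for section, params in sections.items():
        lines.append(f"{section}:")
        for key, value in params.items():
            lines.append(f"  {key}: {value}")
    return "\n".join(lines) + "\n"


yaml_text = to_yaml(skeleton)
print(yaml_text)
```

The printed text can be saved as `config.yaml` and adjusted by hand; whether a file with only these keys passes Segger's validation is not confirmed here.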

#### 2. Minimal Segger Pipeline

In [None]:
# Disable Lightning's automatic SLURM environment detection
from lightning.pytorch.plugins.environments import SLURMEnvironment
SLURMEnvironment.detect = staticmethod(lambda: False)

# Suppress FutureWarning messages
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

# Logging and Lightning imports
import logging
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

# Segger imports
from segger.config.utils import (
    model_from_config,
    predict_from_config,
    ist_sample_from_config,
    data_module_from_config,
)
from segger.config import SeggerConfig

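overwrite = False  # Passed to sample.save(); set True to overwrite an existing dataset
verbose = True     # Show progress bars and the model summary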

In [None]:
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("Segger Pipeline")
for log in [
    "lightning",
    "lightning.pytorch.utilities.rank_zero",
    "lightning.pytorch.accelerators.cuda",
]:
    logging.getLogger(log).setLevel(logging.WARNING)

In [None]:
# 0. Load configuration
config_path = "../path/to/your/config.yaml"
config = SeggerConfig.from_yaml(config_path)

In [None]:
# 1. Dataset creation
logger.info("Creating dataset...")
try:
    sample = ist_sample_from_config(config)
    sample.save(pbar=verbose, overwrite=overwrite)
except FileExistsError:
    msg = (
        "Data directory already exists and is non-empty. "
        "Skipping tile creation."
    )
    logger.warning(msg)

In [None]:
# 2. Training
logger.info("Training model...")
csv_logger = CSVLogger(config.train.save_dir)
trainer = Trainer(
    logger=csv_logger,
    max_epochs=config.train.n_epochs,
    enable_progress_bar=verbose,
    enable_model_summary=verbose,
)
trainer.fit(
    model=model_from_config(config),
    datamodule=data_module_from_config(config),
    ckpt_path=config.train.checkpoint_path,
)

In [None]:
# 3. Prediction
logger.info("Predicting edges...")
predict_from_config(config, pbar=verbose)