Emvlt/GLM

GLM: Graph neural network for Line Manifolds

Code for the paper "Improving the Generalisation of Learned Reconstruction Frameworks" - A deep learning approach for CT image reconstruction using Graph Neural Networks and CNN modules.

Overview

This project implements a learned reconstruction framework for computed tomography (CT) imaging that combines:

  • Graph neural network for Line Manifolds (GLM): GNNs that process sinogram data using information about the acquisition geometry
  • Convolutional Neural Networks: For both sinogram and image domain processing
  • Pseudo-inverse operators: Including backprojection and filtered backprojection
  • End-to-end training: Joint optimization of sinogram processing and image reconstruction
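
The way these components chain together can be sketched as plain function composition. This is only an illustration under stated assumptions: `make_reconstructor` and its three arguments are hypothetical names, and in the repository the stages are presumably PyTorch modules and ODL operators rather than bare callables.

```python
from typing import Any, Callable

Stage = Callable[[Any], Any]


def make_reconstructor(
    sinogram_model: Stage,   # hypothetical: GLM or sinogram CNN
    pseudo_inverse: Stage,   # hypothetical: backprojection or filtered backprojection
    image_model: Stage,      # hypothetical: image-domain CNN
) -> Stage:
    """Chain the three stages into one sinogram -> image mapping."""

    def reconstruct(sinogram: Any) -> Any:
        processed = sinogram_model(sinogram)  # sinogram-domain processing
        image = pseudo_inverse(processed)     # map from sinogram to image domain
        return image_model(image)             # image-domain refinement

    return reconstruct
```

Because the composed function is a single mapping, a loss computed on its output can drive all trainable stages at once, which is what joint optimization refers to.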

The framework is designed to work with the 2DeteCT dataset and uses DVC (Data Version Control) for experiment tracking and pipeline management.

Features

  • Multi-stage training pipeline (preprocessing → pretraining → training)
  • Distributed training support with PyTorch DDP
  • Graph-based sinogram processing using PyTorch Geometric
  • Integration with ODL (Operator Discretization Library) for tomographic operators
  • Experiment tracking with DVCLive
  • Support for various pseudo-inverse methods (backprojection, filtered backprojection)

Installation

Prerequisites

  • Python 3.11 or higher
  • CUDA-capable GPU (recommended for training)
  • uv package manager

Step-by-step Installation

  1. Clone the repository

    git clone https://github.com/Emvlt/GLM
    cd GLM
  2. Install uv (if not already installed)

    curl -LsSf https://astral.sh/uv/install.sh | sh
  3. Install dependencies

    The project uses uv for dependency management. Running commands with uv run installs dependencies automatically; to install them up front, run:

    uv sync
  4. Verify installation

    uv run python src/glm/install_test.py

    This should print version information for all major dependencies:

    • ODL (an unofficial version from my fork): geometry, tomographic operator, 2DeteCT dataloader, and experimental code for graph export and surface handling
    • ASTRA Toolbox: Tomographic Operator Backend
    • PyTorch: Deep Learning
    • PyTorch Geometric: Geometric Deep Learning
    • imageio: raw data IO

Key Dependencies

The project automatically installs:

  • astra-toolbox: For CT reconstruction algorithms
  • torch: Deep learning framework
  • torch-geometric: Graph neural network library
  • odl: Custom fork for tomographic operators (installed from GitHub)
  • dvc & dvclive: Experiment tracking and pipeline management
  • matplotlib, imageio: Visualization and image I/O

Dataset Setup

This project uses the 2DeteCT dataset.

⚠️ You are expected to download the dataset from Zenodo and unzip the files yourself. Because of its size, the dataset is split across several URLs.

The expected directory structure is:

datasets/
├── raw/
│   └── 2detect/
│       ├── slice00001/
│       │   ├── mode2/
│       │   │   ├── dark.tif
│       │   │   ├── flat1.tif
│       │   │   ├── flat2.tif
│       │   │   ├── sinogram.tif
│       │   │   ├── reconstruction.tif
│       │   │   └── segmentation.tif
│       │   └── ...
│       └── ...
└── processed/
    └── 2detect/
        └── (generated by preprocessing)
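
Before running preprocessing, it can help to check that the raw data matches this layout. The following is a minimal sketch, not part of the repository: `find_incomplete_slices` is a hypothetical helper, and treating all six files as required in every slice is an assumption.

```python
from pathlib import Path

# File names copied from the directory tree above. Requiring all of them
# in every slice is an assumption of this sketch.
EXPECTED_FILES = (
    "dark.tif",
    "flat1.tif",
    "flat2.tif",
    "sinogram.tif",
    "reconstruction.tif",
    "segmentation.tif",
)


def find_incomplete_slices(raw_root, mode="mode2"):
    """Return {slice_name: [missing file names]} for incomplete slices."""
    problems = {}
    for slice_dir in sorted(Path(raw_root).glob("slice*")):
        if not slice_dir.is_dir():
            continue
        mode_dir = slice_dir / mode
        missing = [
            name for name in EXPECTED_FILES if not (mode_dir / name).is_file()
        ]
        if missing:
            problems[slice_dir.name] = missing
    return problems
```

Calling `find_incomplete_slices("datasets/raw/2detect")` returns an empty dictionary when every slice directory contains the expected files.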

⚠️ For DVC to work, you must update the paths in params.yaml:

  • data.raw_path: Path to raw dataset
  • data.processed_path: Path for processed data

Usage

Full Training Pipeline

The complete pipeline consists of three stages managed by DVC:

  1. Data Preprocessing

    uv run dvc repro prepare_training

    Preprocesses raw sinograms and reconstructions from the 2DeteCT dataset.

  2. Sinogram Model Pretraining

    uv run dvc repro pretraining

    Pretrains the sinogram processing model (GLM or CNN) in a self-supervised manner for just one epoch.

  3. End-to-End Training

    uv run dvc repro training

    Trains the complete reconstruction pipeline including sinogram processing, pseudo-inverse, and image refinement.
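
The three stages can also be launched in sequence from Python. This is a minimal sketch that only wraps the commands listed above; `run_pipeline` and its `runner` argument are hypothetical, the latter added so the commands can be inspected without executing them.

```python
import subprocess

# Stage names as listed in this README.
STAGES = ("prepare_training", "pretraining", "training")


def stage_command(stage):
    """Build the command line for one DVC stage."""
    return ["uv", "run", "dvc", "repro", stage]


def run_pipeline(stages=STAGES, runner=subprocess.run):
    """Run each stage in order, stopping at the first failure."""
    for stage in stages:
        # check=True makes subprocess.run raise CalledProcessError
        # when a stage exits with a non-zero status.
        runner(stage_command(stage), check=True)
```

Passing a custom `runner` that only records its arguments gives a dry run; with the default, each stage is executed through `uv run dvc repro`.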

Configuration

Main configuration file: params.yaml

Key Parameters

Data paths:

data:
  raw_path: /path/to/raw/2detect
  processed_path: /path/to/processed/2detect

Pretraining hyperparameters:

pretrain_parameters:
  hyperparameters:
    learning_rate: 5e-4
    epochs: 1
    batch_size: 8
    downsampling: 1  # Angle downsampling factor
  active_model: GLM  # or sinogram_CNN
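
As a rough illustration of the `downsampling` parameter, the sketch below keeps every `factor`-th projection angle. This interpretation is an assumption based on the comment in params.yaml; `downsample_angles` is a hypothetical helper and the repository may subsample differently.

```python
def downsample_angles(angles, factor):
    """Keep every `factor`-th projection angle (assumed meaning of `downsampling`)."""
    if factor < 1:
        raise ValueError("factor must be a positive integer")
    return angles[::factor]


# Illustrative numbers only: 360 angles reduced by a factor of 4.
angles = list(range(360))
sparse_angles = downsample_angles(angles, 4)
```

With `factor=1`, as in the default configuration above, all angles are kept.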

Training hyperparameters:

train_parameters:
  hyperparameters:
    learning_rate: 5e-5
    epochs: 40
    batch_size: 8
  active_pseudo_inverse: filtered_backprojection
  active_image_model: image_CNN
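
A mistyped option name in params.yaml can be caught early with a small check such as the one sketched below. It operates on an already-parsed dictionary; `check_params` is hypothetical, and `backprojection` as a valid configuration value is an assumption inferred from the feature list rather than from the configuration file itself.

```python
# active_model values come from the comment in params.yaml above.
# "backprojection" as a config value is an assumption of this sketch.
ALLOWED = {
    ("pretrain_parameters", "active_model"): {"GLM", "sinogram_CNN"},
    ("train_parameters", "active_pseudo_inverse"): {
        "backprojection",
        "filtered_backprojection",
    },
}


def check_params(params):
    """Return a list of human-readable errors for unknown option values."""
    errors = []
    for (section, key), allowed in ALLOWED.items():
        value = params.get(section, {}).get(key)
        if value not in allowed:
            errors.append(
                f"{section}.{key}={value!r} is not one of {sorted(allowed)}"
            )
    return errors
```

An empty list means the selected sinogram model and pseudo-inverse are among the names this sketch knows about.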

Project Structure

glm/
├── src/glm/
│   ├── models/
│   │   ├── gnn.py              # Graph neural network modules
│   │   ├── cnn.py              # Convolutional neural networks
│   │   └── utils.py            # Model loading utilities
│   ├── dataset.py              # Dataset and dataloader
│   ├── preprocess_2detect.py   # Data preprocessing
│   ├── pretrain.py             # Sinogram model pretraining
│   ├── train.py                # End-to-end training
│   ├── run_demo.py             # Demo script
│   └── utils.py                # General utilities
├── params.yaml                 # Configuration file
├── dvc.yaml                    # DVC pipeline definition
├── dvc.lock                    # DVC pipeline lock file
└── pyproject.toml              # Project dependencies

Experiment Tracking

The project uses DVCLive for experiment tracking. Metrics and plots are saved in the dvclive/ directory:

  • Training/validation PSNR
  • Loss curves
  • Sinogram and reconstruction visualizations
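
For reference, PSNR follows the standard definition 10 · log10(data_range² / MSE). The sketch below computes it on flat lists of pixel values; `psnr` and the default `data_range` of 1.0 are assumptions, and the repository may normalise images differently.

```python
import math


def psnr(reference, estimate, data_range=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized sequences."""
    if not reference or len(reference) != len(estimate):
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(data_range ** 2 / mse)
```

Higher values indicate a reconstruction closer to the reference.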

View experiments:

uv run dvc plots show

Model Outputs

Trained models are saved in:

  • src/glm/saved_models/pretrained_sinogram_model.pt - Pretrained sinogram processor
  • src/glm/saved_models/end_to_end_model.pt - Complete reconstruction model

Troubleshooting

CUDA out of memory:

  • Reduce batch_size in params.yaml
  • Reduce n_channels in model parameters
  • Increase downsampling to use fewer projection angles

Data loading errors:

  • Verify dataset paths in params.yaml
  • Check that raw data follows the expected directory structure
  • Ensure preprocessing completed successfully

Import errors:

  • Run uv sync to ensure all dependencies are installed
  • Verify installation with uv run python src/glm/install_test.py

Citation

If you use this code in your research, please cite:

@misc{valat2025improvinggeneralisationlearnedreconstruction,
      title={Improving the Generalisation of Learned Reconstruction Frameworks}, 
      author={Emilien Valat and Ozan Öktem},
      year={2025},
      eprint={2511.12730},
      archivePrefix={arXiv},
      primaryClass={eess.IV},
      url={https://arxiv.org/abs/2511.12730}, 
}

License

🔑 Apache License 2.0

Contact

For questions or issues, please open an issue on GitHub or contact:
