Copyright 2025 Georgia Institute of Technology
CSI-4CAST is a comprehensive framework for generating and evaluating Channel State Information (CSI) prediction models using 3GPP TR 38.901 channel models. The repository provides tools for large-scale dataset generation, model training, and comprehensive evaluation with support for high-performance computing environments (SLURM-based clusters) and local machines.
This framework is developed as part of our research paper CSI-4CAST: A Hybrid Deep Learning Model for CSI Prediction with Comprehensive Robustness and Generalization Testing. (A BibTeX entry for citation is provided at the end of this page.) The corresponding datasets are publicly available on our Hugging Face organization.
Important
- **Paper updated**: the latest version of the paper includes additional statistical baselines, expanded ablation experiments, and new discussion of the mechanism of each module in the proposed model.
- **Model weights are now public**: all model weights (proposed, baselines, and ablation variants) are available at CSI-4CAST/weights. See `z_artifacts/weights/info.md` for download instructions and the expected directory structure.
- **General-purpose tuning framework included**: `ahpt/` is a standalone, reusable Optuna-based hyperparameter tuning framework designed for any deep learning project. It handles search-space definition, SLURM job orchestration, trial pruning, and dashboard monitoring out of the box. The CSI prediction integration in `src/cp/tune/` demonstrates how to adapt it to a specific task. A sample Optuna study is provided in `z_artifacts/outputs/ahpt/ablation/predictor/lstm_replace/`.
- **Proposed model and configs provided**: the proposed CSI-4CAST models for FDD and TDD are in `src/cp/models/proposed/` with their default configurations (`model_fdd_cfg.yml`, `model_tdd_cfg.yml`).
- **All ablation models are included**: the ablation study implementations are available under `src/cp/models/ablation/`.
- **All baselines are included**: statistical baselines are in `src/cp/models/baseline/statistical/`, learning-based baselines in `src/cp/models/baseline/learning/`, and the no-prediction baseline in `src/cp/models/baseline/np.py`.
- **Full experiment artifacts provided**: `z_artifacts/outputs/testing/` now contains complete results for all 9 models across both FDD and TDD scenarios, including computational overhead profiling, prediction performance for every test setting, consolidated analysis (NMSE & SE rankings), and a full visualization suite (line plots, radar charts, violin plots, and performance tables).
```
CSI-4CAST/
├── README.md                 # Project documentation
├── LICENSE                   # License information
├── env.yml                   # Conda environment configuration
├── pyproject.toml            # Project configuration and linting rules
├── asset/                    # Figures and paper PDF for documentation
│   ├── csi-4cast-Mar-26-2026.pdf
│   ├── sample_training.png
│   ├── sample_optuna_1.png
│   └── sample_optuna_2.png
├── ahpt/                     # Reusable hyperparameter tuning utilities
│   └── base/
│       ├── config.py         # Tuning configuration dataclasses
│       ├── submit_jobs.py    # SLURM job submission helpers
│       ├── tune_obj.py       # Optuna objective base classes
│       ├── tune_runner.py    # Study orchestration utilities
│       ├── tune_space.py     # Search-space helpers
│       └── utils.py          # Shared tuning utilities
├── scripts/                  # HPC job scripts (SLURM templates)
│   ├── computational_overhead.sh
│   ├── cp.sh
│   ├── data_gen.sh
│   ├── nd.sh
│   ├── param_estimation.sh
│   └── testing.sh
├── src/                      # Source code
│   ├── data/                 # Data generation module
│   ├── cp/                   # Channel prediction and tuning module
│   │   ├── main.py           # Training entry point
│   │   ├── config/           # Training configuration management
│   │   ├── dataset/          # Data modules
│   │   ├── loss/             # Loss functions
│   │   ├── tune/             # Tuning entry points and monitoring
│   │   │   ├── main.py
│   │   │   ├── monitor.py
│   │   │   ├── submit.py
│   │   │   ├── tune_obj.py
│   │   │   ├── tune_runner.py
│   │   │   └── worker.py
│   │   └── models/           # Model registry and implementations
│   │       ├── __init__.py
│   │       ├── common/       # Shared model components
│   │       ├── proposed/     # Proposed CSI-4CAST models
│   │       ├── ablation/     # Ablation study variants
│   │       └── baseline/     # Baseline models
│   │           ├── np.py     # No-prediction baseline
│   │           ├── learning/      # CNN, RNN, STEMGNN, LLM4CP
│   │           └── statistical/   # AR, PAD, Wiener
│   ├── noise/                # Noise modeling and calibration
│   ├── testing/              # Evaluation and visualization
│   └── utils/                # Shared utilities
└── z_artifacts/              # Generated artifacts and outputs used by the codebase
    ├── config/               # Generated configuration files
    ├── data/                 # Generated datasets and normalization stats
    ├── outputs/              # Training, tuning, noise, and testing outputs
    └── weights/              # Trained checkpoints organized by scenario/model
```
The data generation module provides a complete pipeline for creating realistic CSI datasets using 3GPP channel models.
- `csi_simulator.py`: Configures and implements the CSI simulator based on Sionna's 3GPP TR 38.901 channel model implementation. The simulator generates realistic channel responses for various propagation scenarios, including different channel models, delay spreads, and mobility conditions.
- `data_utils.py`: Defines all simulation parameters and constants following the specifications detailed in the research paper, including antenna configurations, OFDM parameters, subcarrier arrangements, and dataset organization structures.
- `generator.py`: Employs the CSI simulator to generate comprehensive datasets, including:
  - training datasets for model development
  - regular testing datasets for standard and robustness evaluation
  - generalization testing datasets for generalization evaluation
The generator creates three types of CSI data files for each channel configuration:

- `H_U_hist.pt`: Uplink historical CSI data (model input)
- `H_U_pred.pt`: Uplink prediction target CSI data
- `H_D_pred.pt`: Downlink prediction target CSI data (for cross-link scenarios)
Data Dimensions:
- Antennas: 32 (4×4×2 dual-polarized BS antenna array)
- Time slots: 20 total (16 historical + 4 prediction)
- Subcarriers: 300 each for uplink and downlink (750 total with gap)
- Channel models: A, C, D (regular) / A, B, C, D, E (generalization)
- Delay spreads: 30-400 nanoseconds
- Mobility scenarios: 1-45 m/s
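The dimensions above fit together as follows. This is a minimal sanity-check sketch; the variable names and the exact tensor layouts are illustrative assumptions, not the repository's actual conventions:

```python
# Hypothetical sketch of the dataset dimensions listed above.
# All names and the (time, antenna, subcarrier) layout are assumptions.
N_ROWS, N_COLS, N_POL = 4, 4, 2        # 4x4 dual-polarized BS array
N_ANT = N_ROWS * N_COLS * N_POL        # -> 32 antennas

HIST_SLOTS, PRED_SLOTS = 16, 4         # 20 time slots total
N_SC_UL = N_SC_DL = 300                # per-link subcarriers
N_SC_TOTAL = 750                       # UL + guard gap + DL
GAP = N_SC_TOTAL - N_SC_UL - N_SC_DL   # -> 150-subcarrier gap

# Assumed per-sample shapes of the three generated files:
shapes = {
    "H_U_hist.pt": (HIST_SLOTS, N_ANT, N_SC_UL),
    "H_U_pred.pt": (PRED_SLOTS, N_ANT, N_SC_UL),
    "H_D_pred.pt": (PRED_SLOTS, N_ANT, N_SC_DL),
}
```

Check the shapes of the actual `.pt` files after generation, since batch and complex-component dimensions may differ from this sketch.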
The channel prediction module provides a comprehensive framework for training CSI prediction models using PyTorch Lightning.
- `main.py`: Training entry point that orchestrates the entire training process
- `config/config.py`: Configuration management system for training parameters, model settings, and hyperparameters
- `dataset/data_module.py`: PyTorch Lightning data modules for efficient data loading and preprocessing
- `tune/`: CSI prediction tuning workflow, including SLURM submission, Optuna workers, and dashboard monitoring
- `models/`: Model architectures, including:
  - `__init__.py`: `PREDICTORS` registry for model selection
  - `common/base.py`: `BaseCSIModel` class that all models inherit from
  - `proposed/`: Proposed CSI-4CAST models for FDD and TDD
  - `ablation/`: Ablation models for denoiser, ARL, IDFT, embedding, and predictor studies
  - `baseline/`: Baseline implementations including NP, AR, PAD, Wiener, CNN, RNN, STEMGNN, and LLM4CP
- `loss/loss.py`: Custom loss functions optimized for CSI prediction tasks
The noise module handles realistic noise modeling and parameter calibration for comprehensive testing scenarios.
- `noise.py`: Core noise generation functions implementing various realistic noise types
- `noise_degree.py`: Noise parameter calibration system that maps target SNRs to appropriate noise parameters
- `noise_testing.py`: Noise testing utilities and configurations
- `results/decide_nd.json`: Pre-calibrated noise degree mapping for different noise types
The testing module provides comprehensive evaluation frameworks for CSI prediction models across multiple dimensions.
- `config.py`: Testing configuration including model lists, scenarios, noise types, and job allocation settings
- `get_models.py`: Model loading utilities with checkpoint path management
- `computational_overhead/`: Performance profiling for measuring model computational requirements
- `prediction_performance/`: Accuracy evaluation across thousands of testing scenarios
- `results/`: Result processing pipeline including completion checking, data aggregation, and statistical analysis
- `vis/`: Comprehensive visualization suite generating line plots, radar charts, violin plots, and tables
The CSI-4CAST framework is designed to be flexible. Data generation, model training, noise calibration, results processing, and visualization all run locally. Prediction performance testing is designed for SLURM-based HPC clusters but can also be run locally by setting the `SLURM_ARRAY_TASK_ID` environment variable.
```shell
module load mamba/[mamba_version]
mamba env create -f env.yml
mamba activate csi-4cast-env
```

The code related to data generation is in the `src/data` folder and the `src/utils/data_utils.py` file.
The `data_utils.py` file defines all constants that configure the Sionna simulator and the data generation process. It is critical to understand and adjust these constants for your setup before running any code.
For high-performance computing, use the template in `scripts/data_gen.sh`:

```shell
python3 -m src.data.generator --is_train   # Generate training data, typical array size is 1-9
python3 -m src.data.generator              # Generate regular test data, typical array size is 1
python3 -m src.data.generator --is_gen     # Generate generalization test data, typical array size is 1-20
```

For local/single-node execution, use debug mode for minimal datasets:

```shell
python3 -m src.data.generator --debug --is_train   # Debug mode: minimal training data
python3 -m src.data.generator --debug              # Debug mode: minimal test data
python3 -m src.data.generator --debug --is_gen     # Debug mode: minimal generalization data
```

All datasets used in the paper are publicly available on our Hugging Face organization. You can download them individually or in bulk using the provided helper scripts:
```shell
# Download all datasets from Hugging Face
python3 z_artifacts/data/download.py

# Reconstruct the original folder structure
python3 z_artifacts/data/reconstruction.py \
    --input-dir datasets --output-dir z_artifacts/data
```

See `z_artifacts/data/info.md` for detailed download instructions, dataset naming conventions, and the expected directory layout.
After data generation (or downloading), compute normalization statistics using `src/utils/norm_utils.py`:

```shell
python3 -m src.utils.norm_utils
```

The normalization stats will be saved in `z_artifacts/data/stats/[fdd/tdd]/normalization_stats.pkl`.
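To inspect the saved stats, the pickle can be loaded with the standard library. The helper name below and the pickle's exact contents (e.g. per-scenario mean/std arrays) are assumptions; check `src/utils/norm_utils.py` for the real layout:

```python
import pickle
from pathlib import Path

def load_stats(path):
    """Load a normalization-stats pickle and return its contents.

    The structure of the returned object is defined by norm_utils.py;
    this helper only handles the file I/O.
    """
    with Path(path).open("rb") as f:
        return pickle.load(f)

# Usage (path from the README):
# stats = load_stats("z_artifacts/data/stats/tdd/normalization_stats.pkl")
# print(stats)
```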
The model training framework is built on PyTorch Lightning and located in the `src/cp` folder.

Models should be defined under `src/cp/models`, inherit from `BaseCSIModel` in `src/cp/models/common/base.py`, and be registered in the `PREDICTORS` class in `src/cp/models/__init__.py`. See `src/cp/models/baseline/learning/rnn.py` for an example implementation.
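The base-class-plus-registry pattern described above can be sketched in isolation like this. The stand-in classes and the `register` decorator are illustrative only; the real `BaseCSIModel` is a PyTorch Lightning module and `PREDICTORS` lives in `src/cp/models/__init__.py`:

```python
# Minimal sketch of the registry pattern (stand-ins, not the real API).
PREDICTORS = {}

def register(name):
    """Decorator that adds a model class to the PREDICTORS registry."""
    def deco(cls):
        PREDICTORS[name] = cls
        return cls
    return deco

class BaseCSIModel:
    """Stand-in for src/cp/models/common/base.py."""
    def forward(self, h_hist):
        raise NotImplementedError

@register("my_rnn")
class MyRNN(BaseCSIModel):
    def forward(self, h_hist):
        return h_hist  # placeholder: predict "no change"

# Model selection by name, as the training entry point would do:
model = PREDICTORS["my_rnn"]()
```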
Configure the training process in `src/cp/config/config.py`, then generate configuration files:

```shell
python3 -m src.cp.config.config --model [model_name] --output-dir [output_dir] --is_U2D [True/False] --config-file [yaml/json]
```

Default output directory: `z_artifacts/config/cp/[model_name]/`
```shell
python3 -m src.cp.main --hparams_csi_pred [config_file]
```

For HPC clusters, use `scripts/cp.sh`. Training outputs are saved in `z_artifacts/outputs/[TDD/FDD]/[model_name]/[date_time]/` with checkpoints in `ckpts/` and TensorBoard logs in `tb_logs/`.
View training progress:

```shell
tensorboard --logdir [output_directory]/tb_logs
```

A sample training run is included in `z_artifacts/outputs/TDD/RNN/` (RNN model, TDD scenario), containing the configuration snapshot (`config_copy.yaml`), the full training log (`result.log`), model checkpoints (`ckpts/`), and TensorBoard event files (`tb_logs/`). A sample configuration file is also provided at `z_artifacts/config/cp/rnn/tdd_rnn.yaml`.
Below are the TensorBoard training and validation loss curves from this run:
The hyperparameter tuning workflow combines the reusable tuning utilities in `ahpt/` with the CSI prediction-specific integration in `src/cp/tune/`. Each study is defined by a tuning YAML that points to a base training config, a search-space module, and an output directory.

Use `src/cp/models/ablation/predictor/lstm_replace/tuning.yaml` as a reference. This file defines:

- the base training config in `config.yaml`
- the target model `ABL_LSTM_REPLACE_PRED`
- the search space in `src.cp.models.ablation.predictor.lstm_replace.tune_space`
- Optuna study settings such as `n_trials`, pruning, and early stopping
- the output directory `z_artifacts/outputs/ahpt/ablation/predictor/lstm_replace`
```shell
python3 -m src.cp.tune.main \
    --config src/cp/models/ablation/predictor/lstm_replace/tuning.yaml \
    --num_workers 3
```

You can also override SLURM resource settings at submission time:

```shell
python3 -m src.cp.tune.main \
    --config src/cp/models/ablation/predictor/lstm_replace/tuning.yaml \
    --num_workers 3 \
    --time 08:00:00 \
    --mem 64G
```

Each study writes its artifacts under `z_artifacts/outputs/ahpt/ablation/...`, including the Optuna database, best configs, and study history.
To inspect the discovered studies without launching dashboards:

```shell
python3 -m src.cp.tune.monitor --list-only
```

To launch Optuna dashboards for the latest study databases:

```shell
python3 -m src.cp.tune.monitor --base-port 9000
```

The monitor scans `z_artifacts/outputs/ahpt/ablation/` for `study.db` files, keeps the latest run for each ablation, and serves one dashboard per study on consecutive ports.
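The scan described above amounts to a recursive glob plus a newest-file filter. This sketch is inferred from the description, including the assumed `<ablation>/<run_dir>/study.db` layout; `src/cp/tune/monitor.py` is the authoritative implementation:

```python
from pathlib import Path

def latest_studies(base="z_artifacts/outputs/ahpt/ablation"):
    """Map each ablation directory to its most recently modified study.db.

    Assumed layout: <base>/.../<ablation>/<run_dir>/study.db, so the
    grandparent of each study.db identifies the ablation.
    """
    latest = {}
    for db in Path(base).rglob("study.db"):
        key = db.parent.parent  # ablation dir (layout assumption)
        if key not in latest or db.stat().st_mtime > latest[key].stat().st_mtime:
            latest[key] = db
    return latest
```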
A sample Optuna study database is included at `z_artifacts/outputs/ahpt/ablation/predictor/lstm_replace/`. Below are screenshots from the Optuna dashboard showing the optimization history, hyperparameter importance, trial timeline, and intermediate values:
Since realistic noise types cannot be directly defined by SNRs, calibrate the noise parameters first:

```shell
python3 -m src.noise.noise_degree
```

Results are saved in `z_artifacts/outputs/noise/noise_degree/[date_time]/decide_nd.json` and copied to `src/noise/results/decide_nd.json`.
The model evaluation framework in `src/testing` provides comprehensive assessment across multiple dimensions.

Configure models and checkpoint paths in `src/testing/config.py`. Ensure checkpoints conform to the `get_ckpt_path` function in `src/testing/get_models.py`. Default checkpoint path: `z_artifacts/weights/[tdd/fdd]/[model_name]/model.ckpt`. The published experiment weights are also available from CSI-4CAST/weights.
```shell
python3 -m src.testing.computational_overhead.main
```

Results are saved in `z_artifacts/outputs/testing/computational_overhead/[date_time]/` for all configured models.
First, view the testing settings and the SLURM array ID mapping:

```shell
python3 -m src.testing.config
```

This prints a table showing each model/scenario/test-type combination with its SLURM array range:
```
Line  Model   Scenario  Test Type       Jobs  Array Range  Combos  Combos/Job
-----------------------------------------------------------------------------
   0  PAD     TDD       regular            1  1-1             162       162.0
   1  PAD     TDD       robustness         3  2-4             486       162.0
   2  PAD     TDD       generalization    19  5-23           3060       161.1
   3  AR      TDD       regular            1  24-24           162       162.0
 ...
  45  WIENER  TDD       regular            1  346-346         162       162.0
  46  WIENER  TDD       robustness         3  347-349         486       162.0
  47  WIENER  TDD       generalization    19  350-368        3060       161.1
```
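Each array ID in a row's range handles roughly `Combos/Job` test combinations. A sketch of how such a mapping could work is below; the function name and the even-split scheme are assumptions, and the real mapping lives in `src/testing/config.py`:

```python
def combos_for_task(task_id, first_id, n_jobs, n_combos):
    """Return the 0-based combo indices handled by one SLURM array task.

    Splits n_combos as evenly as possible across n_jobs consecutive
    array IDs starting at first_id (ceil division; the last job may
    receive fewer combos).
    """
    per_job = -(-n_combos // n_jobs)  # ceil(n_combos / n_jobs)
    start = (task_id - first_id) * per_job
    return range(start, min(start + per_job, n_combos))

# Inside a job, the task ID would come from the environment, e.g.:
# task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
```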
For HPC clusters, submit the desired models using the array range from the table:

```shell
sbatch --array=1-23 scripts/testing.sh     # Run all PAD tests
sbatch --array=162-184 scripts/testing.sh  # Run all MODEL TDD tests
sbatch --array=1-368 scripts/testing.sh    # Run everything
```

For local execution, set the `SLURM_ARRAY_TASK_ID` environment variable to the desired array ID from the table:

```shell
SLURM_ARRAY_TASK_ID=139 python3 -m src.testing.prediction_performance.main  # Run MODEL FDD regular
```

Results are saved in `z_artifacts/outputs/testing/prediction_performance/[model_name]/[scenario]/[test_type]/[slice_i]/[date_time]/`.
Process all testing results with comprehensive analysis:

```shell
python3 -m src.testing.results.main
```

This performs three steps:
- Check completion status of testing models
- Gather and aggregate all results into CSV files
- Post-process results for scenario-wise distributions based on NMSE and SE metrics
Example output from the completion check:
```
PER-MODEL SUMMARY:
------------------------------------------------------------
AR        3/3 settings done (3708/3708 rows)
CNN       6/6 settings done (7416/7416 rows)
LLM4CP    6/6 settings done (7416/7416 rows)
MODEL     6/6 settings done (7416/7416 rows)
NP        6/6 settings done (7416/7416 rows)
PAD       3/3 settings done (3708/3708 rows)
RNN       6/6 settings done (7416/7416 rows)
STEMGNN   6/6 settings done (7416/7416 rows)
WIENER    6/6 settings done (7416/7416 rows)

All complete: True
```
Results are saved in:

- `z_artifacts/outputs/testing/results/completion_reports/[date_time]/`
- `z_artifacts/outputs/testing/results/gather/[date_time]/`
- `z_artifacts/outputs/testing/results/analysis/[nmse/se]/[date_time]/`
Generate comprehensive visualizations (line plots, radar plots, violin plots, tables):

```shell
python3 -m src.testing.vis.main
```

Results are saved in `z_artifacts/outputs/testing/vis/[date_time]/[line/radar/violin/table]/`.
The `z_artifacts/outputs/testing/` directory contains complete results for all 9 models (AR, CNN, LLM4CP, MODEL, NP, PAD, RNN, STEMGNN, WIENER) across both FDD and TDD scenarios:

- `computational_overhead/`: FLOPs, inference/training time, parameter counts, GPU memory, and energy consumption for every model (`computational_overhead.csv`)
- `prediction_performance/`: per-model, per-scenario raw results organized as `{Model}/{Scenario}/{TestType}/slice_{i}/{timestamp}/result.csv`

  | Model   | Scenarios | Test Types                          |
  |---------|-----------|-------------------------------------|
  | AR      | TDD       | regular, robustness, generalization |
  | CNN     | FDD, TDD  | regular, robustness, generalization |
  | LLM4CP  | FDD, TDD  | regular, robustness, generalization |
  | MODEL   | FDD, TDD  | regular, robustness, generalization |
  | NP      | FDD, TDD  | regular, robustness, generalization |
  | PAD     | TDD       | regular, robustness, generalization |
  | RNN     | FDD, TDD  | regular, robustness, generalization |
  | STEMGNN | FDD, TDD  | regular, robustness, generalization |
  | WIENER  | FDD, TDD  | regular, robustness, generalization |

- `results/`: consolidated analysis from all models
  - `completion_reports/`: per-model completion status (`completion_status.csv`)
  - `gather/`: a single `consolidated_results.csv` merging all models and slices
  - `analysis/nmse/`: NMSE-based pivot tables, rankings, and rank distributions
  - `analysis/se/`: spectral efficiency-based pivot tables, rankings, and rank distributions
- `vis/`: full visualization suite

  | Type         | Path                               | Description |
  |--------------|------------------------------------|-------------|
  | Line plots   | `line/regular/`                    | NMSE and SE vs. noise degree for in-distribution performance (FDD & TDD) |
  |              | `line/robustness/`                 | Performance under burst, package-drop, and phase-noise conditions |
  |              | `line/generalization/`             | Out-of-distribution performance across mobility scenarios |
  | Radar plots  | `radar/`                           | Multi-dimensional ranking comparison (`combined_radar_fdd.pdf`, `combined_radar_tdd.pdf`) with rank summary CSVs |
  | Violin plots | `violin/nmse/`, `violin/se/`       | Distribution of NMSE and SE across regular, robustness, and generalization scenarios for FDD and TDD |
  | Tables       | `table/nmse_cm/`, `table/nmse_ds/` | NMSE performance tables by channel model and delay spread (formatted CSVs + text summaries) |
  |              | `table/se_cm/`, `table/se_ds/`     | SE performance tables by channel model and delay spread |
For detailed analysis and interpretation of these results, please refer to the paper (also available at `asset/csi-4cast-Mar-26-2026.pdf`).
If you use this framework in your research, please cite the corresponding paper:
```bibtex
@misc{cheng2025csi4casthybriddeeplearning,
      title={CSI-4CAST: A Hybrid Deep Learning Model for CSI Prediction with Comprehensive Robustness and Generalization Testing},
      author={Sikai Cheng and Reza Zandehshahvar and Haoruo Zhao and Daniel A. Garcia-Ulloa and Alejandro Villena-Rodriguez and Carles Navarro Manchón and Pascal Van Hentenryck},
      year={2025},
      eprint={2510.12996},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2510.12996},
}
```

This project is licensed under the terms specified in the LICENSE file.


