Skip to content

Nathan fm#76

Merged
renierts merged 104 commits into
foundation_modelfrom
nathan_fm
May 10, 2026
Merged

Nathan fm#76
renierts merged 104 commits into
foundation_modelfrom
nathan_fm

Conversation

@nathanchenseanwalter
Copy link
Copy Markdown
Collaborator

No description provided.

nathanchenseanwalter and others added 30 commits February 17, 2026 18:29
…d training

- Replaced `srun pixi run python` with `srun pixi run torchrun` for improved distributed execution.
- Added parameters for distributed training configuration in the training script.
- Included a comment in the UnimodalTrainer class to indicate the integration of distributed training support.
…d training

- Replaced `srun pixi run python` with `srun pixi run torchrun` for improved distributed execution.
- Added parameters for distributed training configuration in the training script.
- Included a comment in the UnimodalTrainer class to indicate the integration of distributed training support.
- Added detailed sections in the README for data storage locations and Flash Attention installation instructions.
- Removed unused SpectrogramResLSTM classes and related imports from modality models.
- Introduced new Conv3dEncoderBlock and Conv3dDecoderBlock classes in the SpectrogramBaselineEncoder and SpectrogramBaselineDecoder for improved architecture.
- Enhanced SpectrogramBaselineEncoder and Decoder to support LSTM integration and modular block structures.
- Clarified instructions for activating the environment.
- Improved wording on Flash Attention benefits.
- Updated the description for the pre-downloaded wheel URL for Princeton clusters.
- Reorganized environment setup instructions for better readability.
- Enhanced the explanation of Flash Attention's benefits and usage.
- Updated links for Flash Attention wheel downloads to be more user-friendly.
- Corrected spelling in the "Environment Setup" section and clarified the use of Pixi for environment management.
- Updated section titles for consistency, changing "Datas" to "Data."
- Removed the unused CER model implementation from the codebase to streamline the project.
- Re-added .streamlit/secrets.toml to the .gitignore file to prevent tracking of sensitive information.
- Removed previous entries for .pixi/* while retaining .pixi/config.toml to maintain necessary configuration files.
…sses

- Updated `__init__.py` to include `NullDrawer` in the module exports.
- Refactored `distributed.py` to improve DDP wrapping and added `unwrap` method.
- Enhanced `drawing.py` with a new `DrawerProtocol` and implemented `NullDrawer` for no-op drawing.
- Modified `DefaultDrawer` to streamline loss tracking and visualization.
- Added new scripts for SLURM job submissions for various training configurations.
- Introduced `spectrogram_tf_only.py` model with encoder-decoder architecture for spectrogram processing.
- Implemented new metrics for PSNR and SSIM in `metrics.py`.
- Created tracking utilities in `tracking.py` for better training progress monitoring and logging.
…nd gated target loss (also returned to normalized spectrum, non-normalized causing issues)
…on framework

- Refactored VideoBaselineEncoder and VideoBaselineDecoder to improve architecture and maintainability.
- Introduced VideoBaselineAutoEncoder to encapsulate encoder and decoder functionality.
- Added Stage3Trainer for training multimodal prediction models, handling inputs, targets, and observations.
- Implemented MultimodalPredictionModel combining a fusion transformer with forecasting heads for each modality.
- Created CombinedPredictionLoss to compute losses for both token-space and observation-space predictions.
- Added training scripts for unimodal and multimodal models, including support for SLURM job arrays.
- Updated model factory to accommodate new model parameters and defaults.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).
…, the wrong configuration was used to find the correct signal name.

Also, removed warning for duplicated tensor conversion.
…s and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.
The basic encoders are now all working.

Examples are in scripts.
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.
…got to remove unused modalities. This follows the standard getitem function now.
Quick fix for the data standardization. Invalid values have to be ignored.
Fix in the function to create H5 files. bolo data does not have to be flipped anymore as the data is now stored in the correct format.
* Nathan fm (#53)

* chore: Update `pyproject.toml` to reorder authors, enhance README with environment setup instructions, and add validation notes in `validation.txt`. Refactor `dummy_model_2.py` for improved modality configuration and introduce `TextEncoder` enhancements in `text_baseline.py`.

* Refactor demo scripts to utilize new `Prediction4FusionModel` and `DictMSELoss`. Update `run_demo_2.py` and `run_demo_3.py` for improved model initialization and data handling. Enhance `TokamakH5Dataset` to handle degenerate signals and improve data extraction logic. Remove unused `latent_space.py` and integrate new modality fusion models in `modality_fusion.py`.

* Remove unused shot list configuration files and refactor trainer class to introduce MultimodalTrainer and UnimodalTrainer for improved training structure.

* Refactor modality models and trainer classes for improved structure and functionality. Removed unused TimeSeriesEncoder and Decoder, introduced FastTimeSeriesEncoder and SpectrogramAutoEncoder. Updated UnimodalTrainer to support logging and checkpoint management. Enhanced TokamakH5Dataset for better data handling and added checkpoint loading functionality in spectrogram reconstruction script.

* Add padding collate function and update training script for unimodal autoencoder

- Introduced `collate_fn_pad` to handle variable-length tensors in batches.
- Updated `train_unimodal_autoencoder.py` to use the new collate function.
- Modified `train_unimodal.sh` to include additional signal modalities for training.
- Added new autoencoder classes for fast time series and spatial profile modalities, ensuring output shape consistency with adaptive pooling.
- Enhanced video autoencoder implementation for better reconstruction quality.

* Remove spectrogram reconstruction script and refactor modality models

- Deleted `spectrogram_reconstruction.py` as part of the restructuring.
- Refactored modality models to introduce baseline versions for actuator, slow time series, fast time series, spatial profile, spectrogram, and video.
- Updated model registry and signal-to-model mappings to reflect new baseline architecture.
- Enhanced `TokamakH5Dataset` to support additional parameters for FFT and hop length.
- Improved training script for unimodal autoencoders to utilize new baseline models and added support for variable-length tensors.

* Update .gitignore to include pixi environments and add link to HSI-compression-benchmark in SpectrogramBaselineAutoEncoder docstring

* Remove unused shot list files and delete deprecated scripts for training and data handling

* Remove deprecated training scripts for CO2, ECE, MHR, and unimodal training

* Dev peter (#48)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Dev peter (#50)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Adapted the other reconstruction scripts to match the new API.

* Bugfix in the dataset class. When splitting inputs and targets, I forgot to remove unused modalities. This follows the standard getitem function now.

* Prepared an option to preprocess movies. This has to be fully integrated!!!

---------

Co-authored-by: Peter Steiner <61472983+renierts@users.noreply.github.com>

* Dev peter (#55)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Adapted the other reconstruction scripts to match the new API.

* Bugfix in the dataset class. When splitting inputs and targets, I forgot to remove unused modalities. This follows the standard getitem function now.

* Prepared an option to preprocess movies. This has to be fully integrated!!!

* Added a baseline fusion transformer for latent space prediction.
Quick fix for the data standardization. Invalid values have to be ignored.
Fix in the function to create H5 files. bolo data does not have to be flipped anymore as the data is now stored in the correct format.

---------

Co-authored-by: Nathaniel Chen <nathanchen1101@gmail.com>
renierts and others added 28 commits April 13, 2026 14:06
Moved prepare_data.py to scripts, added a batch script to do this on compute nodes.
Added more point names to the data fetching scripts for Omega.
Added docstring to the WelfordTensor class.
Updated modalities.yaml with the new point names added.
…tats. This is still not efficient enough and causes memory issues.
Bugfixes in the trainer.
Cosmetic changes in tracking.py
…_series_baseline.py to filterscope_baseline.py).

Updates in the dataset class: Clipping for log transform can go down to -.99 (sufficient because we subtract 1.0).
Updates in drawing.py: We can now draw all kinds of different plots (except for profiles for now). Added functionality to draw correlation plots, which is important for finding feature distributions.
Added masked loss functions to not consider out-of-range time slices for training.
…ted for both, linear and log10 scale.

Working on more accurate autoencoders for time-series and profiles.
* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Adapted the other reconstruction scripts to match the new API.

* Bugfix in the dataset class. When splitting inputs and targets, I forgot to remove unused modalities. This follows the standard getitem function now.

* Prepared an option to preprocess movies. This has to be fully integrated!!!

* Added a baseline fusion transformer for latent space prediction.
Quick fix for the data standardization. Invalid values have to be ignored.
Fix in the function to create H5 files. bolo data does not have to be flipped anymore as the data is now stored in the correct format.

* Foundation model (#56)

* Nathan fm (#53)

* chore: Update `pyproject.toml` to reorder authors, enhance README with environment setup instructions, and add validation notes in `validation.txt`. Refactor `dummy_model_2.py` for improved modality configuration and introduce `TextEncoder` enhancements in `text_baseline.py`.

* Refactor demo scripts to utilize new `Prediction4FusionModel` and `DictMSELoss`. Update `run_demo_2.py` and `run_demo_3.py` for improved model initialization and data handling. Enhance `TokamakH5Dataset` to handle degenerate signals and improve data extraction logic. Remove unused `latent_space.py` and integrate new modality fusion models in `modality_fusion.py`.

* Remove unused shot list configuration files and refactor trainer class to introduce MultimodalTrainer and UnimodalTrainer for improved training structure.

* Refactor modality models and trainer classes for improved structure and functionality. Removed unused TimeSeriesEncoder and Decoder, introduced FastTimeSeriesEncoder and SpectrogramAutoEncoder. Updated UnimodalTrainer to support logging and checkpoint management. Enhanced TokamakH5Dataset for better data handling and added checkpoint loading functionality in spectrogram reconstruction script.

* Add padding collate function and update training script for unimodal autoencoder

- Introduced `collate_fn_pad` to handle variable-length tensors in batches.
- Updated `train_unimodal_autoencoder.py` to use the new collate function.
- Modified `train_unimodal.sh` to include additional signal modalities for training.
- Added new autoencoder classes for fast time series and spatial profile modalities, ensuring output shape consistency with adaptive pooling.
- Enhanced video autoencoder implementation for better reconstruction quality.

* Remove spectrogram reconstruction script and refactor modality models

- Deleted `spectrogram_reconstruction.py` as part of the restructuring.
- Refactored modality models to introduce baseline versions for actuator, slow time series, fast time series, spatial profile, spectrogram, and video.
- Updated model registry and signal-to-model mappings to reflect new baseline architecture.
- Enhanced `TokamakH5Dataset` to support additional parameters for FFT and hop length.
- Improved training script for unimodal autoencoders to utilize new baseline models and added support for variable-length tensors.

* Update .gitignore to include pixi environments and add link to HSI-compression-benchmark in SpectrogramBaselineAutoEncoder docstring

* Remove unused shot list files and delete deprecated scripts for training and data handling

* Remove deprecated training scripts for CO2, ECE, MHR, and unimodal training

* Dev peter (#48)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Dev peter (#50)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Adapted the other reconstruction scripts to match the new API.

* Bugfix in the dataset class. When splitting inputs and targets, I forgot to remove unused modalities. This follows the standard getitem function now.

* Prepared an option to preprocess movies. This has to be fully integrated!!!

---------



* Dev peter (#55)

* Removed the argument "batch_size" from the trainers.
Changed default hyperparameters in the models.
Added demo for profile reconstruction.
Added script for dataset standardization (has to be run once before model training to store normalization coefficients).

* Bugfix in the dataset class. When iterating over movie configurations, the wrong configuration was used to find the correct signal name.
Also, removed warning for duplicated tensor conversion.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Added base script for video reconstruction. Copied from Aza's branch for debugging purposes.

* Minor changes in the example scripts. More preprocessing options for the dataset class.

* Fixed a bug where the dataset class failed when using multiple workers and opening an H5 file prior to distributing the dataset across all workers.

Significant updates in the Fast time series baseline and actuator reconstruction classes.

* Lots of bugfixes in the dataset, trainer, and models.
The basic encoders are now all working.

Examples are in scripts.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Extended checkpointing - the trainer stores now:
- Model
- Optimizer state
- Scheduler state
- Current loss
- Current epoch

For the sake of continual training.

* Adapted the other reconstruction scripts to match the new API.

* Bugfix in the dataset class. When splitting inputs and targets, I forgot to remove unused modalities. This follows the standard getitem function now.

* Prepared an option to preprocess movies. This has to be fully integrated!!!

* Added a baseline fusion transformer for latent space prediction.
Quick fix for the data standardization. Invalid values have to be ignored.
Fix in the function to create H5 files. bolo data does not have to be flipped anymore as the data is now stored in the correct format.

---------



* Moved some remaining scripts to the correct subdirectories.

* Still working on preparing the dataset. This is not ready to push. Preparation to moving to Stellar.

* Updated the data loader. Bugfix for loading the correct slices from H5 files.

Implemented calculating incremental statistics.

Corrected values in the modality configuration.

Removed redundant script standardize_dataset.py

* Added scripts for data fetching in Omega.
TODO: Write a documentation.

* Added a documentation for setting up Globus CLI on Omega and start a simple file transfer.

* Updated README.md:
- Added information on how to use all the scripts for data fetching.

Updated read_mds.sh
- Added a switch for globus file transfer. This simply stores the H5 files on Omega and we can add more data later.

* More PTData to fetch.

* PEP-8 compatible code.
Moved prepare_data.py to scripts, added a batch script to do this on compute nodes.
Added more point names to the data fetching scripts for Omega.
Added docstring to the WelfordTensor class.
Updated modalities.yaml with the new point names added.

* Generalized make_preprocessing_stats.py and made the function compute_preprocessing_stats more transparent.
Bugfix in modalities.yaml - Channels were missing in ECE.

* A lot of bugfixes in the dataloader and prepare_data.py

* Many bugfixees in the dataset class and for computing preprocessing stats. This is still not efficient enough and causes memory issues.

* Speed-ups in data_loader.py.

* Speed-ups in the dataloader.
Bugfixes in the trainer.
Cosmetic changes in tracking.py

* drawing.py:
- PEP-8 corrections
- Support plots of time signals and videos

Train-val-test split in fast_time_series_reconstruction.py

* Bugfix in processing methods of the dataloader:
- Channels was not handled properly (if selecting slices of a signal).
- Drawing: Restrict plotting to valid signals (not the padded sections after the actual signal).
- Introduced masked loss for fast time series reconstruction.

* Added a separate baseline encoder for filterscopes (renamed fast_time_series_baseline.py to filterscope_baseline.py).
Updates in the dataset class: Clipping for log transform can go down to -.99 (sufficient because we subtract 1.0).
Updates in drawing.py: We can now draw all kinds of different plots (except for profiles for now). Added functionality to draw correlation plots, which is important for finding feature distributions.
Added masked loss functions to not consider out-of-range time slices for training.

* Added a weighted loss to penalize target distributions.
Corrected the R2 score calculation in the drawer.
Renamed profile_reconstruction.py to mse_profile_reconstruction.py
Added ts_core_density_profile_reconstruction.py

* Modified the default parameters of some profile and time-series signals in data_loader.py
Added more loss functions in loss.py
Switched to HuberLoss in filterscopes_reconstruction.py, in mse_profile_reconstruction.py.
Updated model_factory.py to completed signal encoders/decoders.
Moved profile_baseline.py into modality.
Added training scripts for thomson scattering profiles.

* Added CER related info to the dataset class and to the model factory.

* Added dummy perceiver stuff. Be careful - this is not structured nicely yet. Only work in progress.

* Added more RMP point names to the data fetching script.
Restarted work on the latent feature space.

* Updated all scripts according to the increased set of diagnostics and actuators we are using.

* Updated preprocessing_stats. Here, the statistics are now pre-calculated for both, linear and log10 scale.
Working on more accurate autoencoders for time-series and profiles.

---------

Co-authored-by: Nathaniel Chen <nathanchen1101@gmail.com>
Co-authored-by: renierts <ps9551@princeton.edu>
…re space is more compact now.

Added foundation model utilities. This is under development!!!
Too much to comment all.
Mainly, the old foundation model is in archive to be able to restore it at any point.
The new training scripts are train_e2e*.
Adapted dataset functionalities to be compatible with the new training approach.
submit_all.sh now distributes signals across two HIP_VISIBLE_DEVICES
lanes concurrently and samples amd-smi during runs. train_ddp.sh adds a
parameterized torchrun launcher that reaches 2.16x speedup on ECE over
the single-GPU path. train_bes.sh batch_size 8 -> 4 to clear the prior
HIP OOM. .gitignore excludes local logs/ and .claude/ session dirs.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
# Conflicts:
#	pixi.lock
#	scripts/slurm/train_cer_rot.sh
#	scripts/slurm/train_cer_ti.sh
#	scripts/slurm/train_co2_tf_only.sh
#	scripts/slurm/train_ece_conv_fct.sh
#	scripts/slurm/train_ece_conv_nc.sh
#	scripts/slurm/train_ece_conv_tfc.sh
#	scripts/slurm/train_ece_tf_only.sh
#	scripts/slurm/train_filterscopes.sh
#	scripts/slurm/train_mhr_conv_dw_ft.sh
#	scripts/slurm/train_mhr_tf_only.sh
#	scripts/slurm/train_mhr_tf_only_multinode.sh
#	scripts/slurm/train_mhr_weighted_mse.sh
#	scripts/slurm/train_mse.sh
#	scripts/slurm/train_ts_core_density.sh
#	scripts/slurm/train_ts_core_temp.sh
#	scripts/slurm/train_ts_tangential_density.sh
#	scripts/slurm/train_ts_tangential_temp.sh
#	scripts/slurm/train_unimodal.sh
#	src/tokamak_foundation_model/data/data_loader.py
#	src/tokamak_foundation_model/models/modality/__init__.py
#	src/tokamak_foundation_model/models/modality/profile_baseline.py
#	src/tokamak_foundation_model/models/model_factory.py
#	src/tokamak_foundation_model/trainer/trainer.py
Patches the five train_e2e_stage*.py scripts to run under torchrun with
DistributedManager, DistributedSampler, and DDP-wrapped models, and adds
matching ROCm SLURM launchers in scripts/slurm_rocm/.

Stages 1, 2, and 2_delta wrap the high-level model or rollout directly.
Stages 2_extended and 3 introduce a small TrainStepModule wrapper because
their training paths bypass model.__call__; find_unused_parameters=True
keeps DDP happy under gradient checkpointing and LoRA. Validation runs on
all ranks in lockstep to avoid the broadcast_buffers deadlock that hits
when only rank 0 calls forward; only rank 0 logs and writes checkpoints.
Each stage smoke-tested end-to-end on 2x MI210.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- 5 simple production launchers (1x8 default, hand-editable):
  train_e2e_{stage1,stage2,stage2_delta,stage2_extended,stage3}.sh
- 20 shape-matrix launchers for DDP smoke testing:
  train_e2e_<stage>_{1x1,1x8,Nx1,NxN}.sh, with SMOKE=1 + PROFILE=1 hooks
- Profiling sidecar: rocm-smi (GPU util/VRAM/power) + mpstat (CPU) at 1Hz
  via srun --overlap, with _summarize_profile.py aggregator
- profile_indexing.py / profile_indexing.sh: CPU-only timing of
  build_datasets file-length pass; predicts full-pass duration and
  optionally pre-populates lengths cache for training jobs
- _frontier_common.sh: shared module loads, RCCL/MIOpen knobs, MASTER_ADDR
- _srun_rank_wrapper.sh: per-rank env setup mapping SLURM -> torch.distributed
- gitignore: envs/ (21GB conda) and profile/ (session output)
Each rank sees only 1 GPU (masked by ROCR_VISIBLE_DEVICES) so the device
index is always 0 regardless of LOCAL_RANK. Add a `visible > 1` clamp in
DistributedManager.device_index and use it everywhere device_id flows.
Apply the same fix to the explicit DDP wraps in stage 2_extended / stage 3
that used local_rank directly.
# Conflicts:
#	.gitignore
#	scripts/slurm/train_e2e_stage2_delta.sh
#	scripts/slurm/train_e2e_stage2_extended.sh
#	scripts/training/train_e2e_stage1.py
#	scripts/training/train_e2e_stage2.py
#	scripts/training/train_e2e_stage2_delta.py
#	scripts/training/train_e2e_stage2_extended.py
#	scripts/training/train_e2e_stage3.py
#	src/tokamak_foundation_model/data/data_loader.py
#	src/tokamak_foundation_model/e2e/model.py
#	src/tokamak_foundation_model/e2e/output_heads.py
#	src/tokamak_foundation_model/e2e/rollout.py
#	src/tokamak_foundation_model/e2e/tokenizers/slow_time_series.py
…eeze release

After the foundation_model merge, validate() at line 597 dereferenced
model.diagnostics directly, which AttributeError's under DDP. Same shape
of bug at line 1177 for _release_module_freeze (its sibling _apply at
line 1127 was already correct). Both now go through _core(model).
- pyproject.toml: extract the cu124 torch from base pypi-deps into a
  `cuda` feature; add a `frontier` feature with ROCm 7.1 wheels. envs:
    default  = ["cuda"]      (Princeton/CUDA workflow unchanged)
    fdp      = ["fdp","cuda"] (also unchanged)
    frontier = ["frontier"]   (new — Frontier ROCm)
- _frontier_common.sh: drop miniforge3 module + `conda activate`; instead
  put `~/.pixi/bin` on PATH and eval `pixi shell-hook -e frontier`.
- All 27 SLURM scripts: drop the now-orphan `conda activate $CONDA_ENV_PATH`
  line (the common.sh hook handles env activation by itself).

One-time bootstrap on a login node (after pulling):
    pixi install -e frontier

Performance is identical to the prior conda env — same PyTorch ROCm
wheels, just installed via uv instead of conda.
…ions

rocm7.1 wheels only publish torch 2.10.0 paired with torchvision 0.25-0.26;
the original ">=2.5.1,<2.11" / ">=0.20.1,<0.22" pins didn't intersect.
Also list triton-rocm as an explicit dep — torch 2.10 declares it but uv
doesn't auto-discover it through a per-package index spec.
The conda-env bootstrap is superseded by `pixi install -e frontier`
(see _frontier_common.sh). Deleted the 21GB conda env at envs/
locally too; that path was already gitignored.
Copy link
Copy Markdown
Collaborator

@renierts renierts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minimum changes - thank you!!!

@renierts renierts merged commit 4395ce6 into foundation_model May 10, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants