
HyperLM: a fully hyperbolic small language model

Abstract

Hyperbolic geometry is an alternative to ordinary, flat Euclidean space in which elementary operations such as addition and distance are defined differently. There is growing interest in building neural networks on hyperbolic geometry, with several works suggesting improved performance, especially on datasets that contain hierarchies. However, because of these differences in elementary calculations, many modern neural network components do not have an accepted hyperbolic version. We therefore created HyperT, a fully hyperbolic multi-head transformer, and used it to build HyperLM, the first fully hyperbolic language model capable of generating coherent English text. We show that HyperLM scales correctly with size and matches the performance of its Euclidean counterpart. A large ablation study pinpoints which design choices influence performance most, such as the hyperbolic classifier head and residual connections. Contrary to earlier claims, we find no evidence of systematic performance gains over Euclidean models, nor of a neat hierarchy emerging in the token embeddings. Our versatile and clean implementation can serve as a backbone for future research, especially for comparison against novel models released in parallel with our project.
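For intuition on how elementary operations differ, here is a minimal standalone sketch (not the project's hypll-based code) of the closed-form geodesic distance on the Poincaré ball with curvature -1:

```python
import math

def poincare_distance(x, y):
    """Geodesic distance between two points inside the unit Poincare ball.

    Closed form for curvature -1:
    d(x, y) = arccosh(1 + 2*||x - y||^2 / ((1 - ||x||^2) * (1 - ||y||^2)))
    """
    sq = lambda v: sum(c * c for c in v)
    diff = [a - b for a, b in zip(x, y)]
    num = 2.0 * sq(diff)
    den = (1.0 - sq(x)) * (1.0 - sq(y))
    return math.acosh(1.0 + num / den)

# Distances blow up near the boundary: points that are Euclidean-close
# can be hyperbolically far apart, which is what gives the space its
# tree-like, hierarchy-friendly structure.
print(poincare_distance([0.0, 0.0], [0.5, 0.0]))   # ~1.0986 (= ln 3)
print(poincare_distance([0.9, 0.0], [0.95, 0.0]))  # much larger than 0.05
```

This is why drop-in replacements for Euclidean layers are non-trivial: every addition, distance, or mean inside a layer has a hyperbolic counterpart like the one above.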

Publication

The full paper can be found in AdamDivak_HyperLM_thesis.pdf

Installation

I've been running this code on three different systems, so there is some variation in the setup. In all cases a base Python environment was set up, then old-school pip was used to install the project requirements. The only difference was the base environment:

  • On my local Mac/Linux machines, conda was used to set up the base environment
  • On RunPod I used the template runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04 (which is why I used PyTorch 2.4 throughout the project)
  • On SLURM/Snellius I used the 2023 module, but not the PyTorch module, as that had a really ancient version

Replicate this on a new machine:

# clone project
git clone https://github.com/adamdivak/hyper_lm
cd hyper_lm

# clone modified hypll
git clone https://github.com/adamdivak/hyperbolic_learning_library.git hypll

# [OPTIONAL] create conda environment
conda env create -f environment.yaml
conda activate hyper_lm

# install requirements
pip install -r requirements_24.txt

A little helper script with some scaffolding for managing any of the environments:

./scripts/setup_env.sh <remote_host> <root_dir> <venv_root_dir>

How to run

Sync files to remote hosts

Again a little helper: it makes sure the remote environment is set up and rsyncs the code files while skipping logs and other artifacts. It can optionally also run the code, but I didn't use that in the end, as I found it more convenient to run things manually in tmux or to submit jobs on Snellius.

scripts/run_remote.sh <remote_host> --sync-only

<remote_host> can be runpod, snellius or aurora (my regular Linux machine).

Manual initialization on RunPod/aurora

Initialize environment on RunPod/aurora:

# on RunPod
cd /workspace/hyper_lm
tmux
source /workspace/venvs/hyper_lm/bin/activate
# on aurora
source ~/work/venvs/hyper_lm/bin/activate

Basic usage - training

Train models using src/train.py. All config options live in configs/, and any of them can be overridden on the command line. Below are a few minimal examples; more detailed usage can be found in the readme of lightning-hydra-template.

Train model with default configuration:

# train on CPU
python src/train.py trainer=cpu

# train on GPU
python src/train.py trainer=gpu

Train model with chosen experiment configuration from configs/experiment/

python src/train.py experiment=experiment_name.yaml

Override any parameter from command line like this

python src/train.py trainer.max_epochs=20 data.batch_size=64

Train on SLURM: instead of directly running the train script, submit a job to SLURM. The jobs file does a few smart things:

  • copy the data files to the local scratch drive using rsync, to avoid loading them from the network
  • kill the training 5 minutes before the job times out and copy the results back to the network storage

sbatch ~/hyper_lm/scripts/train_snellius.job <args>

Resuming training runs

Resume training from an existing checkpoint. This expects an identical setup and restores the optimizer, LR scheduler, etc. If wandb.id is specified, logging continues to the same run.

python src/train.py experiment=lm/hypll_tiny_gpt ckpt_path="logs/train/runs/2025-05-22_17-45-11/checkpoints/last.ckpt" logger.wandb.id=4bl82led

Resume a training that was run with torch.compile, without using compile now:

python scripts/modify_checkpoint.py clean logs/train/runs/2025-05-26_19-02-04/checkpoints/last.ckpt

Load a model from an existing checkpoint, without loading everything else (like the optimizer state). Good for continuing the training with different parameters. Optionally specify model.net.model_checkpoint_strict_loading=False to ignore missing keys, which makes it possible to partially load a pre-trained model (e.g. when increasing the number of layers in continued training).

python src/train.py experiment=lm/hypll_tiny_gpt model.net.model_checkpoint_path="logs/train/runs/2025-05-22_17-45-11/checkpoints/last.ckpt"
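Under the hood, non-strict loading corresponds to PyTorch's `load_state_dict(strict=False)`. A standalone illustration of the behaviour (toy models, not the project's loading code):

```python
import torch
import torch.nn as nn

# Hypothetical example: a checkpoint from a 1-layer model loaded into a
# 2-layer model. Matching keys are restored; missing keys (the new
# layer) keep their fresh initialization.
small = nn.Sequential(nn.Linear(4, 4))
large = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

state = small.state_dict()
result = large.load_state_dict(state, strict=False)

# strict=False reports mismatches instead of raising on them
print(result.missing_keys)  # ['1.weight', '1.bias']
```

This is what makes it possible to, e.g., continue training with more layers than the checkpoint was trained with.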

Offline training / uploading results later

At some point I didn't have internet access for an extended period, so I needed to run the training offline and sync the results to W&B later.

Train offline, disabling all functionality that would require internet access during training

python src/train.py experiment=lm/hypll_tiny_gpt logger.wandb.offline=true logger.wandb.log_model=false callbacks.wandb_duplicate_run_checker.enabled=false

Copy files from an offline instance for syncing to W&B

python scripts/wandb_offline_run_sync.py copy /Volumes/aurora/work/hyper_lm/logs/train/runs --to ~/tmp/wandb_offline_runs
python scripts/wandb_offline_run_sync.py sync ~/tmp/wandb_offline_runs

Offline server handling

# mount the offline server's home directory locally over sshfs
sshfs aurora:/home/adam /Volumes/aurora
# force-unmount when done
sudo diskutil unmount force /Volumes/aurora

Hyperparameter search

Hyperparameter search was set up using Hydra and Optuna, but this was not what I used in the end for analysis. The whole setup was much more broken than I expected, to the point that I had to manually fix some Optuna plugins. I also couldn't figure out how to express some of the more complex search spaces I wanted, so eventually I went back to dedicated shell scripts to run the experiments.

Not used: built-in hyperparameter search

Run a hyperparameter search experiment by additionally specifying hparams_search on the command line. Right now there are corresponding experiment and hparams_search files for each model. I couldn't figure out how to add multiple models with different configurations to a single hyperparameter search experiment; I'll need to get back to this and clean it up.

python src/train.py experiment=lm/hypll_tiny_gpt hparams_search=hypll_transformer_hp

Hyperparameter / multi-parameter runs

All final experiments in the thesis used dedicated shell scripts to run the experiments. This is not very elegant, but it was easy enough for me. Run any of them by running the corresponding train_lms_*.sh script.

./scripts/train_lms_1_architecture_params.sh

Evaluation

Download specific checkpoint from the server:

rsync -rlptz --exclude=data --exclude=.idea --exclude=.git runpod:/workspace/hyper_lm/logs/train/runs/2025-05-26_19-02-04/ logs/train/runs/2025-05-26_19-02-04/

Evaluate a trained model

python src/eval.py experiment=lm/hypll_tiny_gpt ckpt_path=logs/train/runs/2025-05-26_19-02-04/checkpoints/last_cleaned.ckpt +trainer.limit_val_batches=10

Analysis

Calculate token statistics

SentencePiece claims to calculate this (both when using the Python wrapper and separately using the spm_export_vocab command), but it doesn't: when using BPE, it only saves a priority order of tokens, not their frequency. So we need to calculate it ourselves by re-loading the whole tokenized dataset.

python src/data/tinystories_datamodule.py
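The re-counting itself amounts to streaming the tokenized sequences through a counter; a minimal sketch (illustrative names, not the datamodule's actual code):

```python
from collections import Counter

def token_frequencies(tokenized_sequences):
    """Count how often each token id occurs across a tokenized dataset.

    SentencePiece's BPE model only stores a merge priority per token,
    not corpus frequencies, so we re-derive them from the data itself.
    """
    counts = Counter()
    for seq in tokenized_sequences:
        counts.update(seq)
    return counts

# toy "dataset" of already-tokenized sequences
freqs = token_frequencies([[1, 2, 2, 3], [2, 3, 3, 3]])
print(freqs.most_common(2))  # [(3, 4), (2, 3)]
```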

Visualize model predictions

python src/analysis/visualize_predictions.py -- --checkpoint=logs/train/runs/2025-05-26_19-02-04/checkpoints/last_cleaned.ckpt --input_text="Once upon a time"

Run just the tests on all previously executed runs in a given project. (Also useful for simply downloading the checkpoints for other processing.)

python src/analysis/post_training_test.py --project lm_parameter_search2_ablation --target_project analysis --use_cached

Create most of the plots that are in the thesis. (Sorry, this is a bit of a mess; a clean-up would do these files good.)

./scripts/analyze_results.sh

Auxiliary: verify model correctness by checking that gradients are not mixed up across batches

python src/train.py experiment=seq_class/simple_seq_class callbacks=batch_gradient_verification
python src/train.py experiment=seq_class/simple_seq_class callbacks=batch_gradient_verification model=hypll_gpt_classifier

You can test that the gradient verification works by forcefully introducing errors in the model using _introduce_batch_mixing_error.
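The idea behind the verification: the loss computed on sample i must have zero gradient with respect to every other sample's input, otherwise information leaks across the batch dimension. A simplified standalone sketch of that check (not the actual callback):

```python
import torch
import torch.nn as nn

def check_no_batch_mixing(model, batch, sample_idx=0):
    """Backprop the output of a single sample and verify that input
    gradients for all other samples in the batch stay exactly zero."""
    inputs = batch.clone().requires_grad_(True)
    out = model(inputs)
    out[sample_idx].sum().backward()
    grads = inputs.grad
    others = torch.cat([grads[:sample_idx], grads[sample_idx + 1:]])
    return bool(torch.all(others == 0))

# A plain Linear layer operates per-sample, so there is no mixing
model = nn.Linear(8, 4)
print(check_no_batch_mixing(model, torch.randn(3, 8)))  # True
```

A layer that aggregates over the batch dimension (e.g. BatchNorm in training mode, or a buggy attention mask) would make this check return False, which is exactly the class of error the callback is meant to catch.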

Run sequence classification experiments

There are currently two kinds of experiments implemented, sequence classification and language modelling.

Sequence classification contains multiple simple synthetic tasks to test the attention/transformer implementations on small problems, before moving to full language modelling. These are fast to train and have exact correct solutions. Language modelling contains the training of a full language model using the TinyStories dataset.

Minimal tests on sequence classification

There are currently 3 tasks implemented:

  1. Simple sequence classification, which is a binary classification task on a dataset with 2 useful tokens plus 1 masking token and a single rule. This was designed to test attention. Solving the task requires position counting on variable-length sequences, so it cannot be fully solved by a linear layer, though the linear models achieved ~70% accuracy instead of just 50%.
    • Input: Sequence of tokens 0/1/2. Sequences start and end with a random number of 2's (padding), with random sequences of 0's and 1's in the middle.
    • Output: The second non-padding token (either 0 or 1).
    • Example: Input 2222220102 → Output 1 (the second token after padding)
  2. Dyck language, which is a balanced-parenthesis problem with multiple bracket classes, so the difficulty can be increased by introducing new tokens. I expected this to be more difficult than the simple sequence classification, but it was not.
    • Input: A sequence of opening and closing brackets of various types (e.g., (, ), [, ]).
    • Output: The next valid closing bracket to maintain a well-formed Dyck sequence (balanced parentheses).
    • Example: Input ([[<>[<>]] → Output ] (to close the most recently opened bracket)
  3. Complex sequence classification, which is a binary classification problem using a combination of multiple simple rules (e.g. rule 1 or rule 2, and rule 3 or rule 4, must be satisfied). This was designed to be hard enough that a single head cannot solve it. I assumed it would require either multiple heads or at least two layers.
    • Input: A sequence of tokens 0/1/2.
    • Output: 1 if (rule1 or rule2) and (rule3 or rule4) else 0
      1. Valid if contains the pattern 0100 anywhere (pattern recognition)
      2. Valid if has at least three 1's (not necessarily continuous) (counting)
      3. Valid if tokens on all even positions match (masking/pattern recognition)
      4. Valid if number of 0s and 1s match (counting)
    • Example: Input 11100011010001, rule1: 1, rule2: 1, rule3: 0, rule4: 1 → Output 1
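For illustration, a toy generator for task 1 might look like this (hypothetical code, not the dataset implementation in this repo):

```python
import random

def make_simple_seq(seq_len=10, seed=None):
    """Generate one sample of the simple sequence classification task:
    token 2 pads both ends, tokens 0/1 fill the middle, and the label
    is the *second* non-padding token."""
    rng = random.Random(seed)
    left = rng.randint(1, seq_len - 3)          # leading padding length
    right = rng.randint(1, seq_len - left - 2)  # trailing padding length
    middle = [rng.randint(0, 1) for _ in range(seq_len - left - right)]
    seq = [2] * left + middle + [2] * right
    return seq, middle[1]  # label: second non-padding token

seq, label = make_simple_seq(seq_len=10, seed=0)
print(seq, label)
```

Because the padding length varies per sample, the label's position moves around, which is what forces the model to actually attend rather than read a fixed position.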

There are 4 models I actively use as baseline/testing (there are several others implemented for earlier experiments):

  • Euclidean embedding + linear - a regular embedding layer followed by an MLP
  • Euclidean transformer - a minimal transformer
  • Hyperbolic embedding + linear - a hyperbolic embedding layer followed by an MLP based on hypll
  • Hyperbolic transformer - a minimal transformer based on hypll

# simple sequence classification - simple (Euclidean) linear
python src/train.py experiment=sq/ssq_e_mlp_embed hparams_search=e_mlp_embed_hp trainer=gpu
# simple sequence classification - simple (Euclidean) transformer
python src/train.py experiment=sq/ssq_e_transformer hparams_search=e_transformer_hp trainer=gpu
# simple sequence classification - hyperbolic linear
python src/train.py experiment=sq/ssq_hypll_mlp_embed hparams_search=hypll_mlp_embed_hp trainer=gpu
# simple sequence classification - hyperbolic transformer
python src/train.py experiment=sq/ssq_hypll_transformer hparams_search=hypll_transformer_hp trainer=gpu


Run the Dyck and complex sequence classification tasks with the hyperbolic transformer:

python src/train.py experiment=seq_class/dyck_htrafo_hp hparams_search=htrafo_hp trainer=gpu
python src/train.py experiment=seq_class/csq_htrafo_hp hparams_search=htrafo_hp trainer=gpu

Or simply train all combinations at once

./scripts/train_all_sq.sh

Run reproduction of other papers

There is a reproduction of the base models of other papers (Hypformer, FHNN, HVT, HNN++) in the minimal_tests directory. minimal_tests/transformer_comparison.py contains most of the code, though it's outdated and not very organized at this point, so no guarantees around it.

Built with / some code or inspiration taken from

Contains the model implementations of the papers reproduced above (Hypformer, FHNN, HVT, HNN++), with small modifications to allow comparison on my datasets.

Errata / fix next time

  • float32_matmul_precision is applied during training, but not during evaluation
  • test batch size is limited to 1000. Should use the full test set instead
  • train is limited to 26k batches in most parameter searches, use the full epoch length instead
  • some train scripts use mlp_upscale_factor=1 for historical reasons, this should be removed
  • weight decay is probably incorrectly applied to all parameters, including things like bias and layernorm. Check and adapt the parameter groups from NanoGPT instead.
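The NanoGPT-style fix for the last point splits parameters into two optimizer groups, decaying only tensors with 2+ dimensions (weight matrices, embeddings) and exempting 1-D parameters such as biases and norm gains; a sketch of the idea (not this repo's code):

```python
import torch
import torch.nn as nn

def build_optimizer(model, lr=3e-4, weight_decay=0.1):
    """Split parameters into decayed (matrices/embeddings) and
    non-decayed (biases, norm gains) groups, NanoGPT-style."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        # dim >= 2: weight matrices and embeddings; dim < 2: biases, norms
        (decay if p.dim() >= 2 else no_decay).append(p)
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    return torch.optim.AdamW(groups, lr=lr)

model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
opt = build_optimizer(model)
# Linear.weight is decayed; Linear.bias + LayerNorm weight/bias are not
print([len(g["params"]) for g in opt.param_groups])  # [1, 3]
```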

About

HyperLM: A transformer and GPT-2 compatible small language model based entirely on hyperbolic geometry
