CharlesCNorton/proof2weights

proof2weights

proof2weights defines a neural network, up to and including a GPT-2 transformer, in the Rocq prover, with arithmetic in IEEE-754 binary32 through the Flocq library, and extracts it to OCaml. The extracted program loads a .safetensors file and performs inference using the rounding behavior the proofs specify. The development loads the published 124-million-parameter GPT-2 weights, runs the forward pass over all twelve transformer blocks, and produces the same next-token prediction as the PyTorch reference implementation. The same approach extends to the Llama architecture: the development also runs SmolLM2, an instruction-tuned model, and the verified forward generates a coherent chat response. The weights written to disk are computed by the same definitions the proofs concern, and the floating-point arithmetic executed at inference time is the arithmetic the development reasons about.

The core development is contained in a single file, Phases1_15_complete.v. It contains no Admitted and no axioms beyond the Coq.Reals base that Flocq requires. It compiles under Rocq 9 with coq-flocq and extracts to standalone OCaml. A second file, Native_extract.v, re-extracts the same development with floating-point arithmetic mapped to the host's hardware float, which reduces the time for one GPT-2 forward pass from tens of minutes to seconds.

Motivation

A common approach to verifying a neural network establishes a property in a proof assistant and then implements the weights and inference code separately in Python or C for deployment. The deployed numbers are a transcription of the proven numbers, and the deployed arithmetic is a separate implementation of the proven arithmetic. The two are not guaranteed to agree, and in floating point they frequently differ.

In proof2weights the weights are defined in Rocq and serialized by a Rocq function. That serializer is extracted and executed, so the bytes written to disk are produced by the definitions the proofs concern. The reader is also Rocq: a function parses .safetensors, decodes IEEE-754 values, and assembles a typed model, and the forward pass that consumes the model is extracted from the same source. No separate implementation is introduced.

        Rocq definitions + Flocq IEEE-754
                     |
              extraction (verbatim)
                     |
                 OCaml binary
                 /          \
   write .safetensors      read .safetensors, run inference

Integer-exact core

The foundation is exact integer serialization. Tensors are records of a name, a shape (list nat), and data (list Z); a network is a list of tensors. Values serialize to little-endian i32, and the round trip is proved: decoding the encoding of any 32-bit integer returns that integer (roundtrip_z). On top of this sits a safetensors writer (an 8-byte little-endian header length, a JSON header, then the concatenated tensor bytes) and the inverse readers.
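
The i32 round trip and the writer's byte layout can be illustrated outside Rocq. The Python below is a minimal sketch under stated assumptions, not the extracted serializer: the names encode_i32, decode_i32, and write_safetensors are hypothetical, and the dtype string "I32" follows the public safetensors convention rather than anything confirmed in this development.

```python
import json
import struct

def encode_i32(z: int) -> bytes:
    """Little-endian two's-complement encoding of a signed 32-bit integer."""
    return struct.pack("<i", z)

def decode_i32(b: bytes) -> int:
    """Inverse of encode_i32."""
    return struct.unpack("<i", b)[0]

def write_safetensors(tensors: dict) -> bytes:
    """tensors maps a name to (shape, list of ints). Hypothetical helper.

    Layout: 8-byte little-endian header length, JSON header, tensor bytes.
    """
    header = {}
    data = b""
    for name, (shape, values) in tensors.items():
        start = len(data)
        data += b"".join(encode_i32(v) for v in values)
        header[name] = {
            "dtype": "I32",                      # assumed dtype label
            "shape": shape,
            "data_offsets": [start, len(data)],  # relative to the data section
        }
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + data

blob = write_safetensors({"w": ([2, 2], [1, -2, 3, -4])})
```

The property the Rocq lemma roundtrip_z states for every 32-bit integer can only be spot-checked here, for example with decode_i32(encode_i32(z)) == z at the range boundaries.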

The development also includes the serialization and validation machinery a deployment uses, each with its definitions and the lemmas that state their properties: shape and bounds validation with sound boolean checkers, signed int8 and packed int4 quantization with proved range containment, run-length compression with a decode that inverts encode, chunked streaming, lazy tensors whose force-after-defer is the identity, network sharding under a byte budget, and JSON proof certificates with attestations and provenance chains.
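
Two of these properties are easy to picture in a few lines. The sketch below is illustrative only: rle_encode, rle_decode, and quantize_int8 are hypothetical names, and modelling int8 quantization as a clamp is an assumption, since the text above states only that range containment is proved.

```python
def rle_encode(values):
    """Collapse runs of equal values into [value, count] pairs."""
    runs = []
    for v in values:
        if runs and runs[-1][0] == v:
            runs[-1][1] += 1
        else:
            runs.append([v, 1])
    return runs

def rle_decode(runs):
    """Expand [value, count] pairs back into the original list."""
    out = []
    for v, count in runs:
        out.extend([v] * count)
    return out

def quantize_int8(z: int) -> int:
    """Clamp an integer into the signed int8 range (assumed scheme)."""
    return max(-128, min(127, z))

sample = [0, 0, 0, 7, 7, -3, 0, 0]
restored = rle_decode(rle_encode(sample))
```

Here the decode-inverts-encode property is observed on one input; in the development it is a lemma over all inputs.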

IEEE-754 floating point

The floating-point type is binary32, not a fixed-point approximation. The binary32 type is Flocq's binary_float at precision 24 and exponent bound 128, and the arithmetic (f32_plus, f32_mult, f32_div, f32_neg, f32_abs, f32_compare, and square root via Flocq's Bsqrt) is round-to-nearest, ties-to-even. Bit patterns convert both ways for binary32, binary16, and bfloat16, and the encode/decode round trips are proved (roundtrip_f32, roundtrip_f16), as is B2R f32_one = 1.
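
For readers who want to see the bit-pattern conversions concretely, the following Python sketch performs the same kind of encode/decode with the standard struct module. It is an illustration, not the Flocq definitions; the function names are hypothetical, and NaN patterns are left out because payload bits are not guaranteed to survive a host round trip.

```python
import struct

def f32_of_bits(bits: int) -> float:
    """Interpret a 32-bit pattern as an IEEE-754 binary32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def bits_of_f32(x: float) -> int:
    """Bit pattern of a value that is exactly representable in binary32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def f16_of_bits(bits: int) -> float:
    """Interpret a 16-bit pattern as an IEEE-754 binary16 value."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]

def bits_of_f16(x: float) -> int:
    """Bit pattern of a value that is exactly representable in binary16."""
    return struct.unpack("<H", struct.pack("<e", x))[0]

def bf16_of_bits(bits: int) -> float:
    """bfloat16 is the upper 16 bits of a binary32 pattern."""
    return f32_of_bits(bits << 16)

one_bits = bits_of_f32(1.0)
```

With these helpers, f32_of_bits(0x3F800000), f16_of_bits(0x3C00), and bf16_of_bits(0x3F80) all decode to 1.0.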

The transcendentals are built from these primitives. The exponential is range reduced: the argument is saturated to the binary32 exponential range, divided by 256 so it lands where a short Taylor series is accurate, evaluated by a six-term Taylor series, and squared eight times, which remains finite across the input range. Sigmoid is 1 / (1 + exp(-x)), GELU is x * sigmoid(1.702 x), tanh is 2 sigmoid(2x) - 1, ReLU is a clamp, and softmax subtracts the row maximum before exponentiating. These run in binary32 and preserve vector and matrix dimensions, which is proved for each.
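
The range reduction can be sketched in Python by rounding every intermediate result to binary32. This is a hedged illustration of the scheme described above, not the Rocq definition: the saturation bound of 88.0, the order in which the Taylor terms are accumulated, and all function names are assumptions.

```python
import struct

def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

EXP_LIMIT = 88.0  # assumed saturation bound; the Rocq constant may differ

def f32_exp(x: float) -> float:
    # Saturate so that the eight squarings below cannot overflow.
    x = f32(max(-EXP_LIMIT, min(EXP_LIMIT, x)))
    r = f32(x / 256.0)
    # Six-term Taylor series: 1 + r + r^2/2! + ... + r^5/5!
    term, total = 1.0, 1.0
    for k in range(1, 6):
        term = f32(f32(term * r) / k)
        total = f32(total + term)
    # exp(x) = exp(x / 256) ** 256, computed as eight squarings.
    for _ in range(8):
        total = f32(total * total)
    return total

def f32_sigmoid(x: float) -> float:
    """1 / (1 + exp(-x)), rounded to binary32 after each operation."""
    return f32(1.0 / f32(1.0 + f32_exp(-x)))

def f32_tanh(x: float) -> float:
    """2 * sigmoid(2x) - 1, rounded to binary32 after each operation."""
    return f32(f32(2.0 * f32_sigmoid(f32(2.0 * x))) - 1.0)
```

Dividing by 256 and squaring eight times are inverse steps because 2 to the power 8 is 256; the squarings amplify the relative error of the Taylor stage, which is why the reduced argument is kept small.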

Neural-network library

The library includes the following components, each with a dimension-preservation lemma:

  • Dense layers, residual blocks, and bottleneck blocks.
  • Convolution weight records with im2col-style flattening, max and average pooling with non-negativity bounds.
  • Batch, layer, and group normalization; token and learned position embeddings and their sum.
  • Vanilla RNN, LSTM, and GRU cells, their sequence unrollings (output length equals input length, proved), bidirectional wrapping, and an RNN sequence classifier.
  • Scaled dot-product attention, multi-head attention with head splitting and concatenation, causal masking, and cross-attention, each shown to preserve sequence length.
  • Pre-norm and post-norm transformer blocks, feed-forward sublayers, full encoder and decoder layers, and assembled GPT-style, BERT-style, and full encoder-decoder models.

Both an integer fixed-point path and an IEEE-754 float path exist for the transformer operations; the float path is used for the GPT-2 model below.
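
As an illustration of what "preserves sequence length" means for causal attention, here is a small single-head sketch in Python. It uses ordinary binary64 floats and hypothetical names, and it implements the causal mask by restricting each position to earlier keys; the library's own masking mechanism may be expressed differently.

```python
import math

def softmax(row):
    """Softmax that subtracts the row maximum before exponentiating."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def causal_attention(q, k, v):
    """q, k, v are lists of seq_len vectors; position i sees keys 0..i."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        scores = [
            sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d)
            for j in range(i + 1)
        ]
        w = softmax(scores)
        out.append([
            sum(w[j] * v[j][c] for j in range(i + 1))
            for c in range(len(v[0]))
        ])
    return out

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
```

The output has one row per input position, and the first row equals the first value vector because position 0 can attend only to itself.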

The float GPT-2

The transformer is assembled end to end in binary32. The configuration record matches the published GPT-2 family, and the parameter counts are proved by reflection: gpt2_total_params gpt2_small reduces to 124439808, with the medium, large, and XL counts likewise pinned, alongside head-dimension and feed-forward-expansion checks.
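
The small-model count can be reproduced by hand from the standard GPT-2 layout. The Python below is a sketch of that arithmetic; it reuses the name gpt2_total_params for readability, but the breakdown into terms is an assumption about how the Rocq definition is organized, and only the total for the small configuration is taken from the text above.

```python
def gpt2_total_params(n_embd: int, n_layer: int, vocab_size: int, n_positions: int) -> int:
    """Parameter count for a GPT-2 style model with tied output embedding."""
    embeddings = vocab_size * n_embd + n_positions * n_embd
    layer_norm = 2 * n_embd                       # gain and bias
    attention = (n_embd * 3 * n_embd + 3 * n_embd) + (n_embd * n_embd + n_embd)
    n_inner = 4 * n_embd                          # feed-forward expansion
    mlp = (n_embd * n_inner + n_inner) + (n_inner * n_embd + n_embd)
    block = 2 * layer_norm + attention + mlp
    return embeddings + n_layer * block + layer_norm  # plus the final layer norm

small = gpt2_total_params(n_embd=768, n_layer=12, vocab_size=50257, n_positions=1024)
head_dim = 768 // 12
```

For the small configuration each block contributes 7087872 parameters, the two embedding tables contribute 39383808, and the final layer norm contributes 1536.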

A typed weight model (f32_model_weights) holds the token and position embeddings, a list of per-block weights (two layer norms, the fused QKV and output attention projections, and the two MLP projections), and the final layer norm. The forward pass embeds tokens and positions, runs the pre-norm decoder stack (layer norm, multi-head causal attention, residual, layer norm, MLP, residual), applies the final layer norm, and projects through the tied embedding to logits. Greedy generation decodes over those logits, and the generated sequence is proved to always extend the prompt, so the prompt is a prefix of the output. A finiteness certificate checks that an output contains no NaN or infinity, with a soundness lemma that a passing check implies every entry is a finite IEEE-754 value, and shape validators reject weights whose dimensions do not match the configuration.
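
The prompt-prefix property and the finiteness check have simple executable analogues. The following Python is a minimal sketch with hypothetical names (greedy_generate, all_finite, toy_logits); the toy logits function stands in for the real forward pass and has no connection to the model.

```python
import math

def argmax(xs):
    """Index of the largest entry; the first one wins on ties."""
    return max(range(len(xs)), key=lambda i: xs[i])

def greedy_generate(logits_fn, prompt, n_new):
    """Append n_new greedily chosen tokens to a copy of the prompt."""
    tokens = list(prompt)
    for _ in range(n_new):
        tokens.append(argmax(logits_fn(tokens)))
    return tokens

def all_finite(rows):
    """True when no entry is NaN or infinite."""
    return all(math.isfinite(x) for row in rows for x in row)

def toy_logits(tokens, vocab=5):
    """Hypothetical stand-in: always prefers the successor of the last token."""
    return [1.0 if i == (tokens[-1] + 1) % vocab else 0.0 for i in range(vocab)]

prompt = [3]
generated = greedy_generate(toy_logits, prompt, 3)
```

Because the loop only appends, generated[:len(prompt)] == prompt holds by construction, which is the shape of the argument behind the proved property.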

The safetensors loader

The loader is implemented in Rocq. The header length is read as a little-endian u64, the JSON header is parsed into a string, and a JSON scanner (whitespace, natural numbers, quoted strings, integer arrays, and a substring key search) extracts each tensor's dtype, shape, and data_offsets. A named tensor is loaded by scoping to its key, reading its byte offsets, slicing the data section, and decoding little-endian f32 into binary32. The model loader constructs every GPT-2 tensor name (including the per-layer h.<i>. prefixes, built with a verified nat-to-string), loads and reshapes each, and assembles f32_model_weights. The token-embedding matrix is proved to have vocab_size rows and the assembled model to have exactly n_layer blocks. Decoding bytes to binary32 goes through Flocq's b32_of_bits composed with the single-NaN collapse, so the loaded value is the IEEE-754 value the bytes denote.
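
The loading steps can be mirrored in a few lines of Python. This sketch is not the verified loader: it uses json.loads where the Rocq code uses its own scanner, the function names are hypothetical, and the tensor-name suffixes shown are illustrative rather than taken from the development.

```python
import json
import struct

def load_tensor(blob: bytes, name: str):
    """Return (shape, values) for one F32 tensor in a safetensors byte string."""
    (header_len,) = struct.unpack("<Q", blob[:8])       # little-endian u64
    header = json.loads(blob[8:8 + header_len].decode("utf-8"))
    entry = header[name]
    if entry["dtype"] != "F32":
        raise ValueError("this sketch decodes F32 only")
    start, end = entry["data_offsets"]                  # relative to data section
    data = blob[8 + header_len:][start:end]
    values = [v for (v,) in struct.iter_unpack("<f", data)]
    return entry["shape"], values

def gpt2_tensor_name(layer: int, suffix: str) -> str:
    """Build a per-layer name with the h.<i>. prefix (suffix is illustrative)."""
    return f"h.{layer}.{suffix}"

# A tiny in-memory file with one 2x2 tensor, built only for this example.
_header = json.dumps(
    {"wte.weight": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}
).encode("utf-8")
blob = struct.pack("<Q", len(_header)) + _header + struct.pack("<4f", 1.0, -2.0, 0.5, 3.0)
shape, values = load_tensor(blob, "wte.weight")
```

A real loader would also check that the number of decoded values equals the product of the shape, which corresponds to the shape validators mentioned earlier.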

Inference on GPT-2 weights

The development loads the published 124-million-parameter GPT-2 base weights, runs the IEEE-754 forward over all twelve blocks, projects every one of the 50257 logits through the tied embedding, and predicts a next token. For the prompt "The quick brown" the greedy next token is "ie" (token 494), which matches the prediction of the PyTorch reference for the same prompt, and the top five candidates are returned in the same order. The logits computed here are offset from the reference's by approximately one unit, because the MLP uses the x * sigmoid(1.702 x) form of GELU while GPT-2 was trained with the tanh-based gelu_new; the offset is uniform and does not change the ranking or the argmax.
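
To see how close the two activation forms are pointwise, both can be evaluated side by side. The Python below is an illustrative comparison in binary64 with hypothetical function names; it shows the gap between the activations themselves and makes no claim about how that gap propagates to the logits.

```python
import math

def gelu_sigmoid(x: float) -> float:
    """x * sigmoid(1.702 x), the form used in this development's MLP."""
    return x / (1.0 + math.exp(-1.702 * x))

def gelu_new(x: float) -> float:
    """Tanh-based approximation that GPT-2 was trained with."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

# Largest pointwise difference on a grid over [-5, 5].
grid = [i / 100.0 for i in range(-500, 501)]
max_gap = max(abs(gelu_sigmoid(x) - gelu_new(x)) for x in grid)
```

Both forms agree exactly at zero and differ by a small amount elsewhere on this grid.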

Applying the verified definitions to a model of this size requires addressing two performance constraints. First, the list-based loader cannot ingest a file of roughly 500 MB, because representing it as a list byte builds a linked list of hundreds of millions of boxed values. Second, the extracted matrix transpose inside the linear layer is quadratic in the output dimension through list indexing, which is acceptable at small dimensions but not at 768 and 3072. The GPT-2 runner therefore reads the file bytes natively, decodes each value with the verified f32_bytes_to_binary32, and composes the verified primitives (f32_mat_vec_mul, f32_dot, f32_layer_norm_2d, f32_causal_attention, f32_concat_heads, f32_gelu_vec, f32_add_matrices) in the order the proven block specifies. Each weight matrix is decoded already transposed, so the verified matrix-vector product computes the same dot products the verified linear layer would. Weights are streamed block by block, so memory remains in the low gigabytes. Every floating-point value is produced by the verified operators; only the byte addressing is native. In a mode that prints the full logit matrix, the runner produces output bit-identical to the top-level verified f32_gpt2_logits on a small fixture, confirming the composition agrees with the verified forward.
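
The idea of decoding a matrix already transposed can be sketched as follows. The Python is illustrative only: it assumes the stored matrix is row-major with shape [n_in, n_out], which is the usual GPT-2 convention but is not confirmed above, and all names are hypothetical.

```python
import struct

def dot(a, b):
    """Left-to-right dot product, fixing one summation order."""
    total = 0.0
    for x, y in zip(a, b):
        total += x * y
    return total

def mat_vec_mul(rows, x):
    """One dot product per row."""
    return [dot(row, x) for row in rows]

def decode_transposed(data: bytes, n_in: int, n_out: int):
    """Decode an [n_in, n_out] row-major f32 matrix directly as its transpose.

    The result has n_out rows of length n_in, so no separate transpose pass
    over a list-of-lists is needed afterwards.
    """
    flat = [v for (v,) in struct.iter_unpack("<f", data)]
    return [[flat[i * n_out + o] for i in range(n_in)] for o in range(n_out)]

# W = [[1, 2, 3], [4, 5, 6]] with n_in = 2 and n_out = 3; x has length n_in.
data = struct.pack("<6f", 1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
x = [10.0, 1.0]
y = mat_vec_mul(decode_transposed(data, 2, 3), x)
```

Each output entry is the dot product of x with one column of W, in the same element order a linear layer computing x times W would use.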

Inference and generation on a Llama model

The same approach extends to the Llama architecture, which differs from GPT-2 in four respects: RMSNorm in place of layer normalization, rotary position embeddings (RoPE) in place of learned position embeddings, grouped-query attention, and a SwiGLU feed-forward network. Llama.v adds the primitives these require on top of the binary32 development. RMSNorm and SiLU compose existing operations; f32_sin and f32_cos, which RoPE needs, are defined by argument reduction modulo 2*pi and a Taylor polynomial. The runner loads SmolLM2-135M-Instruct (576 hidden, 30 layers, 9 query and 3 key/value heads, intermediate 1536, tied embeddings) and composes these primitives into the forward pass: RMSNorm, the query/key/value projections, RoPE applied to the per-head query and key vectors, grouped-query causal attention, the output projection, and the SwiGLU block, with weights streamed per layer.
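
The four Llama-specific pieces are small enough to sketch directly. The Python below uses binary64 and hypothetical names; the epsilon, the rotary base theta, and the half-split pairing of dimensions in rope are assumptions chosen for illustration and may not match Llama.v or the SmolLM2 configuration.

```python
import math

def rms_norm(x, weight, eps=1e-5):
    """Scale x by the reciprocal of its root mean square, then by weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale * w for v, w in zip(x, weight)]

def silu(x: float) -> float:
    """x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def swiglu(gate, up):
    """Elementwise silu(gate) * up, the gated part of the feed-forward block."""
    return [silu(g) * u for g, u in zip(gate, up)]

def rope(vec, position, theta=10000.0):
    """Rotate the pairs (vec[i], vec[i + half]) by position-dependent angles."""
    d = len(vec)
    half = d // 2
    out = list(vec)
    for i in range(half):
        angle = position * theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = vec[i] * c - vec[i + half] * s
        out[i + half] = vec[i] * s + vec[i + half] * c
    return out

def kv_head_for(query_head: int, n_q: int = 9, n_kv: int = 3) -> int:
    """Grouped-query attention: consecutive query heads share one KV head."""
    return query_head // (n_q // n_kv)

v = [1.0, 2.0, 3.0, 4.0]
rotated = rope(v, position=7)
```

Because rope applies a plane rotation to each pair, it leaves the vector's length unchanged, and with 9 query heads and 3 key/value heads each key/value head serves three query heads.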

On the chat prompt for "What is the capital of France?", the verified forward reproduces the PyTorch reference's full top-eight next-token ranking, with logits agreeing to four decimal places. Greedy generation over the verified logits produces "The capital of France is Paris." Generation uses a key/value cache: the prompt is processed once, its per-layer rotary keys and values are stored, and each new token is computed as a single position attending over the cache, which is bit-identical to a full recompute and removes the per-token cost of reprocessing the prefix. llama_chat.py drives an interactive session: it applies the model's chat template, sends the token ids to the runner, and decodes the generated ids, so each reply is produced entirely by the verified operators. In the native build the trigonometric functions are taken from the host libm, rounded to binary32; the remaining operations use the same trusted hardware-float boundary as the rest of the development.
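
The equivalence between cached and recomputed attention can be seen in a single-head toy. The Python below is a sketch with hypothetical names (attend, full_recompute, KVCache), not the runner's code; it works in binary64 and covers one attention layer only.

```python
import math

def attend(q, keys, values):
    """Softmax attention of one query over the given keys and values."""
    d = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[c] for e, v in zip(exps, values))
            for c in range(len(values[0]))]

def full_recompute(qs, ks, vs):
    """Causal attention for every position, recomputed from scratch."""
    return [attend(qs[t], ks[:t + 1], vs[:t + 1]) for t in range(len(qs))]

class KVCache:
    """Stores keys and values so each new position is a single attend call."""
    def __init__(self):
        self.keys, self.values = [], []

    def step(self, q, k, v):
        self.keys.append(k)
        self.values.append(v)
        return attend(q, self.keys, self.values)

qs = [[0.1, 0.2], [0.3, -0.4], [0.5, 0.6]]
ks = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
cache = KVCache()
cached = [cache.step(q, k, v) for q, k, v in zip(qs, ks, vs)]
```

The cached and recomputed outputs match exactly because each position evaluates the same operations on the same operands in the same order; the cache only avoids repeating the work for earlier positions.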

Proof-carrying receipts

A generation can emit a receipt that binds the result to the exact weights (by checksum), the prompt, the full output token sequence, and the IEEE-754 semantics. Receipt.v defines the receipt, a checker verify_receipt, and its soundness and completeness theorems; the prompt-preservation guarantee follows from the proven generation property gpt2_generation_preserves_prompt. A receipt is checked without trusting the producer: recompute the weight checksum from the file, re-run the deterministic verified generation, and compare both against the receipt. llama_receipt.py emit writes a receipt for an answer and llama_receipt.py verify recomputes and re-runs to confirm it. Because the forward is a pure verified function, the regeneration is reproducible, so anyone holding the weights and the receipt can confirm that the recorded output is exactly what the model produces under the proven semantics.
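
The check-by-recomputation idea can be sketched as below. This is a hypothetical illustration, not the format defined in Receipt.v or used by llama_receipt.py: the field names, the choice of SHA-256 as the checksum, and the toy generate function are all assumptions.

```python
import hashlib

def weight_checksum(weight_bytes: bytes) -> str:
    """Checksum of the weight file (SHA-256 is an assumed choice)."""
    return hashlib.sha256(weight_bytes).hexdigest()

def emit_receipt(weight_bytes: bytes, prompt, generate) -> dict:
    """Record the weights' checksum, the prompt, and the full output."""
    return {
        "weights_sha256": weight_checksum(weight_bytes),
        "prompt": list(prompt),
        "output": generate(list(prompt)),
        "semantics": "IEEE-754 binary32, round-to-nearest, ties-to-even",
    }

def verify_receipt(receipt: dict, weight_bytes: bytes, generate) -> bool:
    """Recompute the checksum, re-run generation, and compare both."""
    if weight_checksum(weight_bytes) != receipt["weights_sha256"]:
        return False
    output = generate(list(receipt["prompt"]))
    prompt_is_prefix = output[:len(receipt["prompt"])] == receipt["prompt"]
    return prompt_is_prefix and output == receipt["output"]

def generate(prompt):
    """Hypothetical deterministic stand-in for the verified generation."""
    return list(prompt) + [sum(prompt) % 7]

weights = b"toy-weights"
receipt = emit_receipt(weights, [1, 2, 3], generate)
```

The verifier never trusts the recorded output; it accepts only when an independent re-run over the same weights reproduces it.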

Two extraction modes

The development supports two extraction modes from the same source.

The default mode, from Phases1_15_complete.v, keeps binary32 as Flocq's inductive binary_float and keeps Z and positive as their Coq inductive datatypes, so every integer operation, and therefore every float operation built on it, is the computational content of its proof. Z is not mapped to native machine int, because that mapping is unsound when a mantissa-alignment shift or an intermediate product exceeds the representable range, in which case binary32 addition of operands with a large exponent gap produces an incorrect result. nat, used only for indices, dimensions, and token identifiers and always small, extracts to native OCaml int, and ascii extracts to OCaml char with a destructuring matcher. This mode introduces no trusted floating-point boundary; the arithmetic is exactly Flocq's. Its bignum arithmetic runs at microseconds per operation, so a full GPT-2 forward pass takes tens of minutes.
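
The overflow concern can be made concrete with Python's unbounded integers standing in for Z. The sketch below is illustrative: wrap_int63 models a 63-bit native integer as found in 64-bit OCaml, and the exponent gap of 60 is a hypothetical value chosen to exceed that range.

```python
def wrap_int63(n: int) -> int:
    """Reduce n to a 63-bit two's-complement integer, range [-2**62, 2**62 - 1]."""
    n &= (1 << 63) - 1
    return n - (1 << 63) if n >= (1 << 62) else n

mantissa = (1 << 23) | 1   # a 24-bit significand with its lowest bit set
gap = 60                   # hypothetical exponent difference between operands

exact = mantissa << gap                # what unbounded integer arithmetic gives
native = wrap_int63(mantissa << gap)   # what a wrapped 63-bit integer would hold
```

The wrapped value silently loses the leading bit of the significand, so any rounding decision made from it would be made on the wrong number.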

The native mode, from Native_extract.v, re-extracts the same development with binary32 mapped to the host's hardware float, in the manner CompCert extracts its verified floats and treats the IEEE-754 agreement as a trusted boundary at the OCaml level. Each operation is the binary64 result rounded to binary32 through an Int32 bit round-trip. Because binary64 carries 53 bits of significand, which is at least 2 × 24 + 2 = 50 bits (twice binary32's precision plus two), rounding a single +, -, *, /, or sqrt through binary64 and then to binary32 yields the same result as rounding directly in binary32, so the extracted value is the binary32 value Flocq specifies. Decoding reads the four little-endian bytes directly to a binary32 with Int32.float_of_bits. This mode introduces a trusted assumption, that the host's float matches IEEE-754 binary32, and runs approximately three hundred times faster.
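
The compute-in-binary64-then-round pattern is easy to reproduce in Python, whose float is binary64. The following is a sketch of the idea with hypothetical names, not the OCaml produced by Native_extract.v; it relies on struct to perform the round-to-nearest, ties-to-even conversion to binary32.

```python
import struct

def f32(x: float) -> float:
    """Round a binary64 value to the nearest binary32, ties to even."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def f32_add(a: float, b: float) -> float:
    """binary32 addition, computed in binary64 and rounded once more."""
    return f32(a + b)

def f32_mul(a: float, b: float) -> float:
    """binary32 multiplication, computed in binary64 and rounded once more."""
    return f32(a * b)

def f32_div(a: float, b: float) -> float:
    """binary32 division, computed in binary64 and rounded once more."""
    return f32(a / b)

def f32_of_le_bytes(b: bytes) -> float:
    """Decode four little-endian bytes as a binary32 value."""
    return struct.unpack("<f", b)[0]

tie = f32_add(1.0, 2.0 ** -24)           # exactly halfway between two neighbours
above = f32_add(1.0, 3 * 2.0 ** -25)     # just past the halfway point
```

The first sum lands exactly between 1 and its successor and rounds to the even neighbour 1.0; the second rounds up to 1 + 2**-23.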

The two modes were compared directly. On a small fixture their logits are identical in all nine printed significant digits, the number of decimal digits sufficient to distinguish any two binary32 values. On the full 124-million-parameter GPT-2 they produce identical token identifiers and logits that agree to four decimals, the inductive mode in tens of minutes and the native mode in approximately seven seconds.

What is proved

The development proves, with no Admitted or admit:

  • Serialization round trips: i32, and the binary32 and binary16 bit-pattern encode/decode identities.
  • Dimension preservation through the stack: dense, attention (per-head and the masked path), softmax, layer norm, the transformer blocks, the RNN family, and the float linear-algebra and attention primitives.
  • Soundness bridges from boolean checkers to propositions: shape validity, value bounds, network verification, and the float finiteness certificate.
  • GPT-2 parameter counts and configuration validity by reflection.
  • Quantization range containment, pooling bounds, lazy-tensor round trip, and reflexivity of network equality.
  • Prompt preservation under greedy generation, for both the integer and float models.
  • Correct rounding and a half-ULP accuracy bound for each float primitive: multiplication, addition, division, and square root each return the round-to-nearest, ties-to-even result of the exact real operation and differ from it by at most half a ULP (Accuracy.v).

The float arithmetic is exactly Flocq's, and each operation is proved to land within half a ULP of the exact real result, so the extracted executable is a faithful binary32 computation with a per-operation accuracy bound. Composing those bounds into a single worst-case error for the whole forward is the next theorem.
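
A half-ULP bound can be spot-checked with exact rational arithmetic. The Python below is an illustrative test on two sample operands, not a substitute for the theorems in Accuracy.v; the helper names are hypothetical and the ULP helper handles only nonzero normal binary32 values.

```python
import math
import struct
from fractions import Fraction

def f32(x: float) -> float:
    """Round a binary64 value to the nearest binary32, ties to even."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def ulp32(y: float) -> Fraction:
    """Spacing of binary32 values around a nonzero normal y."""
    _, e = math.frexp(y)              # y = m * 2**e with 0.5 <= |m| < 1
    return Fraction(2) ** (e - 24)    # 24-bit significand

def within_half_ulp(computed: float, exact: Fraction) -> bool:
    """Compare a binary32 result against the exact rational result."""
    return abs(Fraction(computed) - exact) <= ulp32(computed) / 2

a, b = f32(0.1), f32(0.3)
product = f32(a * b)
quotient = f32(a / b)
total = f32(a + b)
```

Fraction(a) converts a float to its exact rational value, so the comparison involves no further rounding.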

Use as an IEEE-754 reference

Because the extracted forward pass is deterministic and pins every rounding and every reduction order, it is a fixed reference against which other inference implementations can be measured. Production float32 implementations differ from each other and across hardware because of fused multiply-add contraction, BLAS summation order, and GPU nondeterminism; this development fixes a single evaluation and proves it is the one the IEEE-754 semantics, as formalized by Flocq, specify. The repository includes a differential-testing harness that runs the same network through this reference and through a numpy float32 implementation of the identical operations and reports the divergence and any next-token disagreements; the divergence is small at small sizes and grows with depth, consistent with the compounding of reduction-order rounding. At full scale, the reference produces the same next-token prediction as the PyTorch implementation on GPT-2 weights, as described above.
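
The sensitivity to reduction order is visible even on three numbers. The Python below is a small sketch with hypothetical names that sums the same binary32 values in two orders; it is unrelated to the repository's harness scripts and only illustrates why a pinned order matters.

```python
import struct

def f32(x: float) -> float:
    """Round a binary64 value to the nearest binary32, ties to even."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sum_left_to_right(values):
    """Sequential binary32 accumulation."""
    total = 0.0
    for v in values:
        total = f32(total + v)
    return total

def sum_pairwise(values):
    """Binary32 accumulation over a balanced tree of additions."""
    if len(values) == 1:
        return values[0]
    mid = len(values) // 2
    return f32(sum_pairwise(values[:mid]) + sum_pairwise(values[mid:]))

values = [1.0, 2.0 ** -24, 2.0 ** -24]
sequential = sum_left_to_right(values)
tree = sum_pairwise(values)
divergence = abs(sequential - tree)
```

Sequential accumulation absorbs each small addend into 1.0, while the tree order first combines the two small addends into a value large enough to survive, so the results differ by one unit in the last place.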

Building and running

The build requires Rocq 9 with coq-flocq and OCaml 4.14 or later. The GPT-2 fetch and the differential harness additionally use Python with torch, transformers, numpy, and safetensors.

# Compile the development and extract the inductive OCaml (phases1_15_complete.{ml,mli}).
rocq compile Phases1_15_complete.v

# Compile the native re-extraction (phases1_15_native.{ml,mli}), requires the .vo above.
rocq compile -Q . "" Native_extract.v

# Build the inductive (exact) runner.
ocamlopt -rectypes -w -a phases1_15_complete.mli phases1_15_complete.ml gpt2_talk.ml -o gpt2_talk

# Build the native (fast) runner.
ocamlopt -rectypes -w -a phases1_15_native.mli phases1_15_native.ml gpt2_talk_native.ml -o gpt2_talk_native

# Fetch GPT-2, save f32 weights with the loader's tensor names, and print the
# PyTorch reference prediction.
python gpt2_setup.py

# Predict the next token. Arguments: mode (full|next), file, n_embd, n_head,
# n_layer, n_inner, vocab, n_positions, then the comma-separated token ids.
./gpt2_talk_native next gpt2.safetensors 768 12 12 3072 50257 1024 464,2068,7586

# Llama path: add the Llama primitives, extract them natively, build the runner.
rocq compile -Q . "" Llama.v
rocq compile -Q . "" Llama_native.v
ocamlopt -rectypes -w -a llama_native.mli llama_native.ml llama_talk_native.ml -o llama_talk_native

# Fetch SmolLM2, save f32 weights and the rotary frequencies, capture the oracle.
python smollm_setup.py

# Chat. Tokenization runs locally; the verified forward runs on the host.
python llama_chat.py "What is the capital of France?"

ref_logits.ml is a smaller runner that uses the verified list-based loader on toy .safetensors files: it reads file bytes as inductive Z, calls parse_header_size and parse_header_string to split the header from the data, calls f32_load_model to assemble the typed weights, and calls f32_gpt2_logits.

Repository layout

| Path | Contents |
| --- | --- |
| Phases1_15_complete.v | The development: definitions, proofs, and the inductive extraction. |
| Native_extract.v | Re-extraction mapping binary32 to hardware float (the native build). |
| gpt2_talk.ml | GPT-2 runner against the inductive extraction (exact). |
| gpt2_talk_native.ml | The same runner against the native extraction (fast). |
| gpt2_setup.py | Fetches GPT-2, saves f32 weights with the loader's tensor names, prints the PyTorch reference. |
| Llama.v | Llama primitives: RMSNorm, SiLU, and f32_sin/f32_cos for RoPE. |
| Llama_native.v | Native extraction of the Llama primitives (sin/cos via libm). |
| llama_talk_native.ml | SmolLM2 runner: the verified Llama forward and greedy generation. |
| smollm_setup.py | Fetches SmolLM2, saves f32 weights and rotary frequencies, prints the PyTorch oracle. |
| llama_chat.py | Interactive chat: local tokenization, verified forward on the host. |
| Accuracy.v | Correct-rounding theorems for the f32 multiply, add, divide, and square root. |
| Receipt.v | Inference receipt, the checker verify_receipt, and its soundness and completeness. |
| llama_receipt.py | Emit and verify proof-carrying receipts for generated answers. |
| ref_logits.ml | Smaller runner using the verified list-based loader on toy models. |
| tiny_gpt2_ref.py | numpy reference of the identical computation, and a tiny .safetensors generator. |
| experiment_gen.py, experiment_cmp.py | Differential-testing harness: generate models, compare the reference against numpy. |
| run_float_demo.sh | Compile, extract, build, and run on a toy model. |
| old/ | Earlier per-component development, superseded by the consolidated file. |

Related work

  • Flocq provides the IEEE-754 formalization this development computes in.
  • CompCert's verified floating point is the model for extracting Flocq arithmetic to a real executable, and for the trusted native-float boundary in the native build.
  • MLCert certifies generalization bounds for machine learning in Coq and supports extraction; its focus is bounds rather than a transformer running real weights.
  • verinncoq/converter verifies properties of externally trained networks.
  • Cheerios is verified serialization for Coq.

License

MIT

About

Verified extraction of neural network weights from Coq proofs to deployable formats
