In [None]:
import random
import time
from datetime import timedelta
from datetime import datetime

import os
import numpy as np
import pandas as pd

import gc

import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

from _models_benchmarking import load_opus_samples, load_squad_v2_samples, load_mmlu_samples, load_boolq_samples, load_commonsenseqa_samples, load_ai2arc_e_samples, load_ai2arc_c_samples
from _models_benchmarking import evaluate_with_stats, save_list_to_file

# Frozen Embeddings in Language Modeling: Proof-of-Concept Study

This repository presents code and results from a series of experiments studying language models trained with **predicted and fully frozen token embedding weights** throughout the entire pretraining phase.

*Our work provides a fundamental proof of concept: it is, in principle, possible to train competitive language models with embeddings that were never updated during training.*

---

## Background and Motivation

Traditional transformer language models (e.g. GPT-2, GPT-3) initialize token embeddings randomly and **allow these vectors to be updated** jointly with all other model parameters during training. It is commonly believed that a major part of the language model's semantic capacity is stored in the embedding layer.

Here, we challenge this paradigm:  
- We **predict** (precompute) token embeddings using external criteria (e.g., visual or Unicode composition),  
- Then **fully freeze** those embeddings,
- And train all other model parameters on standard large-scale text corpora.

We compare such **frozen-embedding models** directly with classical models (where embeddings are trainable), using the same tokenizers, architectures, and training setup.

---

## Experimental Setup

- **Training data:** All models were pretrained on a 9-billion-token English (and optionally multilingual) corpus.  
- **Pretraining only:** No large-scale supervised finetuning was performed.  
- **Instruct/SFT "sprinkling":** To enable validation on reasoning benchmarks (MMLU, SQuAD, etc.), a small fraction (~10%) of the training data consisted of instruct/SFT-formatted examples ("instruct data infusion"), but the vast majority (90%) of tokens were standard unsupervised text.
- **Baselines:** For each experiment, we compare:
    - Models with **frozen, predicted embeddings** (never updated during training)
    - Models with **trainable (classic) embeddings** (all weights optimizable)

All other hyperparameters (number of layers, heads, model dimension, batch size, LR schedule, etc.) were kept identical for both types.

---

## Summary of What is Provided

- Reproducible code for data processing, model definition, and training (PyTorch).
- Scripts/notebooks for running evaluation benchmarks (MMLU, SQuAD, ARC, BLEU, etc.).
- All raw results and code for analysis in this repo.

---

## Scientific Contribution

* To our knowledge, this is the first systematic direct comparison of large-scale transformer language models trained with (i) fully frozen, externally-predicted embedding layers and (ii) standard learned embeddings, under identical conditions.
* We demonstrate that competitive language modeling is possible even with completely static and non-semantic embedding weights; the main semantic structure of the model emerges in the transformer blocks, not in the embedding table.
* This opens up directions for language model modularity, standardization of tokenization, and parameter fusion (see MoE experiments).

**Note:** This is a "proof-of-concept" (POC) study. The models are not expected to match or exceed the performance of advanced instruction-tuned models; all results should be interpreted in the context of unsupervised pretraining only. Metrics (MMLU, etc.) will naturally be lower than state-of-the-art because only generic pretraining was performed.

---

## Comparison to GPT-2 Baseline

Pretraining scale is roughly comparable to that of the original GPT-2 medium/large models (300–700M parameters, up to 5B training tokens).  
You may interpret the results as a fair indicator of "GPT-2 class" LM performance, given the data and compute.

---

## Transparency and Reproducibility

**All reported results in the main paper were obtained using the code and scripts in this repository.**

We encourage other researchers to:
- Reproduce and extend the experiments,
- Explore different embedding generation strategies,
- And contribute to further clarity on the role of embedding layers in language modeling.

In [2]:
# Expose only CUDA device index 0 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [3]:
# Use the GPU when PyTorch reports CUDA as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [5]:
# MMLU subject names, in the order they are evaluated and reported below.
all_subjects = [
    'high_school_european_history',
    'business_ethics',
    'clinical_knowledge',
    'medical_genetics',
    'high_school_us_history',
    'high_school_physics',
    'high_school_world_history',
    'virology',
    'high_school_microeconomics',
    'econometrics',
    'college_computer_science',
    'high_school_biology',
    'abstract_algebra',
    'professional_accounting',
    'philosophy',
    'professional_medicine',
    'nutrition',
    'global_facts',
    'machine_learning',
    'security_studies',
    'public_relations',
    'professional_psychology',
    'prehistory',
    'anatomy',
    'human_sexuality',
    'college_medicine',
    'high_school_government_and_politics',
    'college_chemistry',
    'logical_fallacies',
    'high_school_geography',
    'elementary_mathematics',
    'human_aging',
    'college_mathematics',
    'high_school_psychology',
    'formal_logic',
    'high_school_statistics',
    'international_law',
    'high_school_mathematics',
    'high_school_computer_science',
    'conceptual_physics',
    'miscellaneous',
    'high_school_chemistry',
    'marketing',
    'professional_law',
    'management',
    'college_physics',
    'jurisprudence',
    'world_religions',
    'sociology',
    'us_foreign_policy',
    'high_school_macroeconomics',
    'computer_security',
    'moral_scenarios',
    'moral_disputes',
    'electrical_engineering',
    'astronomy',
    'college_biology',
]

# Load every benchmark's sample set once, up front.
# NOTE(review): n_questions=-1 presumably means "all questions" -- confirm in
# _models_benchmarking.
mmlu_samples = load_mmlu_samples(all_subjects, n_questions=-1)
boolq_samples = load_boolq_samples(n_questions=3270)
commonsense_qa_samples = load_commonsenseqa_samples(n_questions=1221)
ai2arc_e_qa_samples = load_ai2arc_e_samples(n_questions=570)
ai2arc_c_qa_samples = load_ai2arc_c_samples(n_questions=299)
squad_v2_samples = load_squad_v2_samples(n_questions=256)

In [6]:
# Map each MMLU subject name to the samples loaded for that subject alone,
# so subjects can be scored individually.
mmlu_samples_by_subj = {
    subject: load_mmlu_samples([subject], n_questions=-1)
    for subject in all_subjects
}

In [7]:
# Load the checkpoint under test and its tokenizer from the Hugging Face Hub.
model_name = 'nemo_bvv_ru'
model_id = f"Bochkov/{model_name}"
print(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Count all parameters and record the size (in billions) as the first entry
# of a new `reports` list, then save that list to the per-model report file.
total_params = sum(param.numel() for param in model.parameters())
status = f"{model_name} Total parameters:     {total_params / 1e9:.1f}B"
print(status)
reports = [status]
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

Bochkov/nemo_bvv_ru
nemo_bvv_ru Total parameters:     0.4B


In [8]:
# Score each MMLU subject separately. The report list is saved after every
# subject (presumably so partial results survive an interrupted run).
for subject in all_subjects:
    subject_samples = mmlu_samples_by_subj[subject]
    mean, std, ci95 = evaluate_with_stats(
        model,
        tokenizer,
        'MMLU',
        subject_samples,
        all_subjects,
        device=device,
        n_runs=10,
        p_source_lang='',
        p_target_lang='',
    )
    status = f"{model_name} MMLU [{subject}]: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
    print(status)
    reports.append(status)
    save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru MMLU [high_school_european_history]: 5.76% ± 1.03% (σ=1.65%)
nemo_bvv_ru MMLU [business_ethics]: 10.30% ± 1.90% (σ=3.07%)
nemo_bvv_ru MMLU [clinical_knowledge]: 16.00% ± 0.86% (σ=1.38%)
nemo_bvv_ru MMLU [medical_genetics]: 9.90% ± 1.70% (σ=2.74%)
nemo_bvv_ru MMLU [high_school_us_history]: 5.15% ± 1.04% (σ=1.67%)
nemo_bvv_ru MMLU [high_school_physics]: 7.02% ± 1.44% (σ=2.32%)
nemo_bvv_ru MMLU [high_school_world_history]: 5.99% ± 0.67% (σ=1.08%)
nemo_bvv_ru MMLU [virology]: 11.39% ± 1.42% (σ=2.29%)
nemo_bvv_ru MMLU [high_school_microeconomics]: 11.18% ± 0.99% (σ=1.60%)
nemo_bvv_ru MMLU [econometrics]: 7.89% ± 1.58% (σ=2.54%)
nemo_bvv_ru MMLU [college_computer_science]: 6.50% ± 0.64% (σ=1.02%)
nemo_bvv_ru MMLU [high_school_biology]: 12.42% ± 1.17% (σ=1.88%)
nemo_bvv_ru MMLU [abstract_algebra]: 6.50% ± 0.93% (σ=1.50%)
nemo_bvv_ru MMLU [professional_accounting]: 7.73% ± 1.00% (σ=1.61%)
nemo_bvv_ru MMLU [philosophy]: 15.59% ± 1.30% (σ=2.09%)
nemo_bvv_ru MMLU [professional_medicin

In [9]:
# MMLU over the pooled sample set covering all subjects.
mean, std, ci95 = evaluate_with_stats(
    model,
    tokenizer,
    'MMLU',
    mmlu_samples,
    all_subjects,
    device=device,
    n_runs=10,
    p_source_lang='',
    p_target_lang='',
)
status = f"{model_name} MMLU: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
print(status)
reports.append(status)
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru MMLU: 8.80% ± 0.20% (σ=0.32%)


In [10]:
# ARC-e benchmark; the result line is appended to the report and saved.
mean, std, ci95 = evaluate_with_stats(
    model,
    tokenizer,
    'ARC-e',
    ai2arc_e_qa_samples,
    [],
    device=device,
    n_runs=10,
    p_source_lang='',
    p_target_lang='',
)
status = f"{model_name} ARC-e: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
print(status)
reports.append(status)
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru ARC-e: 19.53% ± 1.11% (σ=1.80%)


In [11]:
# ARC-c benchmark; the result line is appended to the report and saved.
mean, std, ci95 = evaluate_with_stats(
    model,
    tokenizer,
    'ARC-c',
    ai2arc_c_qa_samples,
    [],
    device=device,
    n_runs=10,
    p_source_lang='',
    p_target_lang='',
)
status = f"{model_name} ARC-c: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
print(status)
reports.append(status)
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru ARC-c: 21.34% ± 1.18% (σ=1.91%)


In [12]:
# CommonsenseQA benchmark (reported under the 'C-SENSE' label).
mean, std, ci95 = evaluate_with_stats(
    model,
    tokenizer,
    'C-SENSE',
    commonsense_qa_samples,
    [],
    device=device,
    n_runs=10,
    p_source_lang='',
    p_target_lang='',
)
status = f"{model_name} C-SENSE: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
print(status)
reports.append(status)
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru C-SENSE: 19.51% ± 0.28% (σ=0.44%)


In [13]:
# SQuAD v2 benchmark (reported under the 'SQUAD' label).
mean, std, ci95 = evaluate_with_stats(
    model,
    tokenizer,
    'SQUAD',
    squad_v2_samples,
    [],
    device=device,
    n_runs=10,
    p_source_lang='',
    p_target_lang='',
)
status = f"{model_name} SQUAD: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
print(status)
reports.append(status)
save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

nemo_bvv_ru SQUAD: 6.80% ± 0.74% (σ=1.20%)


In [None]:
# Translation benchmark for the en/ru pair, scored with BLEU in both directions.
lang = 'en'
lang_second = 'ru'
lang_lang = f"{lang}-{lang_second}"

# Load both directions before evaluating either one.
opus_src_tgt_samples = load_opus_samples(lang, lang_second, lang_lang, n_questions=128)
opus_tgt_src_samples = load_opus_samples(lang_second, lang, lang_lang, n_questions=128)

# Forward direction (en -> ru) first, then the reverse (ru -> en).
for src_lang, tgt_lang, pair_samples in (
    (lang, lang_second, opus_src_tgt_samples),
    (lang_second, lang, opus_tgt_src_samples),
):
    mean, std, ci95 = evaluate_with_stats(
        model,
        tokenizer,
        'BLEU',
        pair_samples,
        [],
        device=device,
        n_runs=10,
        p_source_lang=src_lang,
        p_target_lang=tgt_lang,
    )
    status = f"{model_name} BLEU [{src_lang}-{tgt_lang}]: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
    print(status)
    reports.append(status)
    save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")

In [None]:
# Translation benchmark for the en/zh pair, scored with BLEU in both directions.
lang = 'en'
lang_second = 'zh'
lang_lang = f"{lang}-{lang_second}"

# Load both directions before evaluating either one.
opus_src_tgt_samples = load_opus_samples(lang, lang_second, lang_lang, n_questions=128)
opus_tgt_src_samples = load_opus_samples(lang_second, lang, lang_lang, n_questions=128)

# Forward direction (en -> zh) first, then the reverse (zh -> en).
for src_lang, tgt_lang, pair_samples in (
    (lang, lang_second, opus_src_tgt_samples),
    (lang_second, lang, opus_tgt_src_samples),
):
    mean, std, ci95 = evaluate_with_stats(
        model,
        tokenizer,
        'BLEU',
        pair_samples,
        [],
        device=device,
        n_runs=10,
        p_source_lang=src_lang,
        p_target_lang=tgt_lang,
    )
    status = f"{model_name} BLEU [{src_lang}-{tgt_lang}]: {mean:.2f}% ± {ci95:.2f}% (σ={std:.2f}%)"
    print(status)
    reports.append(status)
    save_list_to_file(reports, f"_models_benchmarking_{model_name}.txt")