# Tutorial: LZ Embeddings

## 1. Setup Instructions

### A. Rust LZ78 Library
You need to install Rust and Maturin, and then install the Python bindings for the `lz78` library as an editable Python package.
1. Install Rust: [Instructions](https://www.rust-lang.org/tools/install).
    - After installing Rust, close and reopen your terminal before proceeding.
2. If applicable, switch to the desired Python environment.
3. Install Maturin: `pip install maturin`
4. Install the `lz78` Python package: `cd crates/python && maturin develop -r && cd ../..`

**NOTE**: If you use virtual environments, you may run into an issue. If you are a conda user, the `(base)` environment may be activated automatically on startup. `maturin` does not allow two virtual environments to be active at once (i.e., one from `venv` and one from `conda`), so make sure only one is active. One solution is to run `conda deactivate` in favor of your `venv`-based virtual environment.

**NOTE**: If you are using macOS, you may run into the following error with `maturin develop`:
```
error[E0463]: can't find crate for core
    = note: the x86_64-apple-darwin target may not be installed
    = help: consider downloading the target with 'rustup target add x86_64-apple-darwin'
```
Running the recommended command `rustup target add x86_64-apple-darwin` should resolve the issue.

### B. LZ78 Embeddings
From the root directory of the repository, run
```
pip install --editable .
```

### **Warning**
Sometimes Jupyter does not register that a cell containing code from the `lz78` library has started running, so the cell appears to be waiting in the queue until it has actually finished.
This can be annoying for operations that take a while to run, and **can be remedied by putting `stdout.flush()` at the beginning of the cell**.

## 2. Imports

```python
from sys import stdout
from tqdm import tqdm
from lz_embed.classical import BasicLZSpectrum, AlphabetInfo
from lz_embed.transformer_based import LZPlusEmbeddingModel, DeepEmbeddingModel
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
```

```python
%load_ext autoreload
%autoreload 2
```

## 3. Summary
This notebook contains two simple ways of utilizing LZ78 to create an embedding.
A high-level description of each method is below; see the rest of the notebook for more details: 
1. **The LZ Spectrum**: computes the proportion of every symbol in the alphabet at each node of the LZ78 tree (in a predefined order) and concatenates all of them. This is simple, but it only works well for small alphabets; otherwise, the resulting embedding becomes too long.
2. **Combining LZ with a pretrained embedding model**: first, some training data is used to build an LZ78 tree. Then, the embeddings for all nodes can be computed using a pretrained embedding model.

    To embed a text, we first use the LZ78 tree to parse the text into phrases. Then, each phrase is embedded based on the embedding of the corresponding leaf, and the weighted average of phrase embeddings is taken.

    _Note_: for this notebook, the embeddings of all nodes are not actually computed _a priori_, but rather they are computed at evaluation time. This is just to facilitate quick experiments.
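To make the "parse the text into phrases" step concrete, here is a minimal pure-Python sketch. It is not the `lz_embed` implementation. It assumes that the LZ78 tree can be represented as the set of phrases (strings) it contains, and that a new text is parsed greedily by descending the tree until the next character would leave it. `build_lz78_tree` and `parse_into_phrases` are hypothetical helpers. How the real library handles characters never seen in training is not described here, so the fallback below is also an assumption.

```python
def build_lz78_tree(train_text: str) -> set[str]:
    """Classic LZ78 incremental parsing: each new phrase is the shortest
    prefix of the remaining text that is not already a node."""
    nodes: set[str] = set()
    current = ""
    for ch in train_text:
        current += ch
        if current not in nodes:
            nodes.add(current)  # grow a new leaf
            current = ""        # restart at the root
    return nodes


def parse_into_phrases(text: str, nodes: set[str]) -> list[str]:
    """Greedy parse against a fixed tree (no nodes are added)."""
    phrases: list[str] = []
    current = ""
    for ch in text:
        if current + ch in nodes:
            current += ch            # still inside the tree: descend
            continue
        if current:
            phrases.append(current)  # fell off the tree: close the phrase
        if ch in nodes:
            current = ch             # restart at the root with this character
        else:
            phrases.append(ch)       # assumption: an unseen symbol is its own phrase
            current = ""
    if current:
        phrases.append(current)
    return phrases


tree = build_lz78_tree("aababc")
print(sorted(tree))                       # ['a', 'ab', 'abc']
print(parse_into_phrases("abcab", tree))  # ['abc', 'ab']
```

Each phrase returned by `parse_into_phrases` is a node of the tree, so its embedding can be looked up (or computed) per node.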

#### **Disclaimer: this is preliminary work, and not very polished, documented, or debugged.**

## 4. LZ Spectrum
### Methodology
The LZ spectrum embeds a sequence over an alphabet $\mathcal{A}$ of size $A$ as follows:
1. It builds an LZ78 tree using the sequence (for now, with no training required). There is an option to limit the depth of the tree, which is recommended.
2. The nodes of the tree are ordered based on their corresponding phrases, as follows:
    - Phrase $x \in \mathcal{A}^n$ is less than phrase $y \in \mathcal{A}^m$ iff $n < m$, or $n = m$ and $\sum_{i=1}^n x_i A^{i-1} < \sum_{i=1}^n y_i A^{i-1}$.

3. The empirical distributions of symbols observed at each node (with the last component dropped, since it is determined by the first $A-1$ components) are concatenated in the ordering defined above. If there are "gaps" (nodes missing from a level of the tree), they are filled in with the uniform distribution.
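The three steps above can be sketched from scratch in NumPy. This is an illustration, not the `BasicLZSpectrum` implementation, and several details are assumptions made for the sketch:
- nodes at depths $0$ through `max_depth` are included;
- when the parser reaches a node at `max_depth`, it returns to the root;
- a node that has not yet observed any symbol is treated as a gap.

Under these assumptions, the offset of a node in the vector follows directly from the ordering rule, and the total length is $(A-1)\sum_{k=0}^{D} A^k = A^{D+1}-1$.

```python
import numpy as np


def lz_spectrum_sketch(seq, alphabet_size: int, max_depth: int) -> np.ndarray:
    A, D = alphabet_size, max_depth
    # Each node is keyed by its phrase (a tuple of symbols) and stores counts
    # of the symbols observed while the parser was sitting at that node.
    counts = {(): np.zeros(A)}
    node = ()
    for s in seq:
        s = int(s)
        counts[node][s] += 1
        child = node + (s,)
        if child in counts:
            node = child                     # descend the tree
        else:
            if len(node) < D:
                counts[child] = np.zeros(A)  # LZ78: grow a new leaf
            node = ()                        # phrase ends: back to the root

    # Levels 0..D hold A**k slots of A-1 components each: A**(D+1) - 1 entries.
    emb = np.full(A ** (D + 1) - 1, 1.0 / A)  # gaps default to uniform
    for phrase, c in counts.items():
        total = c.sum()
        if total == 0:
            continue                          # no data at this node: leave as a gap
        k = len(phrase)
        rank = sum(x * A ** i for i, x in enumerate(phrase))  # sum_i x_i A^(i-1)
        start = (A ** k - 1) + (A - 1) * rank
        emb[start:start + A - 1] = (c / total)[: A - 1]      # drop the last component
    return emb


emb = lz_spectrum_sketch([0, 1, 0, 1], alphabet_size=2, max_depth=2)
# Root: P(0) = 2/3. Node (0,): P(0) = 0. All other slots stay uniform (0.5).
print(emb)
```

Because the length depends only on $A$ and `max_depth`, every sequence maps to a vector of the same size. This mirrors the fixed-length guarantee described below for `max_depth`.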

### Instantiation
To instantiate a `BasicLZSpectrum`, you need to specify information about the alphabet of the sequences you are encoding.
This is via the `AlphabetInfo` dataclass, where you specify either:
1. `AlphabetInfo(alphabet_size=A)`, for encoding integer sequences, or
2. `AlphabetInfo(valid_character_string="abcdefgh...")` for encoding strings. This specifies that strings being encoded can only consist of characters in `valid_character_string`.

You can also specify `max_depth` (which is recommended), which stops adding new leaves to the LZ tree below the specified depth.

**Fixed-length embeddings**: If the `max_depth` argument is specified, then the tensors returned by the `BasicLZSpectrum.embed` function are guaranteed to be a fixed length.
You can also specify the `fixed_length` argument upon instantiation to either pad (with the value $1/A$) or truncate embeddings to a specified length.
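The padding and truncation described for `fixed_length` can be pictured with the small sketch below. `to_fixed_length` is a hypothetical helper written for illustration and is not part of `lz_embed`. The choice to pad and truncate at the end of the vector is an assumption.

```python
import numpy as np


def to_fixed_length(emb, length: int, alphabet_size: int) -> np.ndarray:
    """Hypothetical helper: truncate long embeddings and pad short ones with
    1/A, the value a uniform distribution contributes to each slot."""
    emb = np.asarray(emb, dtype=float)
    if emb.shape[0] >= length:
        return emb[:length]  # assumption: truncate from the end
    pad = np.full(length - emb.shape[0], 1.0 / alphabet_size)
    return np.concatenate([emb, pad])  # assumption: pad at the end


print(to_fixed_length([0.2, 0.7], 4, alphabet_size=2))  # padded with 0.5
```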

```python
model = BasicLZSpectrum(AlphabetInfo(alphabet_size=3), max_depth=4)
```

### Embedding
You can embed a sequence by calling `embed_single`, or a list of sequences by calling `embed`.

```python
stdout.flush()
seq1 = [0, 1, 2, 1, 2, 0, 1, 1, 1, 2, 2, 2] * 5000  # periodic, highly structured
seq2 = np.random.randint(0, 3, size=50000)           # i.i.d. uniform over {0, 1, 2}
emb = model.embed_single(seq1)
(emb1, emb2) = model.embed([seq1, seq2])
assert np.allclose(emb1, emb)  # batch and single-sequence embeddings agree
```

```python
plt.figure(figsize=(12, 3))
plt.stem(emb1, linefmt="r-", markerfmt="ro")
plt.grid(True)
plt.title("LZ Spectrum of Structured Sequence", fontdict={"size": 16})
plt.show()
```

```python
plt.figure(figsize=(12, 3))
plt.stem(emb2, linefmt="b-", markerfmt="bo")
plt.grid(True)
plt.title("LZ Spectrum of Random Sequence", fontdict={"size": 16})
plt.show()
```

### Issues and Next Steps
One key issue with this type of embedding is that the length of the embedding vector scales as $(A-1)A^d$, where $A$ is the alphabet size and $d$ is the depth of the tree. This results in unreasonably large embedding vectors for domains such as text.
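A quick calculation shows how fast this scaling grows. It plugs a few alphabet sizes into $(A-1)A^d$. The depth values are illustrative choices. The text alphabet is the 64-character `valid_character_string` used in Section 5.

```python
def spectrum_length(alphabet_size: int, depth: int) -> int:
    """Embedding-length scaling quoted above: (A - 1) * A**d."""
    return (alphabet_size - 1) * alphabet_size ** depth


# The character set passed to AlphabetInfo in Section 5 (64 characters).
TEXT_ALPHABET = (
    "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ. ,?\n\"';:\t-_"
)

for name, A, d in [
    ("ternary example", 3, 4),
    ("DNA (A, C, G, T)", 4, 8),
    ("text alphabet", len(TEXT_ALPHABET), 4),
]:
    print(f"{name:>16}: A={A:>2}, d={d} -> {spectrum_length(A, d):,} entries")
```

With $d = 4$, the ternary example above needs 162 entries, but the text alphabet needs 1,056,964,608.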

Also, short sequences produce smaller LZ trees, resulting in an embedding vector that is mostly $1/A$-valued, which may be undesirable for downstream tasks.

#### **Some next steps:**
- Use an n-gram model for this instead of LZ
- Select nodes of the LZ tree to include in the embedding (sparsify the tree) based on some training data
- Think of other forms of dimensionality reduction for the embedding (maybe PCA over a training set)
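The PCA idea in the last bullet could start from a NumPy sketch like the one below. It is illustrative only. `fit_pca` and `apply_pca` are hypothetical helpers, and the training matrix is random stand-in data, not real LZ spectra. `k` must not exceed the smaller of the number of training sequences and the embedding length.

```python
import numpy as np


def fit_pca(train_embs: np.ndarray, k: int):
    """Fit a rank-k PCA projection on a (num_sequences, dim) matrix."""
    mean = train_embs.mean(axis=0)
    # Rows of vt are orthonormal principal directions, sorted by singular value.
    _, _, vt = np.linalg.svd(train_embs - mean, full_matrices=False)
    return mean, vt[:k]


def apply_pca(embs: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Project embeddings onto the learned principal directions."""
    return (embs - mean) @ components.T


rng = np.random.default_rng(0)
train = rng.random((100, 242))  # stand-in for spectra of 100 training sequences
mean, comps = fit_pca(train, k=16)
print(apply_pca(train[:5], mean, comps).shape)  # (5, 16)
```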

#### **Note**: this may be more useful for a small-alphabet use case like DNA

## 5. LZ + Embedding Model

### Methodology
1. **Train**: using training data, build an LZ78 tree.
2. For each node in the tree, compute the embedding of the corresponding phrase using some embedding model. For now, this does **not** happen _a priori_ and only happens as phrases are encountered when embedding a sequence.
3. **Embedding**: To embed a sequence, first use the LZ78 sequential probability assignment (SPA) tree to compute:
    - The embeddings of phrases that the sequence parses into, and
    - The average log loss of each such phrase.
4. The final embedding is the weighted average of the embeddings, with the weights proportional to $2^{-\text{average log loss}}$.
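The aggregation in step 4 can be written in a few lines of NumPy. This sketch assumes two things: average log losses are measured in bits (matching the base-2 weight), and the weights are normalized to sum to one. `aggregate_phrase_embeddings` is a hypothetical helper, not a method of `LZPlusEmbeddingModel`.

```python
import numpy as np


def aggregate_phrase_embeddings(phrase_embs, avg_log_loss) -> np.ndarray:
    """Weighted average of phrase embeddings (one row per phrase), with
    weights proportional to 2 ** (-average log loss in bits)."""
    losses = np.asarray(avg_log_loss, dtype=float)
    # Subtracting the minimum rescales every weight by the same factor, so the
    # normalized weights are unchanged, but the largest weight is exactly 1
    # and the sum cannot underflow to zero for large losses.
    w = np.exp2(-(losses - losses.min()))
    w /= w.sum()
    return w @ np.asarray(phrase_embs, dtype=float)


phrase_embs = np.array([[1.0, 0.0], [0.0, 1.0]])
# A phrase with 1 bit lower average loss gets twice the weight: 2/3 vs. 1/3.
print(aggregate_phrase_embeddings(phrase_embs, [1.0, 2.0]))
```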

### Instantiation
To instantiate an `LZPlusEmbeddingModel` object, you need to pass in a `DeepEmbeddingModel` object, which takes the Hugging Face model identifier of the embedding model.

You also need to pass in `AlphabetInfo`, like in **Section 4**.

```python
model = LZPlusEmbeddingModel(
    # Set `device` to a GPU that exists on your machine.
    DeepEmbeddingModel("Alibaba-NLP/gte-Qwen2-7B-instruct", device="cuda:7"),
    alpha_info=AlphabetInfo(valid_character_string="abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ. ,?\n\"';:\t-_"),
)
```

### Training
Training is done via the `LZPlusEmbeddingModel.train` method.

```python
dataset = load_dataset("salesforce/Wikitext", "wikitext-2-v1")
```

```python
EPOCHS = 20
stdout.flush()
for _ in tqdm(range(EPOCHS)):
    for example in dataset["train"]:
        text = example["text"]
        if not text:
            continue  # skip empty lines
        model.train(text)
```

```python
print(f"The LZ tree has {model.spa.get_total_nodes() / 1e6:.2f} million nodes")
```

### Embedding
Embedding is done via the same interface as in **Section 4**.
Embeddings returned are the same dimensionality as the base embedding model (for now).

Below is a quick example:

```python
queries = [
    'how much protein should a female eat',
    'summit define'
]
# No need to add instruction for retrieval documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1  the highest point of a mountain : the top of a mountain. : 2  the highest level. : 3  a meeting or series of meetings between the leaders of two or more governments."
]
input_texts = queries + documents
```

```python
emb = model.embed(input_texts)
```

```python
# Rows 0-1 are the queries, rows 2-3 the documents.
scores = (emb[:2] @ emb[2:].T) * 100

for i in range(2):
    row = ", ".join(f"Document {j}: {scores[i, j].item():.2f}" for j in range(2))
    print(f"Correlation of Query {i} with...\t{row}")
```
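The scores above are scaled dot products. They equal cosine similarities only if `model.embed` returns unit-norm vectors, and this notebook does not say whether it does. The following NumPy sketch normalizes the vectors first. `cosine_scores` is a hypothetical helper. Passing `emb` to it assumes that `np.asarray` can convert it, which for a PyTorch tensor may first require `.cpu()`.

```python
import numpy as np


def cosine_scores(query_embs, doc_embs) -> np.ndarray:
    """Cosine similarity between every query and document, scaled by 100."""
    q = np.asarray(query_embs, dtype=float)
    d = np.asarray(doc_embs, dtype=float)
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    d = d / np.linalg.norm(d, axis=1, keepdims=True)
    return 100 * q @ d.T  # same x100 scaling as the cell above


# Toy vectors: a parallel pair scores 100, an orthogonal pair scores 0.
print(cosine_scores([[2.0, 0.0]], [[5.0, 0.0], [0.0, 3.0]]))
```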

```python
emb
```

### Issues and Next Steps
The number of nodes in an LZ tree is very large, so the "model size" can be quite large. This matters because the competition plot is "model size vs. accuracy".
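If node embeddings are eventually precomputed (step 2 of the Methodology), storage grows linearly with the node count. The back-of-the-envelope sketch below uses hypothetical numbers: the node count and embedding dimension are illustrative and were not measured in this notebook.

```python
def embedding_table_bytes(num_nodes: int, dim: int, bytes_per_value: int = 4) -> int:
    """Memory for one dense embedding per LZ tree node (float32 by default)."""
    return num_nodes * dim * bytes_per_value


# Hypothetical sizes: 1 million nodes, 1024-dimensional float32 embeddings.
size = embedding_table_bytes(1_000_000, 1024)
print(f"{size / 1e9:.1f} GB")  # 4.1 GB
```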

Also, it is unclear how much this differs from existing work that embeds tokens produced by, e.g., byte-pair encoding.

#### **Some next steps**:
1. Using the embeddings for context aggregation, which will reduce the model size
2. Try this out on some downstream tasks
3. Find a good training set for growing the LZ tree