## Testing LZ78 Embeddings: Simple Experiment

In [None]:
from sys import stdout
from tqdm import tqdm
from lz_embed.transformer_based import LZPlusEmbeddingModel, WeightType, EmbeddingType
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import ndcg_score
import mteb

In [None]:
%load_ext autoreload
%autoreload 2

### Very Simple Model Training
This setup is far from optimal! It just trains on Wikipedia text (WikiText-2) and computes embeddings on the fly, as opposed to caching them and applying PCA, which is perhaps the ultimate goal.
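
For reference, the "caching + PCA" alternative mentioned above could look roughly like the sketch below. It is only an illustration: random vectors stand in for cached embeddings, and the helper name `pca_project` and the array shapes are hypothetical, not part of `lz_embed`.

```python
import numpy as np

def pca_project(cached: np.ndarray, k: int):
    """Return (mean, components) for the top-k principal directions of `cached`."""
    mean = cached.mean(axis=0)
    # SVD of the centered data; the rows of vt are the principal directions,
    # ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(cached - mean, full_matrices=False)
    return mean, vt[:k]

rng = np.random.default_rng(0)
cached = rng.normal(size=(200, 32))  # stand-in for a cache of 200 embeddings
mean, components = pca_project(cached, k=8)

# Project the cached embeddings onto the 8-dimensional subspace.
reduced = (cached - mean) @ components.T
print(reduced.shape)
```

The same `mean` and `components` would then be reused to project any new embedding, so the expensive inner-model call happens once per cached item.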

In [None]:
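import gc
# Free GPU memory held by a previously created model, if any.
# (A bare `del model` raises NameError the first time this cell runs.)
globals().pop("model", None)
gc.collect()
torch.cuda.empty_cache()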

In [None]:
import os
# Only needed when the inner model is an OpenAI embedding model (e.g. text-embedding-3-large).
os.environ["OPENAI_API_KEY"] = "YOUR KEY HERE"

In [None]:
model = LZPlusEmbeddingModel(
    # inner_model_name="text-embedding-3-large",
    inner_model_name="Alibaba-NLP/gte-Qwen2-7B-instruct",
    device="cuda:7",
    inner_model_type=EmbeddingType.TRANSFORMERS,
    valid_character_string="abcdefghijklmnopqrstuvwxyz ",
    make_lowercase=True,
    weight_type=WeightType.UNIFORM
)

In [None]:
dataset = load_dataset("Salesforce/wikitext", "wikitext-2-v1")

In [None]:
EPOCHS = 20
stdout.flush()
for _ in tqdm(range(EPOCHS)):
    for example in dataset["train"]:
        text = example["text"]
        # Skip empty rows.
        if not text:
            continue
        model.train_spa(text)

In [None]:
model.spa.prune(5)

In [None]:
print(f"The LZ tree has {model.spa.get_total_nodes() / 1e6:.2f} million nodes")

In [None]:
model.compute_subspace(
    512, num_gen_seqs=2000, gen_seq_len=250, backshift_len=6,
    enable_low_rank_projection=True
)

## MTEB Evaluation
As the `LZPlusEmbeddingModel` class inherits from `SentenceTransformer`, any MTEB task can be evaluated using the `mteb` library's interface.

Below are a few from the benchmark, with a very high-level description of how the task is scored.

### AILA Statutes (Retrieval)
We are given some documents and queries. The embedding model is scored based on whether the relevant documents for each query are close to the query in embedding space.
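
As a rough illustration of this kind of scoring, the sketch below ranks a few documents by cosine similarity to a query and computes nDCG@10 with `sklearn.metrics.ndcg_score`. MTEB retrieval tasks typically report nDCG@10 as the main score, but this is not MTEB's own evaluation code; the embeddings are random stand-ins and `cosine_scores` is a hypothetical helper.

```python
import numpy as np
from sklearn.metrics import ndcg_score

def cosine_scores(query: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Cosine similarity between one query vector and each row of `docs`."""
    query = query / np.linalg.norm(query)
    docs = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    return docs @ query

rng = np.random.default_rng(0)
doc_emb = rng.normal(size=(5, 16))  # stand-in document embeddings
# Stand-in query embedding, constructed to lie very close to document 2.
query_emb = doc_emb[2] + 0.01 * rng.normal(size=16)

scores = cosine_scores(query_emb, doc_emb)
relevance = np.array([0, 0, 1, 0, 0])  # only document 2 is relevant

# ndcg_score expects 2D arrays of shape (n_queries, n_documents).
ndcg = ndcg_score(relevance[None, :], scores[None, :], k=10)
print(f"nDCG@10 = {ndcg:.3f}")
```

Because the only relevant document is ranked first here, the score is at its maximum; pushing it further down the ranking lowers the score.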

In [None]:
tasks = mteb.get_tasks(tasks=["AILAStatutes"])
evaluation = mteb.MTEB(tasks=tasks)

# MTEB skips a task whose results are already cached. If this doesn't actually run,
# delete the task's JSON file under results/test.
results = evaluation.run(
    model, output_folder="results/test",
    show_progress_bar=True
)

In [None]:
results[0].scores["test"][0]["main_score"] * 100

### ArXivHierarchicalClusteringP2P (Clustering)
We are given articles from arXiv, and the embedding model is scored on how well the article embeddings can be hierarchically clustered, compared against the ground-truth "topic" labels for the articles.
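
To make the `v_measure` score concrete, here is a minimal sketch that clusters synthetic "embeddings" and compares the clusters to known labels. It uses flat k-means as a stand-in and does not reproduce MTEB's hierarchical procedure; the data and cluster centers are made up for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import v_measure_score

rng = np.random.default_rng(0)
# Three well-separated synthetic "topics" in a 2D embedding space.
centers = np.array([[5.0, 0.0], [-5.0, 0.0], [0.0, 5.0]])
labels = np.repeat(np.arange(3), 30)  # ground-truth topic per article
emb = centers[labels] + 0.1 * rng.normal(size=(90, 2))

# Cluster the embeddings without looking at the labels.
pred = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(emb)

# V-measure is the harmonic mean of homogeneity and completeness; it is
# invariant to how the cluster ids are numbered.
score = v_measure_score(labels, pred)
print(f"V-measure = {score:.3f}")
```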

In [None]:
tasks = mteb.get_tasks(tasks=["ArXivHierarchicalClusteringP2P"])
evaluation = mteb.MTEB(tasks=tasks)

# MTEB skips a task whose results are already cached. If this doesn't actually run,
# delete the task's JSON file under results/test.
results = evaluation.run(
    model, output_folder="results/test",
    show_progress_bar=True
)

In [None]:
results[0].scores["test"][0]["v_measure"] * 100

### DBpediaClassification
A classification task for encyclopedia articles, scored on accuracy. Classification appears to be performed using k-nearest neighbors in embedding space.

**Warning**: this task takes substantially longer than the previous two: about 45 minutes on an A6000 GPU.
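
A k-nearest-neighbors classifier over embeddings can be sketched as follows. This assumes the kNN reading above is right and is not taken from MTEB's code; the two-class data, the choice of `n_neighbors=5`, and the cosine metric are illustrative assumptions.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
# Two synthetic classes, each clustered around its own center.
centers = np.array([[4.0, 0.0], [-4.0, 0.0]])

y_train = np.repeat([0, 1], 20)
x_train = centers[y_train] + 0.5 * rng.normal(size=(40, 2))
y_test = np.repeat([0, 1], 10)
x_test = centers[y_test] + 0.5 * rng.normal(size=(20, 2))

# Each test point takes the majority label of its 5 nearest training
# embeddings under cosine distance.
clf = KNeighborsClassifier(n_neighbors=5, metric="cosine")
clf.fit(x_train, y_train)

accuracy = float((clf.predict(x_test) == y_test).mean())
print(f"accuracy = {accuracy:.2f}")
```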

In [None]:
tasks = mteb.get_tasks(tasks=["DBpediaClassification"])
evaluation = mteb.MTEB(tasks=tasks)

# MTEB skips a task whose results are already cached. If this doesn't actually run,
# delete the task's JSON file under results/test.
results = evaluation.run(
    model, output_folder="results/test",
    show_progress_bar=True
)

In [None]:
results[0].scores["test"][0]["accuracy"] * 100

## Notes: Next Steps
- (implemented) _lowercase and ignore punctuation_
- Can take the embeddings as we go instead of just at the leaves
- Use a SoTA inner model instead of Qwen; could use an API
    - Tried this; OpenAI embeddings are quite slow
- Focus on Wikipedia classification to see if we should train on more data (in-distribution)
- (implemented) _Monte Carlo -> PCA to get subspace_
- Ablate on different averaging methods
- Try a Zipf weighting
    - This might be morally similar to doing a plain average of the weights (TODO: think more about this point)

### Plan
- First thing we should ablate is the weights
- Our method has a unique advantage of giving an accurate perplexity estimate.
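
For reference, perplexity follows directly from the per-symbol probabilities that a sequential predictor assigns. The sketch below shows only the formula; it does not call `lz_embed`, and the input list of log-probabilities is a hypothetical stand-in for whatever the LZ78 predictor reports.

```python
import math

def perplexity(log2_probs):
    """Perplexity = 2 ** (average negative log2-probability per symbol)."""
    n = len(log2_probs)
    return 2.0 ** (-sum(log2_probs) / n)

# Sanity check: a predictor that is uniform over the 27-symbol alphabet used
# in this notebook (a-z plus space) should have perplexity of about 27.
uniform = [math.log2(1 / 27)] * 100
print(perplexity(uniform))
```

A predictor that has learned structure in the text assigns higher probability to the symbols that actually occur, which pushes the perplexity below the alphabet size.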