# Setup/Config


### Dependencies


In [1]:
import datasets
import gradio as gr
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
from typing import Optional, List
import os

import orcalib as orca
from orcalib.orca_torch import ClassificationMode, OrcaModel
from orcalib.orca_classification import OrcaClassificationHead
from orcalib.orca_torch_mixins import LabelColumnNameMixin, DropExactMatchOption

ImportError: cannot import name 'EXACT_MATCH_THRESHOLD' from 'orca_common' (/Users/ryankopec/Library/Caches/pypoetry/virtualenvs/orcademo-Mh9Exlzv-py3.11/lib/python3.11/site-packages/orca_common/__init__.py)

### Config

If you are training a new model, please update the following config values:

```
USE_PRETRAINED = False
DO_TRAIN = True
```

If you prefer to use a pre-trained model, you can find one for this notebook [here](https://drive.google.com/drive/folders/1SKcZpeVPyG-Mo216RP8d6BNT7Cbon3oP).

If you are starting with a clean DB, update the following config value:

```
RELOAD_MEMORY_TABLE = True
```

Note: this will take a little time, but you will see a progress bar.


In [None]:
# Global Config
# Module-level switches and constants read by the cells below; edit them here before running.

DEVICE = "mps"  # 'mps' (mac), 'cpu', or 'cuda'
DTYPE = torch.float32  # dtype the classification model is cast to after construction
NUM_MEMS = 10  # how many memories should the model use
USE_PRETRAINED = False  # should the pre-trained model be loaded (independent of DO_TRAIN)
PRETRAINED_PATH = "news_classification_demo_model.pt"  # path to save/load the model

SHRINK_DATASET = True  # should the dataset be shrunk to a smaller size for rapid testing

# Data Loading Config
# When True, the memory table is dropped, recreated, indexed, and re-ingested.
RELOAD_MEMORY_TABLE = True

# Training Config
BATCH_SIZE = 32
NUM_EPOCHS = 1  # how many epochs to train for (if training is enabled)
LEARNING_RATE = 1e-4
# NOTE: DO_TRAIN also gates the cached-lookup step and the creation of the train/test data loaders.
DO_TRAIN = True  # should the model be trained (leave as False to use the pretrained model without tuning)
DO_SAVE = True  # should the model be saved after training (WARNING: will overwrite PRETRAINED_PATH if True)
REBUILD_DATASET_LOOKUPS = True  # When True, the dataset lookups will be rebuilt. Otherwise, they will be loaded from disk.

# Testing Config
DO_START_UI = False  # simple gradio UI to manually test the model
DO_RUN_BENCHMARK = True  # benchmark the model on the test set and record results with Orca

# Names and Paths
MEMORY_TABLE_NAME = "memory_table"
MEMORY_INDEX_NAME = "memory_index"  # text index built over the "text" column of the memory table
ORCA_DATABASE_NAME = "news_classification_demo"

# Locations of the datasets saved with their cached memory lookups (see REBUILD_DATASET_LOOKUPS).
TRAIN_DATASET_FILENAME = os.path.join("data", "trainset_cached")
TEST_DATASET_FILENAME = os.path.join("data", "testset_cached")

# Loading Data


We're using a relatively simple open source dataset here that consists of ~120k news headlines, categorized into 4 categories: (World News, Sports, Business, Science/Technology). It has the advantage of being small enough to allow for quick iteration and experimentation, but it's not so small as to be unrealistic compared to real world datasets. It also contains quite a few mis-classifications in the training data, making for a good "real world data quality" approximation.


In [None]:
# Load the AG News train/test splits and connect to the Orca database.
ds = datasets.load_dataset("ag_news")  # vanilla Hugging Face datasets here
ds_test = ds["test"]
ds_train = ds["train"]
print("Loaded ag_news dataset")

if SHRINK_DATASET:
    # Select a subset of the training/testing data to use as memories
    # This is useful for rapid testing
    # The "test" side of each split is the smaller/sampled portion that we keep.
    # NOTE(review): no `seed` is passed to train_test_split, so the selected subset is
    # presumably different on every run — confirm, and pass a seed if reproducibility matters.
    ds_split = ds_train.train_test_split(test_size=0.05)
    ds_train = ds_split["test"]
    print("\tShrank training dataset to 5% of original size")

    ds_split = ds_test.train_test_split(test_size=0.5)
    ds_test = ds_split["test"]
    print("\tShrank testing dataset to 50% of original size")


db = orca.OrcaDatabase(ORCA_DATABASE_NAME)  # creates and connects to an orca database instance
print("Connected to Orca database")

Now we'll create a table to store the memories the model will use during training and inference. We add a text index to allow the model to search over them efficiently in its latent space.

We use Orca's Huggingface Dataset Ingestor to load the data — `orcalib` provides ingestors for many other data sources as well!


In [None]:
# Label name -> class index used for the enum column of the memory table.
# The first four entries match the headline categories described above; "Swimming" is an
# extra entry on top of them.
memory_labels = {
    "World News": 0,
    "Sports": 1,
    "Business": 2,
    "Sci/Tech": 3,
    "Swimming": 4,
}

# Reverse mapping (class index -> label name) used when reporting predictions.
label_to_name = {v: k for k, v in memory_labels.items()}

if RELOAD_MEMORY_TABLE:
    db.drop_table(MEMORY_TABLE_NAME, error_if_not_exists=False)
    # NOTE(review): `TableCreateMode` is not imported by any cell visible in this notebook, so
    # this line raises NameError as written — import it from the package that defines it
    # (presumably orcalib / orca_common; confirm). The table was also just dropped above, so
    # the replace mode looks redundant.
    train_table = db.create_table(
        if_table_exists=TableCreateMode.REPLACE_CURR_TABLE,
        table_name=MEMORY_TABLE_NAME,
        text=orca.TextT.notnull,
        label=orca.EnumT[memory_labels].notnull,
    )
    # We'll use this index to efficiently retrieve memories during training/inference.
    db.create_text_index(index_name=MEMORY_INDEX_NAME, table_name=MEMORY_TABLE_NAME, column="text")

    # NOTE(review): the memories are ingested from `ds_test`, the same split the benchmark cells
    # evaluate on. If that is not intentional it leaks test rows into the memory table — confirm
    # whether `ds_train` was meant here.
    orca.HFDatasetIngestor(db, table_name=MEMORY_TABLE_NAME, dataset=ds_test, replace=True).run()

During training, each element in the training/testing datasets is typically accessed multiple times. To avoid repeating the memory lookups, we'll pre-cache the lookup results as new features in our datasets. During training/testing, we can inject the lookups directly into the model to dramatically reduce training time. Using cached lookups doesn't change the semantics of training; it just makes it faster.

In [None]:

from datasets import Dataset
from orcalib import OrcaLookupCacheBuilder

# The cacher will perform the memory lookups and add the results to the Dataset, which
# will significantly speed up training and inference.
lookup_cacher = OrcaLookupCacheBuilder(
    db=db,
    index_name=MEMORY_INDEX_NAME,
    num_memories=NUM_MEMS,
    embedding_col_name="embedded_text", # Prompt's embedding
    # memory_column_aliases map the memory-lookup column names to feature names that will be added to the
    # Dataset. This is useful for aligning the inputs to your model's forward() method and preventing
    # naming conflicts in the Dataset
    memory_column_aliases={"$embedding": "memory_embeddings", "label": "memory_labels"},
    # NOTE(review): presumably excludes a memory identical to the query text — confirm against orcalib.
    drop_exact_match=True,
)

# Everything below only runs when training is enabled; with DO_TRAIN=False neither the cached
# lookups nor `train_loader` / `test_loader` are created.
if DO_TRAIN:
    if REBUILD_DATASET_LOOKUPS:
        # Run the lookups for the "text" column of both splits, then persist the results
        # so later runs can skip this step.
        ds_train = lookup_cacher.add_lookups_to_hf_dataset(ds_train, "text")
        print("Added lookups to training dataset")
        ds_test = lookup_cacher.add_lookups_to_hf_dataset(ds_test, "text")
        print("Added lookups to testing dataset")

        ds_train.save_to_disk(TRAIN_DATASET_FILENAME)
        ds_test.save_to_disk(TEST_DATASET_FILENAME)
        print("Saved datasets to disk. Re-run with REBUILD_DATASET_LOOKUPS=False to use cached datasets.")
    else:
        # Reuse the datasets (with their lookup features) saved by an earlier run.
        ds_train = Dataset.load_from_disk(TRAIN_DATASET_FILENAME)
        print("Loaded training dataset from disk")
        ds_test = Dataset.load_from_disk(TEST_DATASET_FILENAME)
        print("Loaded testing dataset from disk")

    # Only the training loader is shuffled; the test loader keeps dataset order.
    train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
    print("Created data loaders")

# Building and Training the Model


### Model Definition


Here we're defining a simple classification model for the news headlines dataset using one of the pre-built, memory-augmented pytorch layers that are part of orcalib.

One distinction worth highlighting here is that we are setting the output dimension of the model (i.e. number of classes) higher than the 4 classes the dataset actually has. In a conventional model this would be pretty silly, but with an Orca Augmented model you can add new classes (up to the model output dimension) purely through memory modifications later on, without additional training.


In [None]:
class SimpleOrcaClassificationModel(OrcaModel, LabelColumnNameMixin):
    """News-headline classifier made of a single memory-augmented Orca classification head.

    Args:
        lookup_cacher: Optional cache builder. When given, ``forward`` hands the pre-computed
            memory embeddings/labels to it before running the head; when ``None`` the head is
            called directly.
    """

    def __init__(self, lookup_cacher: Optional[OrcaLookupCacheBuilder] = None):
        super().__init__(database=db)

        head_settings = dict(
            model_dim=768,
            # actually only have 5 classes, but we're leaving room to grow
            num_classes=10,
            memory_index_name=MEMORY_INDEX_NAME,
            label_column_name="label",
            num_memories=NUM_MEMS,
            dropout=0.1,
            activation=torch.nn.functional.relu,
            num_layers=1,
            # forces the model to predict from the memory distribution, leading to much better
            # generalization to un-seen data (at the price of slightly lower static data accuracy)
            classification_mode=ClassificationMode.MEMORY_BOUND,
            drop_exact_match=DropExactMatchOption.TRAINING_ONLY,
            exact_match_threshold=0.99,
        )

        self.lookup_cacher = lookup_cacher
        self.head = OrcaClassificationHead(**head_settings)

    def forward(self, x, memory_embeddings=None, memory_labels=None):
        """Return the head's logits for ``x``.

        If a lookup cacher was supplied at construction, the given ``memory_embeddings`` and
        ``memory_labels`` are injected through it first.
        """
        cacher = self.lookup_cacher
        if cacher is None:
            return self.head(x)

        cacher.inject_lookup_results(
            self,
            memory_embeddings=memory_embeddings,
            memory_labels=memory_labels,
        )
        return self.head(x)

### Training Utils


Very vanilla training loop setup. Nothing particularly interesting or unique to Orca happening in this section.


In [None]:
from dataclasses import dataclass
from typing import Iterator


@dataclass(slots=True)
class PreparedRows:
    text: List[str]
    labels: torch.Tensor
    memory_labels: torch.Tensor
    memory_embeddings: torch.Tensor


def prep_data(iterator) -> Iterator[PreparedRows]:
    for stuff in iterator:
        pass
        inputs = torch.stack(stuff["embedded_text"], dim=1).float().to(DEVICE)
        labels = stuff["label"].to(torch.int64).to(DEVICE)
        memory_labels = torch.stack(stuff["memory_labels"], dim=1).to(torch.int64).to(DEVICE)
        memory_embeddings = torch.stack([
            torch.stack(t, dim=1) for t in stuff["memory_embeddings"]
        ], dim=1).float().to(DEVICE)
        yield PreparedRows(inputs, labels, memory_labels, memory_embeddings)


def get_accuracy(logits, labels):
    """Return, as a Python float, the fraction of rows whose top-scoring class equals the label."""
    predicted = logits.max(dim=1).indices  # index of the largest logit in each row
    hits = predicted == labels
    return hits.float().mean().item()


def get_test_accuracy(model, loader, progress_bar=False, live_lookups=False):
    """Return the model's accuracy over every sample produced by ``loader``.

    Accuracy is computed as (correct predictions) / (samples seen), so a smaller final batch
    is weighted by its actual size rather than counted as a full batch.

    Args:
        model: Model to evaluate. It is switched to eval mode for the duration of the call
            and returned to whatever mode it was in beforehand.
        loader: Iterable of collated batches accepted by ``prep_data``.
        progress_bar: Wrap ``loader`` in ``tqdm`` when True.
        live_lookups: When True the model is called with the inputs only; otherwise the
            cached ``memory_embeddings`` / ``memory_labels`` from the batch are passed in.

    Returns:
        Accuracy in ``[0, 1]``, or ``0.0`` when ``loader`` yields no samples.
    """
    batches = tqdm(loader) if progress_bar else loader
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in prep_data(batches):
                if live_lookups:
                    outputs = model(batch.text)
                else:
                    outputs = model.forward(
                        batch.text,
                        memory_embeddings=batch.memory_embeddings,
                        memory_labels=batch.memory_labels,
                    )
                preds = outputs.max(dim=1).indices
                correct += (preds == batch.labels).sum().item()
                total += batch.labels.numel()
    finally:
        # Restore the caller's mode even if evaluation raised part-way through.
        model.train(was_training)
    return correct / total if total else 0.0


def train_one_epoch(model: torch.nn.Module, optimizer, epoch: int, verbosity: int = 0):
    """Run one pass over the notebook-level ``train_loader`` and print a summary line.

    Args:
        model: Model to train; put into train mode before the first batch.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index, used only in the printed summary.
        verbosity: When positive, running averages are printed every ``verbosity`` steps.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    model.train()

    loss_sum = 0.0
    acc_sum = 0.0
    steps = 0
    for steps, batch in enumerate(prep_data(train_loader), start=1):
        optimizer.zero_grad()
        logits = model.forward(
            batch.text,
            memory_embeddings=batch.memory_embeddings,
            memory_labels=batch.memory_labels,
        )
        loss = loss_fn(logits, batch.labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        acc_sum += get_accuracy(logits, batch.labels)

        should_report = verbosity > 0 and steps % verbosity == 0
        if should_report:
            avg_loss = loss_sum / steps
            avg_acc = acc_sum / steps
            print(f"\t Step {steps}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")

    # Epoch-level averages of the per-batch loss and accuracy, plus accuracy on `test_loader`.
    avg_loss = loss_sum / steps
    avg_acc = acc_sum / steps
    print(
        f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}, Test Accuracy: {get_test_accuracy(model, test_loader):.4f}"
    )

### Training

In [None]:
# Classification model (moved to the configured device/dtype) plus the sentence embedder
# and tokenizer used to embed raw text in the UI and benchmark cells.
model = SimpleOrcaClassificationModel(lookup_cacher).to(DEVICE).to(DTYPE)
embed_model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1").to(DEVICE)
embed_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")

if USE_PRETRAINED:
    # map_location remaps the saved tensors onto DEVICE, so a checkpoint written on another
    # device type (e.g. saved on "cuda", loaded on "mps" or "cpu") still loads.
    model.load_state_dict(torch.load(PRETRAINED_PATH, map_location=DEVICE))

In [None]:
if DO_TRAIN:
    # Build the optimizer once, before the epoch loop: Adam keeps per-parameter running
    # statistics, and re-creating it every epoch would reset them.
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(NUM_EPOCHS):
        train_one_epoch(
            model,
            optimizer,
            epoch,
        )
    if DO_SAVE:
        # Overwrites PRETRAINED_PATH with the freshly trained weights.
        torch.save(model.state_dict(), PRETRAINED_PATH)

# Spinning up a Testing UI


Spinning up a simple Gradio UI to test the model manually, with Orca Curate observability and editability enabled.


In [None]:
if DO_START_UI:
    model.update_curate_settings(
        model_id="news_classification_demo", model_version="0.0.1", extra_tags=["from_test_ui"], batch_size=1
    )
    model.enable_curate()

    def predict(text: str) -> str:
        """Embed ``text``, classify it, record the input/output pair, and return the label name."""
        # The training/evaluation helpers leave the model in train mode, which keeps dropout
        # active; switch to eval on every call so predictions are deterministic even if
        # another cell flipped the mode while the UI is running.
        model.eval()
        with torch.no_grad():
            tokens = embed_tokenizer(text, return_tensors="pt").to(DEVICE)
            embedding = embed_model(**tokens).pooler_output
            outputs = model(embedding)
        _, pred = torch.max(outputs, dim=0)
        pred_index = pred.item()
        # The head has more output classes than there are named labels, so fall back to a
        # readable placeholder instead of raising KeyError.
        label_name = label_to_name.get(pred_index, f"Unknown ({pred_index})")
        model.record_model_input_output(text, label_name)
        return label_name

    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                inp = gr.Textbox(lines=5, placeholder="Enter your text here...", label="News Headline")
                outp = gr.Label(num_top_classes=1)
            btn = gr.Button(value="Submit")
            btn.click(fn=predict, inputs=inp, outputs=outp)

    demo.launch(show_api=False, inline=False)

    print("UI launched")

# Running a Benchmark


Here we run the test set through the model with Orca observability enabled, which tracks model inputs, outputs, and a feedback score (how well the model did). This makes it easy to find and iterate on "bad" examples in the Orca UI.


In [None]:
def eval_dataset(model, dataset):
    """Classify every item of ``dataset`` one at a time and return the fraction predicted correctly.

    For each item the expected label is attached as curate metadata before the model is
    called, then a feedback score (1.0 for a correct prediction, -1.0 otherwise) and the
    input/output pair are recorded on the model. The model is put in eval mode for the run
    and switched to train mode afterwards.
    """
    model.eval()
    hits: float = 0.0
    seen: int = 0
    with torch.no_grad():
        for row in tqdm(dataset):
            target = row["label"]
            target_name = label_to_name[target]
            headline = row["text"]

            # Embed the raw headline with the sentence embedder.
            encoded = embed_tokenizer(headline, return_tensors="pt").to(DEVICE)
            query_embedding = embed_model(**encoded).pooler_output.detach()

            model.update_curate_settings(extra_metadata={"expected_label": f"{target_name} ({target})"})
            logits = model(query_embedding)
            guess: int = torch.max(logits, dim=0)[1].item()
            guess_name = label_to_name[guess]

            is_hit = guess == target
            feedback = 1.0 if is_hit else -1.0
            hits += 1.0 if is_hit else 0.0
            model.record_curate_scores(feedback)
            model.record_model_input_output(headline, f"{guess_name} ({guess})")
            seen += 1
        accuracy = hits / seen
    model.train()
    return accuracy

In [None]:
if DO_RUN_BENCHMARK:
    # Settings applied to everything recorded during the benchmark run.
    benchmark_curate_settings = {
        "model_id": "news_classification_demo",
        "model_version": "0.0.1",
        "extra_tags": {"testset_benchmark"},
        "batch_size": 1,
    }
    model.update_curate_settings(**benchmark_curate_settings)
    model.head.enable_curate()

    test_acc = eval_dataset(model, ds_test)
    print(f"Test Accuracy: {test_acc:.1%}")  # 90%

In [None]:
def nearest_neighbor_classifier(input_text):
    """Predict labels by majority vote over the ``NUM_MEMS`` nearest memories of each input.

    Args:
        input_text: A string or list of strings to classify.

    Returns:
        The most frequent neighbour label for each input (``mode`` along ``dim=1``).
    """
    tokens = embed_tokenizer(input_text, return_tensors="pt", padding="max_length", truncation=True)
    with torch.no_grad():
        embeddings = embed_model(**tokens.to(DEVICE)).pooler_output
    query = embeddings.cpu().numpy().tolist()
    # Query the index this notebook actually creates (MEMORY_INDEX_NAME); the previously
    # hard-coded "train_index" is never created here.
    neighbors = db.vector_scan_index(MEMORY_INDEX_NAME, query).select("label").fetch(NUM_MEMS).to_tensor("label")
    return neighbors.mode(dim=1).values

In [None]:
if DO_RUN_BENCHMARK:
    # Baseline: plain k-nearest-neighbour voting over the memories, with no trained head.
    knn_batch_size = 20
    correct = 0
    steps = 0
    for batch in tqdm(DataLoader(ds_test, batch_size=knn_batch_size)):
        texts = batch["text"]
        expected_labels = batch["label"]
        predicted_labels = nearest_neighbor_classifier(texts)
        correct += (expected_labels == predicted_labels).sum().item()
        # Count the samples actually seen: the final batch may hold fewer than
        # `knn_batch_size` items, and assuming a full batch would understate accuracy.
        steps += len(texts)
    if steps:
        print(f"KNN Ensemble Accuracy: {correct/steps:.1%}")  # 91%
    else:
        print("KNN Ensemble Accuracy: n/a (empty test set)")