# Approximate Nearest Neighbors for Text Classification

Author: [Collin Zoeller](https://www.linkedin.com/in/collinzoeller)
<br> Carnegie Mellon University

This notebook demonstrates how to use the Annoy library for fast approximate nearest-neighbor search to classify text data. The goal is to classify user-generated text into predefined categories using a pre-trained transformer model: Annoy builds an approximate nearest-neighbor index over the embeddings of the target classes, and each new observation is then classified by its nearest neighbor in that embedding space.


Imports

In [None]:
import random
import time
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
from annoy import AnnoyIndex
from sklearn.metrics import classification_report
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt


## Data
(TODO: collect a broader range of data and host it on GitHub.)

## Modules

Helper functions used by the pipeline below.


In [None]:
def noisify(occupation):
    """Return a randomly perturbed copy of a base occupation title.

    Each perturbation is applied independently, in this order, using the
    module-level ``random`` generator:

    * 30%: append a seniority suffix ("Senior", "Junior" or "Lead").
    * 30%: drop one character at a random position.
    * 50%: lowercase the whole string.
    * 10%: uppercase the whole string (applied after the lowercase step,
      so it wins when both fire).

    Args:
        occupation: The clean occupation title.

    Returns:
        The noisy title. An empty input never raises; the character-drop
        step is simply skipped when there is nothing to drop.
    """
    # Occasionally append a random seniority level
    if random.random() < 0.3:
        occupation = occupation + " " + random.choice(["Senior", "Junior", "Lead"])

    # Missing letter: drop one character. The random draw stays first so the
    # random stream is consumed exactly as before; the emptiness guard avoids
    # random.randint(0, -1) raising ValueError on an empty string.
    if random.random() < 0.3 and occupation:
        chars = list(occupation)
        chars.pop(random.randint(0, len(chars) - 1))
        occupation = "".join(chars)

    # lowercase
    if random.random() < 0.5:
        occupation = occupation.lower()
    # all caps
    if random.random() < 0.1:
        occupation = occupation.upper()

    return occupation


def make_random_data(onet_classes, num_obs=10000):
    """Build a synthetic labelled dataset by noising sampled class names.

    Args:
        onet_classes: Sequence of clean class labels to sample from.
        num_obs: How many observations to draw (sampling with replacement).

    Returns:
        A ``(labels, x)`` pair of equal-length NumPy arrays: the sampled
        clean labels and their ``noisify``-ed counterparts, in draw order.
    """
    # Draw every label first, then noise them, matching the original order
    # in which the random generator is consumed.
    sampled = [random.choice(onet_classes) for _ in range(num_obs)]
    frame = pd.DataFrame(sampled, columns=["labels"])
    noisy = frame["labels"].apply(noisify)
    frame["x"] = noisy
    clean_labels = frame["labels"].to_numpy()
    features = frame["x"].to_numpy()
    return clean_labels, features


def build_tree(embeddings, num_trees=10):
    """Create and build an angular-metric Annoy index over ``embeddings``.

    Row ``i`` of ``embeddings`` is stored under item id ``i``, so ids returned
    by later queries index straight back into the rows (and any parallel
    label array).

    Args:
        embeddings: 2-D array-like exposing ``.shape``; one vector per row.
        num_trees: Number of trees passed to ``AnnoyIndex.build``.

    Returns:
        The built ``AnnoyIndex``.
    """
    n_dims = embeddings.shape[1]
    index = AnnoyIndex(n_dims, 'angular')
    for item_id, vector in enumerate(embeddings):
        index.add_item(item_id, vector)
    index.build(num_trees)
    return index


def embed_batched(model, data, dim, batch_size=1000):
    """Encode ``data`` with ``model`` in fixed-size batches.

    Args:
        model: Object exposing ``encode(batch, convert_to_numpy=...,
            show_progress_bar=...)`` (e.g. a SentenceTransformer); assumed to
            return a 2-D array with one row per input.
        data: Sliceable sequence of inputs to encode.
        dim: Embedding width; must match the model's output width and shapes
            the result when ``data`` is empty.
        batch_size: Number of inputs passed to each ``encode`` call.

    Returns:
        A float64 array of shape ``(len(data), dim)`` with rows in input order.
    """
    # Collect per-batch arrays and stack once at the end: calling np.vstack
    # inside the loop re-copies everything accumulated so far (O(n^2)).
    # The float64 (0, dim) seed preserves the original empty-input result and
    # the original float64 dtype promotion.
    chunks = [np.empty((0, dim))]
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        chunks.append(model.encode(batch, convert_to_numpy=True, show_progress_bar=False))

    return np.vstack(chunks)


def predict_labels(tree, classes, embeddings):
    """Assign each embedding the class of its single nearest indexed item.

    Args:
        tree: Built index exposing ``get_nns_by_vector(vector, n)`` (e.g. an
            ``AnnoyIndex``); item id ``i`` must correspond to ``classes[i]``.
        classes: Sequence of class labels indexed by item id.
        embeddings: Iterable of query vectors.

    Returns:
        A NumPy array holding one predicted label per query, in input order
        (an empty float array when there are no queries).

    Raises:
        IndexError: if the index returns no neighbor for a query.
    """
    # Build a Python list once instead of calling np.append per row, which
    # copies the entire array on every iteration (O(n^2)).
    predicted = [classes[tree.get_nns_by_vector(emb, 1)[0]] for emb in embeddings]
    return np.array(predicted)


def evaluate(y_true, y_pred, report_path="report.csv"):
    """Print summary classification metrics and save a per-class report.

    Printed metrics:
      * Accuracy: fraction of predictions equal to the true label.
      * Precision: unweighted mean, over predicted classes, of the fraction
        of predictions for that class that are correct (macro precision).
      * Recall: unweighted mean, over true classes, of the fraction of that
        class's observations predicted correctly (macro recall).
      * F1: harmonic mean of the macro precision and macro recall above.
        This is NOT scikit-learn's macro F1 (the mean of per-class F1 scores).

    Args:
        y_true: True labels, one per observation.
        y_pred: Predicted labels, aligned with ``y_true``.
        report_path: CSV destination for the transposed scikit-learn
            ``classification_report`` (one row per report entry).

    Returns:
        None. Side effects: prints to stdout and writes ``report_path``.
    """
    comp = pd.DataFrame({"label": y_true, "yhat": y_pred})
    comp['correct'] = comp['label'] == comp['yhat']
    accuracy = comp['correct'].mean()
    precision = comp.groupby('yhat')['correct'].mean().mean()
    recall = comp.groupby('label')['correct'].mean().mean()

    # When nothing is correct both terms are 0; define F1 as 0 rather than
    # letting 0/0 produce nan. (Empty input still yields nan via nan inputs.)
    denom = precision + recall
    if denom == 0:
        f1 = 0.0
    else:
        f1 = 2 * (precision * recall) / denom

    print(f"Accuracy: {accuracy:.4f}"
          f"\nPrecision: {precision:.4f}"
          f"\nRecall: {recall:.4f}"
          f"\nF1: {f1:.4f}")

    report = classification_report(y_true, y_pred, output_dict=True)
    pd.DataFrame(report).T.to_csv(report_path)


def visualize(embeddings, labels, save: bool = False, top_n: int = 20,
              filename: str = "embeddings.png"):
    """Scatter-plot embeddings projected to 2-D with PCA, coloured by label.

    Only the ``top_n`` most frequent labels are drawn. The ``tab20`` colormap
    provides 20 distinct colours, so colours are not distinct beyond 20.

    Args:
        embeddings: Sequence of vectors (or a 2-D array); stacked with
            ``np.vstack`` before PCA.
        labels: One label per embedding, used for colouring and the legend.
        save: If True, also write the figure to ``filename``.
        top_n: Number of most frequent labels to plot.
        filename: Output path used when ``save`` is True.

    Returns:
        None. Side effects: shows (and optionally saves) a matplotlib figure.
    """
    embeddings = np.vstack(embeddings)
    labels = np.array(labels)

    pca_embed = PCA(n_components=2).fit_transform(embeddings)

    unique_classes, counts = np.unique(labels, return_counts=True)
    # Negated counts -> argsort yields most frequent classes first.
    top_classes = unique_classes[np.argsort(-counts)[:top_n]]

    colors = plt.colormaps['tab20']

    plt.figure(figsize=(10, 6))
    for i, cls in enumerate(top_classes):
        mask = labels == cls
        plt.scatter(pca_embed[mask, 0], pca_embed[mask, 1],
                    color=colors(i), label=cls, alpha=0.5)

    plt.title(f"Embedding Clusters in 2D: Top {top_n} Occupations")
    plt.legend(fontsize="small", loc="center left", bbox_to_anchor=(1, 0.5))
    plt.tight_layout()

    # Save BEFORE show(): show() may close the figure, after which savefig()
    # would write an empty canvas.
    if save:
        plt.savefig(filename, bbox_inches="tight")
    plt.show()


def pipeline(model: str, labels: np.ndarray, data: np.ndarray, num_trees: int = 10,
             batch_size: int = 1000, save_fig: bool = False) -> np.ndarray:
    """Classify ``data`` by nearest-neighbor search over embedded class names.

    Steps: load a SentenceTransformer, embed the target ``labels``, build an
    Annoy index over those embeddings, embed ``data`` in batches, assign each
    observation the label of its nearest indexed class, then plot the feature
    embeddings coloured by prediction.

    Args:
        model: SentenceTransformer model name or path to load.
        labels: Target class names; item ``i`` in the index is ``labels[i]``.
        data: Raw inputs to classify.
        num_trees: Number of Annoy trees (more = better recall, slower build).
        batch_size: Batch size used when encoding ``data``.
        save_fig: Forwarded to ``visualize`` as ``save``.

    Returns:
        Array of predicted labels, one per element of ``data``.
    """
    # 1. Load pre-trained model. A separate name keeps the ``model`` argument
    #    (a model identifier) from being shadowed by the loaded object.
    encoder = SentenceTransformer(model)

    # 2. Encode target classes
    print(f"Encoding {len(labels)} target label values")
    target_embeddings = encoder.encode(labels, convert_to_numpy=True, show_progress_bar=True)

    # 3. Build Annoy Index for target classes
    print(f"Building Annoy index with {num_trees} trees")
    tree = build_tree(target_embeddings, num_trees=num_trees)

    # 4. Encode feature space
    print(f"Encoding {len(data)} feature vectors")
    feature_embeddings = embed_batched(encoder, data, target_embeddings.shape[1], batch_size=batch_size)

    # 5. Predict labels
    print("Predicting labels")
    yhat = predict_labels(tree, labels, feature_embeddings)

    # 6. Visualize
    visualize(feature_embeddings, yhat, save=save_fig)

    return yhat

## Pipeline

### 1. Load pre-trained model

In [None]:
# Pre-trained sentence-transformers checkpoint to load for this walkthrough.
modelname = "all-MiniLM-L6-v2"
model = SentenceTransformer(modelname)

### 2. Encode target classes

In [None]:
# NOTE(review): `labels` (the target class names) must be defined by an
# earlier cell that is not shown here.
print(f"Encoding {len(labels)} target label values")
target_embeddings = model.encode(
    labels,
    convert_to_numpy=True,
    show_progress_bar=True,
)