# Approximate Nearest Neighbors for Text Classification
Author: [Collin Zoeller](https://www.linkedin.com/in/collinzoeller)
<br> Carnegie Mellon University

This notebook demonstrates how to use the ANNOY library for fast approximate nearest neighbor search to classify text data. The goal is to classify user-generated text data into pre-defined categories using a pre-trained transformer model. The ANNOY library is used to build an approximate nearest neighbor index for the target classes, and then to classify new observations based on their nearest neighbors in the embedding space.


## Imports

In [None]:
!pip install -q sentence-transformers
!pip install -q annoy
!pip install -q kagglehub

In [1]:
import random
import time
import pandas as pd
import numpy as np
import os
import shutil
from glob import glob
from collections import Counter
from sentence_transformers import SentenceTransformer
from annoy import AnnoyIndex
from sklearn.metrics import classification_report
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import kagglehub
from typing import Union, Self




## Data

Use 100k samples of previously scraped data to train and evaluate the model. Produce synthetic data by adding noise to the scraped data.


REDDIT DATA
- Data retrieved from: https://www.reddit.com/r/datasets/comments/w340kj/dataset_of_job_descriptions_for_your_pleasure/
- Data hosted at: https://drive.google.com/drive/folders/1XxNuhiei5taFR6gziofYAx0oWfGeV7y9

KAGGLE DATA
- SOURCE: https://www.kaggle.com/datasets/jatinchawda/job-titles-and-description


In [None]:
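def create_training_data(num_obs=None):
    """Create training data (an array of job titles) from scraped data."""
    df = pd.read_parquet('data/kaggle_clean_data.parquet')
    df = df.rename(columns={"job_title": "title"})
    df1 = pd.read_parquet('data/reddit_data.parquet')
    df = pd.concat([df, df1], ignore_index=True)

    # clean_data expects a 1-D array of strings and returns a 1-D array
    titles = clean_data(df['title'].dropna().to_numpy())

    if num_obs:
        # Sample without replacement, capped at the number of available titles
        titles = np.random.choice(titles, size=min(num_obs, len(titles)), replace=False)

    return titles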


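def clean_data(data: np.ndarray) -> np.ndarray:
    """
    Standardize data to English ASCII-only characters.
    :param data: 1-D array of strings
    :return: cleaned array with non-ASCII characters stripped and empty strings removed
    """
    df = pd.DataFrame(data, columns=['X'])

    # strip non-ASCII characters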
    df['X'] = df['X'].str.encode('ascii', 'ignore').str.decode('ascii')

    # remove empty strings
    df = df[df['X'] != ""]

    return df.X.to_numpy()


def noisify(truevals: np.array):
    """Add noise to the data."""

    df = pd.DataFrame(truevals, columns=['X'])
    if random.random() < 0.3:
        df['X'] = df['X'].str.upper()
    if random.random() > 0.5:
        df['X'] = df['X'].str.lower()
    if random.random() < 0.5:
        df['X'] = df['X'].apply(lambda x: x + " " + random.choice(["Senior", "Junior", "Lead"]))
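    if random.random() < 0.3:
        # Drop one random character. list.pop() returns the removed item,
        # not the remaining string, so slice around the chosen position instead.
        def drop_char(x):
            if len(x) < 2:
                return x
            i = random.randrange(len(x))
            return x[:i] + x[i + 1:]
        df['X'] = df['X'].apply(drop_char)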
    if random.random() < 0.1:
        df['X'] = df['X'].str[::-1]
    if random.random() < 0.2:
        df['X'] = df['X'].apply(lambda x: x.replace(" ", random.choice(["_", "-", ""])))
    if random.random() < 0.2:
        df['X'] = df['X'].apply(lambda x: x + str(random.randint(0, 99)))
    if random.random() < 0.2:
        df['X'] = df['X'].apply(lambda x: x[:random.randint(1, len(x))])
    return df.X.to_numpy()


def make_random_data(classes, num_obs=1000, colidx=0):
    """Create random data with noise."""
    if isinstance(classes, str) and classes.endswith('.csv'):
        df = pd.read_csv(classes)
        classes = df.iloc[:, colidx].unique()

    true_values = np.random.choice(classes, num_obs)
    noisy_data = noisify(true_values)
    return true_values, noisy_data

### Download the Kaggle dataset

In [None]:
if not os.path.exists("data/kaggle_clean_data.parquet"):
    os.makedirs("data", exist_ok=True)
    path = kagglehub.dataset_download("jatinchawda/job-titles-and-description")
    print("Path to dataset files:", path)
    shutil.move(f"{path}/clean_data.parquet", "data/kaggle_clean_data.parquet")

    # Save only the title column
    df = pd.read_parquet("data/kaggle_clean_data.parquet", columns=["job_title"])
    df.to_parquet("data/kaggle_clean_data.parquet")

else:
    print("Kaggle data already downloaded.")

### Format the Reddit dataset
This data should already be downloaded from https://drive.google.com/drive/folders/1XxNuhiei5taFR6gziofYAx0oWfGeV7y9 and saved as data/reddit_jobs.

In [None]:

if not os.path.exists("data/reddit_data.parquet"):

    if not os.path.exists("data/reddit_jobs"):
        raise FileNotFoundError("Download the Reddit dataset from the drive at"
                                " https://drive.google.com/drive/folders/1XxNuhiei5taFR6gziofYAx0oWfGeV7y9 ."
                                "\nSave as data/reddit_jobs.")

    print("Formatting Reddit data...")
    df = pd.concat([pd.read_csv(file) for file in glob("data/reddit_jobs/*.csv")])
    df = df['title'].to_frame()
    df.to_parquet("data/reddit_data.parquet")
    print("Reddit data saved to data/reddit_data.parquet")
else:
    print("Reddit data already formatted correctly.")

## Modules

Functions and tools. The process is modularized for simplicity.


In [None]:
"""
Functions and tools for the occupation classification project.
"""

def timer(func):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        print(f"(finished in {time.time() - start:.2f} seconds, "
              f"avg: {(time.time() - start) / len(args[1]):.4f} sec/label)")
        return result

    return wrapper



class Encoder:
    """
    Encoder for the given model. The model can be a string or a SentenceTransformer object.
    Embeddings can be saved to or loaded from disk.
    """

    def __init__(self, model: Union[os.PathLike, str, SentenceTransformer]):

        """
        :param model: path, model name, or SentenceTransformer object
        """

        self.model = SentenceTransformer(model) if not isinstance(model, SentenceTransformer) else model
        self.embeds = None
        self.labels = None


    @timer
    def embed(self, data: np.array, batched: bool = False, batch_size: int = 1000) -> Self:

        """
        Embed the given data using the model.
        :param data: Array of string data to be embedded
        :param batched: Whether to use batched encoding, better for large datasets
        :param batch_size: Batch size for batched encoding (irrelevant if batched is False)
        :return: Encoder object
        """

        self.labels = data

        if not batched:
            print(f"\nEncoding {len(data)} target label values")
            self.embeds = self.model.encode(data, convert_to_numpy=True, show_progress_bar=True)

            return self

        else:
            num_batches = (len(data) + batch_size - 1) // batch_size
            print(f"\nEncoding {len(data)} target label values ({num_batches} batches)")
            batch_indices = np.array_split(np.arange(len(data)), num_batches)

            self.embeds = np.vstack([self.model.encode(data[indices.tolist()],
                                                       convert_to_numpy=True,
                                                       show_progress_bar=False) for indices in batch_indices])
            return self

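    def save(self, path: Union[os.PathLike, str]) -> tuple[np.ndarray, np.ndarray]:
        # Create the parent directory if needed so the write does not fail
        os.makedirs(os.path.dirname(os.fspath(path)) or ".", exist_ok=True)
        np.savez_compressed(path, labels=self.labels, embeds=self.embeds)
        return self.labels, self.embeds

    def load(self, path: str) -> tuple[np.ndarray, np.ndarray]:
        # Keys must match those used in save(); string labels are usually
        # stored as object arrays, which np.load only reads with allow_pickle=True
        data = np.load(path, allow_pickle=True)
        self.labels = data['labels']
        self.embeds = data['embeds']
        return self.labels, self.embeds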


class Classifier:

    """
    Classifier using Annoy index for fast nearest neighbor search.
    """

    def __init__(self):
        self.tree = None
        self.eval = {}
        self.x_embeddings = None
        self.x_labels = None
        self.pred = None
        self.y_embeddings = None
        self.y_labels = None

    def build_tree(self, labels: np.array, embeddings: np.array, num_trees=10):

        """Build an Annoy index for the given embeddings.
        labels should be same size as embeddings.
        :param labels: string labels
        :param embeddings: Embeddings to build the index
        :param num_trees: Number of trees to build
        """

        print(f"\nBuilding Annoy index with {num_trees} trees")
        t = AnnoyIndex(embeddings.shape[1], 'euclidean')
        for i, emb in enumerate(embeddings):
            t.add_item(i, emb)
        t.build(num_trees)
        self.tree = t

        self.y_embeddings = embeddings
        self.y_labels = labels

        return self

    def predict(self,
                labels: np.array,
                embeddings: np.array,
                neighbors: int = 1,
                save_path: Union[os.PathLike, str, None] = None) -> Self:

        """
        Predict the labels for the given embeddings using the Annoy index.
        :param labels: labels for the embeddings
        :param embeddings: numerical representation of the labels (should be same length as labels)
        :param neighbors: number of neighbors to consider
        :param save_path:
        :return:
        """

        print(f"\nPredicting labels for {len(embeddings)} embeddings")
        indices = [self.tree.get_nns_by_vector(emb.tolist(), neighbors) for emb in embeddings]

        # Map index ids back to the target class labels stored when the index was built
        # (the ids refer to items in the tree, not to the input labels)
        neighbor_labels = [[self.y_labels[idx] for idx in idxs] for idxs in indices]

        # Determine the most common label among neighbors
        y_pred = [Counter(nbrs).most_common(1)[0][0] for nbrs in neighbor_labels]

        self.x_labels = labels
        self.x_embeddings = embeddings
        self.pred = y_pred

        # Save the predictions
        if save_path is not None:
            save_path = str(save_path)
            # One row per observation: the input text and its predicted class
            df = pd.DataFrame({"label": labels, "yhat": y_pred})
            if save_path.endswith(".csv"):
                df.to_csv(save_path)
            elif save_path.endswith(".parquet"):
                df.to_parquet(save_path)
            else:
                np.savez_compressed(save_path, label=np.asarray(labels), yhat=np.asarray(y_pred))

        return self

    def evaluate(self,
                 y_true: np.array,
                 save_path: Union[os.PathLike, str, None] = None) -> None:

        """Evaluate the classification performance against an array of ground truth labels.
        You must predict the labels of the corresponding text before calling this function.
        :param y_true: True labels
        :param save_path: Name or path of the report csv
        :return: prints overall metrics to console and saves per-class and overall metrics to a csv if path is provided
        """

        # overall metrics: accuracy plus macro-averaged precision and recall
        comp = pd.DataFrame({"label": y_true, "yhat": self.pred})
        comp['correct'] = comp['label'] == comp['yhat']
        accuracy = comp['correct'].mean()
        precision = comp.groupby('yhat')['correct'].mean().mean()
        recall = comp.groupby('label')['correct'].mean().mean()
        f1 = 2 * (precision * recall) / (precision + recall)

        macros = {
            "overall": {
                        "accuracy": accuracy,
                        "precision": precision,
                        "recall": recall,
                        "f1": f1
                        }
        }

        print(f"Accuracy: {accuracy:.4f}"
              f"\nPrecision: {precision:.4f}"
              f"\nRecall: {recall:.4f}"
              f"\nF1: {f1:.4f}")

        # per-class metrics (and sklearn's own averages) from classification_report
        report = classification_report(y_true, self.pred, output_dict=True)
        report = {**report, **macros}
        self.eval = report
        if save_path is not None:
            pd.DataFrame(report).T.to_csv(f"{save_path}.csv")
        return

    def visualize(self,
                  top: int = 10,
                  label_points: bool = False,
                  figsize: tuple[int, int] = (10, 10),
                  save: Union[os.PathLike, str, None] = None) -> None:

        """
        Visualize the embeddings and their labels.
        :param top: Number of top classes to plot (>20 classes runs out of colors)
        :param label_points: Label the target class points
        :param figsize: Figure size (larger size better for large top)
        :param save: path to save the plot
        :return:
        """
        x_embeddings = np.vstack(self.x_embeddings)
        x_labels = np.array(self.pred)
        y_embeddings = np.vstack(self.y_embeddings)
        y_labels = np.array(self.y_labels)

        # PCA
        pca = PCA(n_components=2)
        pca_embed_x = pca.fit_transform(x_embeddings)
        pca_embed_y = pca.transform(y_embeddings)

        x_unique_classes, x_counts = np.unique(x_labels, return_counts=True)
        n_unique_classes = x_unique_classes[np.argsort(-x_counts)[:top]]

        # Plot
        colors = plt.colormaps['tab20']
        class_to_color = {cls: colors(i) for i, cls in enumerate(n_unique_classes)}

        plt.figure(figsize=figsize)

        # plot input embeddings
        for cls in n_unique_classes:
            cls_idx = np.where(x_labels == cls)[0]
            print(f"Class: {cls}, Count: {len(cls_idx)}")
            plt.scatter(pca_embed_x[cls_idx, 0], pca_embed_x[cls_idx, 1],
                        color=class_to_color[cls], alpha=0.6, marker='o')

        # Plot target classes
        for cls in n_unique_classes:
            cls_idx = np.where(y_labels == cls)[0]
            plt.scatter(pca_embed_y[cls_idx, 0], pca_embed_y[cls_idx, 1],
                        color=class_to_color[cls], label=f"{cls}", alpha=1.0, marker='x', s=250)

            if label_points:
                for idx in cls_idx:
                    plt.text(pca_embed_y[idx, 0], pca_embed_y[idx, 1], cls, fontsize=9, ha='center',
                             va='bottom')

        plt.title(f"Embedding Clusters: Top {top} Occupations")

        if not label_points:
            plt.legend(loc='upper right')

        # Save before show(): show() may release the figure, leaving savefig with a blank canvas
        if save is not None:
            plt.savefig(save)
        plt.show()
        return


## Pipeline

### 1. Load Output Data

In [3]:
# Output labels from Dingel and Neiman
labels = pd.read_csv('https://raw.githubusercontent.com/jdingel/DingelNeiman-workathome/master/occ_onet_scores/output/occupations_workathome.csv')
labels.head()
labels = labels['title'].to_numpy()


| Unnamed: 0 | onetsoccode | title | teleworkable |
|---|---|---|---|
| 0 | 11-1011.00 | Chief Executives | 1 |
| 1 | 11-1011.03 | Chief Sustainability Officers | 1 |
| 2 | 11-1021.00 | General and Operations Managers | 1 |
| 3 | 11-2011.00 | Advertising and Promotions Managers | 1 |
| 4 | 11-2021.00 | Marketing Managers | 1 |


### 2. Encode target classes

`batch_size` sets the number of observations embedded at once, which helps avoid memory issues. Larger batches are faster, but they may not fit in memory.
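
As a rough illustration of how `Encoder.embed` sizes its batches when `batched=True`: it takes the ceiling of `len(data) / batch_size` to get the number of batches, then lets `np.array_split` divide the indices into near-equal chunks. The sketch below reproduces that arithmetic in plain Python. `batch_sizes` is a hypothetical helper written for this note, not part of the notebook, and it assumes at least one item.

```python
def batch_sizes(n_items: int, batch_size: int) -> list[int]:
    """Chunk sizes after ceiling division followed by a near-even split."""
    # Ceiling division: the smallest number of batches that keeps each one <= batch_size
    num_batches = (n_items + batch_size - 1) // batch_size
    base, extra = divmod(n_items, num_batches)
    # The first `extra` chunks receive one additional item, as np.array_split does
    return [base + 1] * extra + [base] * (num_batches - extra)


print(batch_sizes(2500, 1000))  # [834, 833, 833]
print(batch_sizes(2000, 1000))  # [1000, 1000]
```

Note that the chunks are balanced rather than filled to `batch_size`, so the last batch is never much smaller than the others.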



In [None]:
encoder = Encoder("all-MiniLM-L6-v2")
y_labels, y_embeddings = encoder.embed(labels).save("embeds/occupations.npz")

### 3. Encode data

The data here may be unlabeled (such as the Kaggle or Reddit data), but unlabeled data says little about the model's performance. Consider creating labeled data for evaluation. For demonstration purposes, we create randomized data from the output labels by adding noise.

You can optionally use a sample of the unlabelled data to see how well the model works!
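
One possible variation on the noise step, shown here only as a sketch: the notebook's `noisify` draws each random decision once and applies it to the whole batch, whereas the hypothetical `noisify_one` below perturbs each title independently and takes a seeded `random.Random` so runs are reproducible. The probabilities and the choice of perturbations are assumptions picked for illustration.

```python
import random


def noisify_one(title: str, rng: random.Random) -> str:
    """Apply independent random perturbations to a single title."""
    if rng.random() < 0.5:
        title = title.lower()
    if rng.random() < 0.3 and len(title) > 1:
        # Drop one random character by slicing around it
        i = rng.randrange(len(title))
        title = title[:i] + title[i + 1:]
    if rng.random() < 0.2:
        # Replace spaces with a separator, or remove them
        title = title.replace(" ", rng.choice(["_", "-", ""]))
    return title


titles = ["Chief Executives", "Marketing Managers", "Advertising and Promotions Managers"]
rng = random.Random(42)  # fixed seed so the same noisy sample can be regenerated
noisy = [noisify_one(t, rng) for t in titles]
for clean, dirty in zip(titles, noisy):
    print(f"{clean!r} -> {dirty!r}")
```

Because every row keeps its clean title alongside the noisy one, the pair can serve directly as ground truth and input for evaluation.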

In [None]:
# Use this block to use the synthetic data
true, data = make_random_data(classes=labels, num_obs=10, colidx=1)
x_labels, x_embeddings = encoder.embed(data, batched=True, batch_size=1024).save("embeds/test.npz")

In [None]:
# Use this block to use real (unlabeled) scraped data
num_obs = 100
data = create_training_data(num_obs=num_obs) # samples 100 observations from the scraped data
x_labels, x_embeddings = encoder.embed(data).save("embeds/test.npz")

### 4. Build Annoy Index for target classes
The main hyperparameter when building the index is the number of random trees. More trees make the search more accurate, but the index takes longer to build.
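
To judge whether a given number of trees is accurate enough, one option is to compare the approximate results against an exact brute-force search on a small sample. The sketch below is not part of the notebook and does not call Annoy: `exact_nearest` and `recall_at_1` are hypothetical helpers, and `approx_ids` stands in for whatever ids an approximate index returned.

```python
import math


def exact_nearest(query, targets):
    """Index of the target vector closest to `query` under Euclidean distance."""
    return min(range(len(targets)), key=lambda i: math.dist(query, targets[i]))


def recall_at_1(approx_ids, queries, targets):
    """Share of queries whose approximate neighbor equals the exact nearest neighbor."""
    hits = sum(a == exact_nearest(q, targets) for a, q in zip(approx_ids, queries))
    return hits / len(queries)


targets = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
queries = [[0.9, 0.1], [0.1, 0.8], [0.2, 0.1]]
approx_ids = [1, 2, 1]  # placeholder for ids returned by an approximate index

print(f"{recall_at_1(approx_ids, queries, targets):.3f}")  # 0.667
```

Brute force scales with the number of queries times the number of targets, so this check is best run on a few hundred sampled queries rather than the full data.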

In [None]:
num_trees = 500
tree = Classifier().build_tree(y_labels, y_embeddings, num_trees=num_trees)

### 5. Predict labels

`predict` returns the `Classifier` itself, now holding the input labels, their embeddings, and the predicted labels. Because the index is stored on the object, you can run multiple predictions without rebuilding it, and the object carries everything needed to evaluate and visualize the results.

`num_neighbors` is the number of nearest neighbors to consider when classifying each observation, equivalent to the k in KNN.
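
The vote itself is a plain majority count over the neighbors' labels. A minimal sketch of that step follows (`vote` is a hypothetical name; the notebook does the same thing inline with `Counter`). One consequence worth knowing: when every class appears only once in the index, all k neighbors carry different labels, the vote is a tie, and `Counter.most_common` resolves ties in first-seen order, so the nearest neighbor's label wins.

```python
from collections import Counter


def vote(neighbor_labels):
    """Most common label among neighbors; ties go to the label seen first."""
    return Counter(neighbor_labels).most_common(1)[0][0]


# Neighbors are assumed to be ordered from nearest to farthest
print(vote(["Chefs", "Cooks", "Chefs"]))   # Chefs
print(vote(["Cooks", "Chefs", "Bakers"]))  # Cooks (three-way tie, nearest wins)
```

With one embedding per target class, raising k above 1 therefore does not change the prediction; it matters only if the index holds several examples per class.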

In [None]:
num_neighbors = 10
predicted = tree.predict(x_labels, x_embeddings, neighbors=num_neighbors)

### 5.5 (optional) Evaluate if using labeled data

Returns the standard per-class classification report from sklearn along with overall metrics calculated here. The per-class metrics are precision, recall, F1 score, and support. Overall precision and recall are macro averages, that is, the unweighted mean of the per-class values; overall F1 is the harmonic mean of those two averages, which can differ from sklearn's macro F1 (the mean of the per-class F1 scores).

Only the overall metrics are printed to the console, but the full report is saved to a csv file if a path is provided.
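
A small worked example may make the overall metrics concrete. The sketch below mirrors the grouping logic of `evaluate` in plain Python: precision is averaged over predicted classes, recall over true classes, and F1 is the harmonic mean of the two averages. `macro_metrics` is a hypothetical helper and the four labels are made-up toy data.

```python
from collections import defaultdict


def macro_metrics(y_true, y_pred):
    """Macro precision, macro recall, and the harmonic mean of the two."""
    by_pred, by_true = defaultdict(list), defaultdict(list)
    for t, p in zip(y_true, y_pred):
        by_pred[p].append(t == p)  # correctness grouped by predicted class
        by_true[t].append(t == p)  # correctness grouped by true class
    precision = sum(sum(v) / len(v) for v in by_pred.values()) / len(by_pred)
    recall = sum(sum(v) / len(v) for v in by_true.values()) / len(by_true)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


y_true = ["a", "a", "b", "b"]
y_pred = ["a", "b", "b", "b"]
p, r, f1 = macro_metrics(y_true, y_pred)
print(f"precision={p:.4f} recall={r:.4f} f1={f1:.4f}")
# precision=0.8333 recall=0.7500 f1=0.7895
```

For the same toy data, the per-class F1 scores are 0.6667 for `a` and 0.8 for `b`, whose mean is 0.7333, so the two definitions of overall F1 do not agree in general.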



In [None]:
predicted.evaluate(true, save_path="eval_report")

### 6. Visualize
Create a 2D plot of the 20 most common predicted occupations in the sample. The plot shows the distribution of the embeddings in the feature space, and how they are clustered. Each point represents an observation, and the color represents the assigned class label. The X's represent the target classes.
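
The 2D coordinates come from PCA fitted on the input embeddings, with the same projection then applied to the target class embeddings so both share one coordinate system. The sketch below shows that fit-then-reuse pattern with NumPy's SVD instead of sklearn's `PCA`. `fit_pca_2d` and `project` are hypothetical helpers, the random arrays stand in for real embeddings, and the resulting axes may differ from sklearn's by a sign flip.

```python
import numpy as np


def fit_pca_2d(x: np.ndarray):
    """Return the mean and the top two principal directions of `x`."""
    mean = x.mean(axis=0)
    # Rows of vt are principal directions, ordered by explained variance
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:2]


def project(x: np.ndarray, mean: np.ndarray, components: np.ndarray) -> np.ndarray:
    """Center with the fitted mean, then project onto the fitted directions."""
    return (x - mean) @ components.T


rng = np.random.default_rng(0)
inputs = rng.normal(size=(50, 8))   # stand-in for observation embeddings
targets = rng.normal(size=(5, 8))   # stand-in for target class embeddings

mean, comps = fit_pca_2d(inputs)
inputs_2d = project(inputs, mean, comps)
targets_2d = project(targets, mean, comps)  # reuse the fit, do not refit
print(inputs_2d.shape, targets_2d.shape)  # (50, 2) (5, 2)
```

Refitting on the targets would place them on different axes, and the X markers would no longer line up with the clusters they label.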

In [None]:
# plots the top 20 predicted occupations
predicted.visualize(top=20, label_points=True, figsize=(10, 10), save="visualize.png")