# Language Model Similarity Example

As discussed in the anchors notebook, a similarity anchor lets you supply arbitrary code to calculate the similarity between user-specified text and the passed data.

This notebook shows how to provide a language model to a similarity anchor, so that the knowledge captured in its embedding space can be used as part of the ICAT model.
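To make that contract concrete without ICAT or a transformer, here is a minimal sketch of a stand-in similarity function. It uses word-overlap (Jaccard) similarity instead of embeddings, and it takes a dataframe plus a target text and returns a `pd.Series` of scores indexed like the input. The name `jaccard_similarity` and the `text_col` default are hypothetical, chosen only for illustration; this is not part of ICAT's API.

```python
import pandas as pd


def jaccard_similarity(data: pd.DataFrame, target_text: str, text_col: str = "text") -> pd.Series:
    """Hypothetical stand-in for an embedding-based similarity: the Jaccard overlap
    between the lowercase word sets of each row's text and the target text."""
    target_words = set(target_text.lower().split())

    def score(text: str) -> float:
        words = set(text.lower().split())
        union = words | target_words
        # avoid division by zero when both texts are empty
        return len(words & target_words) / len(union) if union else 0.0

    # one score per row, keeping the dataframe's index so scores line up with the data
    return data[text_col].apply(score)


df_demo = pd.DataFrame({"text": ["the cat sat", "dogs bark loudly", "a cat and a dog"]}, index=[10, 11, 12])
print(jaccard_similarity(df_demo, "the cat"))
```

Keeping the input's index on the returned series is what lets the scores be matched back to the right rows, which is also why the embedding-based version further down returns `pd.Series(..., index=data.index)`.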

In [None]:
import torch

# change these constants as needed based on your hardware constraints
BATCH_SIZE = 16
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
MODEL_NAME = "bert-base-uncased"

For simplicity, we load a basic pre-trained BERT model (the default set by the `MODEL_NAME` constant above) and do no further fine-tuning. In principle, any other pre-trained transformer encoder can be supplied here.

In [None]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)

To use the language model, we define a `featurize()` function that takes a dataframe (containing all the texts we want to analyze) and the anchor instance, `self`, through which we can find the target text we're computing similarity to.

An anchor list instance has a `cache` dictionary (reached from inside an anchor as `self.global_cache`) in which we store the transformer embeddings, computing each one only _once_, so every `featurize()` call after the first is much faster. This `cache` dictionary is also pickled and unpickled when the anchor list is saved and loaded, so it can persist across sessions.
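The index-based caching used in `featurize()` can be illustrated without a language model. In this minimal sketch, a plain dictionary stands in for the anchor list's cache, and a hypothetical `fake_embed` function (a placeholder for the transformer, not part of ICAT) returns deterministic two-number vectors. Only rows whose index is not yet in the cache get embedded.

```python
import numpy as np
import pandas as pd

# hypothetical stand-in for the anchor list's cache dictionary
cache = {}
embed_calls = []  # records which indices were actually embedded


def fake_embed(data: pd.DataFrame) -> pd.DataFrame:
    """Placeholder for the transformer: embeds each text as [length, word count]."""
    embed_calls.append(list(data.index))
    vectors = np.array([[len(t), len(t.split())] for t in data["text"]], dtype=float)
    return pd.DataFrame(vectors, index=data.index)


def cached_embeddings(data: pd.DataFrame, cache_key: str = "demo_embeddings") -> pd.DataFrame:
    """Embed only rows whose index is not already cached, then return rows for ``data``."""
    if cache_key not in cache:
        cache[cache_key] = pd.DataFrame()
    to_embed = data[~data.index.isin(cache[cache_key].index)]
    if len(to_embed) > 0:
        cache[cache_key] = pd.concat([cache[cache_key], fake_embed(to_embed)])
    # look rows up by index so the result lines up with ``data``
    return cache[cache_key].loc[data.index]


df_demo = pd.DataFrame({"text": ["hello world", "foo", "bar baz qux"]})
cached_embeddings(df_demo.iloc[:2])  # embeds rows 0 and 1
cached_embeddings(df_demo)           # embeds only row 2
print(embed_calls)  # [[0, 1], [2]]
```

Because the cache is keyed purely by index, this approach assumes each index label always refers to the same text; if the dataframe were re-indexed with different contents, stale embeddings would be returned.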

In [None]:
import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity

import icat

class LMSimAnchor(icat.anchors.SimilarityAnchorBase):
    NAME = "BERT"  # optional attribute to set default name inside ICAT UI
    # optional attribute to set description of anchor type in ICAT UI
    DESCRIPTION = "Use cosine similarity between BERT embeddings of data and target."

    @staticmethod
    def embed_texts(texts: list) -> np.ndarray:
        """Embed a list of strings with ``text_model``, returning one vector per text:
        the mean of its token embeddings, excluding padding tokens."""
        embedded_batches = []

        # run the tokenizer and model embedding on batches of at most BATCH_SIZE texts
        for batch_start in range(0, len(texts), BATCH_SIZE):
            batch_text = texts[batch_start:batch_start + BATCH_SIZE]

            # tokenize (padding only to the longest text in the batch) and embed
            tokenized = tokenizer(
                batch_text,
                return_tensors="pt",
                truncation=True,
                padding=True,
            ).to(DEVICE)
            with torch.no_grad():  # inference only, so skip gradient tracking
                hidden = text_model(**tokenized).last_hidden_state

            # average the token embeddings to get the full text representation,
            # using the attention mask so padding tokens don't dilute the mean
            mask = tokenized["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            embedded_batches.append(pooled.cpu().numpy())

        return np.concatenate(embedded_batches, axis=0)

    def embed(self, data: pd.DataFrame) -> pd.DataFrame:
        """This function takes some set of data and embeds the text column using
        the transformer model stored in ``text_model``."""
        embeddings = self.embed_texts(data[self.text_col].tolist())
        return pd.DataFrame(embeddings, index=data.index)

    def featurize(self, data: pd.DataFrame) -> pd.Series:
        # the target text we're computing similarity to. Note that for simplicity we only
        # use the first referenced text, but in principle this function could be
        # implemented to handle multiple targets, e.g. use the average embedding.
        target_text = self.reference_texts[0]

        # make sure the cache dataframe exists to place our embeddings into
        cache_key = f"similarity_embeddings_{MODEL_NAME}"
        if cache_key not in self.global_cache:
            self.global_cache[cache_key] = pd.DataFrame()

        # determine data that hasn't been embedded yet, note that we determine this
        # exclusively by index
        cached = self.global_cache[cache_key]
        to_embed = data[~data.index.isin(cached.index)]

        # perform any necessary embeddings and store into global cache
        if len(to_embed) > 0:
            self.global_cache[cache_key] = pd.concat([cached, self.embed(to_embed)])

        # get the full text embedding for the target text (shape: 1 x hidden size)
        target_embedding = self.embed_texts([target_text])

        # compute cosine similarity between the target text embedding and all the
        # embeddings from the dataframe
        similarities = cosine_similarity(
            target_embedding, self.global_cache[cache_key].loc[data.index].values
        )

        # massage the similarity values a little to get better spread in the visualization
        # and put a minimum threshold on "activation"
        similarities = similarities * 2 - 1
        similarities[similarities < 0.25] = 0.0

        return pd.Series(similarities[0], index=data.index)
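The pooling and scoring steps can be checked in isolation with numpy. This sketch uses made-up token vectors in place of real BERT hidden states (the arrays are illustrative, not model outputs). It shows masked mean pooling, which averages token embeddings while ignoring padding positions according to an attention mask, followed by the same `* 2 - 1` rescaling and 0.25 activation threshold used in `featurize()`.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# toy "hidden states": 3 texts x 3 token positions x 2 dims (made-up numbers)
hidden = np.array([
    [[1.0, 0.0], [1.0, 0.0], [9.0, 9.0]],  # last position is padding
    [[0.0, 1.0], [1.0, 1.0], [1.0, 0.0]],  # no padding
    [[0.0, 1.0], [0.0, 1.0], [0.0, 0.0]],  # last position is padding
])
attention_mask = np.array([[1, 1, 0], [1, 1, 1], [1, 1, 0]])

# masked mean pooling: sum over real tokens only, divide by the number of real tokens
mask = attention_mask[:, :, None].astype(hidden.dtype)
pooled = (hidden * mask).sum(axis=1) / mask.sum(axis=1)
# pooled rows: [1, 0], [2/3, 2/3], [0, 1]

# cosine similarity of a toy target embedding against every pooled row
target = np.array([[1.0, 0.0]])
similarities = cosine_similarity(target, pooled)
# similarities[0]: 1.0, about 0.7071, 0.0

# same "massaging" as featurize(): spread values, then zero out weak activations
similarities = similarities * 2 - 1
similarities[similarities < 0.25] = 0.0
print(similarities[0])  # values: 1.0, about 0.4142, 0.0
```

Note that without the mask, the padding vector `[9, 9]` in the first text would pull its mean toward the diagonal and change its similarity to the target; the mask is what keeps padding from affecting the text representation.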

We load a dataset to work with: the training split of the 20 Newsgroups corpus.

In [None]:
from sklearn.datasets import fetch_20newsgroups

dataset = fetch_20newsgroups(subset="train")
df = pd.DataFrame({
    "text": dataset["data"],
    "category": [dataset["target_names"][i] for i in dataset["target"]],
})
# optionally uncomment to work with a smaller subset (e.g. on slower hardware)
# df = df.iloc[0:1999]
df.head()

ICAT must be initialized before use; this takes care of things like panel setup and other prerequisite UI configuration.

In [None]:
icat.initialize()

Now we create a model to explore with. ICAT's "anchor types" tab automatically detects any `Anchor` class definitions, so you can dynamically add that anchor type and work with it directly within the interface. Alternatively, you can pass the anchor types you want to use to the constructor, or afterwards call `model.anchor_list.add_anchor_type(LMSimAnchor)`.

In [None]:
model = icat.Model(df, text_col="text")

In [None]:
model.view

In [None]:
model.save("wip")

In [None]:
mode