<a href="https://colab.research.google.com/github/abigailhaddad/ChatGPT_with_Python_for_shiny_docs/blob/master/finding_keywords_in_model_logits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports and Functions

In [None]:
!pip install einops

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

def load_huggingface_model(model_name):
    """
    Load a causal language model and its matching tokenizer from Hugging Face.

    :param model_name: Model identifier or path passed to ``from_pretrained``.
    :return: Tuple ``(model, tokenizer)`` -- note the model comes first.
    """
    # NOTE: trust_remote_code=True allows code shipped with the checkpoint to run
    # locally; only use it with model repositories you trust.
    shared_options = {"trust_remote_code": True}
    tokenizer = AutoTokenizer.from_pretrained(model_name, **shared_options)
    model = AutoModelForCausalLM.from_pretrained(model_name, **shared_options)
    return model, tokenizer

def get_model_logits(model, tokenizer, prompt, max_length=50):
    """
    Run the model once on the prompt and return the logits for the next token.

    :param model: Callable model; its output must expose a ``logits`` attribute.
    :param tokenizer: Tokenizer whose ``encode`` returns a tensor of input ids.
    :param prompt: Text to feed to the model.
    :param max_length: Currently unused; kept so existing callers keep working.
    :return: Logits at the last sequence position (``logits[:, -1, :]``).
    """
    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt')

    # Inference only: no gradients are needed for a single forward pass.
    with torch.no_grad():
        forward_result = model(
            encoded_prompt,
            return_dict=True,
            output_attentions=False,
        )

    # Keep only the final position, i.e. the scores for whatever token comes next.
    last_position = -1
    return forward_result.logits[:, last_position, :]

def process_logits(logits, tokenizer, top_k=50):
    """
    Process the logits to extract the most likely next tokens.

    Only the first row of ``logits`` is reported.

    :param logits: 2-D tensor of logits with the vocabulary on the last dimension
        (the result is indexed with ``[0]`` after the top-k selection).
    :param tokenizer: Object with a ``decode(list_of_ids)`` method returning text.
    :param top_k: Maximum number of tokens to return. Values larger than the
        vocabulary size are clamped to it instead of raising.
    :return: List of ``(token_text, probability)`` tuples, most likely first.
    """
    # Apply softmax to convert logits to probabilities
    probs = F.softmax(logits, dim=-1)

    # torch.topk raises if k exceeds the size of the dimension, so cap it at
    # the number of available vocabulary entries.
    k = min(top_k, probs.size(-1))
    top_probs, top_indices = torch.topk(probs, k)

    # Hand plain Python ints to the tokenizer rather than 0-d tensors.
    top_tokens = [tokenizer.decode([token_id]) for token_id in top_indices[0].tolist()]

    return list(zip(top_tokens, top_probs[0].tolist()))

def keyword_analysis(tokens, keywords):
    """
    Analyze which keywords can be formed from the tokens, using backtracking to handle multiple paths.

    A keyword counts as found when it equals a concatenation of tokens taken in
    list order, each token used at most once (tokens may be skipped). An empty
    keyword is always considered found.

    :param tokens: List of tokens.
    :param keywords: List of keywords to search for.
    :return: Dictionary with keys True and False, values are lists of found and not found keywords.
        Each list is de-duplicated and follows the order of ``keywords``.
    """

    def can_form_keyword(keyword, offset, start):
        # offset: how much of the keyword is already matched;
        # start: index of the first token still available.
        if offset == len(keyword):
            return True
        for i in range(start, len(tokens)):
            token = tokens[i]
            # Try this token as the next piece; fall through to the next
            # candidate (backtrack) if the remainder cannot be completed.
            if keyword.startswith(token, offset) and can_form_keyword(
                keyword, offset + len(token), i + 1
            ):
                return True
        return False

    results = {True: [], False: []}
    # dict.fromkeys drops duplicate keywords while keeping first-seen order, so
    # the output is deterministic (building the lists from sets was not).
    for keyword in dict.fromkeys(keywords):
        results[can_form_keyword(keyword, 0, 0)].append(keyword)
    return results

def extract_tokens(sorted_token_probs):
    """
    Normalize the tokens from (token, probability) pairs and drop repeats.

    Each token is lowercased and stripped of surrounding whitespace; when two
    tokens normalize to the same text, only the first occurrence is kept.

    :param sorted_token_probs: Iterable of (token, probability) pairs.
    :return: List of unique normalized tokens in their original order.
    """
    normalized = (text.lower().strip() for text, _ in sorted_token_probs)
    # dict keys are unique and remember insertion order, which gives an
    # order-preserving de-duplication in one step.
    return list(dict.fromkeys(normalized))


## Getting the Model

In [None]:
# Identifier of the checkpoint handed to load_huggingface_model
model_name = "microsoft/phi-2"
# Load the model and tokenizer once; the analysis cell below reuses both objects
model, tokenizer = load_huggingface_model(model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Running the Code for Prompt/Keyword Combinations

In [None]:
# This is the prompt we're sending to the model
prompt = "What is the process of photosynthesis?"
# These are the keywords we're looking for in possible responses
keywords = ["chlorophyll", "sunlight", "mitochondria", "plant"]
# These are the model's logits (unnormalized scores) for the token that would follow the prompt
logits = get_model_logits(model, tokenizer, prompt)
# These are tuples pairing each of the top_k most likely next tokens with its probability
token_probs = process_logits(logits, tokenizer, top_k=250)
# These are the unique tokens, lowercased and stripped, in the same order as token_probs
sorted_tokens = extract_tokens(token_probs)
# These are the keyword results: the True ones were present in the tokens (or could be reconstructed
# by joining tokens in list order); the False ones were not
keyword_results = keyword_analysis(sorted_tokens, keywords)
# Bare expression so the notebook displays the result
keyword_results

{True: ['plant'], False: ['chlorophyll', 'sunlight', 'mitochondria']}