## Installation, Imports and Setup

### Tokens for Downloads

Without a GitHub token, the modified fork of Captum (thesis-captum) cannot be installed. Without a Hugging Face (HGF) token, Mistral cannot be loaded from the Hugging Face Hub.

This is set up for Google Colab. Alternatively, the commented-out string variant below can be used; in that case, replace "TOKEN" with an actual token.

*   Github [Token Info](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens)
*   Huggingface [Token Info](https://huggingface.co/docs/hub/security-tokens)


In [1]:
# grabbing tokens for repository and model access
from google.colab import userdata

# read both access tokens from the Colab secrets store, GitHub first
gh_token, hgf_token = (
    userdata.get(secret_name) for secret_name in ("GITHUB_TOKEN", "HGF_TOKEN")
)

# outside of Colab, assign the tokens directly instead:
# gh_token="TOKEN"
# hgf_token="TOKEN"

### Installs and Imports

In [2]:
# basic installs and additional utilies (usually not needed in colab)
!pip install matplotlib
!pip install numpy
!pip install pandas
!pip install ipywidgets
!pip install ipython

# model package installs
!pip install torch

!pip install transformers
!pip install huggingface_hub
!pip install accelerate



In [3]:
# installing captum package from GitHub repository
!pip install git+https://${gh_token}@github.com/LennardZuendorf/thesis-captum.git

# alternatively captum can be installed from pip
## !pip install captum

Collecting git+https://****@github.com/LennardZuendorf/thesis-captum.git
  Cloning https://****@github.com/LennardZuendorf/thesis-captum.git to /tmp/pip-req-build-t_qn_0sx
  Running command git clone --filter=blob:none --quiet 'https://****@github.com/LennardZuendorf/thesis-captum.git' /tmp/pip-req-build-t_qn_0sx
  Resolved https://****@github.com/LennardZuendorf/thesis-captum.git to commit 7dd85e4a2762b0d2c9850c33c966fb9d049dd909
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [4]:
# basic imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# model imports
import torch
import transformers

# interpretability import
import captum

### Setup Model

In [5]:
# pick the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Device set to {device}.")

Device set to cpu.


In [50]:
# setup mistral model and tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# mistral loading function so this can be run individually


def load_mistral(
    model_id: str = "mistralai/Mistral-7B-Instruct-v0.2",
    token: "str | None" = None,
    seed: int = 42,
):
    """Load the Mistral instruct model, its tokenizer and a sampling config.

    Args:
        model_id: Hugging Face Hub repository used for the model, the
            tokenizer and the generation config.
        token: Hugging Face access token. When ``None``, the notebook-level
            ``hgf_token`` (from the token cell) is used if it is defined.
        seed: seed passed to ``transformers.set_seed`` so that sampling is
            reproducible.

    Returns:
        Tuple ``(model, tokenizer, generation_config)``.
    """
    from transformers import set_seed

    # the token read in the token cell was never handed to the hub before
    if token is None:
        token = globals().get("hgf_token")

    # load tokenizer and model from huggingface
    mistral_tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    mistral_model = AutoModelForCausalLM.from_pretrained(model_id, token=token)

    # the model is deliberately kept on the CPU; move it to a GPU only together
    # with the input tensors that are fed to it
    mistral_model.to(torch.device("cpu"))

    mistral_config = GenerationConfig.from_pretrained(model_id, token=token)
    mistral_config.update(
        temperature=0.7,
        max_new_tokens=50,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
    )

    # GenerationConfig.update() silently ignores keys that are not config
    # attributes (such as "seed"), so seed the RNGs explicitly instead
    set_seed(seed)

    return mistral_model, mistral_tokenizer, mistral_config

In [None]:
# loading mistral model and tokenizer
# NOTE(review): load_mistral downloads the 7B checkpoint from the Hugging Face
# Hub and keeps the model on the CPU, so expect high RAM use and slow generation
mistral_model, mistral_tokenizer, mistral_config = load_mistral()

## Implementation Code

### Utilities

In [25]:
# formatting function for format output text and tokens
import re

# function to format the model response nicely


def format_output_text(output: list):
    """Join raw model output tokens into one readable string.

    Tokens are cleaned with ``format_tokens``; tokens that end up empty are
    dropped. The punctuation marks ``.``, ``,``, ``!`` and ``?`` are attached
    to the preceding text without a space; every other token is preceded by
    one space.

    Args:
        output: raw token strings, e.g. an attribution result's ``output_tokens``.

    Returns:
        The combined string with runs of spaces collapsed, or ``""`` when no
        token survives cleaning.
    """
    # remove special tokens and drop the tokens that consisted only of them
    tokens = [tok for tok in format_tokens(output) if tok != ""]
    if not tokens:
        return ""

    # the first surviving token starts the string exactly once (previously an
    # empty first token caused the second token to be emitted twice)
    pieces = [tokens[0]]
    for txt in tokens[1:]:
        # punctuation is glued to the previous token, everything else is spaced
        pieces.append(txt if txt in (".", ",", "!", "?") else " " + txt)

    # return the combined string with multiple spaces removed
    return re.sub(" +", " ", "".join(pieces))


# format the tokens by removing special tokens and special characters
def format_tokens(tokens: list):
    """Strip special tokens and subword markers from each token.

    Every occurrence of BERT-style special tokens (``[CLS]`` etc.), the
    Mistral/Llama sequence markers ``<s>`` and ``</s>``, and the subword
    markers ``▁``, ``Ġ`` and ``</w>`` is removed. Tokens that consisted only
    of such markers become ``""``; they are kept so the output list has the
    same length as the input.

    Args:
        tokens: token strings as produced by a tokenizer.

    Returns:
        A new list of cleaned token strings.
    """
    # "<s>"/"</s>" are the BOS/EOS markers of the Mistral tokenizer used in
    # this notebook; without them the generated EOS leaks into the answer text
    special_tokens = (
        "[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]",
        "▁", "Ġ", "</w>",
        "<s>", "</s>",
    )

    cleaned_tokens = []
    for token in tokens:
        # removing every "▁" also covers the leading word-boundary marker
        for special in special_tokens:
            token = token.replace(special, "")
        cleaned_tokens.append(token)

    return cleaned_tokens


# function to remove orphan whitespaces in a list of text


def remove_orphan_whitespaces(texts: list):
    """Return a new list without entries that are exactly ``" "`` or ``""``.

    All other entries, including strings of several spaces, are kept
    unchanged and in their original order.

    Args:
        texts: list of text fragments.

    Returns:
        The filtered list; the input list is not modified.
    """
    return [fragment for fragment in texts if fragment not in (" ", "")]

### Interpretability

In [42]:
# function for KernelSHAP attribution with captum
def kernel_attribution(test_input: str, model, tokenizer):
    """Run captum's KernelSHAP-based LLM attribution for one prompt.

    Args:
        test_input: prompt text to attribute.
        model: causal language model wrapped by ``KernelShap``.
        tokenizer: tokenizer matching ``model``.

    Returns:
        The full result of ``LLMAttribution.attribute``.
    """
    # imported here because the notebook only imports these names in a later
    # cell; this keeps the function usable regardless of cell execution order
    from captum.attr import KernelShap, LLMAttribution, TextTokenInput

    # creating llm attribution class with KernelSHAP and given Model, Tokenizer
    llm_attribution = LLMAttribution(KernelShap(model), tokenizer)

    # generating attribution
    attribution_input = TextTokenInput(test_input, tokenizer)
    return llm_attribution.attribute(attribution_input)

### Model Functions

In [47]:
# advanced formatting function that takes into account a conversation history
# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
def format_prompt(message: str, history: list, system_prompt: str):
    """Build a Mistral-instruct prompt from the new message and prior turns.

    Args:
        message: the new user message; it is always placed in the final
            ``[INST] ... [/INST]`` segment so the model answers it.
        history: earlier ``(user_message, bot_response)`` tuples, oldest first.
        system_prompt: instruction text. With history it is prepended to the
            first user turn; without history it is sent as its own turn.

    Returns:
        The prompt string, starting with a literal ``<s>`` and ending with
        ``[/INST]``.
    """
    # if no history, use system prompt and example message
    if not history:
        return (
            f"<s>[INST] {system_prompt} [/INST] How can I help you today? </s>"
            f"[INST] {message} [/INST]"
        )

    # the very first exchange carries the system prompt (history[0] is one
    # tuple, so it is unpacked directly rather than iterated)
    first_user, first_bot = history[0]
    prompt = f"<s>[INST] {system_prompt} {first_user} [/INST] {first_bot}</s>"

    # all following exchanges are added as context
    prompt += "".join(
        f"[INST] {user_prompt} [/INST] {bot_response}</s>"
        for user_prompt, bot_response in history[1:]
    )

    # finally the new message, which was previously dropped when history existed
    return prompt + f"[INST] {message} [/INST]"


# function to extract real answer because mistral always returns the full prompt


def format_answer(answer: "str | list") -> str:
    """Extract the newly generated reply from Mistral's decoded output.

    The decoded output repeats the whole prompt, so the reply is the text
    after the LAST ``[/INST]`` marker. This stays correct when the prompt
    contains earlier turns and therefore more markers.

    Args:
        answer: decoded text, or the list returned by ``batch_decode``; for a
            list, the first entry is used.

    Returns:
        The stripped reply with a trailing ``</s>`` removed, or ``""`` when
        there is no ``[/INST]`` marker or the list is empty.
    """
    # respond() hands over batch_decode's list, not a single string
    if isinstance(answer, (list, tuple)):
        if not answer:
            return ""
        answer = answer[0]

    if "[/INST]" not in answer:
        return ""

    # everything after the final [/INST] is the model's answer to the new message
    reply = answer.rsplit("[/INST]", 1)[1].strip()

    # drop the end-of-sequence marker that the decoding keeps
    if reply.endswith("</s>"):
        reply = reply[: -len("</s>")].rstrip()

    return reply

In [None]:
# explained and standart (vanilla) model generation function

# imports
from captum.attr import LLMAttribution, KernelShap, TextTokenInput

# explained response function utilizing captum


def respond_explained(prompt: str):
    """Attribute Mistral's generated answer to the prompt tokens with KernelSHAP.

    Uses the notebook-level ``mistral_model`` and ``mistral_tokenizer``.

    Args:
        prompt: fully formatted prompt string.

    Returns:
        The captum attribution result; ``chat`` reads its ``output_tokens``
        and calls its ``plot_seq_attr``.
    """
    shap_method = KernelShap(mistral_model)
    explainer = LLMAttribution(shap_method, mistral_tokenizer)
    token_input = TextTokenInput(prompt, mistral_tokenizer)
    return explainer.attribute(token_input)


# vanilla generation function returning the model response based on the input
# CREDIT: adapted from official Mistral Ai 7B Instruct documentation on Huggingface
## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2


def respond(prompt: str):
    """Generate a plain (unexplained) Mistral reply for a formatted prompt.

    Uses the notebook-level ``mistral_model``, ``mistral_tokenizer`` and
    ``mistral_config``; ``max_new_tokens`` and the sampling settings come
    from ``mistral_config``.

    Args:
        prompt: prompt string, typically produced by ``format_prompt``.

    Returns:
        The list returned by ``batch_decode``: one decoded string per
        sequence, containing the prompt followed by the generated text.
    """
    # format_prompt already writes a literal "<s>" (BOS); letting the tokenizer
    # add its own BOS as well would start the sequence with two of them
    add_bos = not prompt.startswith("<s>")
    inputs = mistral_tokenizer(
        prompt, return_tensors="pt", add_special_tokens=add_bos
    )

    # keep the input tensors on whatever device the model lives on
    inputs = inputs.to(mistral_model.device)

    output_ids = mistral_model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        generation_config=mistral_config,
        # Mistral defines no pad token; state the EOS fallback explicitly
        pad_token_id=mistral_tokenizer.eos_token_id,
    )
    return mistral_tokenizer.batch_decode(output_ids)

In [53]:
# simple chat function that calls the model
# formats prompts, calls for an answer and returns updated conversation history
def chat(
    message: str,
    mode: str = "vanilla",
    history: "list | None" = None,
    system_prompt: str = "Given a dialog context, you need to respond empathically.",
):
    """Answer one user message and record the exchange in the history.

    Args:
        message: the new user message.
        mode: ``"vanilla"`` for plain generation or ``"explained"`` for a
            KernelSHAP-attributed answer with a plot (case-insensitive).
        history: list of earlier ``(message, answer)`` tuples; it is appended
            to in place. A new list is created when omitted.
        system_prompt: instruction passed to ``format_prompt``.

    Returns:
        The history list, extended by ``(message, answer)``. For an unknown
        mode a message is printed and the history is returned unchanged.
    """
    # a fresh list per call; a shared `[]` default would leak turns between chats
    if history is None:
        history = []

    # formatting the prompt using the model's format_prompt function
    prompt = format_prompt(message, history, system_prompt)

    # interface() validates the mode case-insensitively, so compare the same way
    mode = mode.lower()

    if mode == "vanilla":
        # generating a formatted answer using the model's respond function
        answer = format_answer(respond(prompt))
    elif mode == "explained":
        # generating an attribution using the explained response function
        attribution = respond_explained(prompt)

        # extracting the answer from the attribution output and plotting the explanation
        answer = format_output_text(attribution.output_tokens)
        attribution.plot_seq_attr(show=True)
    else:
        # nothing was generated, so there is no answer to record
        print("Select mode!")
        return history

    # updating the chat history with the new answer
    history.append((message, answer))

    # returning the updated history
    return history


# interface function that provides the chat "interface"


def interface(limit: int = 5):
    """Run a console chat loop with Mistral for ``limit`` turns.

    Each turn asks for a message and a mode ('explained' or 'vanilla'),
    passes both to ``chat`` and then prints the whole conversation so far.

    Args:
        limit: number of user turns before the loop ends.
    """
    history = []

    # asking for input for set turns
    for _ in range(limit):
        # asking for user input
        user_input = input("Enter your message: ")

        # asking for the mode until it is valid; the normalised value is what
        # gets passed on, so "Explained " no longer slips past validation and
        # then fails inside chat()
        while True:
            mode = input("Enter the mode ('explained' or 'vanilla'): ").strip().lower()
            if mode in ("explained", "vanilla"):
                break
            print("Invalid mode. Please enter 'explained' or 'vanilla'.")

        # calling the chat function for conversation
        history = chat(user_input, mode, history)

        # Print the entire conversation history
        for user_msg, bot_response in history:
            print(f"User: {user_msg}")
            print(f"Bot: {bot_response}")

In [None]:
# calling interface chat function
# runs interactive turns (default limit of 5) read via input(); each
# 'explained' turn runs KernelSHAP and can take many minutes
interface()

#### Comment
Using this prototype, one can chat with Mistral in a very simplified form. However, the runtime of KernelSHAP here is typically more than 10 minutes and can exceed 20 minutes, depending on the prompt size.