In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from peft import get_peft_model, LoraConfig, PeftModel
from random import sample
import pandas as pd
import re
import os
import shap

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
adapter_path = "../../ai-KD/llama8b-LoRA-IS"
save_path = "shap-ai.png"

In [2]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [3]:
student_tokenizer = AutoTokenizer.from_pretrained("../../Meta-Llama-3.1-8B")
student_trained_model = AutoModelForCausalLM.from_pretrained(
    "../../Meta-Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map=device
)
student_trained_model = PeftModel.from_pretrained(student_trained_model, adapter_path) # testing trained
student_trained_model = student_trained_model.merge_and_unload()
student_trained_model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [4]:
student_tokenizer.add_special_tokens({"pad_token":"<pad>"})
student_trained_model.generation_config.pad_token_id = student_tokenizer.pad_token_id
student_trained_model.config.vocab_size = len(student_tokenizer)

In [5]:
def compute_llama_shap_values(model, tokenizer, prompt, target_text="18", save_path='shap_plot.png', max_new_tokens=50):
    """
    Compute SHAP values for a LLaMA model's predictions based on the input prompt and save the plot.
    The function focuses on the probability of generating the target_text ("18") as the next token(s).

    Args:
        model: The LLaMA language model.
        tokenizer: The tokenizer associated with the model.
        prompt: Input prompt string.
        target_text: The target text to check in the model's output (e.g., "18").
        save_path: File path to save the SHAP plot.
        max_new_tokens: The maximum number of tokens to generate.

    Returns:
        Saves a SHAP force plot to the specified path.
    """

    # Check tokenizer and model vocabulary sizes
    tokenizer_vocab_size = len(tokenizer)
    model_vocab_size = model.config.vocab_size
    print(f"Tokenizer Vocabulary Size: {tokenizer_vocab_size}")
    print(f"Model Vocabulary Size: {model_vocab_size}")

    # If they are not equal, adjust the model's vocab_size if appropriate
    if tokenizer_vocab_size != model_vocab_size:
        print("Adjusting model's vocab_size to match tokenizer's vocabulary size.")
        model.config.vocab_size = tokenizer_vocab_size
    
    # Tokenize the target text
    target_tokens = tokenizer.tokenize(target_text)
    target_token_ids = tokenizer.convert_tokens_to_ids(target_tokens)
    target_length = len(target_token_ids)

    print(f"Target Text: {target_text}")
    print(f"Target Tokens: {target_tokens}")
    print(f"Target Token IDs: {target_token_ids}")

    # Verify target_token_ids are valid
    vocab_size = model.config.vocab_size
    print(f"Model Vocabulary Size: {vocab_size}")

    if any(tid >= vocab_size or tid < 0 for tid in target_token_ids):
        raise ValueError(f"Invalid token ID in target_token_ids: {target_token_ids}")

    # Maximum sequence length
    max_model_length = model.config.max_position_embeddings
    max_input_length = max_model_length - target_length

    def model_prediction(texts):
        # Ensure texts is a list of strings
        if isinstance(texts, np.ndarray):
            texts = texts.tolist()
            if isinstance(texts[0], list):
                texts = [''.join(sublist) for sublist in texts]
            else:
                texts = [str(t) for t in texts]
        elif isinstance(texts, list):
            texts = [str(t) for t in texts]
        elif isinstance(texts, str):
            texts = [texts]
        else:
            raise ValueError(f"Unexpected type for texts: {type(texts)}")

        # Tokenize the input texts
        inputs = tokenizer(
            texts,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=max_input_length  # Ensure it doesn't exceed the model's max
        ).to(device)

        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        print(f"Input IDs Shape: {input_ids.shape}")
        print(f"Attention Mask Shape: {attention_mask.shape}")
        print(f"Max Model Length: {max_model_length}")

        # Prepare inputs for generation
        batch_size = input_ids.shape[0]
        target_token_ids_batch = torch.tensor([target_token_ids] * batch_size).to(device)

        # Concatenate input_ids and target_token_ids
        input_ids_extended = torch.cat([input_ids, target_token_ids_batch], dim=1)
        attention_mask_extended = torch.cat([attention_mask, torch.ones_like(target_token_ids_batch)], dim=1)

         # After input_ids_extended is defined
        max_token_id = input_ids_extended.max().item()
        min_token_id = input_ids_extended.min().item()
        print(f"Max Token ID in input_ids_extended: {max_token_id}")
        print(f"Min Token ID in input_ids_extended: {min_token_id}")

        assert min_token_id >= 0, "input_ids_extended contains negative token IDs."
        assert max_token_id < vocab_size, f"input_ids_extended contains token IDs >= vocab_size ({vocab_size})."

        print(f"Input IDs Extended Shape: {input_ids_extended.shape}")
        print(f"Attention Mask Extended Shape: {attention_mask_extended.shape}")

        # Ensure the total sequence length does not exceed the model's limit
        if input_ids_extended.shape[1] > max_model_length:
            excess_length = input_ids_extended.shape[1] - max_model_length
            print(f"Excess Length: {excess_length}")
            input_ids_extended = input_ids_extended[:, excess_length:]
            attention_mask_extended = attention_mask_extended[:, excess_length:]

        print(f"Final Input IDs Shape (after truncation): {input_ids_extended.shape}")
        print(f"Final Attention Mask Shape (after truncation): {attention_mask_extended.shape}")
        assert input_ids_extended.shape[1] <= max_model_length, "Input exceeds max model length after truncation!"

        # Forward pass through the model
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids_extended,
                attention_mask=attention_mask_extended,
                output_hidden_states=False,
                use_cache=False
            )
            logits = outputs.logits  # Shape: (batch_size, seq_len, vocab_size)

            # Calculate log probabilities of the target tokens
            log_probs = []
            for i in range(batch_size):
                # Determine positions of target tokens
                start_idx = input_ids_extended.shape[1] - target_length
                end_idx = input_ids_extended.shape[1]

                # Get the logits for the target tokens
                target_logits = logits[i, start_idx - 1:end_idx - 1, :]  # Exclude the last token

                print(f"Logits Shape: {logits.shape}")
                print(f"Target Logits Shape: {target_logits.shape}")

                # Compute log softmax
                log_softmax = torch.log_softmax(target_logits, dim=-1)

                # Gather log probabilities
                token_log_probs = log_softmax[range(target_length), target_token_ids]
                total_log_prob = token_log_probs.sum().item()
                log_probs.append(total_log_prob)

            return np.array(log_probs)

    # Create SHAP explainer with the Text masker
    masker = shap.maskers.Text(tokenizer)
    explainer = shap.Explainer(model_prediction, masker)

    # Compute SHAP values
    shap_values = explainer([prompt])

    # Plot SHAP values and save the plot
    plt.figure(figsize=(12, 4))
    shap.plots.text(shap_values[0], show=False)
    plt.tight_layout()

    # Ensure the directory exists
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Save the plot
    plt.savefig(save_path)
    plt.close()  # Close the figure to free up memory

    print(f"SHAP plot saved to {save_path}")

In [6]:
prompt = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

compute_llama_shap_values(student_trained_model, student_tokenizer, prompt, save_path=save_path, max_new_tokens=350)


Tokenizer Vocabulary Size: 128257
Model Vocabulary Size: 128257
Target Text: 18
Target Tokens: ['18']
Target Token IDs: [972]
Model Vocabulary Size: 128257
Input IDs Shape: torch.Size([1, 2])
Attention Mask Shape: torch.Size([1, 2])
Max Model Length: 131072
Max Token ID in input_ids_extended: 128000
Min Token ID in input_ids_extended: 972
Input IDs Extended Shape: torch.Size([1, 3])
Attention Mask Extended Shape: torch.Size([1, 3])
Final Input IDs Shape (after truncation): torch.Size([1, 3])
Final Attention Mask Shape (after truncation): torch.Size([1, 3])
Logits Shape: torch.Size([1, 3, 128256])
Target Logits Shape: torch.Size([1, 128256])
Input IDs Shape: torch.Size([1, 64])
Attention Mask Shape: torch.Size([1, 64])
Max Model Length: 131072
Max Token ID in input_ids_extended: 128000
Min Token ID in input_ids_extended: 6
Input IDs Extended Shape: torch.Size([1, 65])
Attention Mask Extended Shape: torch.Size([1, 65])
Final Input IDs Shape (after truncation): torch.Size([1, 65])
Final A

../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [0,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [1,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [2,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [3,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [4,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059,0,0], thread: [5,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1284: indexSelectLargeIndex: block: [1059

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
