Implementing LibraGrad From Scratch for Llama 3
===============================================



You can open this notebook on [Google Colab](https://colab.research.google.com/github/NightMachinery/LibraGrad/blob/master/notebooks/llama3.ipynb).



## Enable CUDA



If the notebook has a CUDA-enabled runtime, we can use it:



In [1]:
import torch

force_cpu_p = False
if not force_cpu_p and torch.cuda.is_available():
    device = "cuda"
    cuda_p = True

else:
    device = "cpu"
    cuda_p = False

device

## Install Dependencies



In [1]:
! pip install -U pip

! pip install -U pynight
! pip install 'torch>=2.0.0'
! pip install -U transformers datasets pillow 'matplotlib'
! pip install -U 'numpy<2.0.0'

You might need to restart the Colab runtime for the newly installed package versions to be loaded.



## Llama 3.2



In [1]:
! hostname
! date

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_hf_name = "unsloth/Llama-3.2-1B"
#: You can use any auto-regressive Llama 3 variant, e.g.:
# model_hf_name = "unsloth/Llama-3.2-3B"

tokenizer = AutoTokenizer.from_pretrained(model_hf_name)
model = AutoModelForCausalLM.from_pretrained(model_hf_name)
##
model.tokenizer = tokenizer

model.to(device)
None

In [1]:
model

In [1]:
print(model.__class__)

### Helper Functions



Here we define various helper functions for predicting the next token and computing its attribution scores. Note that the code here is not part of LibraGrad's implementation; rather, it implements the base attribution methods that LibraGrad enhances.

First, let us import the libraries we need:



In [1]:
import os
import sys
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace
from typing import List, Optional, Union, Dict, Any, Tuple
from transformers import PreTrainedTokenizer, PreTrainedModel
from dataclasses import dataclass
from torch import Tensor
import torch.utils.checkpoint

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import colormaps

from PIL import Image, ImageDraw, ImageFont

from pynight.common_torch import (
    module_mapper,
    torch_shape_get,
)

from pynight.common_attr import (
    normalize_map,
)

from pynight.common_icecream import ic

Now let us define some dataclasses for moving data around:



In [1]:
@dataclass
class Attribution:
    """Stores attribution scores and completeness error."""
    scores: torch.Tensor
    completeness_error: float

@dataclass
class AttributionResults:
    """Stores both IxG and FullGrad+ attributions."""
    input_x_grad: Attribution
    fullgrad_plus: Attribution

@dataclass
class TopKComparisonItem:
    """Individual comparison item for top-k logits."""
    rank: int
    original_idx: int
    original_val: float
    libra_idx: int
    libra_val: float
    same_index: bool
    value_diff: float

@dataclass
class LogitsComparison:
    """Results from comparing original and Libra model logits."""
    mse: float
    cosine_similarity: float
    max_absolute_diff: float
    top_k_comparison: List[TopKComparisonItem]

@dataclass
class ModelPrediction:
    """Model prediction results including token information and attributions."""
    next_token_str: str
    next_token_id: int
    logits: torch.Tensor
    probs: torch.Tensor
    input_ids: List[int]
    token_texts: Optional[List[str]] = None
    attributions: Optional[AttributionResults] = None

Now we implement next-token prediction together with its Input-X-Grad (IxG) and FullGrad+ attributions. Llama 3 has no bias terms, so implementing FullGrad+ on it is straightforward. For the same reason, IxG is equivalent to FullGrad on this model, and thus Libra IxG should also have zero completeness error.
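
To make the notion of completeness error concrete, here is a minimal dependency-free sketch on a toy linear score rather than on Llama 3. The names `input_x_grad` and `completeness_error` are illustrative and are not part of the notebook's helpers. It shows that IxG attributions sum exactly to the output when there is no bias, and miss the output by exactly the bias otherwise.

```python
def input_x_grad(weights, x):
    """IxG for f(x) = sum_i w_i * x_i (+ bias): the gradient w.r.t. x_i is w_i."""
    return [w_i * x_i for w_i, x_i in zip(weights, x)]


def completeness_error(attributions, output):
    """Absolute gap between the summed attributions and the model output."""
    return abs(sum(attributions) - output)


weights = [0.5, -1.25, 2.0]
x = [2.0, 4.0, 1.5]
bias = 0.75

attributions = input_x_grad(weights, x)
output_no_bias = sum(w_i * x_i for w_i, x_i in zip(weights, x))
output_with_bias = output_no_bias + bias

# Bias-free: the attributions account for the whole output.
err_no_bias = completeness_error(attributions, output_no_bias)
# With a bias: IxG cannot attribute the bias term to any input.
err_with_bias = completeness_error(attributions, output_with_bias)

print(err_no_bias, err_with_bias)
```

In a deep network the same accounting only holds if every nonlinearity also passes gradients in a way that preserves this sum, which is what the Libra layers later in the notebook are meant to arrange.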



In [1]:
def forward_hook_collector(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor, ...],
    outputs: Any,
    layer_inputs: List[torch.Tensor],
) -> None:
    """Forward hook to collect input tensors from model layers."""
    tensor_input = inputs[0] if isinstance(inputs, tuple) else inputs
    
    tensor_input.requires_grad_(True)
    tensor_input.retain_grad()
    
    layer_inputs.append(tensor_input)

def get_model_outputs(
    model: PreTrainedModel,
    input_ids: torch.Tensor,
    embeddings: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Helper function to compute model outputs and next token prediction.
    """
    if getattr(model.config, 'pretraining_tp', 0) > 1:
        raise ValueError("Tensor parallel models are not supported")

    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    with torch.no_grad():
        attention_mask = torch.ones_like(input_ids, device=device)

    model_inputs = {
        'attention_mask': attention_mask.unsqueeze(0),
        'return_dict': True,
        'output_hidden_states': True,  # Needed for FullGrad+
        'num_logits_to_keep': 1,       # Only compute logits for the last token
    }

    if embeddings is not None:
        model_inputs['inputs_embeds'] = embeddings.unsqueeze(0).to(device)
    else:
        model_inputs['input_ids'] = input_ids.unsqueeze(0)

    outputs = model(**model_inputs)

    # Get logits for the last position only
    logits = outputs.logits[:, -1, :]  # Shape: (1, vocab_size)

    with torch.no_grad():
        probs = F.softmax(logits.detach(), dim=-1)
        next_token_id = torch.argmax(logits.detach(), dim=-1).item()

    return logits, probs, next_token_id

def predict_next_token(
    text: str,
    *,
    model: PreTrainedModel,
    tokenizer: Union[str, PreTrainedTokenizer] = "from_model",
    attribute_p: bool = False,
) -> ModelPrediction:
    """
    Predicts the next token and optionally calculates attributions.
    """
    # Handle tokenizer
    if tokenizer == "from_model":
        tokenizer = model.tokenizer

    device = next(model.parameters()).device

    # Tokenize input text - no need for gradients here
    with torch.no_grad():
        input_ids = tokenizer(text, return_tensors="pt").input_ids[0].to(device)
        token_texts = [tokenizer.decode([token_id]) for token_id in input_ids.cpu().tolist()]

    if not attribute_p:
        # Simple forward pass without attribution
        with torch.no_grad():
            logits, probs, next_token_id = get_model_outputs(model, input_ids)
            logits = logits.detach()
            probs = probs.detach()

        return ModelPrediction(
            next_token_str=tokenizer.decode([next_token_id]),
            next_token_id=next_token_id,
            logits=logits.squeeze(0),
            probs=probs.squeeze(0),
            input_ids=input_ids.cpu().tolist(),
            token_texts=token_texts,
        )

    # Attribution computation
    model.zero_grad()

    embeddings = model.model.embed_tokens(input_ids)
    embeddings.requires_grad_(True)
    embeddings.retain_grad()

    layer_inputs: List[torch.Tensor] = []
    hooks = []

    for layer in model.model.layers:
        hook = layer.register_forward_hook(
            lambda mod, inp, out, li=layer_inputs: forward_hook_collector(mod, inp, out, li)
        )
        hooks.append(hook)

    logits, probs, next_token_id = get_model_outputs(model, input_ids, embeddings)
    target_logit = logits[0, next_token_id]

    # Remove hooks before backward pass
    for hook in hooks:
        hook.remove()

    target_logit.backward()
    # ic(torch_shape_get(embeddings.grad))

    # Calculate attributions - no need for gradients here
    with torch.no_grad():
        # Calculate Input x Gradient attributions
        ixg_attr = (embeddings * embeddings.grad).sum(dim=-1).detach()
        ixg_completeness_error = abs(
            ixg_attr.sum().item() - target_logit.detach().item()
        )

        # Calculate FullGrad+ attributions by averaging IxG across all layer inputs
        fullgrad_attrs = []
        for layer_input in layer_inputs:
            if layer_input.grad is not None:
                attr = (layer_input * layer_input.grad).sum(dim=-1).detach()
                # ic(attr.shape)
                #: attr: [batch=1, tokens]
                
                if len(attr.shape) > 1:
                    attr = attr[0]
                fullgrad_attrs.append(attr)

        if fullgrad_attrs:
            fullgrad_plus = torch.stack(fullgrad_attrs).mean(dim=0).detach()
        else:
            fullgrad_plus = torch.zeros_like(ixg_attr)

        # Calculate completeness error for FullGrad+
        fullgrad_completeness_error = abs(
            fullgrad_plus.sum().item() - target_logit.detach().item()
        )

        # Move tensors to CPU and detach before storing
        attributions = AttributionResults(
            input_x_grad=Attribution(
                scores=ixg_attr.cpu().detach(),
                completeness_error=ixg_completeness_error
            ),
            fullgrad_plus=Attribution(
                scores=fullgrad_plus.cpu().detach(),
                completeness_error=fullgrad_completeness_error
            )
        )

    # Detach results before returning
    return ModelPrediction(
        next_token_str=tokenizer.decode([next_token_id]),
        next_token_id=next_token_id,
        logits=logits.detach().squeeze(0),
        probs=probs.detach().squeeze(0),
        input_ids=input_ids.cpu().tolist(),
        token_texts=token_texts,
        attributions=attributions
    )

We now define some helper functions for comparing the outputs of the model, which we'll later use to verify that LibraGrad has not changed the forward pass:



In [1]:
def compare_logits(
    original_logits: torch.Tensor,
    libra_logits: torch.Tensor,
    top_k: int = 3,
) -> LogitsComparison:
    """
    Compares logits between original and Libra models.
    """
    with torch.no_grad():
        # Basic similarity metrics
        mse = F.mse_loss(original_logits.detach(), libra_logits.detach()).item()
        cos_sim = F.cosine_similarity(
            original_logits.detach().unsqueeze(0),
            libra_logits.detach().unsqueeze(0)
        ).item()
        max_abs_diff = torch.max(torch.abs(
            original_logits.detach() - libra_logits.detach()
        )).item()

        # Get top K indices and values
        original_top_k = torch.topk(original_logits.detach(), top_k)
        libra_top_k = torch.topk(libra_logits.detach(), top_k)

        # Create detailed comparison items
        top_k_comparison = []
        for i in range(top_k):
            comparison_item = TopKComparisonItem(
                rank=i + 1,
                original_idx=original_top_k.indices[i].item(),
                original_val=original_top_k.values[i].item(),
                libra_idx=libra_top_k.indices[i].item(),
                libra_val=libra_top_k.values[i].item(),
                same_index=original_top_k.indices[i].item() == libra_top_k.indices[i].item(),
                value_diff=abs(original_top_k.values[i].item() - libra_top_k.values[i].item())
            )
            top_k_comparison.append(comparison_item)

    return LogitsComparison(
        mse=mse,
        cosine_similarity=cos_sim,
        max_absolute_diff=max_abs_diff,
        top_k_comparison=top_k_comparison
    )

def print_logits_comparison(comparison_results: LogitsComparison) -> None:
    """Prints the logits comparison results in a readable format."""
    print("Overall Metrics:")
    print(f"MSE: {comparison_results.mse:.6f}")
    print(f"Cosine Similarity: {comparison_results.cosine_similarity:.6f}")
    print(f"Max Absolute Difference: {comparison_results.max_absolute_diff:.6f}")

    print("\nTop-K Comparison:")
    print("Rank  Original(idx,val)     Libra(idx,val)        Same?  Diff")
    print("-" * 70)

    for comp in comparison_results.top_k_comparison:
        same_marker = "✓" if comp.same_index else "✗"
        print(
            f"{comp.rank:2d}    ({comp.original_idx:5d}, {comp.original_val:8.3f})    "
            f"({comp.libra_idx:5d}, {comp.libra_val:8.3f})    "
            f"{same_marker}     {comp.value_diff:.6f}"
        )
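
For intuition about what the three overall metrics report, here is a small sketch that recomputes them on plain Python lists. The function names are illustrative stand-ins, not the notebook's API; `compare_logits` above computes the same quantities with PyTorch.

```python
import math


def mse(a, b):
    """Mean squared difference between two equally long vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def max_abs_diff(a, b):
    """Largest element-wise absolute difference."""
    return max(abs(x - y) for x, y in zip(a, b))


original = [2.0, -1.0, 0.5]
identical = list(original)      # what an unchanged forward pass should give
perturbed = [2.0, -1.0, 1.0]    # one logit moved by 0.5

print(mse(original, identical), cosine_similarity(original, identical))
print(mse(original, perturbed), max_abs_diff(original, perturbed))
```

If the Libra layers leave the forward pass untouched, the comparison should look like the `identical` case: zero MSE, cosine similarity of one, and zero maximum difference, up to floating-point noise.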

#### Visualization



Here we define some utilities to visualize the attribution scores.



##### Plotter



In [1]:
def plot_attributions_PIL_v2(
    token_texts,
    attributions,
    title,
    height=70,  # Height in pixels
    pos_cmap="Blues",
    neg_cmap="Oranges",
    show_p=True,
    dpi=300,
    font_size=42,
    white_foreground_threshold=0.6,
):
    """
    Plots the attributions by coloring token backgrounds using PIL.
    Positive scores: white to blue (or custom colormap).
    Negative scores: white to orange (or custom colormap).
    Args:
        token_texts (List[str]): List of token strings.
        attributions (torch.Tensor): Tensor of normalized attribution scores (expected in [-1, 1]).
        title (str): Plot title.
        height (int): Height of the plot in pixels. Default is 70.
        pos_cmap (str): Colormap name for positive scores. Default is "Blues".
        neg_cmap (str): Colormap name for negative scores. Default is "Oranges".
        show_p (bool): If True, displays the image. Default is True.
        dpi (int): Dots per inch. Default is 300.
        font_size (int): Font size in points. Default is 42.
        white_foreground_threshold (float): Threshold for absolute attribution score above which text color becomes white. Default is 0.6.
    Returns:
        PIL.Image: The generated image
    """
    # Convert attributions to numpy if it's a torch tensor
    if hasattr(attributions, "cpu"):
        attributions = attributions.cpu().numpy()

    # Get colormaps
    pos_colormap = colormaps[pos_cmap]
    neg_colormap = colormaps[neg_cmap]

    # Try to load a TrueType font with specified size
    try:
        # Extended font paths including Colab paths
        font_paths = [
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",  # Linux
            "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf",  # Colab
            "/Library/Fonts/Arial.ttf",  # macOS
            "C:\\Windows\\Fonts\\arial.ttf",  # Windows
        ]

        font = None
        for path in font_paths:
            if os.path.exists(path):
                try:
                    font = ImageFont.truetype(path, font_size)
                    break
                except Exception:
                    continue

        if font is None:
            # If no system fonts work, try downloading and using a Google Font
            try:
                import urllib.request
                import tempfile

                # Download Roboto font from Google Fonts
                font_url = "https://github.com/google/fonts/raw/main/apache/roboto/Roboto-Regular.ttf"
                temp_font_path = os.path.join(tempfile.mkdtemp(), "Roboto-Regular.ttf")
                urllib.request.urlretrieve(font_url, temp_font_path)

                font = ImageFont.truetype(temp_font_path, font_size)
                print("Using downloaded Roboto font.")
            except Exception as e:
                # Fallback to default font if download fails
                font = ImageFont.load_default()
                print(
                    f"Warning: Could not download font ({str(e)}). Using default font."
                )
    except Exception as e:
        font = ImageFont.load_default()
        print(f"Warning: Font loading error ({str(e)}). Using default font.")

    # Constants for layout
    padding_horizontal = font_size  # Scale padding with font size
    padding_vertical = font_size // 2
    title_height = int(font_size * 2.5)  # More space for title
    text_height = height - 2 * padding_vertical
    total_height = height + title_height + padding_vertical

    # Create temporary drawing surface for measurements
    temp_img = Image.new("RGB", (1, 1), "white")
    draw = ImageDraw.Draw(temp_img)

    # Calculate exact token widths and total width
    token_sizes = [draw.textbbox((0, 0), token, font=font) for token in token_texts]
    token_widths = [bbox[2] - bbox[0] + padding_horizontal * 2 for bbox in token_sizes]
    total_width = sum(token_widths)

    # Measure title width
    title_bbox = draw.textbbox((0, 0), title, font=font)
    title_width = title_bbox[2] - title_bbox[0]

    # Ensure total width is at least as wide as title
    total_width = max(total_width, title_width + padding_horizontal * 2)

    # Create final image with white background
    img = Image.new("RGB", (total_width, total_height), "white")

    # Set DPI in the image metadata
    img.info["dpi"] = (dpi, dpi)

    draw = ImageDraw.Draw(img)

    # Helper function to get color from matplotlib colormap
    def get_color(score, is_positive=True):
        if is_positive:
            rgb = pos_colormap(score)
        else:
            rgb = neg_colormap(score)
        return tuple(int(x * 255) for x in rgb[:3])

    # Draw title
    title_x = (total_width - title_width) // 2
    title_y = padding_vertical
    draw.text((title_x, title_y), title, fill="black", font=font)

    # Draw tokens with backgrounds
    current_x = 0
    for token, width, score in zip(token_texts, token_widths, attributions):
        # Calculate background color
        if score >= 0:
            color = get_color(score, is_positive=True)
        else:
            color = get_color(-score, is_positive=False)

        # Draw background rectangle
        draw.rectangle(
            [(current_x, title_height), (current_x + width, total_height)], fill=color
        )

        # Calculate exact text position for centering
        text_bbox = draw.textbbox((0, 0), token, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]
        text_x = current_x + (width - text_width) // 2
        text_y = title_height + (height - text_height) // 2

        # Determine text color based on attribution score
        text_color = "white" if abs(score) > white_foreground_threshold else "black"

        # Draw token text
        draw.text((text_x, text_y), token, fill=text_color, font=font)

        current_x += width

    # Display the image if show_p is True
    if show_p:
        try:
            from IPython.display import display

            display(img)

        except ImportError:
            # Use regular show method
            img.show()

    return img


def save_with_dpi(img, filename, dpi=300):
    """
    Helper function to save an image with specific DPI.
    Args:
        img (PIL.Image): The image to save
        filename (str): Output filename
        dpi (int): Dots per inch. Default is 300.
    """
    # Save with DPI information
    img.save(filename, dpi=(dpi, dpi))

In [1]:
def analyze_test_cases(
    test_cases,
    model,
    plot_ixg=True,
    plot_fullgrad=True,
    plot_fn=None,
):
    """
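    Analyzes test cases using the model and plots attributions.

    Each test case is run twice, once with standard gradients and once with
    Libra gradients, by toggling the global `global_libra_mode_p` flag.

    Args:
        test_cases: Iterable of prompt strings to analyze
        model: The model to use for predictions
        plot_ixg: Whether to plot Input x Gradient attributions
        plot_fullgrad: Whether to plot FullGrad+ attributions
        plot_fn: Plotting function; defaults to plot_attributions_PIL_v2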
    """
    global global_libra_mode_p

    if plot_fn is None:
        plot_fn = plot_attributions_PIL_v2

    for text, global_libra_mode_p in itertools.product(test_cases, [False, True]):
        # print(f"\nAnalyzing: {text!r}")

        if global_libra_mode_p:
            method_prefix = "Libra "
        else:
            method_prefix = ""

        #: Get predictions and attributions
        result = predict_next_token(
            text,
            model=model,
            attribute_p=True,
        )

        # print(f"Predicted next token: {result.next_token_str!r}")

        #: Print completeness errors for both methods
        print(
            f"{method_prefix}Input-X-Grad Completeness Error: {result.attributions.input_x_grad.completeness_error:.6f}"
        )
        print(
            f"{method_prefix}FullGrad+ Completeness Error: {result.attributions.fullgrad_plus.completeness_error:.6f}"
        )

        if plot_ixg:
            normalized_ixg = normalize_map(
                result.attributions.input_x_grad.scores,
                normalize=["scale_by_max_signed_attr"],
                outlier_quantile=0.01,
            ).attributions_normalized.clamp(max=1, min=-1)

            title = f"{method_prefix}Input-X-Grad Attributions for Next Token Prediction: {result.next_token_str!r}"
            #: The `!r` conversion flag calls repr() on the value before formatting it, so the token is displayed with quotes and with escape sequences where needed.

            plot_fn(
                result.token_texts,
                normalized_ixg,
                title=title,
            )

        if plot_fullgrad:
            normalized_fg = normalize_map(
                result.attributions.fullgrad_plus.scores,
                normalize=["scale_by_max_signed_attr"],
                outlier_quantile=0.01,
            ).attributions_normalized.clamp(max=1, min=-1)

            title = f"{method_prefix}FullGrad+ Attributions for Next Token Prediction: {result.next_token_str!r}"

            plot_fn(
                result.token_texts,
                normalized_fg,
                title=title,
            )

### Test Samples



Here we define some example sentences for the next-token prediction task on which we test the models.



In [1]:
test_cases = [
    # Basic knowledge completion
    "The chemical symbol for gold is",  # Expected: Au
    "The largest planet in our solar system is",  # Expected: Jupiter
    "The speed of light in meters per second is approximately ",  # Expected: 299792458
    
    # Mathematical reasoning
    "If x + 5 = 12, then x equals",  # Expected: 7
    "The square root of 144 is",  # Expected: 12
    "The next number in the sequence: 2, 4, 8, 16,",  # Expected: 32
    
    # Language patterns and idioms
    "Birds of a feather flock",  # Expected: together
    "When in Rome, do as the Romans",  # Expected: do
    "The early bird catches the",  # Expected: worm
    
    # Context-dependent completion
    "In chess, the piece that moves only diagonally is the",  # Expected: bishop
    "In baseball, three strikes and you're",  # Expected: out
    "In music, forte means to play",  # Expected: loud/loudly
    
    # Multi-token reasoning
    "If today is Monday, the day after tomorrow will be",  # Expected: Wednesday
    "The opposite of artificial intelligence is natural",  # Expected: stupidity
    "Water freezes at 0°C and boils at",  # Expected: 100°C
    
    # Common knowledge with a twist
    "Sharks are fish, but dolphins are",  # Expected: mammals
    "The Earth is round, but it's actually shaped like a",  # Expected: spheroid/ellipsoid
    "A tomato is technically a fruit, but in cooking it's used as a",  # Expected: vegetable
    
    # Pattern completion with multiple valid answers
    "Red, Orange, Yellow, Green, Blue,",  # Expected: Indigo/Purple/Violet
    "Spring, Summer, Fall,",  # Expected: Winter
    
    # Technical completions
    "In Python, a list is mutable but a tuple is",  # Expected: immutable
    "HTTP status code 404 means page not",  # Expected: found
    "RAM stands for Random Access",  # Expected: Memory
    
    # Logical reasoning
    "If all A are B, and all B are C, then all A are",  # Expected: C
    "If it's raining, the ground is wet. The ground is dry, therefore it's",  # Expected: not raining
    "Every action has an equal and opposite",  # Expected: reaction
    
    # Cultural references
    "Luke, I am your",  # Expected: father
    "To be, or not to",  # Expected: be
    "Houston, we have a",  # Expected: problem
    
    # Time-based patterns
    "January, February, March,",  # Expected: April
    "Monday, Tuesday, Wednesday,",  # Expected: Thursday
    "Dawn, morning, noon, afternoon,",  # Expected: evening/dusk
    
    # Numerical patterns with complexity
    "2, 3, 5, 7, 11,",  # Expected: 13 (prime numbers)
    "1, 1, 2, 3, 5, 8,",  # Expected: 13 (Fibonacci)
    "1, 4, 9, 16, 25,",  # Expected: 36 (square numbers)
    
    # Geographic knowledge
    "The Great Wall is in",  # Expected: China
    "The Amazon Rainforest is primarily in",  # Expected: Brazil
    "Mount Everest is located in",  # Expected: Nepal/Tibet
    
    # Scientific concepts
    "E equals mc",  # Expected: squared
    "The human body is made up of approximately 60% ",  # Expected: water
    "The closest star to Earth is the",  # Expected: Sun
]

# Test cases focusing on ambiguity and context
ambiguous_test_cases = [
    # Same prompt, different contexts
    "The word 'bank' can refer to a financial institution or the edge of a",  # Expected: river
    "A mouse can be a rodent or a computer",  # Expected: device/peripheral
    "Java can be a programming language or a type of",  # Expected: coffee
    
    # Contextual completion
    "In biology, a cell is basic unit of life, but in prison it's a",  # Expected: room
    "The word 'bright' can describe intelligence or",  # Expected: light
    "The term 'hard drive' in computing refers to storage, but 'hard' alone means",  # Expected: difficult
]

# Test cases for numerical reasoning
numerical_test_cases = [
    # Mathematical patterns
    "1/4 of 100 is",  # Expected: 25
    "A dozen dozens is",  # Expected: 144
    "The next power of 2: 2,4,8,16,32,",  # Expected: 64
    
    # Unit conversions
    "There are 1000 meters in a",  # Expected: kilometer
    "There are 100 centimeters in a",  # Expected: meter
    "There are 60 seconds in a",  # Expected: minute
]

# Test cases requiring reference to previous context
reference_test_cases = [
    # Simple name references
    "The student's name is Alice Lovely. Her name starts with the letter",  # Expected: A
    "My friend Bob Wilson lives in Paris. Wilson's first name starts with the character '",  # Expected: B
    "Dr. James Smith teaches biology. Dr. Smith's given name is",  # Expected: James
    
    # Multiple references to choose from
    "Mohammad met Mary at a cafe. The person whose name is shorter is",  # Expected: Mary
    "The cities Paris and Rome are beautiful. The city with more letters is",  # Expected: Paris
    
    # Numerical references
    "The temperature was 23 degrees yesterday and 25 today. The difference is ",  # Expected: 2
    "Room A has 15 chairs and Room B has 20 chairs. The smaller number is",  # Expected: 15
    
    # Complex references
    "The red car costs $20000 and the blue car costs $25000. The cheaper car is the",  # Expected: red
    "In the story, Sarah is 12 and her brother Tom is 8. The older sibling's name is",  # Expected: Sarah
    
    # Attribute references
    "The book has a blue cover and golden pages. The color of the cover is",  # Expected: blue
    "The painting shows a sunset over mountains. The scene takes place during",  # Expected: sunset
    
    # Multi-token references
    "James Bond ordered a martini, shaken not stirred. His drink preference was",  # Expected: martini
    "The recipe calls for olive oil and balsamic vinegar. The first ingredient is",  # Expected: olive oil
    
    # Time and sequence references
    "Monday comes before Tuesday. The earlier day is",  # Expected: Monday
    "Spring arrives after winter. The colder season is",  # Expected: winter
    
    # Location references
    "The keys are either in the kitchen or the bedroom. The room that starts with 'k' is",  # Expected: kitchen
    "Between New York and Paris, the American city is",  # Expected: New York
    
    # Property references
    "Diamonds are hard and clouds are soft. The harder object is the",  # Expected: diamond
    "A cheetah runs fast while a turtle moves slowly. The faster animal is",  # Expected: cheetah
    
    # References with distractors
    "Although Maria likes blue, Juan's favorite color is red. Maria prefers the color",  # Expected: blue
    "While the square has 4 sides, the triangle drawn in red has 3. The shape with more sides is called a",  # Expected: square
]

all_test_cases = [
    *reference_test_cases,
    *test_cases,
    *numerical_test_cases,
    *ambiguous_test_cases,
]

### Standard Gradients



#### Store Forward Pass Outputs of Standard Model



Before making any modifications, we run a forward pass with the original model to obtain baseline next-token predictions, which lets us verify later that the forward pass is unchanged.



In [1]:
text = "Once upon a"
original_pred = predict_next_token(text, model=model)

### Libra Layers



We now implement LibraGrad for the Llama 3 model architecture.



In [1]:
global_libra_mode_p = True
global_libra_verbose_p = False

def print_libra(*args, **kwargs):
    if global_libra_verbose_p:
        print(*args, **kwargs)

In [1]:
def swap_backward(forward, backward):
    return forward.detach() + (backward - backward.detach())
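
The following sketch shows what `swap_backward` does on a scalar. The helper is repeated so the snippet runs on its own, and the two toy expressions (`x ** 2` for the forward value, `5.0 * x` for the backward path) are chosen purely for illustration. The result takes its value from the first argument and its gradient from the second.

```python
import torch


def swap_backward(forward, backward):
    # Same definition as in the cell above: the value comes from `forward`,
    # while autograd only sees the `backward` term.
    return forward.detach() + (backward - backward.detach())


x = torch.tensor(3.0, requires_grad=True)

# Forward value of x**2, gradient of 5*x.
y = swap_backward(x ** 2, 5.0 * x)
y.backward()

value = y.item()      # value of x**2 at x = 3
grad = x.grad.item()  # derivative of 5*x, not of x**2

print(value, grad)
```

Because `backward - backward.detach()` is numerically zero, the forward output is unchanged, which is why this trick can alter gradients without affecting the model's predictions.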

#### SiLU



In [1]:
def libra_silu(
    x,
    *,
    libra_p="from_global",
    inplace_p=False,
):
    """Swish (or Silu) activation function.

    It is defined as: `swish(x) = x * sigmoid(x)`.

    The Swish (or Silu) activation function is a smooth,
    non-monotonic function that is unbounded above and
    bounded below.

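    Args:
        x: Input tensor.
        libra_p: If truthy, detach the sigmoid gate so that gradients flow
            only through the linear `x` path. The default "from_global"
            defers to `global_libra_mode_p`.
        inplace_p: If True, multiply `x` by the gate in place.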

    Reference:

    - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)
    """
    # print_libra(f"libra_silu entered: libra_p={libra_p}")

    if libra_p == "from_global":
        libra_p = global_libra_mode_p

    sigmoid_x = F.sigmoid(x)

    if libra_p:
        print_libra("Libra SiLU")

        sigmoid_x = sigmoid_x.detach()

    if inplace_p:
        return x.mul_(sigmoid_x)
    else:
        return x * sigmoid_x


class LibraSiLU(nn.SiLU):
    def __init__(self, inplace: bool = False):
        super().__init__(inplace)
        self.prefix = f"unset.{self.__class__.__name__}"

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        self.prefix = prefix
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(
        self,
        input: torch.Tensor,
        *,
        libra_p="from_global",
    ) -> torch.Tensor:
        return libra_silu(
            input,
            libra_p=libra_p,
            inplace_p=self.inplace,
        )
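
As a sanity check on the idea behind `libra_silu`, here is a scalar sketch using only the standard library. The derivative functions are written out by hand for illustration and are not part of the notebook. With the sigmoid gate treated as a constant, input times gradient reproduces the SiLU output exactly, whereas the ordinary derivative adds an extra term.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def silu(x):
    return x * sigmoid(x)


def silu_grad_standard(x):
    """Ordinary derivative: d/dx [x * s(x)] = s + x * s * (1 - s)."""
    s = sigmoid(x)
    return s + x * s * (1.0 - s)


def silu_grad_libra(x):
    """Derivative when the gate s(x) is detached, i.e. treated as a constant."""
    return sigmoid(x)


x = 1.7
standard_ixg = x * silu_grad_standard(x)
libra_ixg = x * silu_grad_libra(x)

print(silu(x), standard_ixg, libra_ixg)
```

The gap between `standard_ixg` and `silu(x)` equals `x**2 * s * (1 - s)`, which is the part of the gradient that the detached gate removes.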

#### LlamaMLP



In [1]:
class LibraLlamaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=config.mlp_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=config.mlp_bias
        )
        if config.hidden_act != "silu":
            raise NotImplementedError(
                "Only Libra SiLU has been implemented for this model currently."
            )

        self.act_fn = libra_silu

    def forward(
        self,
        x,
        libra_p="from_global",
    ):
        if libra_p == "from_global":
            libra_p = global_libra_mode_p

        # ic(self.act_fn)
        
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)

            gate_proj = torch.cat(
                [
                    F.linear(x, gate_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ],
                dim=-1,
            )
            up_proj = torch.cat(
                [
                    F.linear(x, up_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ],
                dim=-1,
            )

            fused_stream = self.act_fn(gate_proj) * up_proj
            if libra_p:
                print_libra("Libra Self-Gating")
                fused_stream = swap_backward(fused_stream, fused_stream / 2)

            intermediate_states = (fused_stream).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
            
        else:
            fused_stream = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
            if libra_p:
                print_libra("Libra Self-Gating")
                fused_stream = swap_backward(fused_stream, fused_stream / 2)

            down_proj = self.down_proj(fused_stream)

        return down_proj
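
To see why the backward pass uses `fused_stream / 2`, we can check the rule on a toy product of two branches. This is a minimal sketch, not the notebook's own helper: `swap_backward_sketch` is a hypothetical stand-in that assumes `swap_backward` keeps the forward value of its first argument and takes gradients from its second. For a plain product `a * b`, the two gradient-times-input terms each equal `a * b`, so together they count the output twice; halving the gradient restores the balance.

```python
import torch


def swap_backward_sketch(forward: torch.Tensor, backward: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the notebook's `swap_backward` (an assumption):
    # the returned value equals `forward`, but gradients flow only through `backward`.
    return backward + (forward - backward).detach()


a = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
b = torch.tensor([0.5, 4.0, -1.0], requires_grad=True)

product = a * b
halved = swap_backward_sketch(product, product / 2)

# The forward value is unchanged by the swap.
forward_unchanged = torch.allclose(halved, product)

halved.sum().backward()

# Each branch receives half of its usual gradient, so the two
# gradient-times-input terms add up to the output exactly once.
ixg_total = a.detach() * a.grad + b.detach() * b.grad
completeness_holds = torch.allclose(ixg_total, product.detach())
```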

#### LayerNorm (RMSNorm)



In [1]:
class LibraLlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, libra_p="from_global"):
        if libra_p == "from_global":
            libra_p = global_libra_mode_p

        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        denom = torch.rsqrt(variance + self.variance_epsilon)
        if libra_p:
            print_libra("Libra LayerNorm")

            #: Treat the normalizer as a constant during the backward pass.
            denom = denom.detach()

        hidden_states = hidden_states * denom
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
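
The effect of detaching `denom` can be checked on a single vector. The sketch below is an illustration under our own assumptions (the names `rms_norm_sketch` and `input_x_grad_total` are hypothetical and not part of the notebook). It repeats the arithmetic of the forward pass above and compares the sum of input-times-gradient with the output, with and without the detach.

```python
import torch

torch.manual_seed(0)
eps = 1e-6
weight = torch.randn(8)
x0 = torch.randn(8)


def rms_norm_sketch(x: torch.Tensor, detach_denom: bool) -> torch.Tensor:
    # Same arithmetic as the forward pass above, applied to one vector.
    denom = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    if detach_denom:
        denom = denom.detach()  # the normalizer is a constant in backward
    return weight * (x * denom)


def input_x_grad_total(detach_denom: bool) -> tuple:
    # Returns (sum of input * gradient, scalar output) for the same input.
    x = x0.clone().requires_grad_()
    out = rms_norm_sketch(x, detach_denom).sum()
    (grad,) = torch.autograd.grad(out, x)
    return (x.detach() * grad).sum().item(), out.item()


libra_total, libra_out = input_x_grad_total(detach_denom=True)
plain_total, plain_out = input_x_grad_total(detach_denom=False)
```

With the detach, the layer is linear in its input during the backward pass, so `libra_total` matches `libra_out`. Without it, normalization makes the output almost invariant to rescaling the input, so `plain_total` is close to zero even though the forward outputs are identical.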

#### Attention



##### Patch `F.scaled_dot_product_attention`:



In [1]:
#: Save the original function only once, so that reloading this cell does not wrap the patched version again:
if "scaled_dot_product_attention_orig" not in globals():
    scaled_dot_product_attention_orig = F.scaled_dot_product_attention


def scaled_dot_product_attention_patched(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    libra_p="from_global",
    **kwargs,
):
    if libra_p == "from_global":
        libra_p = global_libra_mode_p

    if libra_p:
        print_libra("Libra Attention")

        #: Detach query and key so that gradients flow only through `value`.
        query = query.detach()
        key = key.detach()

    return scaled_dot_product_attention_orig(
        query=query,
        key=key,
        value=value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
        **kwargs,
    )


F.scaled_dot_product_attention = scaled_dot_product_attention_patched
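
As a quick sanity check of what the Libra branch of the patch does, the following sketch calls `F.scaled_dot_product_attention` directly with a detached query and key. The tensor shapes and the helper name `attend_detached` are our own assumptions for illustration; it requires PyTorch 2.0 or later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Shapes: (batch, heads, sequence, head_dim)
query = torch.randn(1, 2, 4, 8, requires_grad=True)
key = torch.randn(1, 2, 4, 8, requires_grad=True)
value = torch.randn(1, 2, 4, 8, requires_grad=True)


def attend_detached(v: torch.Tensor) -> torch.Tensor:
    # Mirrors the Libra branch: attention weights come from detached query/key.
    return F.scaled_dot_product_attention(
        query.detach(), key.detach(), v, is_causal=True
    )


out = attend_detached(value)
out.sum().backward()

# Only `value` receives a gradient; `query` and `key` are cut off.
grads_present = (
    query.grad is not None,
    key.grad is not None,
    value.grad is not None,
)

# With the attention weights fixed, the output is linear in `value`.
with torch.no_grad():
    linear_in_value = torch.allclose(attend_detached(2 * value), 2 * out, atol=1e-5)
```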

#### Replace Modules With Their Libra Counterparts



We're going to swap out parts of our PyTorch model using a helper function called `module_mapper` (imported from `PyNight` earlier). Keep in mind that this function is more of a hack than an officially supported solution, so we should double-check that everything works correctly after making these changes.



In [1]:
from transformers.models.llama.modeling_llama import (
    LlamaRMSNorm,
    LlamaMLP,
)

To replace a module using `module_mapper`, all constructor arguments must be present as attributes on the module. `LlamaRMSNorm` does not store `hidden_size`, so we need to set it manually:



In [1]:
def store_hidden_size_for_LlamaRMSNorm(model):
    for module in model.modules():
        if isinstance(module, LlamaRMSNorm):
            #: The weight shape is (hidden_size,), so we take the first dimension
            module.hidden_size = module.weight.shape[0]

    return model

model = store_hidden_size_for_LlamaRMSNorm(model)

In [1]:
module_mapping = {
    LlamaRMSNorm: LibraLlamaRMSNorm,
    LlamaMLP: LibraLlamaMLP,
    nn.SiLU: LibraSiLU,
    #: We have globally patched `F.scaled_dot_product_attention`, so no need for further modifications.
}

libra_model = module_mapper(model, module_mapping).new_model
libra_model.to(device)

libra_model
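
To make the idea of module replacement concrete, here is a minimal sketch on a toy network. It is not how `module_mapper` is implemented; `swap_modules_sketch` and `TaggedReLU` are hypothetical names, and unlike the call above, this sketch modifies the model in place instead of returning a new one. It also shows why constructor arguments need to be available as attributes: the replacement is rebuilt from `child.inplace`.

```python
import torch
import torch.nn as nn


class TaggedReLU(nn.ReLU):
    """Hypothetical replacement class, used only for this illustration."""


def swap_modules_sketch(root: nn.Module, mapping: dict) -> nn.Module:
    # Walk the module tree and replace each child whose exact type is in `mapping`.
    for name, child in list(root.named_children()):
        new_cls = mapping.get(type(child))
        if new_cls is None:
            swap_modules_sketch(child, mapping)  # recurse into containers
            continue
        # Rebuild from an attribute that mirrors the constructor argument,
        # then copy any parameters and buffers over.
        replacement = new_cls(inplace=child.inplace)
        replacement.load_state_dict(child.state_dict())
        setattr(root, name, replacement)
    return root


toy = nn.Sequential(
    nn.Linear(4, 4),
    nn.ReLU(),
    nn.Sequential(nn.Linear(4, 2), nn.ReLU()),
)
x = torch.randn(3, 4)
before = toy(x)

swapped = swap_modules_sketch(toy, {nn.ReLU: TaggedReLU})
after = swapped(x)

# The forward pass should be unchanged by the swap.
forward_unchanged = torch.equal(before, after)
```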

Now we will verify that the forward pass has not changed after our modifications:



In [1]:
libra_pred = predict_next_token(text, model=libra_model)

#: Verify predictions match
assert original_pred.next_token_id == libra_pred.next_token_id

#: Compare logits
comparison = compare_logits(original_pred.logits.squeeze().cpu(), libra_pred.logits.squeeze().cpu())
print_logits_comparison(comparison)

#### Attribution Tests on Libra



In [1]:
analyze_test_cases(
    test_cases=[
        "You can type whatever you want here. E.g., the capital of France is"
    ],
    model=libra_model,
    plot_ixg=True,
    plot_fullgrad=True
)

In [1]:
analyze_test_cases(
    test_cases=all_test_cases,
    model=libra_model,
    plot_ixg=True,
    plot_fullgrad=True
)

print("\n\nFinished.")