### Import necessary libraries

In [15]:
import os
import sys
import tensorflow as tf
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoTokenizer, TFBertForMaskedLM

### Configuration

This section defines the configuration parameters used throughout the notebook: the model identifier, visualization settings, and output options.

In [16]:
# Identifier of the pre-trained masked language model; main() passes it to
# both AutoTokenizer.from_pretrained and TFBertForMaskedLM.from_pretrained
MODEL = "bert-base-uncased"

# Number of top-scoring candidate tokens to report for the mask position
K = 1

# Font used for every token label drawn on the diagrams (Pillow's default font)
FONT = ImageFont.load_default()

# Minimum side length, in pixels, of one cell in the attention grid
BASE_GRID_SIZE = 40

# Pixels reserved along the top and left edges for the token labels; also
# divided by the token count to enlarge the cells when there are few tokens
PIXELS_PER_WORD = 200

# If True, prints the predictions and attention values to the console
PRINT_TO_CONSOLE = True

# If True, saves the generated diagrams to the output folder
SAVE_TO_OUTPUT_FOLDER = False


### This function retrieves the index of the mask token within the input tensor.

In [17]:
def get_mask_token_index(mask_token_id, inputs):
    """
    Locate the first mask token in a tokenized input.

    Args:
        mask_token_id (int): The ID the tokenizer assigns to its mask token.
        inputs (dict): The tokenized input. Only the 'input_ids' entry is read,
                       and only its first row (index 0) is searched.

    Returns:
        int: Zero-based position of the first occurrence of mask_token_id in
             that row, or None when the row does not contain it.
    """
    # Materialise the first row of input IDs as a plain Python list
    first_sequence = inputs['input_ids'][0].numpy().tolist()

    # Linear scan; stop at the first match
    for position, token_id in enumerate(first_sequence):
        if token_id == mask_token_id:
            return position

    # The mask token does not appear in the sequence
    return None


### This function converts an attention score to a grayscale color value.

In [18]:
def get_color_for_attention_score(attention_score):
    """
    Converts an attention score to a grayscale color value.

    The score is scaled by 255, truncated to an integer, and clamped to the
    valid 8-bit channel range, so a score of 0 maps to black and a score of 1
    maps to white.

    Args:
        attention_score (float): A floating-point value representing the attention score,
                                 typically in the range [0, 1]. Values outside that range
                                 are clamped to the nearest bound.

    Returns:
        tuple: A tuple (R, G, B) where each component is the same integer in the
               range [0, 255], representing a grayscale color corresponding to
               the attention score.
    """
    # Scale to the 8-bit range; int() truncates toward zero
    gray_value = int(attention_score * 255)

    # Clamp so that scores outside [0, 1] still produce a valid channel value,
    # as the documented return range promises
    gray_value = max(0, min(255, gray_value))

    # Return the grayscale value as an RGB tuple
    return gray_value, gray_value, gray_value


### This function generates and saves attention diagrams for each layer and head.

In [19]:
def visualize_attentions(tokens, attentions):
    """
    Produce one attention diagram per (layer, head) pair.

    Args:
        tokens (list): Token strings passed through unchanged to generate_diagram.
        attentions (list): One entry per layer. Entry 0 of each layer is iterated
                           to obtain that layer's per-head attention weights
                           (presumably index 0 is the single sequence in the batch).

    Layers and heads are numbered from 1 in the calls to generate_diagram.
    """
    for layer_number, layer_attentions in enumerate(attentions, start=1):
        # Only the first element of each layer is visualised
        heads = layer_attentions[0]
        for head_number, attention_weights in enumerate(heads, start=1):
            generate_diagram(layer_number, head_number, tokens, attention_weights)


### This function generates and saves an attention diagram for a specific layer and head.

In [20]:
def generate_diagram(layer_number, head_number, tokens, attention_weights):
    """
    Generates and saves an attention diagram for a specific layer and head.

    The diagram is a square image: token labels run along the top (rotated) and
    down the left side, and cell (row i, column j) is shaded according to
    attention_weights[i][j] (lighter means a higher score). The image is written
    to "output/Attention_Layer{layer_number}_Head{head_number}.png".

    Args:
        layer_number (int): The layer number for which the diagram is generated.
        head_number (int): The head number within the layer for which the diagram is generated.
        tokens (list): A list of token strings for which the attentions are computed.
        attention_weights (numpy.ndarray): A 2D array of attention weights for the tokens;
                                           indexed as attention_weights[i][j] for
                                           i, j in range(len(tokens)).

    Raises:
        ValueError: If tokens is empty (there is nothing to draw and the cell
                    size cannot be computed).
    """
    token_count = len(tokens)
    if token_count == 0:
        raise ValueError("tokens must not be empty")

    # Ensure the output directory exists; exist_ok avoids a check-then-create race
    output_dir = "output"
    os.makedirs(output_dir, exist_ok=True)

    # Cells grow beyond the base size when there are only a few tokens
    GRID_SIZE = max(BASE_GRID_SIZE, PIXELS_PER_WORD // token_count)

    # Square canvas: the grid plus a PIXELS_PER_WORD margin for the labels
    image_size = GRID_SIZE * token_count + PIXELS_PER_WORD
    img = Image.new("RGBA", (image_size, image_size), "black")
    draw = ImageDraw.Draw(img)

    for i, token in enumerate(tokens):
        # Offset of row/column i from the image origin
        offset = PIXELS_PER_WORD + i * GRID_SIZE

        # Column label: draw the text on a transparent canvas, rotate the whole
        # canvas by 90 degrees so the text becomes vertical above column i, then
        # composite it onto the diagram using its own alpha channel as the mask
        token_image = Image.new("RGBA", (image_size, image_size), (0, 0, 0, 0))
        token_draw = ImageDraw.Draw(token_image)
        token_draw.text(
            (image_size - PIXELS_PER_WORD, offset),
            token,
            fill="white",
            font=FONT
        )
        token_image = token_image.rotate(90)
        img.paste(token_image, mask=token_image)

        # Row label: right-align the text against the left edge of the grid
        _, _, width, _ = draw.textbbox((0, 0), token, font=FONT)
        draw.text(
            (PIXELS_PER_WORD - width, offset),
            token,
            fill="white",
            font=FONT
        )

    # Draw the attention weights as a grid of grayscale rectangles
    for i in range(token_count):
        y = PIXELS_PER_WORD + i * GRID_SIZE
        for j in range(token_count):
            x = PIXELS_PER_WORD + j * GRID_SIZE
            color = get_color_for_attention_score(attention_weights[i][j])
            draw.rectangle((x, y, x + GRID_SIZE, y + GRID_SIZE), fill=color)

    # Save the image in the output directory with a filename indicating the layer and head numbers
    img.save(os.path.join(output_dir, f"Attention_Layer{layer_number}_Head{head_number}.png"))


### This function interacts with the user, runs the BERT model, and generates attention diagrams.

In [21]:
def main():
    """
    Main function to interact with the user, run the BERT model, and generate attention diagrams.

    This function performs the following steps:
    1. Prompts the user to input a text with a mask token.
    2. Tokenizes the input text and retrieves the index of the mask token.
    3. Runs the pre-trained BERT model to get predictions and attention weights.
    4. Prints the top predictions to the console if PRINT_TO_CONSOLE is set to True.
    5. Prints the attention values to the console if PRINT_TO_CONSOLE is set to True.
    6. Generates and saves attention diagrams for each layer and head if
       SAVE_TO_OUTPUT_FOLDER is set to True.

    Note:
        The input text must include the mask token specified by the tokenizer;
        otherwise the function exits via sys.exit with an explanatory message.
    """
    # Prompt the user to input text with a mask token
    text = input("Text: ")

    # Initialize the tokenizer and tokenize the input text
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    inputs = tokenizer(text, return_tensors="tf")

    # Get the index of the mask token in the input tokens
    mask_token_index = get_mask_token_index(tokenizer.mask_token_id, inputs)
    if mask_token_index is None:
        sys.exit(f"Input must include mask token {tokenizer.mask_token}.")

    # Load the pre-trained BERT model
    model = TFBertForMaskedLM.from_pretrained(MODEL)

    # Run the model to get predictions and attention weights
    result = model(**inputs, output_attentions=True)

    # Get the logits for the mask token and find the top K predictions
    mask_token_logits = result.logits[0, mask_token_index]
    top_tokens = tf.math.top_k(mask_token_logits, K).indices.numpy()

    # Print the top predictions to the console
    if PRINT_TO_CONSOLE:
        print("\nTop Predictions:")
        for i, token in enumerate(top_tokens, start=1):
            prediction = text.replace(tokenizer.mask_token, tokenizer.decode([token]))
            print(f"{i}. {prediction}")

        # Print the attention values for each layer and head to the console
        print("\nAttention Values:")
        for layer_number, layer_attentions in enumerate(result.attentions):
            print(f"\nLayer {layer_number + 1}:")
            for head_number, attention_weights in enumerate(layer_attentions[0]):
                print(f"  Head {head_number + 1}:")
                print(attention_weights.numpy())

    # Generate and save attention diagrams only when saving is enabled
    if SAVE_TO_OUTPUT_FOLDER:
        # Derive the labels from the exact IDs fed to the model so that every
        # row/column of the attention matrices (including any special tokens the
        # tokenizer added) gets its own, correctly aligned label
        input_ids = inputs["input_ids"][0].numpy().tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
        visualize_attentions(tokens, result.attentions)

# Entry point: call main() only when this code is executed directly
# (i.e. __name__ is "__main__"), not when it is imported as a module
if __name__ == "__main__":
    main()


All PyTorch model weights were used when initializing TFBertForMaskedLM.

All the weights of TFBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.



Top Predictions:
1. We turned down a narrow lane and passed through a small field.

Attention Values:

Layer 1:
  Head 1:
[[0.03097289 0.01296986 0.02010352 0.03354304 0.08393382 0.03994218
  0.02355631 0.05800313 0.02138769 0.05763273 0.09235885 0.04889867
  0.26484472 0.05964648 0.15220615]
 [0.06965856 0.06030757 0.11948897 0.09427328 0.03080137 0.07301991
  0.09236757 0.02815308 0.10869991 0.03868811 0.03037876 0.05873355
  0.03434809 0.05633667 0.1047446 ]
 [0.066923   0.05075134 0.05296352 0.05627129 0.0283573  0.04133936
  0.03549168 0.07492625 0.07602086 0.09269796 0.02639455 0.07008545
  0.17591259 0.10197965 0.04988525]
 [0.03277787 0.05419286 0.04733646 0.02089492 0.02889161 0.05015407
  0.04733294 0.0807021  0.03740091 0.05268443 0.02988024 0.03713492
  0.24494487 0.11863713 0.11703465]
 [0.05116723 0.05531684 0.09827884 0.06669771 0.0438147  0.06398281
  0.08585194 0.05248184 0.07872381 0.07641729 0.04633489 0.09173141
  0.07381663 0.07072875 0.04465528]
 [0.0690482  0.07

Usage examples (type one of these sentences at the `Text: ` prompt; the input must contain the `[MASK]` token):

```text
I am going to a restaurant to eat [MASK].
```
```text
We turned down a narrow lane and passed through a small [MASK].
```