# Extracting All Attention Graphs from a Transformer in English and Hebrew

In [None]:
pip install transformers

Here I just used the following prompt:

"I am interested in extracting attention graphs from the transformer model GPT-2 for a given text called `input_text`. Please choose an input text and help me write code to extract the attention graphs for it. Each attention graph should be plotted with weight labels on edges and node labels."

It took a little back-and-forth with the model to get working code. Once the code ran, the attention graphs made it easy to check visually that it was working.

In [None]:
# Plot every GPT-2 attention head as a weighted, directed graph over the tokens
# of an English sentence.
import torch
import networkx as nx
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model

# Set up the GPT-2 model and tokenizer
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name, output_attentions=True)

# Define the input text
input_text = "The quick brown fox jumps over the lazy dog"

# Tokenize the input text and convert it to a tensor of shape (1, sequence_length)
input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0)

# Get the model output with attention weights. This is inference only, so skip
# building the autograd graph.
with torch.no_grad():
    outputs = model(input_ids)

# Node labels and the circular layout depend only on the tokens, so compute
# them once instead of once per (layer, head) figure.
sequence_length = input_ids.shape[1]
node_labels = {i: tokenizer.decode([input_ids[0][i].item()]) for i in range(sequence_length)}
pos = nx.circular_layout(range(sequence_length))
node_size = 2000

# Iterate through each layer of the output and plot one graph per attention head
for layer, attention in enumerate(outputs.attentions):
    # Convert the attention weights to a weighted adjacency matrix
    weights = attention[0].detach().numpy()  # Shape: (num_heads, sequence_length, sequence_length)
    for head in range(weights.shape[0]):
        # Edge i -> j carries weights[head][i][j] (row = query token, column =
        # key token in the Hugging Face attentions layout).
        graph = nx.DiGraph()
        for i in range(sequence_length):
            for j in range(sequence_length):
                graph.add_edge(i, j, weight=weights[head][i][j])
        # Plot the graph with edge and node labels
        plt.figure(figsize=(10, 10))
        # Passing node_size tells networkx where the large nodes end, so the
        # arrowheads stop at the node boundary instead of being hidden under it.
        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey', node_size=node_size)
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}
        # With the default label_pos=0.5, the labels of i -> j and j -> i land on
        # the same midpoint and one hides the other. Shift each label toward its
        # source node so the labels around token i are its own attention row.
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, font_size=12, label_pos=0.7
        )
        nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=node_size)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=12)
        plt.title(f"Layer {layer+1} Head {head+1}")
        plt.axis('off')
        plt.show()


The first Hebrew example, which uses GPT-2, doesn't work properly, as the attention graphs show. It seems GPT-2's byte-level tokenizer splits the Hebrew into pieces of its Unicode (UTF-8) encoding rather than into words. The model then forms attention graphs over those pieces (presumably using a sliding window of 12 tokens).

In [None]:
# Plot every GPT-2 attention head as a weighted, directed graph over the tokens
# of a Hebrew sentence (with nikkud).
import torch
import networkx as nx
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model

# Set up the GPT-2 model and tokenizer
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2Model.from_pretrained(model_name, output_attentions=True)

# Define the input text
input_text = "שְׁמַע יִשְׂרָאֵל יְהוָה אֱלֹהֵינוּ יְהוָה אֶחָד"

# Tokenize the input text and convert it to a tensor of shape (1, sequence_length)
tokens = torch.tensor(tokenizer.encode(input_text, add_special_tokens=True)).unsqueeze(0)

# Get the model output with attention weights. This is inference only, so skip
# building the autograd graph.
with torch.no_grad():
    outputs = model(tokens)

# Node labels and the circular layout depend only on the tokens, so compute
# them once instead of once per (layer, head) figure.
# NOTE: GPT-2's tokenizer is byte-level, so a single token may hold only part
# of a Hebrew character's UTF-8 encoding. Its decoded label may then show the
# replacement character '\ufffd'.
sequence_length = tokens.shape[1]
node_labels = {i: tokenizer.decode([tokens[0][i].item()]) for i in range(sequence_length)}
pos = nx.circular_layout(range(sequence_length))
node_size = 2000

# Iterate through each layer of the output and plot one graph per attention head
for layer, attention in enumerate(outputs.attentions):
    # Convert the attention weights to a weighted adjacency matrix
    weights = attention[0].detach().numpy()  # Shape: (num_heads, sequence_length, sequence_length)
    for head in range(weights.shape[0]):
        # Edge i -> j carries weights[head][i][j] (row = query token, column =
        # key token in the Hugging Face attentions layout).
        graph = nx.DiGraph()
        for i in range(sequence_length):
            for j in range(sequence_length):
                graph.add_edge(i, j, weight=weights[head][i][j])
        # Plot the graph with edge and node labels
        plt.figure(figsize=(10, 10))
        # Passing node_size tells networkx where the large nodes end, so the
        # arrowheads stop at the node boundary instead of being hidden under it.
        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey', node_size=node_size)
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}
        # With the default label_pos=0.5, the labels of i -> j and j -> i land on
        # the same midpoint and one hides the other. Shift each label toward its
        # source node so the labels around token i are its own attention row.
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, font_size=12, label_pos=0.7
        )
        nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=node_size)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=12)
        plt.title(f"Layer {layer+1} Head {head+1}")
        plt.axis('off')
        plt.show()


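Next, I asked it to use a different model, AlephBERT (`onlplab/alephbert-base`), to parse the Hebrew. This worked well and didn't seem to need any special treatment of the nikkud (vowel points). The last cell runs the same verse without nikkud for comparison. At this point, it began struggling to translate from Hebrew to Yiddish.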

In [None]:
# Plot every AlephBERT attention head as a weighted, directed graph over the
# tokens of a Hebrew sentence written with nikkud.
import torch
import networkx as nx
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel

# Set up the AlephBERT (Hebrew BERT) model and tokenizer
model_name = 'onlplab/alephbert-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Define the input text
input_text = "שְׁמַע יִשְׂרָאֵל יְהוָה אֱלֹהֵינוּ יְהוָה אֶחָד"

# Tokenize the input text and convert it to a tensor of shape (1, sequence_length).
# add_special_tokens=True adds the tokenizer's special tokens (presumably
# [CLS]/[SEP] for this BERT-style model), which appear as extra nodes.
input_ids = torch.tensor(tokenizer.encode(input_text, add_special_tokens=True)).unsqueeze(0)

# Get the model output with attention weights. This is inference only, so skip
# building the autograd graph.
with torch.no_grad():
    outputs = model(input_ids)

# Node labels and the circular layout depend only on the tokens, so compute
# them once instead of once per (layer, head) figure.
sequence_length = input_ids.shape[1]
node_labels = {i: tokenizer.decode([input_ids[0][i].item()]) for i in range(sequence_length)}
pos = nx.circular_layout(range(sequence_length))
node_size = 2000

# Iterate through each layer of the output and plot one graph per attention head
for layer, attention in enumerate(outputs.attentions):
    # Convert the attention weights to a weighted adjacency matrix
    weights = attention[0].detach().numpy()  # Shape: (num_heads, sequence_length, sequence_length)
    for head in range(weights.shape[0]):
        # Edge i -> j carries weights[head][i][j] (row = query token, column =
        # key token in the Hugging Face attentions layout).
        graph = nx.DiGraph()
        for i in range(sequence_length):
            for j in range(sequence_length):
                graph.add_edge(i, j, weight=weights[head][i][j])
        # Plot the graph with edge and node labels
        plt.figure(figsize=(10, 10))
        # Passing node_size tells networkx where the large nodes end, so the
        # arrowheads stop at the node boundary instead of being hidden under it.
        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey', node_size=node_size)
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}
        # With the default label_pos=0.5, the labels of i -> j and j -> i land on
        # the same midpoint and one hides the other. Shift each label toward its
        # source node so the labels around token i are its own attention row.
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, font_size=12, label_pos=0.7
        )
        nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=node_size)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=12)
        plt.title(f"Layer {layer+1} Head {head+1}")
        plt.axis('off')
        plt.show()


In [None]:
# Plot every AlephBERT attention head as a weighted, directed graph over the
# tokens of the same Hebrew sentence written without nikkud.
import torch
import networkx as nx
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel

# Set up the AlephBERT (Hebrew BERT) model and tokenizer
model_name = 'onlplab/alephbert-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Define the input text
input_text = "שמע ישראל יהוה אלהינו יהוה אחד"

# Tokenize the input text and convert it to a tensor of shape (1, sequence_length).
# add_special_tokens=True adds the tokenizer's special tokens (presumably
# [CLS]/[SEP] for this BERT-style model), which appear as extra nodes.
input_ids = torch.tensor(tokenizer.encode(input_text, add_special_tokens=True)).unsqueeze(0)

# Get the model output with attention weights. This is inference only, so skip
# building the autograd graph.
with torch.no_grad():
    outputs = model(input_ids)

# Node labels and the circular layout depend only on the tokens, so compute
# them once instead of once per (layer, head) figure.
sequence_length = input_ids.shape[1]
node_labels = {i: tokenizer.decode([input_ids[0][i].item()]) for i in range(sequence_length)}
pos = nx.circular_layout(range(sequence_length))
node_size = 2000

# Iterate through each layer of the output and plot one graph per attention head
for layer, attention in enumerate(outputs.attentions):
    # Convert the attention weights to a weighted adjacency matrix
    weights = attention[0].detach().numpy()  # Shape: (num_heads, sequence_length, sequence_length)
    for head in range(weights.shape[0]):
        # Edge i -> j carries weights[head][i][j] (row = query token, column =
        # key token in the Hugging Face attentions layout).
        graph = nx.DiGraph()
        for i in range(sequence_length):
            for j in range(sequence_length):
                graph.add_edge(i, j, weight=weights[head][i][j])
        # Plot the graph with edge and node labels
        plt.figure(figsize=(10, 10))
        # Passing node_size tells networkx where the large nodes end, so the
        # arrowheads stop at the node boundary instead of being hidden under it.
        nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey', node_size=node_size)
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}
        # With the default label_pos=0.5, the labels of i -> j and j -> i land on
        # the same midpoint and one hides the other. Shift each label toward its
        # source node so the labels around token i are its own attention row.
        nx.draw_networkx_edge_labels(
            graph, pos, edge_labels=edge_labels, font_size=12, label_pos=0.7
        )
        nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=node_size)
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=12)
        plt.title(f"Layer {layer+1} Head {head+1}")
        plt.axis('off')
        plt.show()