In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
from tqdm import tqdm

In [None]:
# Checkpoint to load and the device it is moved to.
model_path = "LayTextLLM/LayTextLLM-Zero"
device = "cuda:0"

# trust_remote_code=True lets the checkpoint's own Python code run, so only point
# model_path at a repository you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left',
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model = model.to(device)

In [None]:
## Keyword arguments forwarded verbatim to model.generate().
generate_params = dict(
    # decoding strategy: greedy (no sampling, single beam)
    use_cache=True,
    do_sample=False,
    num_beams=1,
    # length limits
    max_new_tokens=512,
    min_new_tokens=None,
    # sampling / scoring knobs
    top_p=0.9,
    repetition_penalty=1.0,
    length_penalty=1.0,
    num_return_sequences=1,
    temperature=1.0,
    # extra outputs; the attention cells below read `model_output.attentions`,
    # which requires output_attentions together with return_dict_in_generate
    output_scores=True,
    output_hidden_states=True,
    output_attentions=True,
    return_dict_in_generate=True,
    # NOTE(review): not a standard generate() argument as far as this notebook shows;
    # presumably consumed by the checkpoint's remote code -- confirm.
    keyword=None,
)

In [None]:
# BOX_TOKEN_ID is inserted in front of every OCR segment when the prompt is built and is
# later searched for to locate the layout positions. BOX_TOKEN is presumably the token
# string that id maps to in this tokenizer -- TODO confirm.
BOX_TOKEN = "<unk>"
BOX_TOKEN_ID = 0
# Prompt template; `{ocr}` receives the decoded OCR text and `{question}` the question.
INPUT_PROMPT_TEMPLATE = "given document <document>{ocr}</document>, answer following question: {question} Please think step-by-step.\n## answer:"

# Read the file as UTF-8 explicitly: without `encoding`, open() falls back to the
# platform's default encoding, which can mis-decode non-ASCII OCR text.
with open("datasets/funsd_test.json", "r", encoding="utf-8") as fin:
    test_data = json.load(fin)

print('==========num examples', len(test_data))

In [None]:
with torch.no_grad():
    # Only the example at index 2 is processed; widen the slice to run over more data.
    # `model_output` (and the other names assigned here) are overwritten on every
    # iteration, so later cells see the values of the last processed example.
    for idx, example in enumerate(tqdm(test_data[2:3])):

        input_ids, input_polys = [], []
        img_size = {}

        # Fields of the current example.
        texts, polys = example['ocr'], example['poly']
        w = example['img_size']['w']
        h = example['img_size']['h']
        question, answer = example['question'], example['answer']
        meta = example['metadata']

        ## nothing to feed the model when the OCR result is empty
        if not len(texts):
            continue

        ## one placeholder id + the segment's token ids per OCR segment, and one box per segment
        for text, poly in zip(texts, polys):
            input_ids.append(BOX_TOKEN_ID)
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            input_ids.extend(text_ids)
            # Entries 0/1 and 4/5 of the polygon are scaled by page width/height;
            # presumably the top-left and bottom-right corners -- TODO confirm.
            text_poly = [poly[0] / w, poly[1] / h, poly[4] / w, poly[5] / h]
            input_polys.append(text_poly)

        # layout input: add a leading batch dimension of size 1
        input_polys = torch.as_tensor(input_polys, device=device).unsqueeze(0)

        # Decode the placeholder/text ids back to a string and drop it into the prompt
        # template together with the question.
        input_data = dict(ocr=tokenizer.decode(input_ids), question=question)
        input_texts = INPUT_PROMPT_TEMPLATE.format(**input_data)

        # Tokenise the complete prompt; again a batch of one.
        input_ids = torch.as_tensor(
            tokenizer.encode(input_texts, add_special_tokens=False),
            device=device,
        ).unsqueeze(0)
        attention_mask = torch.ones_like(input_ids)

        # The keyword is spelled `laytout_input` on purpose: it has to match whatever
        # name the checkpoint's remote code expects -- check that code before renaming.
        model_output = model.generate(
            input_ids=input_ids,
            laytout_input=input_polys,
            attention_mask=attention_mask,
            **generate_params,
        )

In [None]:
# Dump the whole returned sequence: position, raw token id, decoded text.
full_sequence = model_output.sequences[0]
for idx, token_id in enumerate(full_sequence):
    print(idx, token_id, tokenizer.decode(token_id))

In [None]:
# Same dump, but skipping the first `input_ids.size(1)` positions so that the
# numbering restarts at 0 right after the prompt.
prompt_length = input_ids.size(1)
generated_part = model_output.sequences[0][prompt_length:]
for idx, token_id in enumerate(generated_part):
    print(idx, token_id, tokenizer.decode(token_id))

In [None]:
# The plotting cells below need seaborn; if it is missing, run `%pip install seaborn` once.

In [None]:
# Shape probe of one attention tensor. The three indices are applied in this order;
# presumably generation step, layer, batch element -- TODO confirm against the
# generate() output documentation.
step_idx, layer_idx, batch_idx = 2, 0, 0
model_output.attentions[step_idx][layer_idx][batch_idx].shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Build a (source token x decoding step) attention heatmap and save it as a PDF.
max_output_length = model_output.sequences.size(1)  # prompt tokens + generated tokens

full_attention_matrices = []

# Entry 0 of `model_output.attentions` is skipped: the squeeze(1) below expects a single
# query row, which presumably only holds for the steps after the prompt pass (TODO
# confirm against the generate() output docs). The original slice start was
# `len(input_ids)`, which is the batch dimension (1) and not a prompt length, so the
# offset is written out explicitly.
FIRST_DECODE_STEP = 1
for attention_score in model_output.attentions[FIRST_DECODE_STEP:]:
    # Layer 0, max over heads -> 1-D vector over the tokens visible at this step.
    # Alternative aggregation (last layer, mean over heads):
    #   attention_score[-1].cpu().squeeze(0).squeeze(1).mean(dim=0)
    attention_map = attention_score[0].cpu().squeeze(0).squeeze(1).max(dim=0)[0]

    # Right-pad with zeros up to the full sequence length so that every step yields a
    # vector of the same size. Converting explicitly avoids relying on NumPy's implicit
    # tensor coercion (and on a dtype NumPy cannot represent).
    padded_attention = np.pad(
        attention_map.float().numpy(),
        (0, max_output_length - attention_map.shape[0]),
        'constant',
        constant_values=0
    )

    full_attention_matrices.append(padded_attention)

# Rows = source tokens, columns = decoding steps.
full_attention_matrix = np.stack(full_attention_matrices, axis=0).T
print(full_attention_matrix.shape)

# Plot the attention matrix
plt.figure(figsize=(100, 200))
sns.heatmap(full_attention_matrix, cmap='viridis', cbar_kws={"shrink": 0.2})
cbar = plt.gcf().axes[-1]  # the colour bar is the last axes added to the figure
cbar.tick_params(labelsize=30)

# Tick labels: every token of the full sequence on the y axis, and on the x axis the
# tokens from position prompt_length + 1 onwards (one per plotted step).
source_tokens = [tokenizer.decode(input_id) for input_id in model_output.sequences[0]]
target_tokens = [tokenizer.decode(input_id) for input_id in model_output.sequences[0][input_ids.size(1) + 1:]]

# Heatmap cell i spans [i, i + 1], so labels go at i + 0.5 to sit on the cell centre
# rather than on the boundary between two cells.
plt.yticks(ticks=np.arange(len(source_tokens)) + 0.5, labels=source_tokens, rotation=0, fontsize=30)
plt.xticks(ticks=np.arange(len(target_tokens)) + 0.5, labels=target_tokens, rotation=90, fontsize=30)
plt.xlabel('Target Token', fontsize=30)
plt.ylabel('Source Token', fontsize=30)
plt.title('LayTextLLM Attention Map', fontsize=30)
plt.savefig("LayTextLLM_Attention_Map.pdf", format="pdf", bbox_inches="tight")
# plt.show()


In [None]:
# First entry of `attentions`, first layer of it; drop the leading batch dimension and
# take the element-wise maximum over the next one (the heads, going by how the plotting
# cells use it).
first_entry_first_layer = model_output.attentions[0][0]
input_attention = first_entry_first_layer.squeeze(0).max(dim=0).values

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Zoomed attention heatmap: a window of source tokens (columns) against the last
# decoding steps (rows), with the first rows replaced by rows taken from
# `input_attention`.
NUM_LAST_STEPS = 20               # number of trailing decoding steps shown as rows
SOURCE_SPAN = slice(90, 105)      # window of source-token positions shown as columns
# Rows of `input_attention` copied over the first rows of the plot; presumably the
# positions of the question tokens inside the prompt -- TODO confirm for this example.
QUESTION_SPAN = slice(136, 149)

attention_matrices = []

for attention_score in model_output.attentions[-NUM_LAST_STEPS:]:
    # LAST layer (index -1), max over heads, restricted to the source window.
    # NOTE(review): `input_attention` is taken from layer 0, so the plot mixes rows
    # from two different layers -- confirm that this is intended.
    attention_map = attention_score[-1].cpu().squeeze(0).squeeze(1).max(dim=0)[0][SOURCE_SPAN]
    attention_matrices.append(attention_map.float().numpy())

# Rows = decoding steps, columns = source tokens inside SOURCE_SPAN.
attention_matrix = np.stack(attention_matrices, axis=0)

question = "What is the quantity of - TICKET CP? ..."
question_tokens = [tokenizer.decode(token_id) for token_id in tokenizer.encode(question, add_special_tokens=False)]

# Overwrite the first rows with the attention rows of QUESTION_SPAN. The row count of
# that span must equal the number of question tokens; fail with a readable message
# instead of a NumPy broadcasting error when it does not.
question_attention = input_attention[QUESTION_SPAN, SOURCE_SPAN].cpu().float().numpy()
if question_attention.shape[0] != len(question_tokens):
    raise ValueError(
        f"QUESTION_SPAN selects {question_attention.shape[0]} rows but the question "
        f"tokenises to {len(question_tokens)} tokens; adjust QUESTION_SPAN or the question."
    )
attention_matrix[:len(question_tokens)] = question_attention
print(attention_matrix.shape)

# Plot the attention matrix
plt.figure(figsize=(20, 20))
sns.heatmap(attention_matrix, cmap='viridis', cbar_kws={"shrink": 0.2})
cbar = plt.gcf().axes[-1]  # the colour bar is the last axes added to the figure
cbar.tick_params(labelsize=30)

# Column labels: tokens inside the source window.
source_tokens = [tokenizer.decode(input_id) for input_id in model_output.sequences[0]][SOURCE_SPAN]

# Row labels: the NUM_LAST_STEPS tokens preceding the final one, with the leading
# entries replaced by the question tokens to match the overwritten rows.
target_tokens = [tokenizer.decode(input_id) for input_id in model_output.sequences[0][-(NUM_LAST_STEPS + 1):-1]]
target_tokens[:len(question_tokens)] = question_tokens

# Heatmap cell i spans [i, i + 1], so labels go at i + 0.5 to sit on the cell centre.
plt.yticks(ticks=np.arange(len(target_tokens)) + 0.5, labels=target_tokens, rotation=0, fontsize=30)
plt.xticks(ticks=np.arange(len(source_tokens)) + 0.5, labels=source_tokens, rotation=90, fontsize=30)
plt.xlabel('Source Token', fontsize=30)
plt.ylabel('Target Token', fontsize=30)
plt.title('LayTextLLM Attention Map', fontsize=30)
# NOTE(review): same file name as the full attention map, so this call overwrites that
# figure -- pick a different name if both are needed.
plt.savefig("LayTextLLM_Attention_Map.pdf", format="pdf", bbox_inches="tight")
# plt.show()


In [None]:
# Shape probe of a row slice of `input_attention`.
# NOTE(review): this uses rows 136:148 while the zoomed plot uses 136:149 -- confirm
# which bound is intended.
input_attention[136:148].cpu().size()

In [None]:
# Attention over all source tokens at column 159. `full_attention_matrix[:][159]` is
# NOT this: `[:]` returns the whole array unchanged, so the second index selects ROW 159
# (one source token across all steps). The score-collection cell reads
# `full_attention_matrix[token_idx][159]`, i.e. 159 is a column index.
full_attention_matrix[:, 159]

In [None]:
## Collect, for every layout placeholder in the prompt, its attention score at one step.
# Column of `full_attention_matrix` to read; columns are decoding steps because the
# matrix is transposed after stacking one row per step.
TARGET_STEP = 159

# Placeholder positions are the prompt positions holding BOX_TOKEN_ID (the id inserted
# in front of each OCR segment), so the shared constant is used instead of a literal 0.
vis_attention_score = [
    full_attention_matrix[token_idx][TARGET_STEP]
    for token_idx, token_id in enumerate(input_ids[0].tolist())
    if token_id == BOX_TOKEN_ID
]
# Number of placeholders found; kept because the original cell exposed this counter.
coordinate_idx = len(vis_attention_score)

In [None]:
## visualize score
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
from PIL import Image
import matplotlib as mpl


# Draw every OCR box on the page image, coloured by its attention score.
image_path = 'cord_example.jpg'  # Replace with your image path
image = Image.open(image_path)

# One score per layout placeholder, in prompt order.
scores = vis_attention_score

# Create figure and axes
fig, ax = plt.subplots(1)
ax.imshow(image)

# Map scores linearly onto the 'Reds' colormap (lowest score -> lightest colour).
cmap = plt.get_cmap('Reds')
norm = mpl.colors.Normalize(vmin=min(scores), vmax=max(scores))
# Each box is indexed as box[0], box[1] (one corner) and box[4], box[5] (the opposite
# corner); box i is paired with scores[i].
bounding_boxes = polys
for i, box in enumerate(bounding_boxes):
    x, y = box[0], box[1]
    # Deliberately NOT named `w`/`h`: those names hold the page width/height set by the
    # inference cell and are read again further down the notebook.
    box_w, box_h = box[4] - box[0], box[5] - box[1]
    score = scores[i]

    # Colour for this score after normalisation to [0, 1].
    color = cmap(norm(score))

    # Outline only, so the underlying text stays readable.
    rect = patches.Rectangle((x, y), box_w, box_h, linewidth=2, edgecolor=color, facecolor='none')
    ax.add_patch(rect)

# A ScalarMappable carrying the same cmap/norm lets the colour bar show the score scale.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])  # dummy array; the colour bar only needs cmap and norm

# Add color bar with legend
cbar = plt.colorbar(sm, ax=ax)
cbar.set_label('Score')
# Display the plot
print("Figure size:", fig.get_size_inches())
plt.savefig("cord_example.pdf", format="pdf", bbox_inches="tight")
plt.show()


In [None]:
# Decode the complete returned sequence in one go, keeping special tokens visible.
output_ids = model_output.sequences[0]
output_str = tokenizer.decode(
    output_ids,
    skip_special_tokens=False,
)
print(output_str)

In [None]:
# Convert a box given on a 0-1000 scale into absolute pixel coordinates.
area = [142, 438, 635, 453]
x_min, y_min, x_max, y_max = area[0]/1000,  area[1]/1000,  area[2]/1000,  area[3]/1000
# Read the page size from the example itself rather than from the globals `w`/`h`:
# the box-plotting cell reassigns those names to the size of the last drawn box, which
# would silently shrink the result here.
width, height = example['img_size']['w'], example['img_size']['h']
abs_x_min, abs_y_min, abs_x_max, abs_y_max = x_min*width, y_min*height, x_max*width, y_max*height

In [None]:
from PIL import Image, ImageDraw

# Draw the absolute-coordinate box on the page image and export the result as a PDF.
image_path = '82092117.png'  # Replace with your image path
image = Image.open(image_path)

# Convert to RGB so the outline can be drawn in colour even on a grayscale scan.
image = image.convert("RGB")

# Create a drawing object
draw = ImageDraw.Draw(image)

# Bounding box as (x1, y1, x2, y2) in pixels.
bounding_box = (abs_x_min, abs_y_min, abs_x_max, abs_y_max)  # Replace with your bounding box coordinates

# Draw the bounding box (outline with color and width)
draw.rectangle(bounding_box, outline='red', width=3)

# Show the image
image.show()

# Save the annotated image. `bbox_inches="tight"` was dropped: it is a matplotlib
# savefig option and is not a parameter Pillow's save() uses.
output_path = 'funsd_with_bbx.pdf'
image.save(output_path, format="pdf")


In [None]:
# Show the absolute box corners as (x_min, y_min, x_max, y_max).
(abs_x_min, abs_y_min, abs_x_max, abs_y_max)