In [1]:
import torch
from transformers import BertTokenizer, BertForTokenClassification
from transformers import BigBirdTokenizer, BigBirdForTokenClassification
from transformers import LongformerTokenizer, LongformerForTokenClassification 
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
import nltk
import numpy as np
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
from bertviz import model_view, head_view

# Fetch the 'punkt' and 'punkt_tab' NLTK data packages (no-op when already present).
# Presumably these back the `sent_tokenize` call imported above -- confirm against
# the installed NLTK version.
nltk.download('punkt')
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt to /home/zhizheng/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     /home/zhizheng/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


True

In [11]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

MODEL = "Longformer"  # one of the keys of MODEL_REGISTRY below
FINETUNE = False      # True -> overwrite the pretrained weights with the checkpoint at model_path

# Per-architecture settings:
#   (tokenizer class, model class, pretrained name, fine-tuned checkpoint, max sequence length)
MODEL_REGISTRY = {
    "Bert": (
        BertTokenizer,
        BertForTokenClassification,
        'bert-base-uncased',
        './ckpts/finetune_cnn_model_updated/bert_summarization_epoch_3.pt',
        512,
    ),
    "BigBird": (
        BigBirdTokenizer,
        BigBirdForTokenClassification,
        'google/bigbird-roberta-base',
        './ckpts/finetune_cnn_model_updated/bigbird_summarization_epoch_3.pt',
        4096,
    ),
    "Longformer": (
        LongformerTokenizer,
        LongformerForTokenClassification,
        'allenai/longformer-base-4096',
        './ckpts/finetune_cnn_model_updated/longformer_summarization_epoch_3.pt',
        4096,
    ),
}

# Fail fast on a typo instead of hitting a NameError on `tokenizer`/`model` further down.
if MODEL not in MODEL_REGISTRY:
    raise ValueError(f"Unknown MODEL {MODEL!r}; expected one of {sorted(MODEL_REGISTRY)}")

tokenizer_cls, model_cls, pretrained_name, model_path, MAX_LEN = MODEL_REGISTRY[MODEL]

# Load the tokenizer and model. num_labels=1 gives one logit per token, and
# output_attentions=True makes the forward pass return the attention maps.
tokenizer = tokenizer_cls.from_pretrained(pretrained_name)
model = model_cls.from_pretrained(pretrained_name, num_labels=1, output_attentions=True)

if FINETUNE:
    # map_location lets a checkpoint saved on GPU load on a CPU-only host as well.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# NOTE: with FINETUNE=False the token-classification head keeps its fresh random
# initialisation, so the per-token predictions are not meaningful in that mode.
model.to(device)
model.eval()

# Load CNN/DailyMail dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")
# Slice the rows first so that only 100 examples are materialised, not whole columns.
first_test_rows = dataset['test'][:100]
test_dataset = {
    'article': first_test_rows['article'],
    'highlights': first_test_rows['highlights'],
}

# Use the first article from the dataset as an example
article = test_dataset['article'][0]
highlights_gt = test_dataset['highlights'][0]

cuda


Some weights of LongformerForTokenClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
max_length = MAX_LEN
article = str(article)
sentences = sent_tokenize(article)

# Encode the whole article once, truncated/padded to exactly max_length tokens.
encoding = tokenizer.encode_plus(
    article,
    add_special_tokens=True,
    max_length=max_length,
    padding='max_length',
    truncation=True,
    return_attention_mask=True,
    return_tensors='pt'
)

input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)

# Get model predictions (inference only, so no gradients are needed)
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)

# One logit per token of the single article in the batch.
logits = outputs.logits[0]

# Convert logits to binary labels: probability > 0.5 marks a token as "important".
slogits = torch.sigmoid(logits.squeeze(-1))
predictions = (slogits > 0.5).int().cpu().numpy()

# Extract important sentences based on predictions.
# Sentences are mapped onto token positions by re-tokenising each sentence on its
# own and advancing a cursor by its token count.
# NOTE(review): this assumes a sentence tokenised in isolation yields the same
# number of tokens as it does inside the full article -- confirm for the chosen
# tokenizer, otherwise the cursor can drift.
summary = []
current_pos = 1  # Skip the leading special token

for sent in sentences:
    sent_length = len(tokenizer.encode(sent, add_special_tokens=False))
    if sent_length == 0:
        # Nothing to score; also avoids taking the mean of an empty slice (NaN).
        continue
    if current_pos + sent_length >= max_length:
        # This sentence runs past the truncated window. Everything after it is
        # outside the window too, so stop rather than score later sentences
        # against a cursor that was never advanced.
        break
    sent_prediction = predictions[current_pos:current_pos + sent_length]
    if sent_prediction.mean() > 0.5:  # If the sentence is mostly predicted as important
        summary.append(sent)
    current_pos += sent_length

highlights = '\n'.join(summary)

In [13]:
# Show the source article, the reference highlights and the extracted highlights.
report_sections = (
    ("Original Article:\n", article),
    ("\nOriginal highlight:\n", highlights_gt),
    ("\nGenerated highlight:\n", highlights),
)
for heading, body in report_sections:
    print(heading, body)

Original Article:
 (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wedne

In [14]:
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk import pos_tag
import nltk


# This cell needs the stop-word list and a POS tagger in addition to the tokenizer
# data. Fetch them quietly so the cell also runs in a fresh environment. The tagger
# resource name differs between NLTK releases, so both spellings are requested; an
# id unknown to the installed index is skipped.
for nltk_resource in ("stopwords", "averaged_perceptron_tagger", "averaged_perceptron_tagger_eng"):
    nltk.download(nltk_resource, quiet=True)

text = article
tokens = word_tokenize(text)
pos_tags = pos_tag(tokens)
stop_words = set(stopwords.words("english"))

# Keep nouns (NN*), verbs (VB*) and adjectives (JJ*) ...
CONTENT_POS_PREFIXES = ("NN", "VB", "JJ")
# ... but drop stop words and leftover markup fragments such as the pieces of "<br/>".
MARKUP_FRAGMENTS = ('<', '>', '/', 'br')

content_tokens = [
    word
    for word, pos in pos_tags
    if pos.startswith(CONTENT_POS_PREFIXES)
    and word.lower() not in stop_words
    and word not in MARKUP_FRAGMENTS
]
print(content_tokens)

['CNN', 'Palestinian', 'Authority', 'became', 'member', 'International', 'Criminal', 'Court', 'Wednesday', 'step', 'gives', 'court', 'jurisdiction', 'alleged', 'crimes', 'Palestinian', 'territories', 'formal', 'accession', 'marked', 'ceremony', 'Hague', 'Netherlands', 'court', 'based', 'Palestinians', 'signed', 'ICC', 'founding', 'Rome', 'Statute', 'January', 'accepted', 'jurisdiction', 'alleged', 'crimes', 'committed', 'occupied', 'Palestinian', 'territory', 'including', 'East', 'Jerusalem', 'June', 'month', 'ICC', 'opened', 'preliminary', 'examination', 'situation', 'Palestinian', 'territories', 'paving', 'way', 'possible', 'war', 'crimes', 'investigations', 'Israelis', 'members', 'court', 'Palestinians', 'subject', 'counter-charges', 'Israel', 'United', 'States', 'ICC', 'member', 'opposed', 'Palestinians', 'efforts', 'join', 'body', 'Palestinian', 'Foreign', 'Minister', 'Riad', 'al-Malki', 'speaking', 'Wednesday', 'ceremony', 'said', 'move', 'greater', 'justice', 'Palestine', 'becom

In [15]:
# Build a (1, seq_len) boolean mask marking the model tokens whose decoded,
# whitespace-stripped text is exactly one of the content words found above.
content_vocab = set(content_tokens)      # set membership instead of a list scan per token
token_id_list = input_ids[0].tolist()    # one transfer instead of iterating a device tensor

# Decode each distinct id once; the padded sequence repeats many ids.
decoded_by_id = {
    token_id: tokenizer.decode(token_id).strip()
    for token_id in set(token_id_list)
}

content_word_mask = torch.tensor(
    [[decoded_by_id[token_id] in content_vocab for token_id in token_id_list]],
    dtype=torch.bool,
)

# Combine all masks (currently only the content-word mask is used)
valid_token_mask = content_word_mask

In [16]:
# Keep only the first MAX_VIS_TOKENS selected tokens so the attention view stays readable.
MAX_VIS_TOKENS = 12

# Work on a copy: valid_token_mask may alias another mask (e.g. content_word_mask),
# and trimming it in place would silently change that one too.
valid_token_mask = valid_token_mask.clone()

# Find the column indices of True values
true_indices = torch.where(valid_token_mask)[1]

if len(true_indices) > MAX_VIS_TOKENS:
    # Switch off every selected position after the first MAX_VIS_TOKENS.
    valid_token_mask[0, true_indices[MAX_VIS_TOKENS:]] = False

In [17]:
# Build ATTN: one dense (1, heads, seq_len, seq_len) attention tensor per layer.
#
# Longformer returns its sliding-window attention in a banded layout: for query
# token i, the last (window + 1) entries of the final axis hold the weights for
# keys i - window/2 ... i + window/2. The loop below places each band row at its
# absolute key positions in a dense matrix, clipping at the sequence boundaries.
# Other models already return dense attention and are used unchanged.
if MODEL=="Longformer":
    ATTN = []
    layer_windows = model.longformer.config.attention_window  # one window size per layer
    for layer_attention, window in zip(outputs.attentions, layer_windows):
        half_window = window // 2
        band_width = 2 * half_window + 1

        # Any leading columns belong to global-attention tokens; none are set for
        # this task, so this slice is the whole tensor here.
        band = layer_attention[..., -band_width:]

        num_heads, seq_len = layer_attention.shape[1], layer_attention.shape[2]
        # Allocate on the same device/dtype as the model output so that filling it
        # does not need a device-to-host copy per row.
        local_attn = layer_attention.new_zeros((1, num_heads, seq_len, seq_len))

        # Draw local attention
        for i in range(seq_len):
            first_col = i - half_window                # absolute key index of band column 0
            lo = max(first_col, 0)                     # clip at the start of the sequence
            hi = min(i + half_window + 1, seq_len)     # clip at the end of the sequence
            local_attn[0, :, i, lo:hi] = band[0, :, i, lo - first_col:hi - first_col]

        # No global attention is used, so the dense local attention is the full result.
        ATTN.append(local_attn.to(device))

else:
    ATTN = outputs.attentions

In [18]:
# For every layer, cut the attention matrix down to the selected tokens (rows and
# columns) and renormalise each row so that it sums to 1 over the kept tokens.
# Heads are kept separate; nothing is averaged or concatenated.
masked_attn = []
print(outputs.attentions[0].shape)

selected = valid_token_mask[0]
for layer_attention in ATTN:
    # Keep the selected query rows, then the selected key columns.
    selected_attention = layer_attention[:, :, selected][:, :, :, selected]
    row_totals = selected_attention.sum(dim=-1, keepdim=True)
    # A selected token may put zero weight on all other selected tokens (for
    # example when they fall outside its local window). Clamp the divisor so such
    # a row stays all-zero instead of becoming NaN.
    smallest = torch.finfo(selected_attention.dtype).tiny
    selected_attention = selected_attention / row_totals.clamp_min(smallest)
    masked_attn.append(selected_attention)

# Shape is (batch, heads, n_selected, n_selected)
print(masked_attn[0].shape)
# Extract valid tokens for the axis labels
tokens = tokenizer.convert_ids_to_tokens(input_ids[0][valid_token_mask[0]])


torch.Size([1, 12, 4096, 513])
torch.Size([1, 12, 12, 12])


In [19]:
# Render bertviz's head view for the per-layer attention among the selected tokens;
# `tokens` supplies the label for each selected position.
head_view(masked_attn, tokens) 

<IPython.core.display.Javascript object>