In [3]:
# Import necessary libraries
import torch
from transformers import BertTokenizer, BertModel
from sentence_transformers import SentenceTransformer, util, models
from nltk.tokenize import word_tokenize
from spacy.matcher import Matcher
import spacy
import spacy_transformers
import random
import pandas as pd
from tqdm import tqdm
import math

In [5]:
# Specify the BERT model variant you want to use
model_name = "../ilm/new_final_model"

# Load BERT tokenizer and model
#tokenizer_w = BertTokenizer.from_pretrained(model_name)
model_w = BertModel.from_pretrained(model_name, output_attentions=True)
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
model_s = SentenceTransformer(modules=[word_embedding_model, pooling_model])

Some weights of the model checkpoint at ../ilm/new_final_model were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at ../ilm/new_final_model and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias

ValueError: Non-consecutive added token '[PAD]' found. Should have index 30522 but has index 0 in saved vocabulary.

In [None]:
def get_attention_matrix(sentence, target_words, tokenizer, model):
    # Initialize tokenizer and model

    # Tokenize input and obtain outputs
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512)
    outputs = model(**inputs)

    # Get attention weights from outputs
    attentions = outputs.attentions  # List of attention tensors for each layer

    # Define weights for the last four layers
    weights = torch.tensor([0.1, 0.2, 0.3, 0.4], device=attentions[0].device)

    # Extract the last four layers and apply weights
    last_four_layers = torch.stack(attentions[-4:])
    weighted_layers = last_four_layers * weights[:, None, None, None, None]

    # Sum across the weighted layers and then average over the heads
    weighted_sum = torch.sum(weighted_layers, dim=0)
    avg_attention = torch.mean(weighted_sum, dim=1)[0]

    # Aggregate subword attentions for whole words
    tokens = tokenizer.tokenize(sentence)
    word_attention_list = []
    word_list = []
    i = 0
    while i < len(tokens):
        if not tokens[i].startswith("##"):
            word = tokens[i]
            if i == len(tokens) - 1 or not tokens[i+1].startswith("##"):
                word_attention_list.append(avg_attention[i].tolist())
            else:
                subword_count = 1
                subword_attention = avg_attention[i].clone()
                word += tokens[i+1][2:]
                while i + subword_count < len(tokens) and tokens[i + subword_count].startswith("##"):
                    subword_attention += avg_attention[i + subword_count]
                    word += tokens[i + subword_count + 1][2:] if i + subword_count + 1 < len(tokens) and tokens[i + subword_count + 1].startswith("##") else ""
                    subword_count += 1
                word_attention_list.append((subword_attention / subword_count).tolist())
                i += subword_count - 1
            word_list.append(word)
        i += 1

   # Convert attention to dictionary form for whole words
    attention_dict = {}
    for i, word in enumerate(word_list):
        attention_dict[word] = {word_list[j]: word_attention_list[i][j] for j in range(len(word_list))}

    # Compute importance scores for all words
    all_importance_scores = {}
    for word, weights in attention_dict.items():
        all_importance_scores[word] = sum(weights.values())

    # Extract importance scores for target words
    importance_scores = {word: all_importance_scores[word] for word in target_words if word in all_importance_scores}

    # Normalize the importance scores so they sum to 1 for the target words
    total_score = sum(importance_scores.values())
    for word in importance_scores:
        importance_scores[word] /= total_score

    return attention_dict, importance_scores


In [None]:
def get_semantic_weights(sentence, matcher):

    doc = nlp(sentence)
    matches = matcher(doc)

    chunks = []

    for match in matches:
        match_id, start, end = match
        string_id = nlp.vocab.strings[match_id]
        span = doc[start:end]
        # print(span.text)
        chunks.append((start, end))

    chunk_phrases = [str(doc[start:end]) for start, end in chunks]
    attention_dict, importance_scores = get_attention_matrix(sentence, chunk_phrases, tokenizer_w, model_w)
    s_embedding = model_s.encode(sentence, convert_to_tensor=True)

    weights = []
    for chunk in chunks:

        start, end = chunk
        chunk_phrase = doc[start:end]

        new_sent = ' '.join([doc[:start].text, doc[end:].text])
        # sentence.replace(chunk, '')
        new_embedding = model_s.encode(new_sent, convert_to_tensor=True)
        # print(chunk_phrase, ": ", new_sent)
        cosine_score = util.cos_sim(s_embedding, new_embedding)
        weights.append(((start, end), 1-cosine_score.cpu().squeeze().numpy()))

    total = sum([score for (chunk, score) in weights])
    weights = [(chunk, score/total) for (chunk, score) in weights]

    weights.sort(key = lambda x : x[1], reverse = True)
    sem_weight_scores = {}
    for (start, end), weight in weights:
      sem_weight_scores[str(doc[start:end])] = weight
    return sem_weight_scores, importance_scores

In [None]:
nlp = spacy.load("en_core_web_trf")

pattern = [{"POS": {"IN": ["NOUN", "PROPN","VERB","ADJ","ADV"]}}]

matcher = Matcher(nlp.vocab)
matcher.add("pattern",[pattern])

In [None]:
sample = "why is the fish fishing?"

In [None]:
weights, scores = get_semantic_weights(sample, matcher)
print(weights, scores)
# for (start, end), weight in weights:
#     print((start,end), doc[start:end], ":", weight)

In [None]:
def mask_function(document, words_to_mask):
    """
    Returns a list of 3-tuples indicating positions of masked words.

    Parameters:
    - document (str): The input document.
    - words_to_mask (list): List of words that need to be masked.

    Returns:
    - list: List of 3-tuples (infilling type, span offset, span length).
    """

    masked_positions = []

    for word in words_to_mask:
        offset = document.find(word)
        while offset != -1:
            masked_positions.append(("mask", offset, len(word)))
            offset = document.find(word, offset + len(word))

    # Sort by offsets to ensure the list is in order
    return sorted(masked_positions, key=lambda x: x[1])


In [None]:
import math
import re

def get_masked_template(sentence, n=0.5, semantic_score_weight=0.5):
    # 1. Split the joke into setup and punchline.
    setup, punchline = sentence.split("?")
    
    # 2. Calculate semantic and attention scores for the entire sentence.
    sem_weight_scores, attention_scores = get_semantic_weights(sentence, matcher)
    
    # Create a final_score dictionary just as before.
    final_score = {}
    for k in sem_weight_scores:
        final_score[k] = semantic_score_weight*sem_weight_scores[k] + (1-semantic_score_weight)*attention_scores[k]
    
    # 3. Determine words to be masked for setup and punchline.
    # Using regular expression to split the text by spaces and punctuations.
    setup_words = re.findall(r'\b\w+\b', setup)
    punchline_words = re.findall(r'\b\w+\b', punchline)
    
    setup_scores = {word: final_score[word] for word in setup_words if word in final_score}
    punchline_scores = {word: final_score[word] for word in punchline_words if word in final_score}

    setup_masked_words = [k for k, v in sorted(setup_scores.items(), key=lambda item: item[1], reverse=True)]
    punchline_masked_words = [k for k, v in sorted(punchline_scores.items(), key=lambda item: item[1], reverse=True)]
    
    final_setup_masked_words = setup_masked_words[:math.ceil(len(setup_masked_words)*n)]
    final_punchline_masked_words = punchline_masked_words[:math.ceil(len(punchline_masked_words)*n)]
    
    # 4. Mask the setup and punchline separately.
    masked_setup = mask_function(setup, final_setup_masked_words)
    masked_punchline = mask_function(punchline, final_punchline_masked_words)
    
    # 5. Combine the masked setup and punchline.
    masked_joke = masked_setup + "?" + masked_punchline

    return final_score, masked_joke


In [None]:
# def get_masked_template(sentence, n=0.5, semantic_score_weight=0.5):
#   sem_weight_scores, attention_scores = get_semantic_weights(sentence, matcher)
#   final_score = {}
#   for k in sem_weight_scores:
#     final_score[k] = semantic_score_weight*sem_weight_scores[k] + (1-semantic_score_weight)*attention_scores[k]
#   masked_words = [k for k, v in sorted(final_score.items(), key=lambda item: item[1], reverse=True)]
#   final_masked_words = masked_words[:math.ceil(len(masked_words)*n)]
#   return final_score, mask_function(sentence, final_masked_words)

In [None]:
get_masked_template("i am fishing for fishies because i love food")

In [None]:
def apply_mask(sentence, mask_spans):
    """
    Masks the specified spans in the sentence.

    Parameters:
    - sentence (str): The input sentence.
    - mask_spans (list): List of 3-tuples specifying spans to mask.

    Returns:
    - str: Sentence with specified spans replaced by [MASK].
    """

    # Reverse the list so that we can mask from the end of the sentence.
    # This ensures that the earlier offsets don't change.
    mask_spans.reverse()

    for _, offset, length in mask_spans:
        sentence = sentence[:offset] + '[MASK]' + sentence[offset + length:]

    return sentence
