In [1]:
%pip install datasets

In [2]:
import pandas as pd
import torch
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from torch.utils.data import DataLoader
from datasets import load_dataset
from tqdm.notebook import tqdm

# The models the authors used:
try:
    from transformers import BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup
    from transformers import AlbertForMaskedLM, AlbertTokenizer
except ModuleNotFoundError:
    %pip install transformers
    from transformers import BertForMaskedLM, BertTokenizer, AdamW, get_linear_schedule_with_warmup
    from transformers import AlbertForMaskedLM, AlbertTokenizer

nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [3]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda')

In [6]:
DailyNews_ds = load_dataset('json', data_files='../datasets/DailyNews_300.json', split='train')
DailyNews_ds

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['annotators_ids', 'scores', 'summary', 'text'],
    num_rows: 300
})

In [21]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)
model = BertForMaskedLM.from_pretrained('bert-base-uncased').to(DEVICE)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [50]:
class DataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, batch):

        summaries, texts = zip(*batch)

        max_summary_length = min(max(len(summary) for summary in summaries), 512)

        summaries_ids = torch.tensor([self.tokenizer(
            summary,
            add_special_tokens=False,
            truncation=True,
            max_length=max_summary_length,
            padding='max_length')['input_ids'] for summary in summaries])

        max_num_sentences = max(len(sent_tokenize(text.strip())) for text in texts)

        texts_ids = []
        for text in texts:
            tokenized_sentences = [
                self.tokenizer(sentence,
                               add_special_tokens=False,
                               truncation=True,
                               max_length=max_summary_length, # for consistency, since I can't get the maximum sentence length easily
                               padding='max_length')['input_ids']
                for sentence in sent_tokenize(text.strip())
            ]
            padded_sentences = tokenized_sentences + [[self.tokenizer.pad_token_id] * max_summary_length] * (max_num_sentences - len(tokenized_sentences))  # side-effect: there will be many sentences filled with only pad token IDs
            texts_ids.append(padded_sentences)

        texts_ids = torch.tensor(texts_ids)

        return {'summaries_ids': summaries_ids, 'texts_ids': texts_ids}

In [56]:
def BLANC_help(model, dataloader, M=6, L_min=4, device='cpu'):
    """
    Calculate BLANC similarity between summaries and texts using a BERT-type model.

    Parameters:
    - summaries (Tensor): Tensor of tokenized summaries.
    - texts (Tensor): Tensor of tokenized sentences.
    - model: BERT-type model.
    - M (int): Parameter M for the algorithm (default is 6).
    - L_min (int): Minimum length requirement for masked words (default is 4).

    Returns:
    - Tensor: BLANC similarity scores.
    """

    batch_scores = []

    for batch in dataloader:
        summaries = batch['summaries_ids'].to(device) # Shape: [batch_size, max_summary_length]
        texts = batch['texts_ids'].to(device) # Shape: [batch_size, num_sentences, max_summary_length]
        texts = texts.squeeze(1)

        # print(f'summaries shape: {summaries.shape}')
        # print(f'texts shape: {texts.shape}')

        filler = torch.zeros_like(summaries).fill_(tokenizer.convert_tokens_to_ids('.'))  # Shape: [batch_size, max_summary_length]
        # print(f'filler shape: {filler.shape}')
        S = torch.zeros((2, 2), dtype=torch.float)

        for i in range(M):
            # masked_texts = torch.where((torch.arange(texts.size(1)) - i) % M == 0 & (texts >= L_min), tokenizer.mask_token_id, texts)
            masked_texts = torch.where((torch.arange(texts.size(1)).to(device) - i) % M == 0, tokenizer.mask_token_id, texts).to(device) # Shape: [batch_size, num_sentences, max_summary_length] -- need to find a way to get the word lengths
            # print(f'masked_texts shape: {masked_texts.shape}')
            input_base = torch.cat((filler, masked_texts), dim=1).to(device)  # Shape: [batch_size, max_summary_length * 2]
            input_help = torch.cat((summaries, masked_texts), dim=1).to(device) # Shape: [batch_size, max_summary_length * 2]

            # print(f'input_base shape: {input_base.shape}')
            # print(f'input_help shape: {input_help.shape}')

            tokenized_input_base = model(**{'input_ids': input_base}).logits
            tokenized_input_help = model(**{'input_ids': input_help}).logits

            masked_tokens = torch.nonzero(masked_texts == tokenizer.mask_token_id, as_tuple=False)  # Shape: sometimes [batch_size=32, 2], sometimes [batch_size=64, 2]
            # print(f'masked_tokens shape: {masked_tokens.shape}')

            for j in range(masked_tokens.size(0)):
                idx_batch, idx_token = masked_tokens[j]
                predicted_idx_base = torch.argmax(tokenized_input_base[idx_batch, idx_token]).item()
                predicted_idx_help = torch.argmax(tokenized_input_help[idx_batch, idx_token]).item()

                predicted_word_base = tokenizer.convert_ids_to_tokens(predicted_idx_base)
                predicted_word_help = tokenizer.convert_ids_to_tokens(predicted_idx_help)

                k = int(predicted_word_base == texts[idx_batch, idx_token])
                m = int(predicted_word_help == texts[idx_batch, idx_token])
                S[k, m] += 1

        print(f'S: {S}')
        try:
            B = (S[0, 1] - S[1, 0]) / (S[0, 0] + S[1, 1] + S[0, 1] + S[1, 0])
        except ZeroDivisionError:
            B = torch.zeros(1, dtype=torch.float)
        batch_scores.append(B)

    avg_B = sum(batch_scores) / len(batch_scores)

    return avg_B


In [57]:
dataset = DailyNews_ds.select_columns(['summary', 'text'])
data_collator = DataCollator(tokenizer)

batch_size = 32

dataloader = DataLoader(
    dataset, batch_size=batch_size, collate_fn=data_collator, shuffle=True
    )

BLANC_help(model, dataloader, M=6, L_min=4, device=DEVICE)

S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[224.,   0.],
        [  0.,   0.]])
S: tensor([[84.,  0.],
        [ 0.,  0.]])


tensor(0.)

Ideas for improvement:
1. try other models
2. try other datasets
3. test on other problems